A Primer on TensorFlow 2.0

This post is also available as a Python notebook.

From September 2017 to October 2018, I worked on TensorFlow 2.0 alongside many engineers. In this post, I’ll explain what TensorFlow 2.0 is and how it differs from TensorFlow 1.x. Towards the end, I’ll briefly compare TensorFlow 2.0 to PyTorch 1.0. This post represents my own views; it does not represent the views of Google, my former employer.

TensorFlow (TF) 2.0 is a significant, backwards-incompatible update to TF’s execution model and API.

Execution model. In TF 2.0, all operations execute imperatively by default. Graphs and the graph runtime are both abstracted away by a just-in-time tracer that translates Python functions executing TF operations into executable graph functions. This means in TF 2.0, there is no Session, and no global graph state. The tracer is exposed as a Python decorator, tf.function. This decorator is for advanced users. Using it is completely optional.

API. TF 2.0 makes tf.keras the high-level API for constructing and training neural networks. But you don’t have to use Keras if you don’t want to. You can instead use lower-level operations and automatic differentiation directly.

To follow along with the code examples in this post, install the TF 2.0 alpha.

pip install tensorflow==2.0.0-alpha0
import tensorflow as tf


  1. Why TF 2.0?
  2. Imperative execution
  3. State
  4. Automatic differentiation
  5. Keras
  6. Graph functions
  7. Comparison to other Python libraries
  8. Domain-specific languages for machine learning

I. Why TF 2.0?

TF 2.0 largely exists to make TF easier to use, for newcomers and researchers alike.

TF 1.x requires metaprogramming

TF 1.x was designed to train extremely large, static neural networks. Representing a model as a dataflow graph and separating its specification from its execution simplifies training at scale, which explains why TF 1.x uses Python as a declarative metaprogramming tool for graphs.

But most people don’t need to train Google-scale models, and most people find metaprogramming difficult. Constructing a TF 1.x graph is like writing assembly code, and this abstraction is so low-level that it is hard to produce anything but the simplest differentiable programs using it. Programs that have data-dependent control flow are particularly hard to express as graphs.

Metaprogramming is (often) unnecessary

It is possible to implement automatic differentiation by tracing computations while they are executed, without static graphs; Chainer, PyTorch, and autograd do exactly that. These libraries are substantially easier to use than TF 1.x, since imperative programming is so much more natural than declarative programming. Moreover, when training models with large operations on a single machine, these graph-free libraries are competitive with TF 1.x performance. For these reasons, TF 2.0 privileges imperative execution.

Graphs are still sometimes useful, for distribution, serialization, code generation, deployment, and (sometimes) performance. That’s why TF 2.0 provides the just-in-time tracer tf.function, which transparently converts Python functions into functions backed by graphs. This tracer also rewrites tensor-dependent Python control flow to TF control flow, and it automatically adds control dependencies to order reads and writes to TF state. This means that constructing graphs via tf.function is much easier than constructing TF 1.x graphs manually.

Multi-stage programming

The ability to create polymorphic graph functions via tf.function at runtime makes TF 2.0 similar to a multi-stage programming language.

For TF 2.0, I recommend the following multi-stage workflow. Start by implementing your program in imperative mode. Once you’re satisfied that your program is correct, measure its performance. If the performance is unsatisfactory, analyze your program using cProfile or a comparable tool to find bottlenecks consisting of TF operations. Next, refactor the bottlenecks into Python functions, and stage these functions in graphs with tf.function.

Figure: A multi-stage workflow for TF 2.0.

If you mostly use TF 2.0 to train large deep models, you probably won’t need to analyze or stage your programs. If, on the other hand, you write programs that execute lots of small operations, such as MCMC samplers or reinforcement learning algorithms, you’ll likely find this workflow useful. In such cases, the Python overhead incurred by executing operations eagerly actually matters.
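
To illustrate the profiling step, here’s a minimal sketch that uses cProfile and pstats from the Python standard library to rank functions by cumulative time. The functions slow_step, fast_step, and program are hypothetical stand-ins for your own code. In a TF 2.0 program, the hot spots you’re looking for are functions that issue many small TF operations.

```python
import cProfile
import io
import pstats


def slow_step(n):
    # Hypothetical bottleneck: many tiny operations in a Python loop.
    total = 0.0
    for i in range(n):
        total += i * 0.5
    return total


def fast_step():
    # Hypothetical cheap step.
    return sum(range(10))


def program():
    results = []
    for _ in range(20):
        results.append(slow_step(10_000))
        results.append(fast_step())
    return results


def profile_report(fn, limit=5):
    """Run fn() under cProfile and return the top `limit` entries by cumulative time."""
    profiler = cProfile.Profile()
    profiler.enable()
    fn()
    profiler.disable()
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    stats.sort_stats("cumulative").print_stats(limit)
    return stream.getvalue()


report = profile_report(program)
print(report)
```

In the printed report, slow_step should sit near the top of the cumulative-time column, which marks it as the candidate to refactor into its own function and stage.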

II. Imperative execution

In TF 2.0, all operations are executed imperatively, or “eagerly”, by default. If you’ve used NumPy or PyTorch, TF 2.0 will feel familiar. For example, the following line of code will immediately construct two tensors backed by concrete numerical values and then execute the add operation.

tf.constant([1., 2.]) + tf.constant([3., 4.])
<tf.Tensor: id=1440, shape=(2,), dtype=float32, numpy=array([4., 6.], dtype=float32)>

Contrast the above code snippet to its verbose, awkward TF 1.x equivalent:

# TF 1.X code
x = tf.placeholder(tf.float32, shape=[2])
y = tf.placeholder(tf.float32, shape=[2])
value = x + y

with tf.Session() as sess:
  print(sess.run(value, feed_dict={x: [1., 2.], y: [3., 4.]}))
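
The TF 1.x snippet above is metaprogramming: the Python code builds a description of a computation, and a separate step runs it. Here’s a toy, pure-Python sketch of that define-then-run style. Node, Placeholder, Add, and run are hypothetical names that only mimic the shape of the TF 1.x API; they are not TensorFlow code.

```python
class Node:
    """A node in a toy dataflow graph; building it computes nothing."""

    def __add__(self, other):
        return Add(self, other)


class Placeholder(Node):
    def __init__(self, name):
        self.name = name


class Add(Node):
    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs


def run(node, feed_dict):
    """Toy 'session': evaluate `node`, looking up placeholder values in feed_dict."""
    if isinstance(node, Placeholder):
        return feed_dict[node]
    if isinstance(node, Add):
        lhs = run(node.lhs, feed_dict)
        rhs = run(node.rhs, feed_dict)
        return [a + b for a, b in zip(lhs, rhs)]
    raise TypeError(f"unknown node type: {type(node).__name__}")


# Stage 1: describe the computation (no arithmetic happens here).
x = Placeholder("x")
y = Placeholder("y")
value = x + y

# Stage 2: execute the description by feeding concrete values.
result = run(value, {x: [1.0, 2.0], y: [3.0, 4.0]})
print(result)  # [4.0, 6.0]
```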

In TF 2.0, there are no placeholders, no sessions, and no feed dicts. Because operations are executed immediately, you can use (and differentiate through) if statements and for loops (no more tf.cond or tf.while_loop). You can also use whatever Python data structures you like, and debug your programs with print statements and pdb.

If TF detects that a GPU is available, it will automatically run operations on the GPU when possible. The target device can also be controlled explicitly.

if tf.test.is_gpu_available():
  with tf.device('gpu:0'):
    tf.constant([1., 2.]) + tf.constant([3., 4.])

III. State

Using tf.Variable objects in TF 1.x required wrangling global collections of graph state, with confusing APIs like tf.get_variable, tf.variable_scope, and tf.initializers.global_variables. TF 2.0 does away with global collections and their associated APIs. If you need a tf.Variable in TF 2.0, you just construct and initialize it directly:

tf.Variable(tf.random.normal([3, 5]))
<tf.Variable 'Variable:0' shape=(3, 5) dtype=float32, numpy=
array([[ 0.13141578, -0.18558209,  1.2412338 , -0.5886968 , -0.9191646 ],
       [ 1.186105  , -0.45135704,  0.57979995,  0.12573312, -0.7697861 ],
       [ 0.28296474,  1.2735683 , -0.08385598,  0.59388596, -0.2402552 ]], dtype=float32)>

IV. Automatic differentiation

TF 2.0 implements reverse-mode automatic differentiation (also known as backpropagation), using a trace-based mechanism. This trace, or tape, is exposed as a context manager, tf.GradientTape. The watch method designates a Tensor as something that we’ll need to differentiate with respect to later. Notice that by tracing the computation of dy_dx under the first tape, we’re able to compute d2y_dx2.

x = tf.constant(3.0)
with tf.GradientTape() as t1:
  t1.watch(x)  # constants must be watched explicitly
  with tf.GradientTape() as t2:
    t2.watch(x)
    y = x * x
  dy_dx = t2.gradient(y, x)
d2y_dx2 = t1.gradient(dy_dx, x)
<tf.Tensor: id=62, shape=(), dtype=float32, numpy=6.0>
<tf.Tensor: id=68, shape=(), dtype=float32, numpy=2.0>

tf.Variable objects are watched automatically by tapes.

x = tf.Variable(3.0)
with tf.GradientTape() as t1:
  with tf.GradientTape() as t2:
    y = x * x
  dy_dx = t2.gradient(y, x)
d2y_dx2 = t1.gradient(dy_dx, x)
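
To make the tape idea concrete, here’s a toy, pure-Python sketch of trace-based reverse-mode differentiation. Tape and Value are hypothetical classes, not TF’s implementation: they support only scalar + and *, compute only first derivatives, and record every operation rather than only watched ones. Each operation appends its inputs and local derivatives to the tape as it executes, and gradient walks the tape backwards, applying the chain rule.

```python
class Tape:
    """Toy reverse-mode tape: records each op's inputs and local derivatives."""

    def __init__(self):
        self.records = []  # (output, [(input, d_output/d_input), ...])

    def gradient(self, target, source):
        # Walk recorded ops in reverse, accumulating adjoints via the chain rule.
        adjoints = {id(target): 1.0}
        for output, partials in reversed(self.records):
            upstream = adjoints.get(id(output), 0.0)
            for inp, local in partials:
                adjoints[id(inp)] = adjoints.get(id(inp), 0.0) + upstream * local
        return adjoints.get(id(source), 0.0)


class Value:
    """A scalar that records multiplications and additions on a tape."""

    def __init__(self, v, tape):
        self.v, self.tape = v, tape

    def __mul__(self, other):
        out = Value(self.v * other.v, self.tape)
        # d(a*b)/da = b, d(a*b)/db = a
        self.tape.records.append((out, [(self, other.v), (other, self.v)]))
        return out

    def __add__(self, other):
        out = Value(self.v + other.v, self.tape)
        self.tape.records.append((out, [(self, 1.0), (other, 1.0)]))
        return out


tape = Tape()
x = Value(3.0, tape)
y = x * x + x               # y = x^2 + x, traced while it executes
dy_dx = tape.gradient(y, x)  # 2x + 1 at x = 3
print(y.v, dy_dx)  # 12.0 7.0
```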

V. Keras

TF 1.x is notorious for having many mutually incompatible high-level APIs for neural networks. TF 2.0 has just one high-level API: tf.keras, which essentially implements the Keras API but is customized for TF. Several standard layers for neural networks are available in the tf.keras.layers namespace.

Keras layers can be composed via tf.keras.Sequential() to obtain an object representing their composition. For example, the code below trains a toy CNN on MNIST. (Of course, MNIST can be solved by much simpler methods, like least squares.)

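mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
input_shape = [28, 28, 1]
data_format = 'channels_last'

def max_pool():
  # Returns a new (weightless) 2x2 pooling layer each time it is called.
  return tf.keras.layers.MaxPooling2D(
      (2, 2), (2, 2), padding='same', data_format=data_format)

model = tf.keras.Sequential([
  tf.keras.layers.Reshape(target_shape=input_shape,
    input_shape=[28, 28]),
  tf.keras.layers.Conv2D(32, 5,
    padding='same', data_format=data_format,
    activation=tf.nn.relu),
  max_pool(),
  tf.keras.layers.Conv2D(64, 5,
    padding='same', data_format=data_format,
    activation=tf.nn.relu),
  max_pool(),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(1024, activation=tf.nn.relu),
  tf.keras.layers.Dropout(0.3),
  tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

# Optimizer chosen for illustration; the loss matches integer labels + softmax.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=1)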
60000/60000 [==============================] - 238s 4ms/sample - loss: 0.3417 - accuracy: 0.9495
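
As a sanity check on the CNN above, here’s the shape and parameter arithmetic worked out in plain Python, assuming data_format='channels_last' and the layer sizes in the listing. The helper names conv_same, pool_same, and dense are hypothetical. 'Same' padding with stride 1 keeps the spatial size, each 2x2 pool with stride 2 halves it (rounding up), and Reshape, pooling, Flatten, and Dropout contribute no weights.

```python
import math


def conv_same(hw, in_ch, out_ch, kernel):
    """'Same'-padded, stride-1 conv: spatial size unchanged; kernel weights + biases."""
    params = kernel * kernel * in_ch * out_ch + out_ch
    return hw, out_ch, params


def pool_same(hw, stride=2):
    """'Same'-padded pooling with the given stride: ceil(size / stride)."""
    return math.ceil(hw / stride)


def dense(in_units, out_units):
    """Fully connected layer: weight matrix + biases."""
    return out_units, in_units * out_units + out_units


hw, ch = 28, 1                                # after Reshape: 28x28x1
hw, ch, p_conv1 = conv_same(hw, ch, 32, 5)    # 28x28x32
hw = pool_same(hw)                            # 14x14x32
hw, ch, p_conv2 = conv_same(hw, ch, 64, 5)    # 14x14x64
hw = pool_same(hw)                            # 7x7x64
flat = hw * hw * ch                           # Flatten: 3136 features
units, p_d1 = dense(flat, 1024)
units, p_d2 = dense(units, 10)

total = p_conv1 + p_conv2 + p_d1 + p_d2
print(flat, total)  # 3136 3274634
```

Nearly all of the weights live in the first Dense layer; the total can be cross-checked against model.summary().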

Alternatively, the same model could have been written as a subclass of tf.keras.Model.

class ConvNet(tf.keras.Model):
  def __init__(self, input_shape, data_format):
    super(ConvNet, self).__init__()
    self.reshape = tf.keras.layers.Reshape(
      target_shape=input_shape, input_shape=[28, 28])
    self.conv1 = tf.keras.layers.Conv2D(32, 5,
      padding='same', data_format=data_format,
      activation=tf.nn.relu)
    self.pool = tf.keras.layers.MaxPooling2D(
      (2, 2), (2, 2), padding='same', data_format=data_format)
    self.conv2 = tf.keras.layers.Conv2D(64, 5,
      padding='same', data_format=data_format,
      activation=tf.nn.relu)
    self.flt = tf.keras.layers.Flatten()
    self.d1 = tf.keras.layers.Dense(1024, activation=tf.nn.relu)
    self.dropout = tf.keras.layers.Dropout(0.3)
    self.d2 = tf.keras.layers.Dense(10, activation=tf.nn.softmax)

  def call(self, x):
    x = self.reshape(x)
    x = self.conv1(x)
    x = self.pool(x)
    x = self.conv2(x)
    x = self.pool(x)
    x = self.flt(x)
    x = self.d1(x)
    x = self.dropout(x)
    return self.d2(x)

If you don’t want to use tf.keras, you can use low-level APIs like tf.reshape, tf.nn.conv2d, tf.nn.max_pool, tf.nn.dropout, and tf.matmul directly.

VI. Graph functions

For advanced users who need graphs, TF 2.0 provides tf.function, a just-in-time tracer that converts Python functions that execute TensorFlow operations into graph functions. A graph function is a TF graph with named inputs and outputs. Graph functions are executed by a C++ runtime that automatically partitions graphs across devices, then parallelizes and optimizes them before execution.

Calling a graph function is syntactically equivalent to calling a Python function. Here’s a very simple example.

@tf.function
def add(tensor):
  return tensor + tensor + tensor

# Executes as a dataflow graph
add(tf.ones([2, 2]))
<tf.Tensor: id=1487, shape=(2, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.]], dtype=float32)>

The add function is also polymorphic in the data types and shapes of its Tensor arguments (and the run-time values of the non-Tensor arguments), even though TF graphs are not.

add(tf.ones([2, 2], dtype=tf.uint8)) 
<tf.Tensor: id=1499, shape=(2, 2), dtype=uint8, numpy=
array([[3, 3],
       [3, 3]], dtype=uint8)>

Every time a graph function is called, its “input signature” is analyzed. If the input signature doesn’t match one it has seen before, tf.function re-traces the Python function and constructs another concrete graph function. (In programming-language terms, this is like multiple dispatch or lightweight modular staging.) This means that for one Python function, many concrete graph functions might be constructed. It also means that every call that triggers a trace will be slow, but subsequent calls with the same input signature will be much faster.
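
The dispatch-on-signature behavior can be sketched in plain Python. In this toy version, toy_function and signature_of are hypothetical: a nested list stands in for a tensor, its element type and nesting lengths stand in for dtype and shape, and other arguments are keyed by value, roughly mirroring how non-Tensor arguments are treated. Instead of building a graph, the sketch just caches the Python function per signature and counts how often a "trace" would happen.

```python
import functools


def signature_of(arg):
    """Toy input signature: (element type, shape) for nested lists, the value otherwise."""
    if isinstance(arg, list):
        shape = []
        probe = arg
        while isinstance(probe, list):
            shape.append(len(probe))
            probe = probe[0] if probe else None
        return ("tensor", type(probe).__name__, tuple(shape))
    return ("python", arg)


def toy_function(python_fn):
    """Hypothetical stand-in for a tracer: one cached 'concrete function' per signature."""
    concrete = {}

    @functools.wraps(python_fn)
    def wrapper(*args):
        key = tuple(signature_of(a) for a in args)
        if key not in concrete:
            wrapper.trace_count += 1   # the slow retrace would happen here
            concrete[key] = python_fn  # a real tracer would store a graph instead
        return concrete[key](*args)

    wrapper.trace_count = 0
    return wrapper


@toy_function
def scale(xs, factor):
    return [x * factor for x in xs]


scale([1.0, 2.0], 3)  # new signature: trace
scale([5.0, 6.0], 3)  # same element type, shape, and Python value: reuse
scale([1, 2], 3)      # new element type (int): retrace
scale([1.0, 2.0], 4)  # new non-tensor value: retrace
print(scale.trace_count)  # 3
```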

Lexical closure, state, and control dependencies

Graph functions support lexically closing over tf.Tensor and tf.Variable objects. You can mutate tf.Variable objects inside a graph function, and tf.function will automatically add the control dependencies needed to ensure that your reads and writes happen in program-order.

a = tf.Variable(1.0)
b = tf.Variable(1.0)

@tf.function
def f(x, y):
  a.assign(y * b)
  b.assign_add(x * a)
  return a + b

f(tf.constant(1.0), tf.constant(2.0))
<tf.Tensor: id=1569, shape=(), dtype=float32, numpy=5.0>
a
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>
b
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.0>
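
Why do these control dependencies matter? In a graph, data flow alone does not necessarily force the read of a inside b.assign_add(x * a) to happen after a.assign(y * b). The plain-Python simulation below (no TF involved; the helper names are hypothetical) runs f’s two updates in program order and in swapped order, starting from a = b = 1.0 with x = 1.0 and y = 2.0.

```python
def assign_a(state, x, y):
    # Mirrors a.assign(y * b) from f.
    state["a"] = y * state["b"]


def assign_add_b(state, x, y):
    # Mirrors b.assign_add(x * a) from f.
    state["b"] += x * state["a"]


def execute(ops, x, y):
    state = {"a": 1.0, "b": 1.0}  # both variables start at 1.0
    for op in ops:
        op(state, x, y)
    return state["a"] + state["b"], state


in_order, final_state = execute([assign_a, assign_add_b], 1.0, 2.0)
swapped, _ = execute([assign_add_b, assign_a], 1.0, 2.0)
print(in_order, swapped)  # 5.0 6.0
```

Only the program-order run reproduces the 5.0, a = 2.0, and b = 3.0 shown above, which is the ordering the automatically inserted control dependencies guarantee.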

Python control flow

tf.function automatically rewrites Python control flow that depends on tf.Tensor data into graph control flow, using autograph. This means that you no longer need to use constructs like tf.cond and tf.while_loop. For example, if we were to translate the following function into a graph function via tf.function, autograph would convert the for loop into a tf.while_loop, because it depends on tf.range(100), which is a tf.Tensor.

def matmul_many(tensor):
  accum = tensor
  for _ in tf.range(100):  # will be converted by autograph
    accum = tf.matmul(accum, tensor)
  return accum

It’s important to note that if tf.range(100) were replaced with range(100), then the loop would be unrolled, meaning that a graph with 100 matmul operations would be generated.
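
The difference between range and tf.range can be sketched with a toy graph recorder in plain Python. Recorder, trace_unrolled, and trace_staged are hypothetical; the point is only that a Python loop runs while the graph is being built, so it emits one node per iteration, whereas a staged loop emits a single loop node whose body is recorded once.

```python
class Recorder:
    """Toy graph builder: each matmul call appends a node instead of computing."""

    def __init__(self):
        self.nodes = []

    def matmul(self, a, b):
        self.nodes.append("matmul")
        return f"matmul_{len(self.nodes)}"

    def while_loop(self, n, body, init):
        # Record the loop body once as a subgraph, then emit a single loop node.
        body_graph = Recorder()
        body(body_graph, init)
        self.nodes.append(("while", n, tuple(body_graph.nodes)))
        return "while_out"


def trace_unrolled(g):
    accum = "tensor"
    for _ in range(100):  # Python loop: executes at trace time
        accum = g.matmul(accum, "tensor")
    return accum


def trace_staged(g):
    return g.while_loop(100, lambda sub, acc: sub.matmul(acc, "tensor"), "tensor")


unrolled, staged = Recorder(), Recorder()
trace_unrolled(unrolled)
trace_staged(staged)
print(len(unrolled.nodes), len(staged.nodes))  # 100 1
```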

You can inspect the code that autograph generates on your behalf.

print(tf.autograph.to_code(matmul_many))

from __future__ import print_function

def tf__matmul_many(tensor):
    with ag__.function_scope('matmul_many'):
      do_return = False
      retval_ = None
      accum = tensor

      def loop_body(loop_vars, accum_1):
        with ag__.function_scope('loop_body'):
          _ = loop_vars
          accum_1 = ag__.converted_call('matmul', tf, ag__.ConversionOptions(recursive=True, verbose=0, strip_decorators=(ag__.convert, ag__.do_not_convert, ag__.converted_call), force_conversion=False, optional_features=ag__.Feature.ALL, internal_convert_user_code=True), (accum_1, tensor), {})
          return accum_1,
      accum, = ag__.for_stmt(ag__.converted_call('range', tf, ag__.ConversionOptions(recursive=True, verbose=0, strip_decorators=(ag__.convert, ag__.do_not_convert, ag__.converted_call), force_conversion=False, optional_features=ag__.Feature.ALL, internal_convert_user_code=True), (100,), {}), None, loop_body, (accum,))
      do_return = True
      retval_ = accum
      return retval_

tf__matmul_many.autograph_info__ = {}


Graph functions can provide significant speed-ups for programs that execute many small TF operations. For these programs, the Python overhead incurred by executing each operation imperatively outstrips the time spent actually running the operations. As an example, let’s benchmark the matmul_many function imperatively and as a graph function.

graph_fn = tf.function(matmul_many)

Here’s the imperative (Python) performance.


%timeit matmul_many(tf.ones([2, 2]))
100 loops, best of 3: 13.5 ms per loop

The first call to graph_fn is slow, since this is when the graph function is generated.


%time graph_fn(tf.ones([2, 2]))
CPU times: user 158 ms, sys: 2.02 ms, total: 160 ms
Wall time: 159 ms
<tf.Tensor: id=1530126, shape=(2, 2), dtype=float32, numpy=
array([[1.2676506e+30, 1.2676506e+30],
       [1.2676506e+30, 1.2676506e+30]], dtype=float32)>

But subsequent calls are roughly seven times faster than executing matmul_many imperatively (1.97 ms versus 13.5 ms per call).


%timeit graph_fn(tf.ones([2, 2]))
1000 loops, best of 3: 1.97 ms per loop

VII. Comparison to other Python libraries

There are many libraries for machine learning. Out of all of them, PyTorch 1.0 is the one that’s most similar to TF 2.0. Both TF 2.0 and PyTorch 1.0 execute imperatively by default, and both provide ways to transform Python functions into graph-backed functions (compare tf.function and torch.jit). The PyTorch JIT tracer, torch.jit.trace, doesn’t implement the multiple-dispatch semantics that tf.function does, and it also doesn’t rewrite the AST. On the other hand, TorchScript lets you use Python control flow, but unlike tf.function, it doesn’t let you mix in arbitrary Python code that parametrizes the construction of your graph. That means that in comparison to tf.function, TorchScript makes it harder for you to shoot yourself in the foot, while potentially limiting your creative expression.

So should you use TF 2.0, or PyTorch 1.0? It depends. Because TF 2.0 is in alpha, it still has some kinks, and its imperative performance still needs work. But you can probably count on TF 2.0 becoming stable sometime this year. If you’re in industry, TensorFlow has TFX for production pipelines, TFLite for deploying to mobile, and TensorFlow.js for the web. PyTorch recently made a commitment to production; since then, they’ve added C++ inference and deployment solutions for several cloud providers. For research, I’ve found that TF 2.0 and PyTorch 1.0 are sufficiently similar that I’m comfortable using either one, and my choice of framework depends on my collaborators.

The multi-stage approach of TF 2.0 is similar to what’s done in JAX. JAX is great if you want a functional programming model that looks exactly like NumPy, but with automatic differentiation and GPU support; this is, in fact, what many researchers want. If you don’t like functional programming, JAX won’t be a good fit.

VIII. Domain-specific languages for machine learning

TF 2.0 and PyTorch 1.0 are very unusual libraries. It has been observed that these libraries resemble domain-specific languages (DSLs) for automatic differentiation and machine learning, embedded in Python (see also our paper on TF Eager, TF 2.0’s precursor). What TF 2.0 and PyTorch 1.0 accomplish in Python is impressive, but they’re pushing the language to its limits.

There is now significant work underway to embed ML DSLs in languages that are more amenable to compilation than Python, like Swift (DLVM, Swift for TensorFlow, MLIR), and Julia (Flux, Zygote). So while TF 2.0 and PyTorch 1.0 are great libraries, do stay tuned: over the next year (or two, or three?), the ecosystem of programming languages for machine learning will continue to evolve rapidly.
