TensorFlow 2.0: Functions, not Sessions.


Design Proposal

Basic idea: Python functions as Graphs

tf.function is a decorator that “defines a TensorFlow function”. A “TensorFlow function” defines a computation as a graph of TensorFlow operations, with named arguments and explicit return values. For example:

import tensorflow as tf

@tf.function
def compute_z1(x, y):
  return tf.add(x, y)

@tf.function
def compute_z0(x):
  return compute_z1(x, tf.square(x))

z0 = compute_z0(2.)
z1 = compute_z1(2., 2.)

Having the Python function correspond to what the runtime will execute reduces conceptual complexity in translating between the two domains.

Referencing state: Variables, tables, etc.

A tf.function-decorated Python function encapsulates a graph and its execution. The Python function may reference stateful objects (i.e., state backed by DT_RESOURCE tensors in the runtime, e.g., tf.Variable) by referencing the corresponding Python object, and these will be captured as implicit inputs to the function.

Comparing TensorFlow code today with how we propose it will look in 2.x:

# TF 1.x
W = tf.Variable(
    tf.glorot_uniform_initializer()(
        (10, 10)))
b = tf.Variable(tf.zeros(10))
c = tf.Variable(0)

x = tf.placeholder(tf.float32)
ctr = c.assign_add(1)
with tf.control_dependencies([ctr]):
  y = tf.matmul(x, W) + b
init = tf.global_variables_initializer()

with tf.Session() as sess:
  sess.run(init)
  print(sess.run(y, feed_dict={x: make_input_value()}))
  assert int(sess.run(c)) == 1

# TF 2.0
W = tf.Variable(
    tf.glorot_uniform_initializer()(
        (10, 10)))
b = tf.Variable(tf.zeros(10))
c = tf.Variable(0)

@tf.function
def f(x):
  c.assign_add(1)
  return tf.matmul(x, W) + b

print(f(make_input_value()))
assert int(c) == 1

Worthy of note: in TensorFlow 1.x, the memory backing the variables W and b in the runtime lives for the lifetime of the Session, unrelated to the lifetime of the Python objects. In 2.x, the lifetime of the runtime state is tied to the lifetime of the corresponding Python objects.
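As a minimal sketch of the 2.x behavior (the shape below is illustrative, not from the proposal): the buffer backing a variable is allocated when the Python object is constructed and released when that object is garbage-collected, with no Session involved.

W = tf.Variable(tf.zeros((1024, 1024)))  # runtime buffer is allocated here
y = tf.reduce_sum(W)                     # ops can use it immediately, eagerly or inside a tf.function
del W                                    # once the Python object is collected, the backing buffer is freed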

Control dependencies

In TensorFlow graphs today, control dependencies are sometimes needed to ensure correct evaluation order.

# TF 1.x
v = tf.Variable(1.0)
init_op = tf.global_variables_initializer()
assign_op = v.assign(2.0)
read = v.read_value()

with tf.Session() as sess:
  sess.run(init_op)
  val = sess.run(read)
  print(val)  # Will print 1.0, the assign is ignored
  val = sess.run([read, assign_op])[0]
  print(val)  # Non-deterministically prints 1.0 or 2.0

# TF 2.0
v = tf.Variable(1.0)

@tf.function
def f():
  v.assign(2.0)
  return v.read_value()

print(f())  # Always prints 2.0.

Note that the intention here is to avoid observable differences from program order. For example:

a = tf.Variable(1.0)
b = tf.Variable(1.0)

@tf.function
def f():
  a.assign(2.0)
  b.assign(3.0)
  return a + b

print(f())

This will always print 5.0, since the assignments occur before the read. However, there is no guaranteed ordering between the assignments to a and b, as any difference in their order is not observable.
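By contrast, when one stateful operation's effect is observable by another, program order is preserved. A small sketch (not from the proposal) where the second assignment reads the first variable:

a = tf.Variable(1.0)
b = tf.Variable(1.0)

@tf.function
def g():
  a.assign(2.0)
  b.assign(a + 1.0)  # reads a, so it is ordered after the assignment above
  return b.read_value()

print(g())  # Always prints 3.0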

Functions that create state

v = None

@tf.function
def f(x):
  global v
  if v is None:
    v = tf.Variable(1.0)
  return tf.cast(x, tf.float32) + v

f(tf.constant(1, dtype=tf.float32))  # Creates the variable, returns 2.0
f(tf.constant(2, dtype=tf.int32))    # Reuses the variable, returns 3.0

Two things to note here:
  1. State (like tf.Variable objects) is only created the first time the function f is called. If any variables are created in the first execution of f, then @tf.function will trace f again the second time it is invoked, in order to record the behavior that will be used from then on.
  2. The caller must make sure that any variable referenced by the function still exists whenever the function is evaluated; @tf.function itself will keep only weak references to these created variables (see the sketch after this list).
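A minimal sketch of the second point, with illustrative names: the dictionary state below is what keeps the created variable alive, since @tf.function only holds a weak reference to it.

state = {}

@tf.function
def counter():
  if "count" not in state:
    state["count"] = tf.Variable(0)  # created on the first call only
  state["count"].assign_add(1)
  return state["count"].read_value()

print(counter())  # 1
print(counter())  # 2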

Trace Caches

Since a new graph is traced whenever a new input signature is encountered, a function can encapsulate multiple graphs. For example, the following creates two graphs:

@tf.function
def f(x):
  return tf.square(x)

f(tf.constant(1, dtype=tf.int32))
f(tf.constant(1.0, dtype=tf.float32))

Note the use of tf.constant to ensure that the argument is a Tensor. If the argument were a Python value, an additional graph would be traced for each distinct value. For example, the following two calls will result in two additional graphs being traced:

f(1.0)
f(2.0)

Where arguments are not Tensors, the “value” of the argument is used to compute the trace_cache_key. For example:

@tf.function
def f(x, use_multiply):
  return tf.multiply(x, x) if use_multiply else tf.square(x)

f(tf.constant(2.0), True)
f(tf.constant(2.0), False)

will result in 2 graphs being created, since the value of the Python object passed as the second argument differs between the two calls and therefore produces two different cache keys.

Note that the “type” of Tensor inputs to the function also incorporates the shape. For example:

@tf.function
def f(x):
  return tf.add(x, 1.)

f(tf.constant([2.0]))
f(tf.constant([2.0, 3.0]))
f(tf.constant([[2.0]]))
f(tf.constant([3.0]))
f(tf.constant([4.0, 5.0]))

will result in 3 graphs being created, one for each distinct input shape: (1,), (2,), and (1, 1).

The trace_cache_key also incorporates the “context” in which the call was made. For example:

@tf.function
def f(x):
  return tf.add(x, 1.)

with tf.device("/device:CPU:0"):
  f(tf.constant(2.0))
with tf.device("/device:GPU:0"):
  f(tf.constant(2.0))

will create 2 graphs, since the device context differs between the two calls.

CAUTION: Too many traces
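
Because each distinct Python value (and each new Tensor shape, absent an input signature) produces a new trace_cache_key, a function called with many different Python values will be traced many times, which costs time and memory. Below is a sketch of the problem and one way to avoid it by passing a Tensor instead of a Python number; the function name is illustrative, not from the proposal.

@tf.function
def scale(x, factor):
  return x * factor

# Each distinct Python float is a new cache key: this traces 100 graphs.
for i in range(100):
  scale(tf.constant(3.0), float(i))

# Passing the varying value as a Tensor reuses a single graph.
for i in range(100):
  scale(tf.constant(3.0), tf.constant(float(i)))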

CAUTION: Mutable non-Tensor arguments

The trace_cache_key includes the Python object for non-Tensor arguments. Mutations of these arguments might not be detected. For example:

# non-Tensor object
class Params(object):
  multiply = True

p = Params()

@tf.function
def f(x, y):
  return tf.multiply(x, 2.) if y.multiply else tf.add(x, 2.)

f(3., p)  # Returns 6.0
p.multiply = False
f(3., p)  # Mutations to `p` may not trigger a retrace, so might still return 6.0
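
One way to avoid this pitfall (a sketch, not part of the proposal) is to pass the relevant value itself rather than the mutable object, so that a change in the value produces a different cache key, just as in the use_multiply example above:

@tf.function
def g(x, multiply):
  return tf.multiply(x, 2.) if multiply else tf.add(x, 2.)

g(3., p.multiply)  # p.multiply is True, so this returns 6.0
p.multiply = False
g(3., p.multiply)  # False is a new cache key; a new graph is traced and this returns 5.0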

Input Signatures

An “input signature” can be explicitly specified to control the trace_cache_key computation based on the type and shape of Tensor (and list of Tensor) arguments to f.

@tf.function(input_signature=((tf.float32, [None]),))
def f(x):
  return tf.add(x, 1.)

f(tf.constant([2.0]))       # Returns [3.0]
f(tf.constant([2.0, 3.0]))  # Matches the input signature, as [None] in the
                            # signature matches the actual shape [2]
f(tf.constant([[2.0]]))     # Raises an error as the argument doesn't match the
                            # input signature
f(tf.constant([2], dtype=tf.int32))  # Raises an error as the dtype of the argument
                                     # does not match the input signature

# f is backed by a single Graph since the input signature specification allowed
# for the same graph to be used when the input shape is (1,) or (2,).
  1. For a Tensor argument, it specifies a (dtype, shape pattern) pair (see the sketch after this list).
    • (tf.float32, [None]) means the argument must be a float32 vector (with any number of elements).
    • (tf.int32, []) means that the argument must be an int32 scalar.
  2. For a list of Tensor objects, it specifies an optional list length and the signature for elements in the list (i.e., the dtype and shape pattern for all elements in the list).
  3. For non-Tensor arguments: tf.PYTHON_VALUE
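
For example, a sketch (not from the proposal) combining the vector and scalar patterns from item 1 in a two-argument signature:

@tf.function(input_signature=((tf.float32, [None]), (tf.int32, [])))
def scale_by(x, k):
  return x * tf.cast(k, tf.float32)

scale_by(tf.constant([1.0, 2.0]), tf.constant(3))       # Returns [3.0, 6.0]
scale_by(tf.constant([1.0, 2.0, 3.0]), tf.constant(2))  # Reuses the same graph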

You can use tf.TRACE_ON_NEW_VALUE to relax the dtype restriction:

@tf.function(input_signature=((tf.TRACE_ON_NEW_VALUE, [None]),))
def f(x):
  return tf.square(x)

f(tf.constant([2.0]))                   # Returns [4.0]
f(tf.constant([2, 2], dtype=tf.int32))  # Returns [4, 4] after tracing a new graph

Classes

A member function of a class may also be decorated with @tf.function and it will work. As with the free function above, the method may create variables, but only on its first call for each instance:

class AnyShapeModel(object):
  def __init__(self):
    self.v = None

  @tf.function
  def increment(self, amount):
    if self.v is None:
      self.v = tf.Variable(tf.zeros_like(amount))
    self.v.assign_add(amount)

model1 = AnyShapeModel()
model1.increment(tf.constant(3))
assert int(model1.v) == 3
model1.increment(tf.constant(4))
assert int(model1.v) == 7

model2 = AnyShapeModel()
model2.increment(tf.constant([4, 5]))
assert model2.v.numpy().tolist() == [4, 5]

The semantics here are that each new instance is allowed to create variables in each @tf.function-decorated method once, on its first call.

tf.function-ing Python control flow

If the function has data-dependent control flow, then although the function will execute fine with eager execution enabled, decorating it with @tf.function will fail. For example:

def f(x, y):
  if tf.equal(y, 0.0):
    return y
  return x / y

x = tf.constant(2.0)
y = tf.constant(2.0)

f(x, y)  # Will be 1.0

df = tf.function(f)
df(x, y)  # Will raise an error complaining about the data-dependent control flow

To fix this, one would have to use the graph construction APIs for control flow (tf.cond, tf.while_loop):

def f(x, y):
  return tf.cond(tf.equal(y, 0.0), lambda: y, lambda: x / y)

x = tf.constant(2.0)
y = tf.constant(2.0)

f(x, y)  # Will be 1.0

df = tf.function(f)
df(x, y)  # Will be 1.0

This situation can be improved with the help of AutoGraph, which allows such control flow to be expressed directly in Python:

df = tf.function(autograph=True)(f)
df(x, y)  # Will be 1.0
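
With autograph, data-dependent loops can also be written as ordinary Python. A small sketch (not from the proposal) with a loop whose trip count depends on a Tensor value:

@tf.function(autograph=True)
def count_halvings(x):
  count = tf.constant(0)
  while tf.greater(x, 1.0):  # converted to graph-level control flow by autograph
    x = x / 2.0
    count += 1
  return count

print(count_halvings(tf.constant(16.0)))  # Prints 4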