TensorFlow 2.0: Functions, not Sessions.
Design Proposal
Basic idea: Python functions as Graphs
`tf.function` is a decorator that "defines a TensorFlow function". A "TensorFlow function" defines a computation as a graph of TensorFlow operations, with named arguments and explicit return values. For example:
```python
import tensorflow as tf

@tf.function
def compute(x, y):
  return tf.add(x, y)

z = compute(tf.constant(2.0), tf.constant(3.0))
```
Having the Python function correspond to what the runtime will execute reduces conceptual complexity in translating between the two domains.
Referencing state: Variables, tables, etc.
A `tf.function`-decorated Python function encapsulates a graph and its execution. The Python function may reference stateful objects (i.e., state backed by DT_RESOURCE tensors in the runtime, e.g., tf.Variable) simply by referencing the corresponding Python object, and these will be captured as implicit inputs to the function.
Comparing TensorFlow code today with how we propose it looks in 2.x:
```python
# TF 1.x
W = tf.Variable(tf.glorot_uniform_initializer()((10, 10)))
b = tf.Variable(tf.zeros(10))

def dense(x):
  return tf.nn.relu(tf.matmul(x, W) + b)

x = tf.placeholder(tf.float32, shape=(None, 10))
y = dense(x)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  out = sess.run(y, feed_dict={x: [[1.0] * 10]})
```

```python
# TF 2.x
W = tf.Variable(tf.glorot_uniform_initializer()((10, 10)))
b = tf.Variable(tf.zeros(10))

@tf.function
def dense(x):
  return tf.nn.relu(tf.matmul(x, W) + b)

out = dense(tf.ones((1, 10)))
```
Worthy of note here: in TensorFlow 1.x, the memory backing the variables W and b in the runtime lives for the lifetime of the Session, unrelated to the lifetime of the Python objects. In 2.x, the lifetime of the Python objects and the corresponding runtime state are tied together.
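Concretely, this means a 2.x program can release runtime memory by ordinary Python means. A minimal sketch of the intended semantics:

```python
v = tf.Variable(tf.zeros((1000, 1000)))  # allocates a DT_RESOURCE-backed
                                         # buffer in the runtime
del v  # dropping the last Python reference also releases the runtime buffer
```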
Control dependencies
In TensorFlow graphs today, control dependencies are sometimes needed to ensure correct evaluation order.
```python
# TF 1.x
v = tf.Variable(1.0)
assign_op = v.assign(2.0)
# Without an explicit control dependency, a read may observe the
# value either before or after the assignment:
with tf.control_dependencies([assign_op]):
  read = v.read_value()
```

```python
# TF 2.x
v = tf.Variable(1.0)

@tf.function
def f():
  v.assign(2.0)
  return v.read_value()  # guaranteed to observe the assignment
```
Note that the intention here is only to avoid observable differences from program order, not to impose a total order on all operations. For example:
```python
a = tf.Variable(1.0)
b = tf.Variable(1.0)

@tf.function
def f():
  a.assign(2.0)
  b.assign(3.0)
  return a + b

print(f())
```
This will always print 5.0, since the assignments occur before the read. However, there is no guaranteed ordering between the assignments to a and b, as any difference between the two orders is not observable.
Functions that create state
```python
v = None

@tf.function
def f(x):
  global v
  if v is None:
    v = tf.Variable(1.0)
  return v.assign_add(x)

print(f(1.0))  # 2.0: the variable is created on the first call
print(f(2.0))  # 4.0: the same variable is reused
```
- State (like tf.Variable objects) is only created the first time the function f is called. If any variables are created in the first execution of f, @tf.function will trace f a second time on the next invocation, in order to record the behavior that will be used from then on.
- The caller must make sure that any variable referenced by the function still exists whenever the function is evaluated. @tf.function itself will keep only weak references to these created variables.
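For instance, continuing the sketch above (the precise failure mode is left unspecified by the proposal):

```python
f(1.0)    # the variable now exists; the global `v` holds the only
          # strong reference to it
v = None  # @tf.function kept only a weak reference, so the variable
          # can now be garbage collected and a later f(2.0) may fail
```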
Trace Caches
Since a new graph is traced whenever a new input signature is encountered, a single function can encapsulate multiple graphs. For example, two graphs are created by the following:
```python
@tf.function
def f(x):
  return tf.square(x)

f(tf.constant(1, dtype=tf.int32))  # traces a graph specialized to int32
f(tf.constant(1.0))                # traces a second graph for float32
```
Note the use of tf.constant to ensure that the argument is a Tensor. If the argument were a Python value, an additional graph would be traced for each distinct value. For example, the following two calls will result in two additional graphs being traced:
```python
f(1.0)
f(2.0)
```
Where arguments are not Tensors, the “value” of the argument is used to compute the trace_cache_key. For example:
```python
@tf.function
def f(x, use_multiply):
  return tf.multiply(x, x) if use_multiply else tf.add(x, x)

f(tf.constant(2.0), True)
f(tf.constant(2.0), False)
```
will result in 2 graphs being created, since the two calls result in two different cache keys because the value of the Python object (the second argument) changes between the two.
Note that the “type” of Tensor inputs to the function also incorporates the shape. For example:
```python
@tf.function
def f(x):
  return tf.square(x)

f(tf.constant(1.0))              # scalar
f(tf.constant([1.0, 2.0]))       # vector with 2 elements
f(tf.constant([1.0, 2.0, 3.0]))  # vector with 3 elements
```
will result in 3 graphs being created.
The trace_cache_key also incorporates the “context” in which the call was made. For example:
```python
@tf.function
def f(x):
  return tf.square(x)

with tf.device("/device:CPU:0"):
  f(tf.constant(1.0))

with tf.device("/device:GPU:0"):
  f(tf.constant(1.0))
```
This will create 2 graphs, one for each device context.
CAUTION: Too many traces
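Because every new non-Tensor value, Tensor dtype/shape, and calling context triggers a fresh trace, a function invoked with many distinct Python values can accumulate many graphs and pay the tracing cost each time. A sketch of the pitfall (the function scale and its arguments are illustrative, not from the proposal):

```python
@tf.function
def scale(x, factor):
  return x * factor

for i in range(100):
  # Each distinct Python value of `factor` is a new cache key,
  # so this traces 100 separate graphs.
  scale(tf.constant(1.0), float(i))

for i in range(100):
  # Passing the varying value as a Tensor instead reuses a single
  # float32 scalar graph for all 100 calls.
  scale(tf.constant(1.0), tf.constant(float(i)))
```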
CAUTION: Mutable non-Tensor arguments
The trace_cache_key includes the Python object for non-Tensor arguments. Mutations of these arguments might not be detected. For example:
```python
# non-Tensor object
class Params(object):
  def __init__(self):
    self.multiply = True

p = Params()

@tf.function
def f(x, params):
  return tf.multiply(x, x) if params.multiply else tf.add(x, x)

f(tf.constant(2.0), p)  # traces a graph in which params.multiply is True
p.multiply = False
f(tf.constant(2.0), p)  # same object, same cache key: the stale trace
                        # may be reused and the mutation silently ignored
```
Input Signatures
An “input signature” can be explicitly specified to control the trace_cache_key computation based on the type and shape of Tensor (and list of Tensor) arguments to f.
```python
@tf.function(input_signature=((tf.float32, [None]),))
def f(x):
  return tf.add(x, 1.0)

f(tf.constant([2.0]))       # matches the signature
f(tf.constant([2.0, 3.0]))  # also matches: the same graph is reused
```
- For a Tensor argument, it specifies a (dtype, shape pattern).
- (tf.float32, [None]) means the argument must be a float32 vector (with any number of elements).
- (tf.int32, []) means that the argument must be an int32 scalar.
- For a list of Tensor objects, it specifies an optional list length and the signature for elements in the list (i.e., the dtype and shape pattern for all elements in the list).
- For non-Tensor arguments: tf.PYTHON_VALUE
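The proposal does not spell out how these elements compose; a plausible sketch combining them (the nesting and the function g are assumptions):

```python
# A signature for (list of float32 vectors, non-Tensor Python flag).
@tf.function(input_signature=([(tf.float32, [None])], tf.PYTHON_VALUE))
def g(xs, verbose):
  total = tf.add_n(xs)  # all list elements share the declared dtype/shape
  if verbose:           # a Python value: incorporated into the cache key
    print("summing", len(xs), "tensors")
  return total
```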
tf.TRACE_ON_NEW_VALUE can be used to lift the fixed-dtype restriction:
```python
@tf.function(input_signature=((tf.TRACE_ON_NEW_VALUE, [None]),))
def f(x):
  return tf.add(x, 1)

f(tf.constant([1.0, 2.0]))              # traces a float32 graph
f(tf.constant([1, 2], dtype=tf.int32))  # traces a second graph for int32
```
Classes
If a member function of a class does not create variables, it may be decorated with @tf.function and it will simply work. Member functions that create variables are also supported, with per-instance state:
```python
class AnyShapeModel(object):
  def __init__(self):
    self.v = None

  @tf.function
  def increment(self, amount):
    if self.v is None:
      self.v = tf.Variable(tf.zeros_like(amount))
    self.v.assign_add(amount)

m1 = AnyShapeModel()
m1.increment(tf.constant(3))       # m1.v is created with scalar shape
m1.increment(tf.constant(4))       # reuses m1.v: its value is now 7
m2 = AnyShapeModel()
m2.increment(tf.constant([3, 4]))  # a new instance may create its own,
                                   # differently shaped variable
```
The semantics here are that each new instance is allowed to create variables in each @tf.function once.
tf.function-ing Python control flow
If the function has data-dependent control flow, then although the function will execute fine with eager execution enabled, decorating it with @tf.function will fail. For example:
```python
def f(x, y):
  if tf.equal(y, 0.0):  # data-dependent condition on a Tensor
    return y
  return x / y

f(tf.constant(4.0), tf.constant(2.0))               # fine with eager execution
tf.function(f)(tf.constant(4.0), tf.constant(2.0))  # fails during tracing
```
To fix this, one would have to use the graph construction APIs for control flow (tf.cond, tf.while_loop):
```python
def f(x, y):
  return tf.cond(tf.equal(y, 0.0), lambda: y, lambda: x / y)
```
This situation can be improved with the help of autograph, which allows such control flow to be expressed directly in Python (using the original definition of f, with the Python if):
```python
df = tf.function(autograph=True)(f)
df(tf.constant(4.0), tf.constant(2.0))  # the Python `if` is converted to tf.cond
```