Status | Accepted |
---|---|
Author(s) | ashankar@google.com, joshl@google.com |
Sponsor | apassos@google.com |
Updated | 2018-10-02 |
This document presents a proposal to make TensorFlow more "Pythonic" in 2.0. In five bullet points, the proposal is to:
- Encourage the encapsulation of graph computation as Python functions (where the graph is executed when the function is invoked, instead of via `Session`).
- Align "state" in the TensorFlow runtime (e.g., resource tensors like those that back `tf.Variable` objects) with state in the Python program (e.g., Python objects corresponding to the runtime state, with lifetimes attached to each other).
- Make it easy to export these encapsulations to a `GraphDef`+Checkpoint and/or `SavedModel`.
- Enable eager execution by default.
- Provide a path for incorporating existing code that uses the 1.x APIs to construct TensorFlow graphs as functions in TensorFlow 2.x programs.
This document liberally employs sample code to describe the end-user effect of proposed changes.
(We say "encourage" instead of "require" since removing the `Session` API from the Python frontend within a year may be an unrealistic aspiration, particularly given its use in Estimators and the use of `MonitoredSession` and hooks. The `Session` API may have to stick around in `tf.compat.v1`.)
Today, the TensorFlow graph defines the union of all computation that the author of the graph may be interested in. The actual computation to execute is defined by the arguments to `tf.Session.run`. Once this subgraph is defined, the runtime can optimize and execute it. For example, consider the following:
import tensorflow as tf
x = tf.placeholder(tf.float32)
y = tf.square(x)
z = tf.add(x, y)
sess = tf.Session()
z0 = sess.run(z, feed_dict={x: 2.}) # 6.0
z1 = sess.run(z, feed_dict={x: 2., y: 2.}) # 4.0
Though there is one `tf.Graph` object the user is interacting with (`tf.get_default_graph()`), the two `sess.run` calls are executing different programs (indeed, the runtime ends up with two separate `Graph` objects in C++, one for each program), equivalent to:
def compute_z0(x):
return tf.add(x, tf.square(x))
def compute_z1(x, y):
return tf.add(x, y)
The core proposal of this document is the alignment between computation expressed in Python and the computation executed by the runtime. Instead of defining a graph and then selecting the subgraph to execute at `sess.run()` time, the exact computation of interest is encapsulated in a Python callable. For example, the program above that uses `sess.run()` to compute `z0` and `z1` can be written as:
import tensorflow as tf
@tf.function
def compute_z1(x, y):
return tf.add(x, y)
@tf.function
def compute_z0(x):
return compute_z1(x, tf.square(x))
z0 = compute_z0(2.)
z1 = compute_z1(2., 2.)
Here, `tf.function` is a decorator that "defines a TensorFlow function". A "TensorFlow function" defines a computation as a graph of TensorFlow operations, with named arguments and explicit return values. Users define the function they want TensorFlow to "accelerate" as a Python function and integrate it into their Python program like any other Python function call.
Having the Python function correspond to what the runtime will execute reduces conceptual complexity in translating between the two domains. It also affords an opportunity to provide more helpful stacktraces on errors. More advanced features available today (e.g., carving sub-graphs, feeding intermediate values) will still be possible (discussed later), though most users should not need to think in terms of graphs, feeds, and fetches. The constructed graph also provides a natural point for accelerators/acceleration libraries (NVIDIA TensorRT, Google Cloud TPUs etc.) to hook in for rewrites.
`function` constructs a TensorFlow graph by "tracing" the TensorFlow operations executed by the Python function. Specifically:
- `f` is a Python function that returns zero or more `Tensor`s.
- `function(f)` is a Python function that returns a Python callable, `F`.
- When `F` is invoked it:
  - Potentially casts inputs to tensors if an input signature was specified (see the "Input Signatures" section below).
  - Determines a "trace_cache_key" (based on the types and/or values of the arguments).
  - Every time a new trace_cache_key is encountered, it invokes `f` to create a TensorFlow graph, `G`. If the trace_cache_key has been seen before, it looks up `G` from a cache (a small demonstration of this follows the list).
  - It executes the graph defined by `G`, feeding each argument as a value of the corresponding node in the graph, and returns a tuple of `Tensor`s (or a list of `Tensor`s).
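To make the caching behavior above concrete, here is a small experiment one can run today with the `tf.contrib.eager.defun` preview mentioned later in this document (a sketch, assuming TensorFlow 1.x with eager execution enabled); the Python-level `print` only runs while a new graph is being traced:

```python
import tensorflow as tf
tf.enable_eager_execution()

@tf.contrib.eager.defun  # preview of the proposed `function` semantics
def f(x):
  print("tracing")  # Python side effects run only while a new graph is traced
  return tf.square(x)

f(tf.constant(1.0))         # prints "tracing": a graph for float32 scalars is created
f(tf.constant(2.0))         # same trace_cache_key: the cached graph is reused, nothing printed
f(tf.constant([1.0, 2.0]))  # new shape, new trace_cache_key: prints "tracing" again
```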
A `function`-decorated Python function encapsulates a graph and its execution. The Python function may reference stateful objects (i.e., state backed by `DT_RESOURCE` tensors in the runtime, e.g., `tf.Variable`) by referencing the corresponding Python object, and these will be captured as implicit inputs to the function.
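For example, a minimal sketch of implicit state capture under the proposed semantics:

```python
v = tf.Variable(3.0)

@tf.function
def scale(x):
  # `v` is not an explicit argument; it is captured as an implicit DT_RESOURCE input.
  return x * v

print(scale(tf.constant(2.0)))  # 6.0
v.assign(10.0)
print(scale(tf.constant(2.0)))  # 20.0: the captured resource reflects the update
```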
Comparing TensorFlow code today with how we propose it looks in 2.x:
TensorFlow 1.x | 2.0 |
W = tf.Variable( tf.glorot_uniform_initializer()( (10, 10))) b = tf.Variable(tf.zeros(10)) c = tf.Variable(0) x = tf.placeholder(tf.float32) ctr = c.assign_add(1) with tf.control_dependencies([ctr]): y = tf.matmul(x, W) + b init = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) print(sess.run(y, feed_dict={x: make_input_value()})) assert int(sess.run(c)) == 1 |
W = tf.Variable( tf.glorot_uniform_initializer()( (10, 10))) b = tf.Variable(tf.zeros(10)) c = tf.Variable(0) @tf.function def f(x): c.assign_add(1) return tf.matmul(x, W) + b print(f(make_input_value())) assert int(c) == 1 |
Worthy of note here: in TensorFlow 1.x, the memory underlying the variables `W` and `b` in the runtime lives for the lifetime of the `Session`, unrelated to the lifetime of the Python objects. In 2.x, the lifetime of the Python objects and the runtime state are tied together.
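For instance (a sketch of the 2.x lifetime semantics described above):

```python
v = tf.Variable(tf.zeros([1000, 1000]))  # allocates a buffer in the runtime
del v  # dropping the last Python reference releases the runtime state as well
```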
In TensorFlow graphs today, control dependencies are sometimes needed to ensure correct evaluation order. For example, consider the following:
v = tf.Variable(1.0)
init_op = tf.global_variables_initializer()
assign_op = v.assign(2.0)
read = v.read_value()
with tf.Session() as sess:
sess.run(init_op)
val = sess.run(read)
print(val) # Will print 1.0, the assign is ignored
val = sess.run([read, assign_op])[0]
print(val) # Non-deterministically prints 1.0 or 2.0
The output here is not deterministic, since `val` may evaluate to either 1.0 or 2.0 depending on whether the runtime happened to execute `assign_op` before `read` or not. `tf.control_dependencies` is a mechanism provided to add annotations at graph construction time to influence graph execution. The TensorFlow user, a Python programmer, is thus forced to think about two execution models: TensorFlow graphs and the Python interpreter. To eliminate this cognitive load, `function` will automatically insert control dependencies to ensure that (1) operations that produce or consume a given `DT_RESOURCE` tensor and (2) operations that are marked stateful (`REGISTER_OP(...).SetIsStateful()`) follow graph construction order. Thus:
v = tf.Variable(1.0)
@tf.function
def f():
v.assign(2.0)
return v.read_value()
print(f()) # Always prints 2.0.
Note that the intention here is to avoid observable differences from program order. For example:
a = tf.Variable(1.0)
b = tf.Variable(1.0)
@tf.function
def f():
a.assign(2.0)
b.assign(3.0)
return a + b
print(f())
Will always print 5.0, since the assignments occur before the read. However, there is no guaranteed ordering between the assignments of `a` and `b` (as any difference in that is not observable).

A preview of this is implemented in `tf.contrib.eager.defun` today (using `AutomaticControlDependencies`).
In the above code, no `tf.Variable` objects are created inside a `tf.function`-decorated function. This makes it clear that the code will have the same semantics once wrapped.

Note that if the function naturally creates state only on the first trace, all is well:
v = None
@tf.function
def f(x):
global v
if v is None:
v = tf.Variable(1.0)
return tf.cast(x, tf.float32) + v
f(tf.constant(1, dtype=tf.float32)) # Creates the variable, returns 2.0
f(tf.constant(2, dtype=tf.int32)) # Reuses the variable, returns 3.0
To support this, `function` imposes some requirements on the decorated function:

- State (like `tf.Variable` objects) is only created the first time the function `f` is called. How that is accomplished is left up to the implementation of `f`. If any variables are created in the first execution of `f`, then `@tf.function` will trace `f` again the second time it is invoked, in order to record the behavior that will be used from then on. No variables may be created during that second trace, or any other trace after that (due to different dtypes, shapes, or non-tensor arguments).
- The caller must make sure that any variable referenced by the function still exists whenever the function is evaluated; `@tf.function` itself will keep only weak references to these created variables. Thus, if the referenced state does not exist when the decorated function is invoked, an exception will be raised (see the sketch below).
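The weak-reference behavior can be illustrated as follows (a sketch under the proposed semantics):

```python
v = tf.Variable(1.0)

@tf.function
def read():
  return v.read_value()

read()  # OK: the referenced variable is alive.
del v   # The program drops its last strong reference; tf.function holds only a weak one.
read()  # Expected to raise an exception, since the referenced state no longer exists.
```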
In the future we may want to allow for function-local `tf.Variable`s, which are created and destroyed each time the decorated function is invoked.
Every argument to a `function`-decorated Python function (`F`) must be either:

- a `Tensor` object (NumPy `ndarray`s are converted to the equivalent `Tensor`), or
- a list of `Tensor` objects, or
- an arbitrary Python value.

(There seems to be some interest in supporting structured inputs using `nest.flatten` and `nest.pack_sequence_as`. This will be considered as follow-up work.)
Every time `F` is invoked in the Python program, a `trace_cache_key` is computed as a function of:

- the element datatype and shape of every `Tensor` argument,
- the length of the list and the (dtype, shape) of every element, for each list-of-`Tensor` argument,
- the concrete value of every non-`Tensor` (and non-list-of-`Tensor`) Python argument, and
- the "context" in which `F` is invoked (e.g., the device prescribed by the `tf.device()` scope in which `F` is invoked).
This key is used to determine if a new graph needs to be created or if a previously created graph can be invoked.
Since new graphs are traced when new input signatures are encountered, a `function` can encapsulate multiple graphs. For example, consider the following:
@tf.function
def f(x):
return tf.square(x)
f(tf.constant(1, dtype=tf.int32))
f(tf.constant(1.0, dtype=tf.float32))
There are two graphs created here: one which corresponds to the `Square` operation applied to `DT_INT32` tensors, and one with the `Square` operation applied to `DT_FLOAT` tensors. The object returned by `function` encapsulates multiple graphs (lazily generated based on the type and shape of input arguments), multiplexing between them in `__call__`.
Note the use of `tf.constant` to ensure that the argument is a `Tensor`. If the argument were a Python value, then additional graphs would be traced for each such value. For example, the following two calls will result in two additional graphs being traced:
f(1.0)
f(2.0)
When arguments are not `Tensor`s, the "value" of the argument is used to compute the `trace_cache_key`. For example:
@tf.function
def f(x, use_multiply):
return tf.multiply(x, x) if use_multiply else tf.square(x)
f(tf.constant(2.0), True)
f(tf.constant(2.0), False)
will result in 2 graphs being created, since the two calls result in two different cache keys because the value of the Python object (the second argument) changes between the two.
Note that the "type" of Tensor
inputs to the function also incorporates the shape. For example:
@tf.function
def f(x): return tf.add(x, 1.)
f(tf.constant([2.0]))
f(tf.constant([2.0, 3.0]))
f(tf.constant([[2.0]]))
f(tf.constant([3.0]))
f(tf.constant([4.0, 5.0]))
will result in 3 graphs being created:

- one for when the first argument is a `tf.float32` vector with 1 element,
- one for when the first argument is a `tf.float32` vector with 2 elements, and
- one for when the first argument is a `tf.float32` 1x1 matrix.
The trace_cache_key also incorporates the "context" in which the call was made. For example:
@tf.function
def f(x): return tf.add(x, 1.)
with tf.device("/device:CPU:0"):
f(tf.constant(2.0))
with tf.device("/device:GPU:0"):
f(tf.constant(2.0))
Will create 2 graphs, one where the operations are pinned to the CPU device and one where they are pinned to the GPU device.
Since new traces are generated on demand, the object returned by `function` may hold on to more resources than the user realizes. Possible mitigations:

- Garbage collect the graphs when the weak reference to any component of the `trace_cache_key` is no longer alive.
- Use input signatures to prevent unnecessary retraces (see the "Input Signatures" section below).
- Raise / log an error when the ratio of traces to calls exceeds some threshold (e.g., if every 2 calls to a `function`-decorated function generate a new graph).
The trace_cache_key includes the Python object for non-`Tensor` arguments. Mutations of these arguments might not be detected. For example:
class Params(object):
multiply = True
p = Params()
@tf.function
def f(x, y):
return tf.multiply(x, 2.) if y.multiply else tf.add(x, 2.)
f(3., p) # Returns 6.0
p.multiply = False
f(3., p) # Mutations to `p` may not trigger a retrace, so might still return 6.0
Tracing the decorated function to create a new graph for each input shape is a conservative choice. Often the same graph suffices for `Tensor`s of multiple shapes. As a trivial example, consider:
@tf.function
def f(x): return tf.add(x, 1.)
f(tf.constant(1.0)) # Scalar argument
f(tf.constant([1.0, 2.0])) # Vector argument
f(tf.constant([[3.0]])) # Matrix
This snippet would result in 3 graphs being traced. An "input signature" can be explicitly specified to control the `trace_cache_key` computation based on the type and shape of `Tensor` (and list of `Tensor`) arguments to `f`.
For example:
@tf.function(input_signature=((tf.float32, [None]),))
def f(x): return tf.add(x, 1.)
f(tf.constant([2.0])) # Returns [3.0]
f(tf.constant([2.0, 3.0])) # Matches the input signature as [None]
# matches the actual shape [2]
f(tf.constant([[2.0]])) # Raises an error as the arguments don't match the
# input signature.
f(tf.constant([2], dtype=tf.int32)) # Raises an error as the dtype of the argument
# does not match the input signature
# f is backed by a single Graph since the input signature specification allowed
# for the same graph to be used when the input shape is (1,) or (2,).
An "input signature" specifies a pattern for each of the arguments that may be accepted by the function
-decorated function. Specifically:
- For a
Tensor
argument, it specifies a (dtype, shape pattern).
For example:(tf.float32, [None])
means the argument must be a float32 vector (with any number of elements).(tf.int32, [])
means that the argument must be an int32 scalar.
In this case, non-Tensor
Python values provided at call time are automatically converted (usingtf.convert_to_tensor
) to aTensor
matching this signature.
- For a list of
Tensor
objects, it specifies an optional list length and the signature for elements in the list (i.e., the dtype and shape pattern for all elements in the list). - For non-
Tensor
arguments:tf.PYTHON_VALUE
When an input signature is specified, new graphs are traced only when the value of a Python argument or the context in which the function is invoked changes. If this is considered too restrictive, a possible future extension would be to annotate the signature of an argument so that new traces can be created. For example:
@tf.function(input_signature=((tf.TRACE_ON_NEW_VALUE, [None]),))
def f(x): return tf.square(x)
f(tf.constant([2.0])) # Returns [4.0]
f(tf.constant([2, 2], dtype=tf.int32)) # Returns [4, 4] after tracing a new graph
We've introduced a single new symbol: `function`, which consumes a Python function and returns a callable Python object. The precise API of the object is being iterated on in go/tf-2.0-function-api, but at a high level it will have methods to do the following (a rough usage sketch follows the list):

- List all captured state (`tf.Variable` objects and other `DT_RESOURCE` tensors used by the computation and provided as implicit inputs).
- Access the `tf.Graph` that corresponds to the graph executed by the `__call__` method of the object.
- Execute the function with custom `RunOptions` and retrieve `RunMetadata`.
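As an illustration, the kind of usage this enables might look as follows; the accessors shown (`variables`, `graph_function`) are taken from examples later in this document and are still being settled in go/tf-2.0-function-api, so treat this as a sketch rather than a final API:

```python
W = tf.Variable(tf.glorot_uniform_initializer()((10, 10)))

@tf.function
def f(x):
  return tf.matmul(x, W)

print(f.variables)  # captured state provided as implicit inputs, e.g. [W]
# Retrieve the tf.Graph traced for a particular input signature:
g = f.graph_function((tf.float32, (None, 10))).graph
print(len(g.as_graph_def().node))
```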
If a member function of a class does not create variables, it may be decorated with `@tf.function` and it will work:
class ScalarModel(object):
def __init__(self):
self.v = tf.Variable(0)
@tf.function
def increment(self, amount):
self.v.assign_add(amount)
model1 = ScalarModel()
model1.increment(tf.constant(3))
assert int(model1.v) == 3
model1.increment(tf.constant(4))
assert int(model1.v) == 7
model2 = ScalarModel()
model2.increment(tf.constant(5))
assert int(model2.v) == 5
This works since `increment()` has `self` as a non-tensor argument, and a new trace will be created for each value of `self`. However, if variables are created in a method, we want to allow a new set of variables for every instantiation of `self`.
class AnyShapeModel(object):
def __init__(self):
self.v = None
@tf.function
def increment(self, amount):
if self.v is None:
self.v = tf.Variable(tf.zeros_like(amount))
self.v.assign_add(amount)
model1 = AnyShapeModel()
model1.increment(tf.constant(3))
assert int(model1.v) == 3
model1.increment(tf.constant(4))
assert int(model1.v) == 7
model2 = AnyShapeModel()
model2.increment(tf.constant([4, 5]))
assert (model2.v.numpy() == [4, 5]).all()
The semantics here are that each new instance is allowed to create variables in each `@tf.function` once.

In addition, as long as all variable creation/initialization happens while we are tracing, we should be able to support exporting the initialization graph when exporting a `SavedModel` or `MetaGraphDef`.
The definition of `tf.function` above is careful to check that invoking a decorated Python function would have the same behavior as invoking an undecorated one. This is to guard against it being passed code from TensorFlow 1.x that expects to only be called once (and relies on things like graph collections to track which variables are created), for example:
def f(x, do_add):
v = tf.Variable(5.0)
if do_add:
v.assign_add(x)
else:
v.assign_sub(x)
return v
For this case, we use a different API, `tf.compat.v1.wrap_function`, that treats any created variables as static local state:
f_add = tf.compat.v1.wrap_function(f, tf.TensorSpec(tf.float32, ()), True)
assert float(f_add(1.0)) == 6.0
assert float(f_add(1.0)) == 7.0
# Can call tf.compat.v1.wrap_function again to get a new trace, a new set
# of variables, and possibly different non-template arguments.
f_sub = tf.compat.v1.wrap_function(f, tf.TensorSpec(tf.float32, ()), False)
assert float(f_sub(1.0)) == 4.0
assert float(f_sub(1.0)) == 3.0
Note these differences from `tf.function`:

- Only ever traces `f()` once (per call to `tf.compat.v1.wrap_function`).
- The complete input tensor signature (via `tf.TensorSpec` calls) and the values of all non-tensor arguments must be specified when wrapping the function. Note: we may want a `tf.tensor_like(x)` convenience function that returns `tf.TensorSpec(x.dtype, x.shape)`.
- Will include extra TF v1.x compatibility features like collections, and access to v1.x APIs like `tf.compat.v1.get_variable()`.
- Will not automatically insert control dependencies to maintain program order across stateful operations/state accesses.
- May only use a function or Python constant to initialize variables, not tensors. This is a technical limitation, required by the fact that we need some way of disentangling the initializers for variables from the other operations in the function.
- Keeps strong references to variables created in `f`, and weak references to other variables accessed by `f`. This matches the v1.x graph behavior that variables have the lifetime of the graph in which they are created, and can generally be accessed through graph collections. Some common patterns of writing v1.x code don't leave any references to those variables around; keeping references to them extends their lifetime to match that of the object returned by `tf.compat.v1.wrap_function`.
- Typically won't be used as a decorator. Calling `tf.compat.v1.wrap_function` takes some arguments, traces the function, and creates an object with state. The lifetime of the return value should be tracked explicitly by saving it in a variable.
Treating state (like `tf.Variable`) as static local state does mean that the behavior of a `tf.compat.v1.wrap_function`-wrapped Python function differs from that of an undecorated one. In the above example, `f(1.0, True)` will always return 6.0 (as a scalar `Tensor`), while each call to `f_add(1.0)` will return a different value. We propose this separate `tf.compat.v1.wrap_function` endpoint specifically to make it easy to migrate TensorFlow 1.x libraries to TensorFlow 2.0. The behavior of the 2.0 `tf.function` is restricted to cases where we can say that the behavior will match.
We recognize that code written for TensorFlow 1.x commonly does not encapsulate state in Python objects, instead adding to hidden (graph-)global collections. We will support code that accesses collections inside a `tf.compat.v1.wrap_function`, though those collections will be local to a single trace.
With the `tf.compat.v1.wrap_function` proposed above, most graph construction library functions written against TensorFlow 1.x can be incorporated into TensorFlow 2.x programs:
def f(x):
W = tf.compat.v1.get_variable(name="weight", shape=[10, 10])
b = tf.compat.v1.get_variable(name="bias", shape=[10],
initializer=tf.zeros_initializer())
c = tf.Variable(0, dtype=tf.int32, name="counter")
with tf.control_dependencies([c.assign_add(1)]):
return tf.matmul(x, W) + b
f = tf.compat.v1.wrap_function(f, tf.TensorSpec(tf.float32, None))
print(f(make_input_value()))
assert len(f.variables) == 3
assert f.variables[0].name == "weight"
In this case, the object returned by `tf.compat.v1.wrap_function` owns the state created within `f`, and the `__call__` method on it invokes the corresponding computation.
Long story short, `tf.compat.v1.wrap_function` helps in incorporating graph construction code written against TensorFlow 1.x into TensorFlow 2.x programs. `wrap_function` constructs the same kind of object as a `function`-decorated function, which provides the conceptual equivalent of graph construction and `Session.run`.
So far we've only considered Python programs. One of the key features of TensorFlow is the ability to integrate models created (and possibly trained) in a Python program into an application written in another programming language and/or platform (e.g., servers, mobile phones, self-driving cars). This ability will of course remain, with a smoother path to exporting models.
In TensorFlow 1.x, "saving a model" could mean one of three things:
- Saving parameter values, but not the computation: a "checkpoint" containing the values of all model parameters (`tf.train.Saver`/`tf.train.Checkpoint`). Restoring this model requires that the restoring program duplicate the Python code to construct the graph with the same model parameters.
- Saving the computation graph, but not the parameter values: the computation is represented by a `GraphDef` that can be exported by calls to `tf.Graph.as_graph_def()` or `tf.train.export_meta_graph()`, and reconstructed by calls to `tf.import_graph_def()`/`tf.train.import_meta_graph()`. Note that the parameter (`tf.Variable`) values are not saved, but their initializers are.
- Saving both the computation and the parameter values: the two packaged together in a SavedModel. At a high level, the SavedModel format packages the `MetaGraphDef`, checkpoint, and a signature (names of input and output tensors) (`tf.saved_model.simple_save`/`tf.saved_model.builder.SavedModelBuilder`). This is the format preferred for exporting for serving via TensorFlow Serving or to other languages (e.g., `SavedModelBundle.load()` in Java, `LoadSavedModel` in Go).
The objects created by `function` encapsulate (1) the computation expressed as a `GraphDef`, and (2) the state used by it. Thus, these objects are naturally suited for import/export in any of the above formats, using something like the following:
TensorFlow 1.x | 2.x |
Save only the parameters, not the computation | |
W = tf.get_variable( "weights", shape=[10, 10]) # Presumably the train_op is # a little fancier train_op = W.assign_add(1.) saver = tf.train.Saver() with tf.Session() as sess: sess.run(W.initializer) sess.run(train_op) saver.save(sess, "/tmp/checkpoint/") with tf.Session() as sess: saver.restore(sess, "/tmp/checkpoint/") sess.run(train_op) |
W = tf.Variable( tf.glorot_uniform_initializer()( (10, 10))) @tf.function def train(): W.assign_add(1.) train() ckpt = tf.train.Checkpoint(W=W) ckpt.save("/tmp/checkpoint") ckpt.restore("/tmp/checkpoint") |
Exporting/Importing GraphDefs | |
W = tf.get_variable("weights", shape=[10, 10]) x = tf.placeholder( tf.float32, shape=(None, 10))) y = tf.matmul(x, W) graph = tf.get_default_graph() graph_def = graph.as_graph_def() with open("/tmp/graph.pb", "w") as f: f.write( graph_def.SerializeToString()) tf.reset_default_graph() graph_def = tf.GraphDef() with open("/tmp/graph.pbtxt") as f: graph_def.ParseFromString(f.read()) tf.import_graph_def(graph_def) |
W = tf.Variable( tf.glorot_uniform_initializer()( (10, 10))) @tf.function def f(x): return tf.matmul(x, W) # Retrieve the object corresponding to # a particular input signature: graph = f.graph_function( (tf.float32, (None, 10))).graph graph_def = graph.as_graph_def() with open("/tmp/graph.pb", "w") as f: f.write(graph_def.SerializeToString()) |
Exporting/Importing SavedModels | |
def save_model(): W = tf.get_variable("weights", shape=[10, 10]) x = tf.placeholder( tf.float32, shape=(None, 10)) y = tf.matmul(x, W) with tf.Session() as sess: sess.run( tf.global_variables_initializer()) tf.saved_model.simple_save( sess, "/tmp/model", inputs={"x": x}, outputs={"y": y}) def load_model(): sess = tf.Session() with sess.as_default(): inputs, outputs = tf.saved_model.simple_load(sess, "/tmp/model") return inputs, outputs, sess |
To be worked on but something along the lines of:
class Model(tf.train.Checkpointable): def __init__(self): self.W = tf.Variable(...) @tf.function def f(self, x): return tf.matmul(x, self.W) m = Model() tf.saved_model.export("/tmp/model", m) m = tf.saved_model.import("/tmp/model") |
One reservation expressed by TensorFlow graph/session enthusiasts today is that they value the ability to write generic analysis/inspection tooling on graphs, without needing to understand or modify the Python code that constructed the graph. To put it differently, some find it easier to navigate the `GraphDef` program than the Python program.

This ability will be maintained. `function`-decorated Python functions have an associated graph, and new functions can be created by specifying the sub-graph of interest. For example:
TensorFlow 1.x | TensorFlow 2.x |
Carving out a subgraph | |
def build_graph(): x = tf.placeholder(tf.float32) y = tf.square(x) z = tf.square(y) with tf.Session() as sess: build_graph() sess.run("Square_1:0", feed_dict={"Square:0": 2.0}) # 4.0 |
@tf.function def f(x): return tf.square(tf.square(x)) # tf.Graph corresponding to "x" # being a float32 tensor with unknown # shape graph = f.graph_function( (tf.float32, None)).graph f2 = tf.NewGraphFunction( graph, inputs=["Square:0"], outputs=["Square_1:0"]) # The above may optionally take a # "prune" argument to allow for # pruning stateful operations in # `graph` that are not in the path # from inputs to outputs. f2(2.0) # 4.0 |
Extending a graph | |
def build_graph(): x = tf.placeholder(tf.float32) y = tf.square(x) return y y = build_graph() z = tf.square(y) with tf.Session() as sess: # Line below will return 16.0 sess.run(z, feed_dict={"Placeholder:0": 2.0}) |
@tf.function def f(x): return tf.square(x) @tf.function def g(x): return tf.square(f(x)) g(2.0) # 16.0 |
At the lowest level of the API, distributed execution continues to work with `tf.device` annotations, where the device name can reference remote devices as well, just like it does today.
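For instance (a sketch only; the device string below is illustrative and assumes a cluster with a connected worker task):

```python
with tf.device("/job:worker/task:0/device:GPU:0"):
  # The traced graph for this call is placed on the remote device.
  z = compute_z1(tf.constant(2.0), tf.constant(2.0))
```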
The `DistributionStrategy` API, typically aimed at synchronous training, will continue to be the method of choice (the API can be used inside a `function`). Other APIs such as go/tf-replicator will also be usable.
The author realizes that this section could do with more detail. However, to keep this document focused, those details will be discussed separately. In particular, usage of `MonitoredSession` and session hooks today needs additional thought.
`function` decorates a graph construction function and transparently recreates graphs if needed. However, this does mean that if the function has data-dependent control flow, then although it will execute fine with eager execution enabled, decorating it with `function` will fail. For example:
def f(x, y):
if tf.equal(y, 0.0):
return y
return x / y
x = tf.constant(2.0)
y = tf.constant(2.0)
f(x, y) # Will be 1.0
df = tf.function(f)
df(x, y) # Will raise an error complaining about the data-dependent control flow
To fix this, one would have to use the graph construction APIs for control flow (`tf.cond`, `tf.while_loop`):
def f(x, y):
return tf.cond(tf.equal(y, 0.0), lambda: y, lambda: x/y)
x = tf.constant(2.0)
y = tf.constant(2.0)
f(x, y) # Will be 1.0
df = tf.function(f)
df(x, y) # Will be 1.0
This situation can be improved with the help of autograph, which allows expressing control flow in Python. Whether autograph will be enabled by default is still under debate, but the option will be there as a flag on `function`. For example:
df = tf.function(autograph=True)(f)
df(x, y) # Will be 1.0
The summary writing operations (`tb.summary.scalar`, `tb.summary.image`, etc.) can be naturally placed in the graph by using them in a `function`-decorated function. These operations require two "external" inputs (the summary writer resource and the recording condition) that will be picked up from the context (e.g., `tb.summary.create_file_writer` and `tb.summary.record_summary_every_n_global_steps`). When defining the graph, these inputs are converted to placeholders, which are then resolved at function invocation time. Thus, something like this:
writer = tf.contrib.summary.create_file_writer('/tmp/test')
with writer.as_default(), tf.contrib.summary.always_record_summaries():
f()
with writer.as_default(), tf.contrib.summary.never_record_summaries():
f()
Will write one summary to `writer` whether `f` is defined as:
def f():
tb.summary.scalar("loss", compute_loss())
Or
f = tf.contrib.eager.defun(f)
(NOTE: As of August 2018, this is not the case, but it will be. See b/112269952).
Note that the runtime is free to prune away the summary writing operations when the function is invoked in a context where there is no summary writer resource or the condition is false.
So far this proposal has dealt with the encapsulation of TensorFlow graphs in Python functions with the intention of making it easier to integrate TensorFlow-accelerated computation in Python programs.
Additionally, this proposal suggests enabling eager execution by default in TensorFlow 2.0. Keeping `function` in mind, this basically means:
- Inside the context of defining a TensorFlow function (i.e., within a `function`-decorated function), `tf.Tensor` objects created refer to symbolic tensors.
- Outside this context, `tf.Tensor` objects created are backed by concrete values, and TensorFlow API calls execute their operations immediately. The underlying memory of the tensor can be backed by any device (i.e., CPU/GPU) and is not restricted to host memory (like NumPy arrays). (See the example below.)
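A small example of the distinction (a sketch under the proposed 2.0 defaults):

```python
x = tf.constant([1.0, 2.0])
print(x.numpy())  # concrete value, available immediately under eager execution

@tf.function
def double(y):
  # Inside the decorated function, `y` is a symbolic Tensor: it has a dtype and
  # shape, but no concrete value until the traced graph is executed.
  return y * 2.0

print(double(x).numpy())  # executing the traced graph yields a concrete value again
```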
See the docstring for `tf.contrib.eager.defun`, the evolving playground for the implementation of the proposal in this document. The basic takeaway is that:

- For users that embrace symbolic tensors and graphs: continue doing so, with your code placed inside a `function`-decorated Python function.
- We believe most users (new ones in particular) will find it more convenient to deal with `Tensor` objects backed by concrete values and then selectively "compile" portions of their Python program into TensorFlow graphs, rather than being exposed to graph metaprogramming in Python upfront. In spirit, this is similar to Swift4TensorFlow, with the obvious glaring difference that graph program extraction here is manually specified (with the `function` decoration).
NOTE: In TensorFlow 1.x, eager execution is enabled by `tf.enable_eager_execution()`. Once invoked, all public API endpoints that consume or produce symbolic `Tensor` objects begin to produce and consume `Tensor` objects that are backed by a concrete value. See the "Research and Experimentation" section at www.tensorflow.org/tutorials for an introduction.
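For example, in a 1.x program today:

```python
import tensorflow as tf

tf.enable_eager_execution()  # must be called once, at program startup

x = tf.add(1, 2)
print(x)          # a Tensor backed by a concrete value
print(x.numpy())  # 3
```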
- This change only applies to the TensorFlow Python frontend.
  - TensorFlow.js is already "eager by default".
  - Swift4TensorFlow has similar design goals, doing away with the define-then-run style of TensorFlow graphs.
  - Most other language bindings (Java, C++, Go, others) mostly target deployment of defined models in applications. While an imperative style might help simplify model development and training in these languages, doing so is explicitly out of scope for TensorFlow 2.0. The notion of graphs and sessions will remain in them, as well as in the stable C API. In these APIs, the lifetime of program state (like variables) will continue to be tied to the lifetime of the `Session`.
- Users of `Estimator` will see no change.
  - Canned Estimators are black boxes that create and train models. Enabling eager execution will have no effect on their usage. This is true today.
  - The `model_fn` of a regular (non-canned) `Estimator` will remain a graph construction function.
- SavedModel will continue to be the format encouraged for exporting trained models.
  - Crudely speaking, a SavedModel encapsulates a Graph, a checkpoint of variable values, and some metadata like signature information (names of input and output tensors).
  - A path will be provided to easily export models in this format (e.g., via `tf.keras.Model.save()`). There may be instances where converting the Python code to a graph is not trivial (e.g., it uses Python features that autograph does not support), in which case exporting to a SavedModel (and thus a Graph) will fail.
How state (`DT_RESOURCE` tensors) created inside a `function` should be handled is actively being debated. Options include:
- "Lifting" state out as a static local function variable
- Mimic the undecorated code - creating and destroying variables on each call.
tf.contrib.eager.function
today treats state as function-static variables, which allows for code like:
def f(x):
v = tf.Variable(1, dtype=x.dtype)
v.assign_add(x)
return v
df = tf.contrib.eager.defun(f)
# tf.function(f) proposed in this document will raise an exception on first use
x = tf.constant(1, dtype=tf.float32)
print(df(x)) # 2.0
print(df(x)) # 3.0
However, the one major issue with this approach is that it behaves differently from how an undecorated function would:
print(f(1.0), df(1.0)) # 2.0, 2.0
print(f(1.0), df(1.0)) # 2.0, 3.0
To be conservative, we propose some restrictions on `function`, such as:

- State is created only once, i.e., `function` will fail if calling `f` a second time results in new state being created.
- `function`-decorated functions can only produce `Tensor` return values.
- If you want to convert TF v1.x code like `f` above, you may use `tf.compat.v1.wrap_function`, which guarantees that it will only trace `f` once.
Another option would be to match typical Python functions, where state is created and destroyed during the call to the function. So:
def f(x):
v = tf.Variable(1.0)
v.assign_add(x)
return v
df = tf.function(f)
assert f(1.0) == df(1.0) # Both will be 2.0
assert f(1.0) == df(1.0) # Still 2.0, since 'v' would be recreated.
This seems like an avenue definitely worth pursuing, but it requires careful consideration of some additional design points, such as escape analysis of return values (e.g., the lifetime of `tf.Variable` objects that are returned from a decorated function).

For now, we propose that `function` continue with the restricted abilities proposed in this document, and that a "maintain Python semantics" decorator be investigated independently.
- Naming: `tf.compat.v1.wrap_function` or `tf.compat.v1.defun` or `tf.compat.v1.function` or `tf.compat.v1.wrap_graph_as_function`?
- Signatures in Python 3? (From ngc92)
- Supporting structured inputs: as proposed, arguments to `function` must be either `Tensor` objects, objects that can be converted to a `Tensor` (`tf.convert_to_tensor`), or opaque Python objects. Perhaps we can support nested structures of `Tensor`s (using `nest.flatten` and `nest.pack_sequence_as`), or even arbitrary Python objects? If this is supported, then specifying an `input_signature` may become cumbersome, but perhaps we can have a `function(infer_signature_from_first_call=True)` to make that easier.