-
Notifications
You must be signed in to change notification settings - Fork 75.3k
Expand file tree
/
Copy pathfunc_graph.py
More file actions
1186 lines (1040 loc) · 48.6 KB
/
func_graph.py
File metadata and controls
1186 lines (1040 loc) · 48.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FuncGraph and related functionality."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections as py_collections
import itertools
import weakref
import numpy as np
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.framework.auto_control_deps import AutomaticControlDependencies
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import compat
from tensorflow.python.util import memory
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.lazy_loader import LazyLoader
# This is to avoid a circular dependency:
# function -> func_graph
function = LazyLoader("function", globals(),
"tensorflow.python.eager.function")
def_function = LazyLoader(
"def_function", globals(),
"tensorflow.python.eager.def_function")
WHITELIST_COLLECTIONS = [
ops.GraphKeys.GLOBAL_VARIABLES,
ops.GraphKeys.LOCAL_VARIABLES,
ops.GraphKeys.TRAINABLE_VARIABLES,
variable_scope._VARSTORE_KEY, # pylint: disable=protected-access
variable_scope._VARSCOPESTORE_KEY # pylint: disable=protected-access
]
_EAGER_CONST_THRESHOLD = 128
class UnknownArgument(object):
"""Signifies an argument which is not currently handled."""
pass
def convert_structure_to_signature(structure, arg_names=None):
"""Convert a potentially nested structure to a signature.
Args:
structure: Structure to convert, where top level collection is a list or a
tuple.
arg_names: Optional list of arguments that has equal number of elements as
`structure` and is used for naming corresponding TensorSpecs.
Returns:
Identical structure that has TensorSpec objects instead of Tensors and
UknownArgument instead of any unsupported types.
"""
def encode_arg(arg, path):
"""A representation for this argument, for converting into signatures."""
if isinstance(arg, ops.Tensor):
user_specified_name = None
try:
user_specified_name = compat.as_str(
arg.op.get_attr("_user_specified_name"))
except ValueError:
pass
if path and user_specified_name and user_specified_name != path[0]:
# The user has explicitly named the argument differently than the name
# of the function argument.
name = user_specified_name
else:
name = "/".join([str(p) for p in path])
return tensor_spec.TensorSpec(arg.shape, arg.dtype, name)
if isinstance(arg, composite_tensor.CompositeTensor):
# TODO(b/133606651) Do we need to inject arg_name?
return arg._type_spec # pylint: disable=protected-access
if isinstance(arg, (
int,
float,
bool,
type(None),
dtypes.DType,
tensor_spec.TensorSpec,
type_spec.TypeSpec,
)):
return arg
return UnknownArgument()
# We are using the flattened paths to name the TensorSpecs. We need an
# explicit name for them downstream.
flattened = nest.flatten_with_tuple_paths(structure)
if arg_names:
if len(arg_names) != len(structure):
raise ValueError(
"Passed in arg_names don't match actual signature (%s)." % arg_names)
# Replace all top-level names with their actual arg_names. If a path before
# was "(2,'a',1)", it will become "(arg_names[2],'a',1)".
flattened = [
((arg_names[path[0]],) + path[1:], arg) for path, arg in flattened
]
mapped = [encode_arg(arg, path) for path, arg in flattened]
return nest.pack_sequence_as(structure, mapped)
class FuncGraph(ops.Graph):
"""Graph representing a function body.
Attributes:
name: The name of the function.
inputs: Placeholder tensors representing the inputs to this function. The
tensors are in this FuncGraph. This represents "regular" inputs as well as
captured inputs (i.e. the values of self.captures), with the regular
inputs coming first.
outputs: Tensors that will be returned by this function. The tensors are in
this FuncGraph.
control_outputs: Operations that must be executed before the function
represented by this graph can be said to have been executed.
structured_input_signature: A tuple of (args, kwargs), which are both
possibly-nested python objects that were received by this function. Note
that these structures might contain Python `None`s.
structured_outputs: A possibly-nested python object which will be returned
by this function. The Tensors in this structure are the same as those of
self.outputs. Note that this structure might contain Python `None`s.
variables: Variables that should be watched during function execution.
outer_graph: The graph this function is defined in. May be another FuncGraph
or the global default Graph.
captures: Maps external tensor -> internal tensor (i.e. input placeholder).
The entries are in the order they were captured.
control_captures: Set of external ops on which this graph has a control
dependency.
seed: The graph-level random seed.
capture_by_value: If True, the func graph will capture Variables by value
instead of reference.
"""
def __init__(self, name, collections=None, capture_by_value=None):
"""Construct a new FuncGraph.
The graph will inherit its graph key, collections, seed, and distribution
strategy stack from the current context or graph.
Args:
name: the name of the function.
collections: a dictionary of collections this FuncGraph should start
with. If not specified (None), the FuncGraph will read (but not write
to) the outer graph's collections that are not whitelisted, and both
read and write to the outer graph's collections that are whitelisted.
The current whitelisted collections are the global variables, the
local variables, and the trainable variables.
Defaults to None.
capture_by_value: An optional boolean. If True, the func graph will
capture Variables by value instead of reference. By default inherit
from outer graphs, and failing that will default to False.
"""
super(FuncGraph, self).__init__()
self.name = name
self.inputs = []
self.outputs = []
self.control_outputs = []
self.control_captures = set()
self.structured_input_signature = None
self.structured_outputs = None
self._weak_variables = []
self._watched_variables = object_identity.ObjectIdentityWeakSet()
self.outer_graph = ops.get_default_graph()
self._captures = py_collections.OrderedDict()
# If not None, records the names of output args of this function. Used to
# preserve the output names in the signature of a serialized+deserialized
# function. Private at the moment mostly because it's often out of date.
self._output_names = None
# Maps arbitrary key -> (closure, nest of placeholders), where at function
# call time the value of closure() will be used to feed the nest of
# placeholders.
self._deferred_captures = py_collections.OrderedDict()
# Inherit capture-by-value from outer graph.
if capture_by_value is not None:
self.capture_by_value = capture_by_value
elif self.outer_graph is not None and isinstance(
self.outer_graph, FuncGraph):
self.capture_by_value = self.outer_graph.capture_by_value
else:
self.capture_by_value = False
self._building_function = True
# Map from resource tensor name to last op (in program order) which uses
# this tensor. Used to enforce that execution order matches program order
# for resource tensors.
self._last_op_using_resource_tensor = {}
graph = self.outer_graph
if context.executing_eagerly():
self.seed = context.global_seed()
# [for tf-data user migration from TF1.0 to 2.0] seed_used keep track of
# any None op_seed for random_op in the function, in which case we end up
# using function seed, which could be unintended behavior for the op.
self._seed_used = False
else:
self.seed = graph.seed
self._seed_used = False
# TODO(allenl): Figure out if we can remove colocation stack
# specialization (currently used in cond_v2), here and in the cache key.
self._colocation_stack = graph._colocation_stack.copy() # pylint: disable=protected-access
if collections is None:
for collection_name in graph.get_all_collection_keys():
if collection_name not in WHITELIST_COLLECTIONS:
self._collections[collection_name] = graph.get_collection(
collection_name)
for collection_name in WHITELIST_COLLECTIONS:
self._collections[collection_name] = graph.get_collection_ref(
collection_name)
else:
self._collections = collections
# Keep track of whether this FuncGraph is exportable to SavedModel. Use
# `graph.mark_as_unsaveable(reason)` to mark this FuncGraph and any
# dependent functions as unsaveable.
self._saveable = True
self._saving_errors = set()
def __str__(self):
return "FuncGraph(name=%s, id=%s)" % (self.name, id(self))
def watch_variable(self, v):
"""Marks the variable v as accessed while building this graph."""
while self is not None and isinstance(self, FuncGraph):
self._watched_variables.add(v)
self = self.outer_graph
def capture_call_time_value(self, closure, spec, key=None):
"""Creates a placeholder which at call time has the value closure().
Useful, for example, to respect TensorFlow context managers, which are often
dynamically scoped.
Args:
closure: function which takes no arguments, to be evaluated at function
call time, returning a nest of tensors compatible with `spec`.
spec: nest of TypeSpec for the value to capture.
key: optional. If not None, multiple calls to lazy_capture with the same
key in the same graph will return the same placeholder, and the
first closure will be used at function call time.
Returns:
Nest of placeholders which, at function call time, will be fed with the
result of calling closure().
Raises:
ValueError: at function call time, if the return value of closure() is
not compatible with `spec`.
"""
if key is None:
key = object()
if key not in self._deferred_captures:
def convert_to_placeholder(s):
if not isinstance(s, tensor_spec.TensorSpec):
raise TypeError(
"Expected a nest of `TypeSpec` objects, found %s of type %s." %
(s, type(s)))
return array_ops.placeholder(dtype=s.dtype, shape=s.shape)
placeholder = nest.map_structure(
convert_to_placeholder, spec, expand_composites=True)
def wrapped_closure():
ret_nest = closure()
nest.assert_same_structure(spec, ret_nest, expand_composites=True)
# This uses the tensor dtype defined in `spec` when converting values
# in `ret_nest` to tensors.
# pylint: disable=protected-access
y = nest.map_structure(lambda s, r: s._to_components(r), spec, ret_nest,
expand_composites=False)
# pylint: enable=protected-access
return nest.flatten(y, expand_composites=True)
self._deferred_captures[key] = (wrapped_closure, placeholder)
return self._deferred_captures[key][1]
def control_dependencies(self, control_inputs):
"""Handles control dependencies.
FuncGraph wraps Graph's control_dependencies logic by first filtering out
any external tensors / operations and storing them in the graph's
control_captures member. Any consumers of this function graph must then
decide how to handle the control captures.
Args:
control_inputs: A list of `Operation` or `Tensor` objects which
must be executed or computed before running the operations
defined in the context. Can also be `None` to clear the control
dependencies.
Returns:
A context manager that specifies control dependencies for all
operations constructed within the context.
Raises:
TypeError: If `control_inputs` is not a list of `Operation` or
`Tensor` objects.
"""
if control_inputs is None:
return super(FuncGraph, self).control_dependencies(control_inputs)
filtered_control_inputs = []
for c in control_inputs:
# Check for _UnreadVariable
if (isinstance(c, ops.IndexedSlices) or
(hasattr(c, "_handle") and hasattr(c, "op"))):
c = c.op
graph_element = ops._as_graph_element(c) # pylint: disable=protected-access
if graph_element is None:
graph_element = c
if graph_element is not None and getattr(
graph_element, "graph", None) is not self:
self.control_captures.add(graph_element)
else:
filtered_control_inputs.append(graph_element)
return super(FuncGraph, self).control_dependencies(filtered_control_inputs)
def as_default(self):
outer_cm = super(FuncGraph, self).as_default()
@tf_contextlib.contextmanager
def inner_cm():
"""Context manager for copying distribute.Strategy scope information."""
graph = ops.get_default_graph()
# pylint: disable=protected-access
# TODO(b/112906995, nareshmodi): distribution strategy depends on
# inheriting this stack from the default graph even in eager mode. Maybe
# it should be part of the eager context? This would also allow us to
# remove a get_default_graph() call from the function cache lookup.
old_strategy_stack = self._distribution_strategy_stack
self._distribution_strategy_stack = list(
graph._distribution_strategy_stack)
# We ignore device placements from any outer scopes while tracing the
# function when possible, to avoid hard-coding them in the function
# graph. "Default" placements come from the PartitionedCallOp's placement,
# so that the same trace of the Python function may be placed on several
# different devices and saved functions may be placed on new devices when
# restored.
old_device_stack = self._device_function_stack
if context.executing_eagerly():
if self._distribution_strategy_stack:
self._device_function_stack = self._device_function_stack.copy()
self._add_device_to_stack(context.context().device_name)
else:
if (self._distribution_strategy_stack
or device_stack_has_callable(graph._device_function_stack)):
# Hard-code devices from device functions in the function body
self._device_function_stack = graph._device_function_stack.copy()
old_creator_stack = self._variable_creator_stack
self._variable_creator_stack = graph._variable_creator_stack
# Inherit the graph key, since this is used for matching variables in
# optimizers.
old_graph_key = self._graph_key
self._graph_key = graph._graph_key
# Inherit the auto_cast_variable_read_dtype, since this should not change
# inside a function.
old_auto_cast_var_read_dtype = self._auto_cast_variable_read_dtype
self._auto_cast_variable_read_dtype = graph._auto_cast_variable_read_dtype
# pylint: enable=protected-access
with outer_cm as g:
try:
yield g
finally:
self._distribution_strategy_stack = old_strategy_stack
self._device_function_stack = old_device_stack
self._variable_creator_stack = old_creator_stack
self._graph_key = old_graph_key
self._auto_cast_variable_read_dtype = old_auto_cast_var_read_dtype
return inner_cm()
@property
def output_types(self):
return [t.dtype for t in self.outputs]
@property
def output_shapes(self):
return [t.shape for t in self.outputs]
@property
def variables(self):
"""A list of variables accessed by this FuncGraph.
Note that functions keep only weak references to variables. Calling the
function after a variable it accesses has been deleted is an error.
Yields:
Strong references to variables accessed by this FuncGraph.
"""
for weak_v in self._weak_variables:
v = weak_v()
if v is None:
raise AssertionError(
"Called a function referencing variables which have been deleted. "
"This likely means that function-local variables were created and "
"not referenced elsewhere in the program. This is generally a "
"mistake; consider storing variables in an object attribute on "
"first call.")
yield v
@variables.setter
def variables(self, var_list):
self._weak_variables = [weakref.ref(v) for v in var_list]
def _capture_by_value(
self,
op_type,
inputs,
dtypes, # pylint: disable=redefined-outer-name
input_types=None,
name=None,
attrs=None,
op_def=None,
compute_device=True):
# When capturing by value, do the read outside
reverse_captures = dict((id(v), k) for k, v in self.captures)
uncaptured_inputs = [reverse_captures.get(id(t), t) for t in inputs]
with ops.init_scope():
if context.executing_eagerly():
attr_list = ("dtype", int(attrs["dtype"].type))
value, = execute.execute(
compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list,
context.context())
else:
op = ops.get_default_graph()._create_op_internal( # pylint: disable=protected-access
op_type,
uncaptured_inputs,
dtypes,
input_types,
name,
attrs,
op_def,
compute_device)
value = op.outputs[0]
captured_value = self.capture(value)
return captured_value.op
def create_op(
self,
op_type,
inputs,
dtypes=None, # pylint: disable=redefined-outer-name
input_types=None,
name=None,
attrs=None,
op_def=None,
compute_shapes=True,
compute_device=True):
"""Like Graph.create_op, except handles external input tensors.
This overload adds functionality to create_op to "capture" any external
input tensors, i.e. tensors from the eager context or outer function graphs
if this is a nested function. See `capture` for more information.
Args:
op_type: The `Operation` type to create. This corresponds to the
`OpDef.name` field for the proto that defines the operation.
inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
dtypes: (Optional) A list of `DType` objects that will be the types of the
tensors that the operation produces.
input_types: (Optional.) A list of `DType`s that will be the types of
the tensors that the operation consumes. By default, uses the base
`DType` of each input in `inputs`. Operations that expect
reference-typed inputs must specify `input_types` explicitly.
name: (Optional.) A string name for the operation. If not specified, a
name is generated based on `op_type`.
attrs: (Optional.) A dictionary where the key is the attribute name (a
string) and the value is the respective `attr` attribute of the
`NodeDef` proto that will represent the operation (an `AttrValue`
proto).
op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
the operation will have.
compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
computed).
compute_device: (Optional.) If True, device functions will be executed
to compute the device property of the Operation.
Returns:
An `Operation` object.
"""
del compute_shapes
if self.capture_by_value and op_type in ["ReadVariableOp",
"ResourceGather"]:
return self._capture_by_value(op_type, inputs, dtypes, input_types, name,
attrs, op_def, compute_device)
# This capturing logic interacts poorly with control flow contexts which
# want to replace inputs of ops far too late in the process. This can lead
# the context to get confused and try to create an Enter for an Enter. We
# can detect this here and skip the additional Enter which can confuse loop
# validation logic.
if op_type == "Enter" and inputs[0].op.type == "Enter":
if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
return inputs[0].op
# Calling AddValue on the control flow contexts to force creation of the
# backward accumulators in the original graph before we create placeholders
# to capture the inputs.
ctxt = ops.get_default_graph()._control_flow_context # pylint: disable=protected-access
for i, inp in enumerate(inputs):
# TPU Estimator defines a control flow context with no AddValue method.
if ctxt is not None and hasattr(ctxt, "AddValue"):
inp = ctxt.AddValue(inp)
inp = self.capture(inp)
inputs[i] = inp
return super(FuncGraph, self)._create_op_internal( # pylint: disable=protected-access
op_type, inputs, dtypes, input_types, name, attrs, op_def,
compute_device)
def capture(self, tensor, name=None):
"""Captures `tensor` if it's external to this graph.
If `tensor` is from a different graph, returns a placeholder for it.
`tensor` and the placeholder will appear in self.captures, and the
placeholder will appear in self.inputs. Multiple calls to this method with
the same `tensor` argument will return the same placeholder. If `tensor` is
from this graph, returns `tensor`.
Args:
tensor: Tensor. May be from this FuncGraph or a different graph.
name: Optional name if a placeholder is created.
Returns:
Tensor from this FuncGraph.
Raises:
InaccessibleTensorError: if any tensors are accessed in a manner that
bypasses the mechanisms required for the data dependencies to be correctly
wired.
"""
# Note: _forward_func_graph is currently only set when building the gradient
# graph graph of a defun call. If the backwards graph tries to capture
# tensors those will be captured first in the forward graph. This
# makes sure that any tensor needed by a custom_gradient is correctly
# captured.
if (getattr(tensor, "graph", None) is not self and
hasattr(self, "_forward_func_graph") and
isinstance(self._forward_func_graph, FuncGraph)):
tensor = self._forward_func_graph.capture(tensor)
if isinstance(tensor, ops.EagerTensor):
if name is None:
name = str(ops.uid())
# Small EagerTensors are captured with Const ops
if (tensor.dtype in dtypes.TF_VALUE_DTYPES and
np.prod(tensor.shape) <= _EAGER_CONST_THRESHOLD):
return self.capture_eager_tensor(tensor, name)
# Large EagerTensors and resources are captured with Placeholder ops
return self._capture_helper(tensor, name)
if tensor.graph is not self:
if name is None:
name = tensor.op.name
inner_graph = tensor.graph
while inner_graph is not None and isinstance(inner_graph, FuncGraph):
if inner_graph is self:
raise errors.InaccessibleTensorError(
"The tensor '%s' cannot be accessed here: it is defined"
" in another function or code block. Use return values,"
" explicit Python locals or TensorFlow collections to access"
" it. Defined in: %s; accessed from: %s.\n"
% (tensor, tensor.graph, self))
inner_graph = inner_graph.outer_graph
return self._capture_helper(tensor, name)
return tensor
def _capture_helper(self, tensor, name):
capture = self._captures.get(ops.tensor_id(tensor))
if capture is None:
placeholder = _create_substitute_placeholder(
tensor, name=name, dtype=tensor.dtype)
self.add_capture(tensor, placeholder)
else:
placeholder = capture[1]
tape.record_operation("captured_value", [placeholder], [tensor],
lambda x: [x])
return placeholder
@property
def captures(self):
"""Order list of tuples containing external and internal captures."""
return self._captures.values()
def add_capture(self, tensor, placeholder):
"""Capture a specific tensor and utilize the provided placeholder.
Args:
tensor: Tensor to captures.
placeholder: Provided placeholder for the tensor.
"""
self._captures[ops.tensor_id(tensor)] = (tensor, placeholder)
self.inputs.append(placeholder)
def reset_captures(self, capture_list):
"""Set the captures with the provided list of captures & placeholder."""
self._captures = py_collections.OrderedDict()
for tensor, placeholder in capture_list:
self._captures[ops.tensor_id(tensor)] = (tensor, placeholder)
def pop_capture(self, tensor):
"""Remove the capture and return the generated placeholder."""
capture = self._captures.pop(ops.tensor_id(tensor), None)
if capture is None:
return None
return capture[1]
def clear_captures(self):
# TODO(b/115366440): Delete this method when a custom OrderedDict is added.
# Clearing captures using clear() leaves some cycles around.
while self._captures:
self._captures.popitem()
memory.dismantle_ordered_dict(self._captures)
while self._deferred_captures:
self._deferred_captures.popitem()
memory.dismantle_ordered_dict(self._deferred_captures)
def capture_distributed_variable(self, variable, placeholder):
"""Add given distributed variable to captures with given placeholder."""
self._captures[ops.tensor_id(variable)] = (variable, placeholder)
tape.record_operation("captured_value", [placeholder], [variable],
lambda x: [x])
def capture_eager_tensor(self, tensor, name):
capture = self._captures.get(ops.tensor_id(tensor))
if capture is None:
# We clear all control dependencies and place the Const op on the same
# device as the source tensor. The device placement may be relaxed at
# a later date.
with ops.control_dependencies(None), self.device(tensor.device):
graph_const = constant_op.constant(tensor.numpy(), dtype=tensor.dtype,
shape=tensor.shape, name=name)
self.add_capture(tensor, graph_const)
else:
graph_const = capture[1]
tape.record_operation("captured_value", [graph_const], [tensor],
lambda x: [x])
return graph_const
@property
def external_captures(self):
"""External tensors captured by this function."""
return [c[0] for c in self._captures.values()]
@property
def internal_captures(self):
"""Placeholders in this function corresponding captured tensors."""
return [c[1] for c in self._captures.values()]
@property
def deferred_external_captures(self):
"""Ordered nest of tensors whose placeholders will be fed at call time."""
return [c[0] for c in self._deferred_captures.values()]
@property
def deferred_internal_captures(self):
"""List of nest of placeholders which at call time will be fed."""
return [c[1] for c in self._deferred_captures.values()]
@property
def variable_captures(self):
"""Map of tensor ids of variable handles to variables which are captured."""
return {
ops.tensor_id(self._captures[ops.tensor_id(v.handle)][1]): v
for v in self.variables
if ops.tensor_id(v.handle) in self._captures
}
def mark_as_unsaveable(self, error_message):
"""Marks this FuncGraph as unsaveable.
Any attempts to export this FuncGraph will raise an error with the specified
message.
Args:
error_message: List or string containing the error message to be raised
when saving this FuncGraph to SavedModel.
"""
self._saveable = False
if isinstance(error_message, str):
error_message = [error_message]
self._saving_errors.update(error_message)
@property
def saveable(self):
"""Returns whether this FuncGraph is saveable."""
return self._saveable
@property
def saving_errors(self):
"""Returns set of errors preventing this FuncGraph from being saved."""
return self._saving_errors
def func_graph_from_py_func(name,
python_func,
args,
kwargs,
signature=None,
func_graph=None,
autograph=False,
autograph_options=None,
add_control_dependencies=True,
arg_names=None,
op_return_value=None,
collections=None,
capture_by_value=None,
override_flat_arg_shapes=None):
"""Returns a `FuncGraph` generated from `python_func`.
Args:
name: an identifier for the function.
python_func: the Python function to trace.
args: the positional args with which the Python function should be called;
ignored if a signature is provided.
kwargs: the keyword args with which the Python function should be called;
ignored if a signature is provided.
signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
and dtypes of the arguments. When a signature is provided, `args` and
`kwargs` are ignored, and `python_func` is traced with Tensors conforming
to `signature`. If `None`, the shapes and dtypes are inferred from the
inputs.
func_graph: Optional. An instance of FuncGraph. If provided, we will use
this graph else a new one is built and returned.
autograph: whether to use autograph to compile `python_func`.
See https://www.tensorflow.org/guide/autograph for more information.
autograph_options: additional knobs to control when `autograph=True`.
See https://www.tensorflow.org/guide/autograph for more information.
add_control_dependencies: If True, automatically adds control dependencies
to ensure program order matches execution order and stateful ops always
execute.
arg_names: Optional list of argument names, used to give input placeholders
recognizable names.
op_return_value: Optional. A Tensor. If set and `python_func` returns
Operations, those return values will be replaced with this value. If not
set, returning an Operation triggers an error.
collections: a dictionary of collections this FuncGraph should start
with. If not specified (None), the FuncGraph will read (but not write to)
the outer graph's collections that are not whitelisted, and both
read and write to the outer graph's collections that are whitelisted.
The current whitelisted collections are the global variables, the
local variables, and the trainable variables.
Defaults to None.
capture_by_value: An optional boolean. If True, the func graph will capture
Variables by value instead of reference. By default inherit from outer
graphs, and failing that will default to False.
override_flat_arg_shapes: An optional list of instances that are either
`None` or `TensorShape`. The length must match that of
`nest.flatten((args, kwargs), expand_composites=True)`. The entries
containing value `None` must match entries in flattened arguments
containing non-tensors, while entries containing a `TensorShape` must
match entries in the flattened arguments containing tensors.
Returns:
A FuncGraph.
Raises:
TypeError: If any of `python_func`'s return values is neither `None` nor a
`Tensor`.
ValueError: If both `signature` and `override_flat_arg_shapes` are
passed in.
"""
if op_return_value is not None:
assert isinstance(op_return_value, ops.Tensor), op_return_value
if func_graph is None:
func_graph = FuncGraph(name, collections=collections,
capture_by_value=capture_by_value)
assert isinstance(func_graph, FuncGraph)
if add_control_dependencies:
control_manager = AutomaticControlDependencies()
else:
control_manager = ops.NullContextmanager()
with func_graph.as_default(), control_manager as a:
current_scope = variable_scope.get_variable_scope()
default_use_recource = current_scope.use_resource
current_scope.set_use_resource(True)
if signature is not None and override_flat_arg_shapes is not None:
raise ValueError(
"Passed both signature and override_flat_arg_shapes: %s and %s."
% (signature, override_flat_arg_shapes))
if signature is not None:
args = signature
kwargs = {}
# Creates and names placeholders for all arguments.
if override_flat_arg_shapes is not None:
flat_args = nest.flatten(args, expand_composites=True)
arg_shapes = override_flat_arg_shapes[:len(flat_args)]
kwarg_shapes = override_flat_arg_shapes[len(flat_args):]
else:
arg_shapes = None
kwarg_shapes = None
func_args = _get_defun_inputs_from_args(
args, arg_names, flat_shapes=arg_shapes)
func_kwargs = _get_defun_inputs_from_kwargs(
kwargs, flat_shapes=kwarg_shapes)
# Convert all Tensors into TensorSpecs before saving the structured inputs.
# If storing pure concrete functions that are not called through polymorphic
# functions, we don't have access to FunctionSpec, so we need to call the
# TensorSpecs by their `arg_names` for later binding.
func_graph.structured_input_signature = (
convert_structure_to_signature(func_args, arg_names),
convert_structure_to_signature(func_kwargs))
flat_func_args = nest.flatten(func_args, expand_composites=True)
flat_func_kwargs = nest.flatten(func_kwargs, expand_composites=True)
# Temporarily set inputs to allow graph building code to inspect
# them. Reassigned below.
func_graph.inputs = [arg for arg in flat_func_args + flat_func_kwargs
if isinstance(arg, ops.Tensor)]
# Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
# Variables to help check whether mutation happens in calling the function
# Copy the recursive list, tuple and map structure, but not base objects
func_args_before = nest.pack_sequence_as(func_args, flat_func_args,
expand_composites=True)
func_kwargs_before = nest.pack_sequence_as(
func_kwargs, flat_func_kwargs, expand_composites=True)
def convert(x):
"""Converts a function output to a Tensor."""
if x is None:
return None
if op_return_value is not None and isinstance(x, ops.Operation):
# TODO(b/79881896): we currently can't capture external control deps, so
# this won't work if x needs to be captured (i.e. if python_func returns
# captured Operations).
with ops.control_dependencies([x]):
x = array_ops.identity(op_return_value)
elif not isinstance(x, tensor_array_ops.TensorArray):
try:
x = ops.convert_to_tensor_or_composite(x)
except (ValueError, TypeError):
raise TypeError(
"To be compatible with tf.contrib.eager.defun, Python functions "
"must return zero or more Tensors; in compilation of %s, found "
"return value of type %s, which is not a Tensor." %
(str(python_func), type(x)))
if add_control_dependencies:
x = a.mark_as_return(x)
return x
try:
if autograph:
from tensorflow.python import autograph # pylint: disable=g-import-not-at-top
_, original_func = tf_decorator.unwrap(python_func)
def wrapper(*args, **kwargs):
"""Calls a converted version of original_func."""
# TODO(mdan): Push this block higher in tf.function's call stack.
try:
return autograph.converted_call(
original_func,
autograph.ConversionOptions(
recursive=True,
optional_features=autograph_options,
user_requested=True,
), args, kwargs)
except Exception as e: # pylint:disable=broad-except
if hasattr(e, "ag_error_metadata"):
raise e.ag_error_metadata.to_exception(e)
else:
raise
# Wrapping around a decorator allows checks like tf_inspect.getargspec
# to be accurate.
converted_func = tf_decorator.make_decorator(original_func, wrapper)
python_func = tf_decorator.rewrap(python_func, original_func,
converted_func)
func_outputs = python_func(*func_args, **func_kwargs)
# invariant: `func_outputs` contains only Tensors, CompositeTensors,
# TensorArrays and `None`s.
func_outputs = nest.map_structure(convert, func_outputs,
expand_composites=True)
check_mutation(func_args_before, func_args)
check_mutation(func_kwargs_before, func_kwargs)
finally:
current_scope.set_use_resource(default_use_recource)
# Variables in `func_args`, `func_kwargs` should be explicit inputs
# to the function, not captured inputs.
graph_variables = list(func_graph._watched_variables) # pylint: disable=protected-access
arg_variables = object_identity.ObjectIdentitySet()
inputs = []
for arg in (nest.flatten(func_args, expand_composites=True) +
nest.flatten(func_kwargs, expand_composites=True)):
if isinstance(arg, resource_variable_ops.BaseResourceVariable):
# Even if an argument variable was not used in the function, we've
# already manually captured the resource Tensor when creating argument
# placeholders.
resource_placeholder = func_graph.pop_capture(arg.handle)
if resource_placeholder is None:
continue
arg_variables.add(arg)
inputs.append(resource_placeholder)
elif isinstance(arg, ops.Tensor):
inputs.append(arg)
variables = [v for v in graph_variables if v not in arg_variables]
func_graph.inputs = (
inputs + func_graph.internal_captures + nest.flatten(
func_graph.deferred_internal_captures, expand_composites=True))
func_graph.structured_outputs = func_outputs
# Returning a closed-over tensor does not trigger convert_to_tensor.
func_graph.outputs.extend(
func_graph.capture(x)
for x in flatten(func_graph.structured_outputs)
if x is not None)
func_graph.variables = variables
if add_control_dependencies:
func_graph.control_outputs.extend(control_manager.ops_which_must_run)
return func_graph
def maybe_captured(tensor):
"""If t is a captured value placeholder, returns the original captured value.
Args:
tensor: Tensor.
Returns:
A tensor, potentially from a different Graph/FuncGraph.
"""
if (not isinstance(tensor, ops.EagerTensor) and
tensor.op.graph.building_function and tensor.op.type == "Placeholder"):
for input_t, placeholder_t in tensor.op.graph.captures:
if tensor == placeholder_t:
return maybe_captured(input_t)
# pylint: enable=protected-access
return tensor
def device_stack_has_callable(device_stack):
"""Checks whether a device stack contains a callable."""
return any(callable(spec._device_name_or_function) # pylint: disable=protected-access
for spec in device_stack.peek_objs())
def check_mutation(n1, n2):
"""Check if two list of arguments are exactly the same."""
errmsg = ("Function to be traced should not modify structure of input "
"arguments. Check if your function has list and dictionary "
"operations that alter input arguments, "
"such as `list.pop`, `list.append`")
try:
nest.assert_same_structure(n1, n2, expand_composites=True)
except ValueError:
raise ValueError(errmsg)
for arg1, arg2 in zip(nest.flatten(n1, expand_composites=True),
nest.flatten(n2, expand_composites=True)):