forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathwrap_function.py
More file actions
94 lines (77 loc) · 3.03 KB
/
wrap_function.py
File metadata and controls
94 lines (77 loc) · 3.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Prototype decorator for defining legacy-graph-mode functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import function
from tensorflow.python.framework import func_graph
from tensorflow.python.ops import variable_scope
class VariableHolder(object):
  """Callable wrapper around a python function that records its variables.

  Calling the holder runs the wrapped function inside a variable creator
  scope; every object produced by the creator chain while that scope is
  active is appended to the holder's internal list.
  """

  def __init__(self, fn):
    """Stores `fn` and starts with an empty list of recorded variables."""
    self._variables = []
    self._fn = fn

  def variable_creator_scope(self, next_creator, **kwargs):
    """Variable creator that records whatever `next_creator` returns.

    Args:
      next_creator: the next creator in the chain; invoked with `kwargs`.
      **kwargs: keyword arguments forwarded unchanged to `next_creator`.

    Returns:
      The object returned by `next_creator`.
    """
    created = next_creator(**kwargs)
    self._variables.append(created)
    return created

  def __call__(self, *args, **kwargs):
    """Calls the wrapped function with this holder's creator installed."""
    creator_scope = variable_scope.variable_creator_scope(
        self.variable_creator_scope)
    with creator_scope:
      return self._fn(*args, **kwargs)
def wrap_function(fn, signature, name=None):
  """Traces the TF 1.x style function `fn` into a graph function.

  `fn` is wrapped in a `VariableHolder` and traced with
  `func_graph.func_graph_from_py_func` using `signature` to build the
  arguments. The holder is attached to the returned function as
  `_variable_holder`, so the variables recorded during tracing stay reachable
  from the returned object. The returned graph function can then be called
  with tensors matching `signature`.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Calling tf.compat.v1.wrap_function again produces a new trace with a new
  # set of variables, and may use different non-template arguments.
  f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Args:
    fn: python function to be wrapped.
    signature: the placeholder and python arguments to be passed to the
      wrapped function.
    name: Optional. The name of the function.

  Returns:
    The wrapped graph function.
  """
  variable_holder = VariableHolder(fn)
  # Trace through the holder rather than `fn` directly so that the holder's
  # variable creator is active while `fn` runs.
  traced_graph = func_graph.func_graph_from_py_func(
      name,
      variable_holder,
      args=None,
      kwargs=None,
      signature=signature,
      add_control_dependencies=False)
  wrapped = function.Function(traced_graph, signature=signature)
  # Keep the holder, and with it the recorded variables, reachable from the
  # returned function.
  wrapped._variable_holder = variable_holder
  return wrapped