Skip to content

Commit 8e50419

Browse files
Charles Nicholsontensorflower-gardener
authored andcommitted
Introduce TFDecorator, a base class for Python TensorFlow decorators. Provides basic introspection and "unwrap" services, allowing tooling code to fully 'understand' the wrapped object.
Change: 1538540
1 parent c3bf39b commit 8e50419

57 files changed

Lines changed: 1354 additions & 335 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

tensorflow/contrib/distributions/python/ops/distribution.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
import abc
2222
import contextlib
23-
import inspect
2423
import types
2524

2625
import numpy as np
@@ -33,6 +32,7 @@
3332
from tensorflow.python.framework import tensor_util
3433
from tensorflow.python.ops import array_ops
3534
from tensorflow.python.ops import math_ops
35+
from tensorflow.python.util import tf_inspect
3636

3737

3838
_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [
@@ -154,12 +154,12 @@ def __new__(mcs, classname, baseclasses, attrs):
154154
if class_special_attr_value is None:
155155
# No _special method available, no need to update the docstring.
156156
continue
157-
class_special_attr_docstring = inspect.getdoc(class_special_attr_value)
157+
class_special_attr_docstring = tf_inspect.getdoc(class_special_attr_value)
158158
if not class_special_attr_docstring:
159159
# No docstring to append.
160160
continue
161161
class_attr_value = _copy_fn(base_attr_value)
162-
class_attr_docstring = inspect.getdoc(base_attr_value)
162+
class_attr_docstring = tf_inspect.getdoc(base_attr_value)
163163
if class_attr_docstring is None:
164164
raise ValueError(
165165
"Expected base class fn to contain a docstring: %s.%s"

tensorflow/contrib/distributions/python/ops/kullback_leibler.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,20 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21-
import inspect
22-
2321
from tensorflow.python.framework import ops
2422
from tensorflow.python.ops import array_ops
2523
from tensorflow.python.ops import control_flow_ops
2624
from tensorflow.python.ops import math_ops
25+
from tensorflow.python.util import tf_inspect
2726

2827

2928
_DIVERGENCES = {}
3029

3130

3231
def _registered_kl(type_a, type_b):
3332
"""Get the KL function registered for classes a and b."""
34-
hierarchy_a = inspect.getmro(type_a)
35-
hierarchy_b = inspect.getmro(type_b)
33+
hierarchy_a = tf_inspect.getmro(type_a)
34+
hierarchy_b = tf_inspect.getmro(type_b)
3635
dist_to_children = None
3736
kl_fn = None
3837
for mro_to_a, parent_a in enumerate(hierarchy_a):

tensorflow/contrib/framework/python/ops/arg_scope.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,9 @@ def conv2d(*args, **kwargs)
6161
from __future__ import absolute_import
6262
from __future__ import division
6363
from __future__ import print_function
64-
import contextlib
65-
import functools
64+
65+
from tensorflow.python.util import tf_contextlib
66+
from tensorflow.python.util import tf_decorator
6667

6768
__all__ = ['arg_scope',
6869
'add_arg_scope',
@@ -106,7 +107,7 @@ def _add_op(op):
106107
_DECORATED_OPS[key_op] = _kwarg_names(op)
107108

108109

109-
@contextlib.contextmanager
110+
@tf_contextlib.contextmanager
110111
def arg_scope(list_ops_or_scope, **kwargs):
111112
"""Stores the default arguments for the given set of list_ops.
112113
@@ -170,7 +171,6 @@ def add_arg_scope(func):
170171
Returns:
171172
A tuple with the decorated function func_with_args().
172173
"""
173-
@functools.wraps(func)
174174
def func_with_args(*args, **kwargs):
175175
current_scope = _current_arg_scope()
176176
current_args = kwargs
@@ -181,8 +181,7 @@ def func_with_args(*args, **kwargs):
181181
return func(*args, **current_args)
182182
_add_op(func)
183183
setattr(func_with_args, '_key_op', _key_op(func))
184-
setattr(func_with_args, '__doc__', func.__doc__)
185-
return func_with_args
184+
return tf_decorator.make_decorator(func, func_with_args)
186185

187186

188187
def has_arg_scope(func):

tensorflow/contrib/keras/python/keras/backend_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,11 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21-
import inspect
22-
2321
import numpy as np
2422

2523
from tensorflow.contrib.keras.python import keras
2624
from tensorflow.python.platform import test
25+
from tensorflow.python.util import tf_inspect
2726

2827

2928
def compare_single_input_op_to_numpy(keras_op,
@@ -207,7 +206,7 @@ def test_reduction_ops(self):
207206
compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7, 5),
208207
keras_kwargs={'axis': -1},
209208
np_kwargs={'axis': -1})
210-
if 'keepdims' in inspect.getargspec(keras_op).args:
209+
if 'keepdims' in tf_inspect.getargspec(keras_op).args:
211210
compare_single_input_op_to_numpy(keras_op, np_op,
212211
input_shape=(4, 7, 5),
213212
keras_kwargs={'axis': 1,

tensorflow/contrib/keras/python/keras/engine/topology.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from __future__ import print_function
2121

2222
import copy
23-
import inspect
2423
import json
2524
import os
2625
import re
@@ -35,6 +34,7 @@
3534
from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
3635
from tensorflow.contrib.keras.python.keras.utils.layer_utils import print_summary as print_layer_summary
3736
from tensorflow.python.framework import tensor_shape
37+
from tensorflow.python.util import tf_inspect
3838

3939

4040
# pylint: disable=g-import-not-at-top
@@ -584,7 +584,7 @@ def __call__(self, inputs, **kwargs):
584584
user_kwargs = copy.copy(kwargs)
585585
if not _is_all_none(previous_mask):
586586
# The previous layer generated a mask.
587-
if 'mask' in inspect.getargspec(self.call).args:
587+
if 'mask' in tf_inspect.getargspec(self.call).args:
588588
if 'mask' not in kwargs:
589589
# If mask is explicitly passed to __call__,
590590
# we should override the default mask.
@@ -2166,7 +2166,7 @@ def run_internal_graph(self, inputs, masks=None):
21662166
kwargs = {}
21672167
if len(computed_data) == 1:
21682168
computed_tensor, computed_mask = computed_data[0]
2169-
if 'mask' in inspect.getargspec(layer.call).args:
2169+
if 'mask' in tf_inspect.getargspec(layer.call).args:
21702170
if 'mask' not in kwargs:
21712171
kwargs['mask'] = computed_mask
21722172
output_tensors = _to_list(layer.call(computed_tensor, **kwargs))
@@ -2177,7 +2177,7 @@ def run_internal_graph(self, inputs, masks=None):
21772177
else:
21782178
computed_tensors = [x[0] for x in computed_data]
21792179
computed_masks = [x[1] for x in computed_data]
2180-
if 'mask' in inspect.getargspec(layer.call).args:
2180+
if 'mask' in tf_inspect.getargspec(layer.call).args:
21812181
if 'mask' not in kwargs:
21822182
kwargs['mask'] = computed_masks
21832183
output_tensors = _to_list(layer.call(computed_tensors, **kwargs))

tensorflow/contrib/keras/python/keras/layers/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from __future__ import print_function
2020

2121
import copy
22-
import inspect
2322
import types as python_types
2423

2524
import numpy as np
@@ -35,6 +34,7 @@
3534
from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_dump
3635
from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_load
3736
from tensorflow.python.framework import tensor_shape
37+
from tensorflow.python.util import tf_inspect
3838

3939

4040
class Masking(Layer):
@@ -595,7 +595,7 @@ def __init__(self, function, mask=None, arguments=None, **kwargs):
595595

596596
def call(self, inputs, mask=None):
597597
arguments = self.arguments
598-
arg_spec = inspect.getargspec(self.function)
598+
arg_spec = tf_inspect.getargspec(self.function)
599599
if 'mask' in arg_spec.args:
600600
arguments['mask'] = mask
601601
return self.function(inputs, **arguments)

tensorflow/contrib/keras/python/keras/layers/wrappers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020
from __future__ import print_function
2121

2222
import copy
23-
import inspect
2423

2524
from tensorflow.contrib.keras.python.keras import backend as K
2625
from tensorflow.contrib.keras.python.keras.engine import InputSpec
2726
from tensorflow.contrib.keras.python.keras.engine import Layer
2827
from tensorflow.python.framework import tensor_shape
28+
from tensorflow.python.util import tf_inspect
2929

3030

3131
class Wrapper(Layer):
@@ -284,7 +284,7 @@ def _compute_output_shape(self, input_shape):
284284

285285
def call(self, inputs, training=None, mask=None):
286286
kwargs = {}
287-
func_args = inspect.getargspec(self.layer.call).args
287+
func_args = tf_inspect.getargspec(self.layer.call).args
288288
if 'training' in func_args:
289289
kwargs['training'] = training
290290
if 'mask' in func_args:

tensorflow/contrib/keras/python/keras/testing_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,10 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21-
import inspect
22-
2321
import numpy as np
2422

2523
from tensorflow.contrib.keras.python import keras
24+
from tensorflow.python.util import tf_inspect
2625

2726

2827
def get_test_data(train_samples,
@@ -98,7 +97,7 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
9897
layer.set_weights(weights)
9998

10099
# test and instantiation from weights
101-
if 'weights' in inspect.getargspec(layer_cls.__init__):
100+
if 'weights' in tf_inspect.getargspec(layer_cls.__init__):
102101
kwargs['weights'] = weights
103102
layer = layer_cls(**kwargs)
104103

tensorflow/contrib/keras/python/keras/utils/generic_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20-
import inspect
2120
import marshal
2221
import sys
2322
import time
@@ -26,6 +25,8 @@
2625
import numpy as np
2726
import six
2827

28+
from tensorflow.python.util import tf_decorator
29+
from tensorflow.python.util import tf_inspect
2930

3031
_GLOBAL_CUSTOM_OBJECTS = {}
3132

@@ -116,6 +117,7 @@ def get_custom_objects():
116117

117118

118119
def serialize_keras_object(instance):
120+
_, instance = tf_decorator.unwrap(instance)
119121
if instance is None:
120122
return None
121123
if hasattr(instance, 'get_config'):
@@ -149,7 +151,7 @@ def deserialize_keras_object(identifier,
149151
if cls is None:
150152
raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
151153
if hasattr(cls, 'from_config'):
152-
arg_spec = inspect.getargspec(cls.from_config)
154+
arg_spec = tf_inspect.getargspec(cls.from_config)
153155
if 'custom_objects' in arg_spec.args:
154156
custom_objects = custom_objects or {}
155157
return cls.from_config(

tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919
from __future__ import print_function
2020

2121
import copy
22-
import inspect
2322
import types
2423

2524
import numpy as np
2625

2726
from tensorflow.contrib.keras.python.keras.models import Sequential
2827
from tensorflow.contrib.keras.python.keras.utils.np_utils import to_categorical
28+
from tensorflow.python.util import tf_inspect
2929

3030

3131
class BaseWrapper(object):
@@ -97,7 +97,7 @@ def check_params(self, params):
9797

9898
legal_params = []
9999
for fn in legal_params_fns:
100-
legal_params += inspect.getargspec(fn)[0]
100+
legal_params += tf_inspect.getargspec(fn)[0]
101101
legal_params = set(legal_params)
102102

103103
for params_name in params:
@@ -182,7 +182,7 @@ def filter_sk_params(self, fn, override=None):
182182
"""
183183
override = override or {}
184184
res = {}
185-
fn_args = inspect.getargspec(fn)[0]
185+
fn_args = tf_inspect.getargspec(fn)[0]
186186
for name, value in self.sk_params.items():
187187
if name in fn_args:
188188
res.update({name: value})

0 commit comments

Comments
 (0)