-
Notifications
You must be signed in to change notification settings - Fork 75.3k
Expand file tree
/
Copy pathmodule_wrapper.py
More file actions
216 lines (182 loc) · 7.31 KB
/
module_wrapper.py
File metadata and controls
216 lines (182 loc) · 7.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides wrapper for TensorFlow modules."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import importlib
import sys
import types
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util import tf_stack
from tensorflow.tools.compatibility import all_renames_v2
# Maximum number of deprecation warnings a single TFModuleWrapper instance
# logs; compared against the wrapper's `_tfmw_warning_count` before warning.
_PER_MODULE_WARNING_LIMIT = 1
def get_rename_v2(name):
  """Looks up the replacement recorded for `name` in the v2 rename table.

  Args:
    name: Symbol name to look up in `all_renames_v2.symbol_renames`.

  Returns:
    The entry stored for `name`, or None when the table has no such entry.
  """
  renames = all_renames_v2.symbol_renames
  return renames[name] if name in renames else None
def _call_location():
  """Returns the location of the code that triggered the wrapper.

  Returns:
    A '<file>:<line>' string built from the first extracted stack entry, or
    'UNKNOWN' when no stack entries were extracted.
  """
  # The entry of interest sits 3 frames above this one, i.e. above
  # __getattr__, _tfmw_add_deprecation_warning and _call_location itself, so
  # this function must be called directly from _tfmw_add_deprecation_warning.
  frames = tf_stack.extract_stack_file_and_line(max_length=4)
  if frames:
    caller = frames[0]
    return '{}:{}'.format(caller.file, caller.line)
  # Should never happen, since we are always inside a function here.
  return 'UNKNOWN'
def contains_deprecation_decorator(decorators):
  """Checks whether any of the given decorators is a deprecation decorator.

  Args:
    decorators: Iterable of objects exposing a `decorator_name` attribute.

  Returns:
    True if at least one entry has `decorator_name == 'deprecated'`,
    False otherwise (including for an empty iterable).
  """
  for decorator in decorators:
    if decorator.decorator_name == 'deprecated':
      return True
  return False
def has_deprecation_decorator(symbol):
  """Tells whether `symbol` carries a deprecation decorator.

  The decorators wrapped around `symbol` itself are inspected first. When
  none of them is a deprecation decorator and the unwrapped object is a class,
  the decorators wrapped around its `__init__` method are inspected as well.

  Args:
    symbol: Python object.

  Returns:
    True if symbol has deprecation decorator.
  """
  decorators, target = tf_decorator.unwrap(symbol)
  if contains_deprecation_decorator(decorators):
    return True

  # Only classes that expose __init__ get the second look; plain functions
  # and every other kind of object are reported as not deprecated.
  is_class_with_init = (
      not tf_inspect.isfunction(target) and
      tf_inspect.isclass(target) and
      hasattr(target, '__init__'))
  if not is_class_with_init:
    return False

  init_decorators, _ = tf_decorator.unwrap(target.__init__)
  return contains_deprecation_decorator(init_decorators)
class TFModuleWrapper(types.ModuleType):
  """Wrapper for TF modules to support deprecation messages and lazyloading.

  The wrapper copies the wrapped module's `__dict__`, forwards attribute
  writes and deletions to the wrapped module, and imports names listed in
  `public_apis` the first time they are requested. Wrapper-private state is
  stored under names prefixed with `_tfmw_` so the attribute hooks can tell it
  apart from the wrapped module's own attributes.
  """

  def __init__(  # pylint: disable=super-on-old-class
      self,
      wrapped,
      module_name,
      public_apis=None,
      deprecation=True,
      has_lite=False):  # pylint: enable=super-on-old-class
    """Creates a wrapper around `wrapped`.

    Args:
      wrapped: Module object to wrap. Its `__dict__` is copied into the
        wrapper, and its `__all__` is (re)assigned when `public_apis` is given
        or when it has no `__all__` yet.
      module_name: Name used to build `tf.<module_name>.<symbol>` for rename
        lookups; when falsy, `tf.<symbol>` is used instead.
      public_apis: Optional dict mapping an attribute name to an indexable
        pair. If the first item is truthy it is the module to import and the
        second item is the attribute to read from it; otherwise the second
        item is itself the module to import.
      deprecation: Whether attribute access may log deprecation warnings.
      has_lite: Whether the `lite` attribute must always be resolved through
        `public_apis`.
    """
    super(TFModuleWrapper, self).__init__(wrapped.__name__)
    self.__dict__.update(wrapped.__dict__)
    # Prefix all local attributes with _tfmw_ so that we can
    # handle them differently in attribute access methods.
    self._tfmw_wrapped_module = wrapped
    self._tfmw_module_name = module_name
    self._tfmw_public_apis = public_apis
    self._tfmw_print_deprecation_warnings = deprecation
    self._tfmw_has_lite = has_lite
    # Set __all__ so that import * work for lazy loaded modules
    if self._tfmw_public_apis:
      self._tfmw_wrapped_module.__all__ = list(self._tfmw_public_apis.keys())
      self.__all__ = list(self._tfmw_public_apis.keys())
    else:
      if hasattr(self._tfmw_wrapped_module, '__all__'):
        self.__all__ = self._tfmw_wrapped_module.__all__
      else:
        self._tfmw_wrapped_module.__all__ = [
            attr for attr in dir(self._tfmw_wrapped_module)
            if not attr.startswith('_')
        ]
        self.__all__ = self._tfmw_wrapped_module.__all__

    # names we already checked for deprecation
    self._tfmw_deprecated_checked = set()
    self._tfmw_warning_count = 0

  def _tfmw_add_deprecation_warning(self, name, attr):
    """Print deprecation warning for attr with given name if necessary.

    A warning is logged only while the per-module warning limit has not been
    reached, the name has not been checked before, a v2 rename exists for the
    fully qualified name, and `attr` does not already carry a deprecation
    decorator.

    Args:
      name: Attribute name that was accessed on this wrapper.
      attr: The object the attribute resolved to.
    """
    if (self._tfmw_warning_count < _PER_MODULE_WARNING_LIMIT and
        name not in self._tfmw_deprecated_checked):

      self._tfmw_deprecated_checked.add(name)

      if self._tfmw_module_name:
        full_name = 'tf.%s.%s' % (self._tfmw_module_name, name)
      else:
        full_name = 'tf.%s' % name
      rename = get_rename_v2(full_name)
      if rename and not has_deprecation_decorator(attr):
        # _call_location() relies on being called directly from this method,
        # so it is evaluated once here and the result is reused below.
        call_location = _call_location()
        # skip locations in Python source
        if not call_location.startswith('<'):
          logging.warning(
              'From %s: The name %s is deprecated. Please use %s instead.\n',
              call_location, full_name, rename)
          self._tfmw_warning_count += 1

  def _tfmw_import_module(self, name):
    """Imports the object registered for `name` and caches it.

    Args:
      name: Key into `public_apis`.

    Returns:
      The imported attribute or module. It is also stored on the wrapped
      module and in this wrapper's `__dict__`.
    """
    symbol_loc_info = self._tfmw_public_apis[name]
    if symbol_loc_info[0]:
      module = importlib.import_module(symbol_loc_info[0])
      attr = getattr(module, symbol_loc_info[1])
    else:
      attr = importlib.import_module(symbol_loc_info[1])
    setattr(self._tfmw_wrapped_module, name, attr)
    self.__dict__[name] = attr
    return attr

  def __getattribute__(self, name):  # pylint: disable=super-on-old-class
    """Resolves attributes already present on the wrapper.

    Dunder names and `_tfmw_` names are returned without any deprecation
    check; every other name may trigger a deprecation warning.
    """
    # Workaround to make sure we do not import from tensorflow/lite/__init__.py
    if name == 'lite':
      if self._tfmw_has_lite:
        # _tfmw_import_module already stores the result on the wrapped module
        # and in this wrapper's __dict__.
        return self._tfmw_import_module(name)

    attr = super(TFModuleWrapper, self).__getattribute__(name)
    if name.startswith('__') or name.startswith('_tfmw_'):
      return attr

    if self._tfmw_print_deprecation_warnings:
      self._tfmw_add_deprecation_warning(name, attr)
    return attr

  def __getattr__(self, name):
    """Fallback lookup: reads from the wrapped module or lazily imports.

    Raises:
      AttributeError: If the wrapped module lacks `name` and it is not listed
        in `public_apis`.
    """
    try:
      attr = getattr(self._tfmw_wrapped_module, name)
    except AttributeError:
      if not self._tfmw_public_apis:
        raise
      if name not in self._tfmw_public_apis:
        raise
      attr = self._tfmw_import_module(name)

    if self._tfmw_print_deprecation_warnings:
      self._tfmw_add_deprecation_warning(name, attr)
    return attr

  def __setattr__(self, arg, val):  # pylint: disable=super-on-old-class
    """Sets an attribute, mirroring non-private names onto the wrapped module.

    Names without the `_tfmw_` prefix are also written to the wrapped module
    and appended to `__all__` when not already listed there.
    """
    if not arg.startswith('_tfmw_'):
      setattr(self._tfmw_wrapped_module, arg, val)
      self.__dict__[arg] = val
      if arg not in self.__all__ and arg != '__all__':
        self.__all__.append(arg)
    super(TFModuleWrapper, self).__setattr__(arg, val)

  def __dir__(self):
    """Lists attribute names, including not-yet-imported public APIs."""
    if self._tfmw_public_apis:
      public_attrs = {
          attr for attr in dir(self._tfmw_wrapped_module)
          if not attr.startswith('_')
      }
      return list(set(self._tfmw_public_apis.keys()).union(public_attrs))
    else:
      return dir(self._tfmw_wrapped_module)

  def __delattr__(self, name):  # pylint: disable=super-on-old-class
    """Deletes an attribute from the wrapper and the wrapped module.

    Raises:
      AttributeError: If the attribute does not exist on the object it is
        deleted from.
    """
    if name.startswith('_tfmw_'):
      super(TFModuleWrapper, self).__delattr__(name)
    else:
      delattr(self._tfmw_wrapped_module, name)
      # The wrapper keeps its own copy of wrapped attributes in __dict__;
      # drop it too, otherwise the deleted name would still resolve here.
      self.__dict__.pop(name, None)

  def __repr__(self):
    """Returns the wrapped module's repr."""
    return self._tfmw_wrapped_module.__repr__()

  def __getstate__(self):
    """Returns the module name, which is all the state that gets pickled."""
    return self.__name__

  def __setstate__(self, d):
    """Restores this wrapper from the wrapper registered in `sys.modules`.

    Args:
      d: Module name produced by `__getstate__`; `sys.modules[d]` is expected
        to be a TFModuleWrapper.
    """
    # pylint: disable=protected-access
    source = sys.modules[d]
    # Carry over every constructor argument so the restored wrapper keeps its
    # lazy-loading table and its deprecation / lite settings.
    self.__init__(
        source._tfmw_wrapped_module,
        source._tfmw_module_name,
        public_apis=source._tfmw_public_apis,
        deprecation=source._tfmw_print_deprecation_warnings,
        has_lite=source._tfmw_has_lite)
    # pylint: enable=protected-access