Skip to content

Commit b27022d

Browse files
kevemantensorflower-gardener
authored andcommitted
Fix GetOpList and GetPythonWrappers SWIG wrappers for Python 3.
- Return the result from GetOpList as uninterpreted bytes object. - Write a input typemap for GetPythonWrappers to receive python 'bytes' object and convert to const char* pointer and length. Change: 117258253
1 parent 725e968 commit b27022d

5 files changed

Lines changed: 22 additions & 11 deletions

File tree

tensorflow/python/client/tf_session.i

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,7 @@ tensorflow::ImportNumpy();
237237
// is not expected to be NULL-terminated, and TF_Buffer.length does not count
238238
// the terminator.
239239
%typemap(out) TF_Buffer (TF_GetOpList,TF_GetBuffer) {
240-
%#if PY_MAJOR_VERSION < 3
241-
$result = PyString_FromStringAndSize(
242-
%#else
243-
$result = PyUnicode_FromStringAndSize(
244-
%#endif
240+
$result = PyBytes_FromStringAndSize(
245241
reinterpret_cast<const char*>($1.data), $1.length);
246242
}
247243

tensorflow/python/framework/load_library.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,6 @@ def load_op_library(library_filename):
4343
Pass "library_filename" to a platform-specific mechanism for dynamically
4444
loading a library. The rules for determining the exact location of the
4545
library are platform-specific and are not documented here.
46-
Expects the symbols "RegisterOps", "RegisterKernels", and "GetOpList", to be
47-
defined in the library.
4846
4947
Args:
5048
library_filename: Path to the plugin.
@@ -78,7 +76,7 @@ def load_op_library(library_filename):
7876
op_list_str = py_tf.TF_GetOpList(lib_handle)
7977
op_list = op_def_pb2.OpList()
8078
op_list.ParseFromString(compat.as_bytes(op_list_str))
81-
wrappers = py_tf.GetPythonWrappers(op_list_str, len(op_list_str))
79+
wrappers = py_tf.GetPythonWrappers(op_list_str)
8280

8381
# Get a unique name for the module.
8482
module_name = hashlib.md5(wrappers).hexdigest()

tensorflow/python/framework/python_op_gen.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -693,8 +693,8 @@ string GetAllPythonOps(const char* hidden, bool require_shapes) {
693693
return GetPythonOps(ops, hidden, require_shapes);
694694
}
695695

696-
string GetPythonWrappers(const char* buf, size_t len) {
697-
string op_list_str(buf, len);
696+
string GetPythonWrappers(const char* op_wrapper_buf, size_t op_wrapper_len) {
697+
string op_list_str(op_wrapper_buf, op_wrapper_len);
698698
OpList ops;
699699
ops.ParseFromString(op_list_str);
700700
return GetPythonOps(ops, "", false);

tensorflow/python/framework/python_op_gen.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ string GetPythonOps(const OpList& ops, const string& hidden_ops,
3434
// Get the python wrappers for a list of ops in a OpList.
3535
// buf should be a pointer to a buffer containing the binary encoded OpList
3636
// proto, and len should be the length of that buffer.
37-
string GetPythonWrappers(const char* buf, size_t len);
37+
string GetPythonWrappers(const char* op_wrapper_buf, size_t op_wrapper_len);
3838

3939
} // namespace tensorflow
4040

tensorflow/python/framework/python_op_gen.i

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,23 @@ limitations under the License.
1919
#include "tensorflow/python/framework/python_op_gen.h"
2020
%}
2121

22+
// Input typemap for GetPythonWrappers.
23+
// Accepts a python object of 'bytes' type, and converts it to
24+
// a const char* pointer and size_t length. The default typemap
25+
// going from python bytes to const char* tries to decode the
26+
// contents from utf-8 to unicode for Python version >= 3, but
27+
// we want the bytes to be uninterpreted.
28+
%typemap(in) (const char* op_wrapper_buf, size_t op_wrapper_len) {
29+
char* c_string;
30+
Py_ssize_t py_size;
31+
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
32+
SWIG_fail;
33+
}
34+
$1 = c_string;
35+
$2 = static_cast<size_t>(py_size);
36+
}
37+
38+
2239
%ignoreall;
2340
%unignore tensorflow::GetPythonWrappers;
2441
%include "tensorflow/python/framework/python_op_gen.h"

0 commit comments

Comments
 (0)