forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdevice_lib.i
More file actions
88 lines (68 loc) · 2.61 KB
/
device_lib.i
File metadata and controls
88 lines (68 loc) · 2.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/platform/base.i"
%{
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/public/session_options.h"
%}
// `options` is not an argument on the Python side (numinputs=0): the wrapper
// declares a default-constructed SessionOptions local and hands AddDevices a
// pointer to it.
%typemap(in, numinputs=0) const tensorflow::SessionOptions& options (
    tensorflow::SessionOptions default_options) {
  $1 = &default_options;
}
// `devices` is an output-only parameter, so it is hidden from Python
// (numinputs=0). The wrapper passes the address of an empty, wrapper-local
// vector. The argout typemap for the same parameter reads back whatever
// AddDevices stored in it.
%typemap(in, numinputs=0) std::vector<tensorflow::Device*>* devices (
    std::vector<tensorflow::Device*> device_list) {
  $1 = &device_list;
}
// Hide `name_prefix` from Python (numinputs=0): callers cannot supply it, and
// AddDevices always receives a default-constructed (empty) string.
%typemap(in, numinputs=0) const string& name_prefix (
    string empty_prefix) {
  $1 = &empty_prefix;
}
// Converts the Device* list filled in by the wrapped call into the Python
// return value: a list with one serialized DeviceAttributes proto per device
// (`bytes` on Python 3, `str` on Python 2).
//
// Only each device's attributes() are read, so ownership of every Device is
// taken up front. The devices are destroyed when this block is left, either
// normally or through SWIG_fail.
%typemap(argout) std::vector<tensorflow::Device*>* devices {
  std::vector< std::unique_ptr<tensorflow::Device> > safe_devices;
  safe_devices.reserve($1->size());
  for (auto* device : *$1) safe_devices.emplace_back(device);

  auto temp_string_list = tensorflow::make_safe(PyList_New(0));
  if (!temp_string_list) {
    SWIG_fail;
  }
  for (const auto& device : safe_devices) {
    const tensorflow::DeviceAttributes& attr = device->attributes();
    string attr_serialized;
    if (!attr.SerializeToString(&attr_serialized)) {
      PyErr_SetString(PyExc_RuntimeError,
                      "Unable to serialize DeviceAttributes");
      SWIG_fail;
    }
    // std::string::data() is already const char*, so no cast is needed.
    tensorflow::Safe_PyObjectPtr safe_attr_string = tensorflow::make_safe(
%#if PY_MAJOR_VERSION < 3
        PyString_FromStringAndSize(
%#else
        PyBytes_FromStringAndSize(
%#endif
            attr_serialized.data(), attr_serialized.size()));
    // Propagate the constructor's own error (e.g. MemoryError). Passing NULL
    // on to PyList_Append would replace it with a SystemError.
    if (!safe_attr_string) {
      SWIG_fail;
    }
    if (PyList_Append(temp_string_list.get(), safe_attr_string.get()) == -1) {
      SWIG_fail;
    }
  }
  // $result already owns a reference: the converted return value of the
  // wrapped call. Release it before replacing it, otherwise it leaks.
  Py_XDECREF($result);
  $result = temp_string_list.release();
}
// Wrap only tensorflow::DeviceFactory::AddDevices from device_factory.h.
// NOTE(review): %ignoreall / %unignore / %unignoreall are not defined in this
// file. They are presumably macros from base.i that hide every declaration by
// default and then re-enable only the names listed below. Confirm against
// base.i.
%ignoreall
%unignore tensorflow;
%unignore tensorflow::DeviceFactory;
%unignore tensorflow::DeviceFactory::AddDevices;
%include "tensorflow/core/common_runtime/device_factory.h"
%unignoreall
// NOTE(review): %newobject normally marks functions whose returned object the
// caller owns, and it only applies to declarations parsed after it. Here it
// names a type rather than a function and comes after the %include above, so
// it looks like a no-op. Confirm before removing.
%newobject tensorflow::SessionOptions;