Skip to content

Commit a750307

Browse files
committed
Generator ops with unique_ptr #54
1 parent b27bdb7 commit a750307

2 files changed

Lines changed: 5026 additions & 5896 deletions

File tree

include/cppflow/ops_generator/generator.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -78,38 +78,38 @@ def code(self):
7878
'string' : '''
7979
std::vector<std::size_t> {0}_sizes; {0}_sizes.reserve({0}.size());
8080
std::transform({0}.begin(), {0}.end(), std::back_inserter({0}_sizes), [](const auto& s) {{ return s.size();}});
81-
TFE_OpSetAttrStringList(op, "{orig:}", reinterpret_cast<const void *const *>({0}.data()), {0}_sizes.data(), {0}.size());
81+
TFE_OpSetAttrStringList(op.get(), "{orig:}", reinterpret_cast<const void *const *>({0}.data()), {0}_sizes.data(), {0}.size());
8282
''',
83-
'int' : 'TFE_OpSetAttrIntList(op, "{orig:}", {0}.data(), {0}.size());',
84-
'float' : 'TFE_OpSetAttrFloatList(op, "{orig:}", {0}.data(), {0}.size());',
85-
'bool' : 'TFE_OpSetAttrBoolList(op, "{orig:}", std::vector<unsigned char>({0}.begin(), {0}.end()).data(), {0}.size());',
86-
'type' : 'TFE_OpSetAttrTypeList(op, "{orig:}", reinterpret_cast<const enum TF_DataType *>({0}.data()), {0}.size());',
83+
'int' : 'TFE_OpSetAttrIntList(op.get(), "{orig:}", {0}.data(), {0}.size());',
84+
'float' : 'TFE_OpSetAttrFloatList(op.get(), "{orig:}", {0}.data(), {0}.size());',
85+
'bool' : 'TFE_OpSetAttrBoolList(op.get(), "{orig:}", std::vector<unsigned char>({0}.begin(), {0}.end()).data(), {0}.size());',
86+
'type' : 'TFE_OpSetAttrTypeList(op.get(), "{orig:}", reinterpret_cast<const enum TF_DataType *>({0}.data()), {0}.size());',
8787
'shape' : '''
8888
std::vector<const int64_t*> {0}_values; {0}_values.reserve({0}.size());
8989
std::vector<int> {0}_ndims; {0}_ndims.reserve({0}.size());
9090
std::transform({0}.begin(), {0}.end(), std::back_inserter({0}_values), [](const auto& v) {{ return v.data();}});
9191
std::transform({0}.begin(), {0}.end(), std::back_inserter({0}_ndims), [](const auto& v) {{ return v.size();}});
92-
TFE_OpSetAttrShapeList(op, "{orig:}", {0}_values.data(), {0}_ndims.data(), {0}.size(), context::get_status());
92+
TFE_OpSetAttrShapeList(op.get(), "{orig:}", {0}_values.data(), {0}_ndims.data(), {0}.size(), context::get_status());
9393
status_check(context::get_status());
9494
''',
9595
}[self.type].format(self.name.replace('template', 'template_arg'), orig=self.name)).replace('\n', '\n ')
9696

9797
else:
9898
return textwrap.dedent({
9999
'shape' : '''
100-
TFE_OpSetAttrShape(op, "{orig:}", {0}.data(), {0}.size(), context::get_status());
100+
TFE_OpSetAttrShape(op.get(), "{orig:}", {0}.data(), {0}.size(), context::get_status());
101101
status_check(context::get_status());
102102
''',
103-
'int' : 'TFE_OpSetAttrInt(op, "{orig:}", {0});',
104-
'float' : 'TFE_OpSetAttrFloat(op, "{orig:}", {0});',
105-
'string': 'TFE_OpSetAttrString(op, "{orig:}", (void*) {0}.c_str(), {0}.size());',
106-
'type' : 'TFE_OpSetAttrType(op, "{orig:}", {0});',
107-
'bool' : 'TFE_OpSetAttrBool(op, "{orig:}", (unsigned char){0});',
103+
'int' : 'TFE_OpSetAttrInt(op.get(), "{orig:}", {0});',
104+
'float' : 'TFE_OpSetAttrFloat(op.get(), "{orig:}", {0});',
105+
'string': 'TFE_OpSetAttrString(op.get(), "{orig:}", (void*) {0}.c_str(), {0}.size());',
106+
'type' : 'TFE_OpSetAttrType(op.get(), "{orig:}", {0});',
107+
'bool' : 'TFE_OpSetAttrBool(op.get(), "{orig:}", (unsigned char){0});',
108108
'tensor': '''
109-
TFE_OpSetAttrTensor(op, "{orig:}", {0}.tf_tensor.get(), context::get_status());
109+
TFE_OpSetAttrTensor(op.get(), "{orig:}", {0}.tf_tensor.get(), context::get_status());
110110
status_check(context::get_status());
111111
''',
112-
'n_attr': 'TFE_OpSetAttrInt(op, "{orig:}", {n_attr:}.size());'
112+
'n_attr': 'TFE_OpSetAttrInt(op.get(), "{orig:}", {n_attr:}.size());'
113113

114114
}[self.type].format(self.name.replace('template', 'template_arg'), orig=self.name, n_attr=self.number_attr)).replace('\n', '\n ')
115115

@@ -145,7 +145,7 @@ def code(self):
145145
{} {}({}{}) {{
146146
147147
// Define Op
148-
auto op = TFE_NewOp(context::get_context(), "{}", context::get_status());
148+
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(TFE_NewOp(context::get_context(), "{}", context::get_status()), &TFE_DeleteOp);
149149
status_check(context::get_status());
150150
151151
// Required input arguments
@@ -157,23 +157,22 @@ def code(self):
157157
// Execute Op
158158
int num_outputs_op = 1;
159159
TFE_TensorHandle* res[1] = {{nullptr}};
160-
TFE_Execute(op, res, &num_outputs_op, context::get_status());
160+
TFE_Execute(op.get(), res, &num_outputs_op, context::get_status());
161161
status_check(context::get_status());
162-
TFE_DeleteOp(op);
163162
return tensor(res[0]);
164163
}}
165164
''')
166165

167166
# Add single input template
168167
add_inputs = textwrap.dedent('''
169-
TFE_OpAddInput(op, {}.tfe_handle.get(), context::get_status());
168+
TFE_OpAddInput(op.get(), {}.tfe_handle.get(), context::get_status());
170169
status_check(context::get_status());
171170
''').replace('\n', '\n ')
172171

173172
add_inputs_list = textwrap.dedent('''
174173
std::vector<TFE_TensorHandle*> {0}_handles; {0}_handles.reserve({0}.size());
175174
std::transform({0}.begin(), {0}.end(), std::back_inserter({0}_handles), [](const auto& t) {{ return t.tfe_handle.get();}});
176-
TFE_OpAddInputList(op, {0}_handles.data(), {0}.size(), context::get_status());
175+
TFE_OpAddInputList(op.get(), {0}_handles.data(), {0}.size(), context::get_status());
177176
status_check(context::get_status());
178177
''').replace('\n', '\n ')
179178

0 commit comments

Comments
 (0)