@@ -78,38 +78,38 @@ def code(self):
7878 'string' : '''
7979 std::vector<std::size_t> {0}_sizes; {0}_sizes.reserve({0}.size());
8080 std::transform({0}.begin(), {0}.end(), std::back_inserter({0}_sizes), [](const auto& s) {{ return s.size();}});
81- TFE_OpSetAttrStringList(op, "{orig:}", reinterpret_cast<const void *const *>({0}.data()), {0}_sizes.data(), {0}.size());
81+ TFE_OpSetAttrStringList(op.get() , "{orig:}", reinterpret_cast<const void *const *>({0}.data()), {0}_sizes.data(), {0}.size());
8282 ''' ,
83- 'int' : 'TFE_OpSetAttrIntList(op, "{orig:}", {0}.data(), {0}.size());' ,
84- 'float' : 'TFE_OpSetAttrFloatList(op, "{orig:}", {0}.data(), {0}.size());' ,
85- 'bool' : 'TFE_OpSetAttrBoolList(op, "{orig:}", std::vector<unsigned char>({0}.begin(), {0}.end()).data(), {0}.size());' ,
86- 'type' : 'TFE_OpSetAttrTypeList(op, "{orig:}", reinterpret_cast<const enum TF_DataType *>({0}.data()), {0}.size());' ,
83+ 'int' : 'TFE_OpSetAttrIntList(op.get() , "{orig:}", {0}.data(), {0}.size());' ,
84+ 'float' : 'TFE_OpSetAttrFloatList(op.get() , "{orig:}", {0}.data(), {0}.size());' ,
85+ 'bool' : 'TFE_OpSetAttrBoolList(op.get() , "{orig:}", std::vector<unsigned char>({0}.begin(), {0}.end()).data(), {0}.size());' ,
86+ 'type' : 'TFE_OpSetAttrTypeList(op.get() , "{orig:}", reinterpret_cast<const enum TF_DataType *>({0}.data()), {0}.size());' ,
8787 'shape' : '''
8888 std::vector<const int64_t*> {0}_values; {0}_values.reserve({0}.size());
8989 std::vector<int> {0}_ndims; {0}_ndims.reserve({0}.size());
9090 std::transform({0}.begin(), {0}.end(), std::back_inserter({0}_values), [](const auto& v) {{ return v.data();}});
9191 std::transform({0}.begin(), {0}.end(), std::back_inserter({0}_ndims), [](const auto& v) {{ return v.size();}});
92- TFE_OpSetAttrShapeList(op, "{orig:}", {0}_values.data(), {0}_ndims.data(), {0}.size(), context::get_status());
92+ TFE_OpSetAttrShapeList(op.get() , "{orig:}", {0}_values.data(), {0}_ndims.data(), {0}.size(), context::get_status());
9393 status_check(context::get_status());
9494 ''' ,
9595 }[self .type ].format (self .name .replace ('template' , 'template_arg' ), orig = self .name )).replace ('\n ' , '\n ' )
9696
9797 else :
9898 return textwrap .dedent ({
9999 'shape' : '''
100- TFE_OpSetAttrShape(op, "{orig:}", {0}.data(), {0}.size(), context::get_status());
100+ TFE_OpSetAttrShape(op.get() , "{orig:}", {0}.data(), {0}.size(), context::get_status());
101101 status_check(context::get_status());
102102 ''' ,
103- 'int' : 'TFE_OpSetAttrInt(op, "{orig:}", {0});' ,
104- 'float' : 'TFE_OpSetAttrFloat(op, "{orig:}", {0});' ,
105- 'string' : 'TFE_OpSetAttrString(op, "{orig:}", (void*) {0}.c_str(), {0}.size());' ,
106- 'type' : 'TFE_OpSetAttrType(op, "{orig:}", {0});' ,
107- 'bool' : 'TFE_OpSetAttrBool(op, "{orig:}", (unsigned char){0});' ,
103+ 'int' : 'TFE_OpSetAttrInt(op.get() , "{orig:}", {0});' ,
104+ 'float' : 'TFE_OpSetAttrFloat(op.get() , "{orig:}", {0});' ,
105+ 'string' : 'TFE_OpSetAttrString(op.get() , "{orig:}", (void*) {0}.c_str(), {0}.size());' ,
106+ 'type' : 'TFE_OpSetAttrType(op.get() , "{orig:}", {0});' ,
107+ 'bool' : 'TFE_OpSetAttrBool(op.get() , "{orig:}", (unsigned char){0});' ,
108108 'tensor' : '''
109- TFE_OpSetAttrTensor(op, "{orig:}", {0}.tf_tensor.get(), context::get_status());
109+ TFE_OpSetAttrTensor(op.get() , "{orig:}", {0}.tf_tensor.get(), context::get_status());
110110 status_check(context::get_status());
111111 ''' ,
112- 'n_attr' : 'TFE_OpSetAttrInt(op, "{orig:}", {n_attr:}.size());'
112+ 'n_attr' : 'TFE_OpSetAttrInt(op.get() , "{orig:}", {n_attr:}.size());'
113113
114114 }[self .type ].format (self .name .replace ('template' , 'template_arg' ), orig = self .name , n_attr = self .number_attr )).replace ('\n ' , '\n ' )
115115
@@ -145,7 +145,7 @@ def code(self):
145145 {} {}({}{}) {{
146146
147147 // Define Op
148- auto op = TFE_NewOp(context::get_context(), "{}", context::get_status());
148+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op( TFE_NewOp(context::get_context(), "{}", context::get_status()), &TFE_DeleteOp );
149149 status_check(context::get_status());
150150
151151 // Required input arguments
@@ -157,23 +157,22 @@ def code(self):
157157 // Execute Op
158158 int num_outputs_op = 1;
159159 TFE_TensorHandle* res[1] = {{nullptr}};
160- TFE_Execute(op, res, &num_outputs_op, context::get_status());
160+ TFE_Execute(op.get() , res, &num_outputs_op, context::get_status());
161161 status_check(context::get_status());
162- TFE_DeleteOp(op);
163162 return tensor(res[0]);
164163 }}
165164 ''' )
166165
167166 # Add single input template
168167 add_inputs = textwrap .dedent ('''
169- TFE_OpAddInput(op, {}.tfe_handle.get(), context::get_status());
168+ TFE_OpAddInput(op.get() , {}.tfe_handle.get(), context::get_status());
170169 status_check(context::get_status());
171170 ''' ).replace ('\n ' , '\n ' )
172171
173172 add_inputs_list = textwrap .dedent ('''
174173 std::vector<TFE_TensorHandle*> {0}_handles; {0}_handles.reserve({0}.size());
175174 std::transform({0}.begin(), {0}.end(), std::back_inserter({0}_handles), [](const auto& t) {{ return t.tfe_handle.get();}});
176- TFE_OpAddInputList(op, {0}_handles.data(), {0}.size(), context::get_status());
175+ TFE_OpAddInputList(op.get() , {0}_handles.data(), {0}.size(), context::get_status());
177176 status_check(context::get_status());
178177 ''' ).replace ('\n ' , '\n ' )
179178
0 commit comments