Skip to content
Merged
Prev Previous commit
Next Next commit
rename
  • Loading branch information
xadupre committed Nov 13, 2023
commit ad82c19ec08d43448c063202727df8b0db20f707
8 changes: 4 additions & 4 deletions _unittests/ut_light_api/test_translate_classic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_check_code(self):
outputs.append(make_tensor_value_info("Y", TensorProto.FLOAT, shape=[]))
graph = make_graph(
nodes,
"noname",
"onename",
inputs,
outputs,
initializers,
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_exp(self):
outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
graph = make_graph(
nodes,
'noname',
'light_api',
inputs,
outputs,
initializers,
Expand Down Expand Up @@ -161,7 +161,7 @@ def test_transpose(self):
outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
graph = make_graph(
nodes,
'noname',
'light_api',
inputs,
outputs,
initializers,
Expand Down Expand Up @@ -223,7 +223,7 @@ def test_topk_reverse(self):
outputs.append(make_tensor_value_info('Indices', TensorProto.FLOAT, shape=[]))
graph = make_graph(
nodes,
'noname',
'light_api',
inputs,
outputs,
initializers,
Expand Down
3 changes: 2 additions & 1 deletion onnx_array_api/light_api/inner_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
return lines

def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
name = kwargs.get("name", "noname")
lines = [
"graph = make_graph(",
" nodes,",
" 'noname',",
f" {name!r},",
" inputs,",
" outputs,",
" initializers,",
Expand Down
7 changes: 6 additions & 1 deletion onnx_array_api/light_api/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,16 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
),
)
)
if isinstance(self.proto_, (GraphProto, FunctionProto)):
name = self.proto_.name
else:
name = self.proto_.graph.name
rows.extend(
self.emitter(
EventType.END_FUNCTION
if isinstance(self.proto_, FunctionProto)
else EventType.END_GRAPH
else EventType.END_GRAPH,
name=name,
)
)

Expand Down