Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add function to translate functions
  • Loading branch information
xadupre committed Jan 5, 2024
commit e71097107fc63c4a4c77fcd7abe8a03637324149
Binary file not shown.
1 change: 0 additions & 1 deletion _unittests/ut_light_api/test_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,5 +220,4 @@ def test_aionnxml(self):


if __name__ == "__main__":
TestTranslate().test_export_if()
unittest.main(verbosity=2)
10 changes: 9 additions & 1 deletion _unittests/ut_light_api/test_translate_classic.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,6 @@ def test_aionnxml(self):
.to_onnx()
)
code = translate(onx, api="onnx")
print(code)
expected = dedent(
"""
opset_imports = [
Expand Down Expand Up @@ -318,6 +317,15 @@ def test_aionnxml(self):
self.maxDiff = None
self.assertEqual(expected, code)

def test_remove_nodes(self):
path = os.path.join(
os.path.dirname(__file__), "_data", "custom_ops_type_inference_fails_0.onnx"
)
onx = load(path)
text = translate(onx, api="onnx")
with open("debug_test_remove_nodes.py", "w") as f:
f.write(text)


if __name__ == "__main__":
# TestLightApi().test_topk()
Expand Down
50 changes: 49 additions & 1 deletion onnx_array_api/light_api/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ class EventType(IntEnum):
END_FUNCTION = 8
INITIALIZER = 9
SPARSE_INITIALIZER = 10
FUNCTION_INPUT = 11
FUNCTION_OUTPUT = 12
FUNCTION_ATTRIBUTES = 13

@classmethod
def to_str(cls, self) -> str:
Expand Down Expand Up @@ -63,6 +66,21 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
if event == EventType.END_GRAPH:
return self._emit_end_graph(**kwargs)

if event == EventType.BEGIN_FUNCTION:
return self._emit_begin_function(**kwargs)

if event == EventType.END_FUNCTION:
return self._emit_end_function(**kwargs)

if event == EventType.FUNCTION_INPUT:
return self._emit_function_input(**kwargs)

if event == EventType.FUNCTION_OUTPUT:
return self._emit_function_output(**kwargs)

if event == EventType.FUNCTION_ATTRIBUTES:
return self._emit_function_attributes(**kwargs)

raise ValueError(f"Unexpected event {EventType.to_str(event)}.")

def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
Expand Down Expand Up @@ -104,11 +122,21 @@ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
srows = ".".join(rows[:-1])
return [], f"g().{srows}"

if isinstance(value, tuple) and len(value) == 2 and value[1] is None:
# in a function, an attribute receiving a value from an attribute
v = value[0]
name = v.name
ref = v.ref_attr_name
dt = v.type
return [], f"(name={name!r}, ref_attr_name={ref!r}, dt={dt})"


raise ValueError(
f"Unable to render an attribute {type(v)}, "
f"attribute type={value[0].type}, "
f"dtype={getattr(v, 'dtype', '-')}, "
f"shape={getattr(v, 'shape', '-')}, {value}."
f"shape={getattr(v, 'shape', '-')}, type(value)={type(value)}, "
f"value={value!r}."
)

def join(self, rows: List[str], single_line: bool = False) -> str:
Expand Down Expand Up @@ -161,6 +189,26 @@ def _emit_sparse_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
)

def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]:
raise NotImplementedError(
f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
)

def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]:
raise NotImplementedError(
f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
)

def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]:
raise NotImplementedError(
f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
)

def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]:
raise NotImplementedError(
f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
)


class Emitter(BaseEmitter):
"""
Expand Down
41 changes: 41 additions & 0 deletions onnx_array_api/light_api/inner_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,44 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
lines[-1] = lines[-1][:-1]
lines.extend([" )", ")"])
return before_lines + lines

def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]:
lines = [
"",
f"name_f = {kwargs['name']!r}",
f"domain_f = {kwargs['domain']!r}",
"nodes = []",
"inputs = []",
"outputs = []",
"atts = []",
]
return lines

def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]:
return [f"inputs.append({kwargs['name']!r})"]

def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]:
return [f"outputs.append({kwargs['name']!r})"]

def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]:
atts = kwargs["attributes"]
if isinstance(atts, list) and all(map(lambda t: isinstance(t, str), atts)):
return [f"atts.extend({atts!r})"]
raise NotImplementedError(f"Unable to process function attributes {atts!r}.")

def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]:
lines = [
"functions.append(",
" make_function(",
" domain, ",
" name, ",
" inputs, ",
" outputs, ",
" nodes, ",
" attributes=atts, ",
" opset_imports=opset_imports,",
" )",
")",
]
return lines

30 changes: 20 additions & 10 deletions onnx_array_api/light_api/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
nodes = self.proto_.graph.node
initializers = self.proto_.graph.initializer
sparse_initializers = self.proto_.graph.sparse_initializer
attributes = []
elif isinstance(self.proto_, (FunctionProto, GraphProto)):
inputs = self.proto_.input
outputs = self.proto_.output
Expand All @@ -48,19 +49,19 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
else:
initializers = []
sparse_initializers = []
attributes = (
self.proto_.attribute if hasattr(self.proto_, "attribute") else []
)
else:
raise ValueError(f"Unexpected type {type(self.proto_)} for proto.")

if sparse_initializers:
raise NotImplementedError("Sparse initializer not supported yet.")

rows.extend(
self.emitter(
EventType.BEGIN_FUNCTION
if isinstance(self.proto_, FunctionProto)
else EventType.BEGIN_GRAPH
)
)
if isinstance(self.proto_, FunctionProto):
rows.extend(self.emitter(EventType.BEGIN_FUNCTION, name=self.proto_.name, domain=self.proto_.domain))
else:
rows.extend(self.emitter(EventType.BEGIN_GRAPH))

for i in initializers:
rows.extend(
Expand All @@ -71,7 +72,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:

for i in inputs:
if isinstance(i, str):
rows.extend(self.emitter(EventType.INPUT, name=i))
rows.extend(self.emitter(EventType.FUNCTION_INPUT, name=i))
else:
rows.extend(
self.emitter(
Expand All @@ -85,6 +86,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
)
)

if attributes:
rows.extend(
self.emitter(EventType.FUNCTION_ATTRIBUTES, attributes=list(attributes))
)

for node in nodes:
atts = self.extract_attributes(node)
rows.extend(
Expand All @@ -100,7 +106,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:

for o in outputs:
if isinstance(o, str):
rows.extend(self.emitter(EventType.INPUT, name=o))
rows.extend(self.emitter(EventType.FUNCTION_OUTPUT, name=o))
else:
rows.extend(
self.emitter(
Expand All @@ -127,7 +133,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
)

if isinstance(self.proto_, ModelProto) and len(self.proto_.functions) > 0:
raise NotImplementedError("Local functions are not yet implemented.")
for fu in self.proto_.functions:

cl = self.__class__(fu, self.emitter)
text = cl.export(False, single_line=False)
rows.extend(text)

rows.extend(self.emitter(EventType.TO_ONNX))
if as_str:
Expand Down