forked from lutzroeder/netron
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcaffe2-script.py
More file actions
executable file
·200 lines (182 loc) · 7.03 KB
/
caffe2-script.py
File metadata and controls
executable file
·200 lines (182 loc) · 7.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
from __future__ import unicode_literals
from __future__ import print_function
import io
import json
import logging
import pydoc
import os
import re
import sys
def get_support_level(dir):
    """Classify an operator's source directory into a support tier.

    Args:
        dir: Directory of the operator's source file, '/'-separated.

    Returns:
        'core' when the path contains 'caffe2/caffe2/operators',
        'contribution' when a path component is exactly 'contrib',
        'experimental' when a path component is exactly 'experiments',
        otherwise 'default'. Checks are applied in that order.
    """
    if 'caffe2/caffe2/operators' in dir:
        return 'core'
    components = dir.split('/')
    # Order matters: 'contrib' wins over 'experiments' when both appear.
    for component_name, level in (('contrib', 'contribution'), ('experiments', 'experimental')):
        if component_name in components:
            return level
    return 'default'
def update_argument_type(type):
    """Translate a caffe2 argument type spelling into the metadata type name.

    Args:
        type: Type string taken from an argument description.

    Returns:
        The metadata type name ('int', 'int[]', 'float', 'string',
        'string[]' or 'bool').

    Raises:
        Exception: if the spelling is not recognised.
    """
    # Each row: (accepted caffe2 spellings, metadata type name).
    known_types = (
        (('int',), 'int'),
        (('[int]', 'int[]'), 'int[]'),
        (('float',), 'float'),
        (('string',), 'string'),
        (('List(string)',), 'string[]'),
        (('bool',), 'bool'),
    )
    for spellings, metadata_type in known_types:
        if type in spellings:
            return metadata_type
    raise Exception('Unknown argument type ' + str(type))
def update_argument_default(value, type):
    """Convert a textual default value into a Python value of ``type``.

    Args:
        value: Default value text parsed from an argument description.
        type: Metadata type name: 'int', 'float', 'bool' or 'string'.

    Returns:
        The converted default (int, float, bool or str).

    Raises:
        Exception: if ``type`` is unsupported, or a 'bool' default is
            neither 'True' nor 'False'.
    """
    if type == 'int':
        return int(value)
    if type == 'float':
        # Trailing '~' characters are stripped before conversion; presumably
        # some descriptions append one to float defaults.
        return float(value.rstrip('~'))
    if type == 'bool':
        if value == 'True':
            return True
        if value == 'False':
            return False
        # Previously this fell through to 'Unknown argument type bool',
        # which misreported a supported type as unknown.
        raise Exception('Unknown bool default ' + str(value))
    if type == 'string':
        return value.strip('"')
    raise Exception('Unknown argument type ' + str(type))
def update_argument(schema, arg):
    """Merge one caffe2 OpSchema argument into ``schema['attributes']``.

    The attribute entry is matched by name and created/appended if absent.
    If the argument description starts with a ``*( ... )*`` (or ``*( ... )``)
    header, the header is split on ';' (or on ',' when it has no ';') into
    ``name: value`` properties that may set the attribute's 'type',
    'default' and 'option'. The text after the header becomes the attribute's
    'description'. Non-required arguments are marked optional.

    Args:
        schema: Operator schema dict; mutated in place.
        arg: caffe2 argument object exposing ``name``, ``description`` and
            ``required``.

    Raises:
        Exception: on an unrecognised property name, or (via
            ``update_argument_type`` / ``update_argument_default``) an
            unrecognised type or default.
    """
    # Property types with no metadata equivalent; they are ignored.
    ignored_types = {'primitive', 'int | Tuple(int)', '[]', 'TensorProto_DataType', 'Tuple(int)'}
    # Property names that are deliberately not recorded.
    ignored_properties = {'must be > 1.0', 'default=\'NCHW\'', 'type depends on dtype', 'Required=True'}

    attributes = schema.setdefault('attributes', [])
    attribute = next((a for a in attributes if 'name' in a and a['name'] == arg.name), None)
    if attribute is None:
        attribute = {'name': arg.name}
        attributes.append(attribute)

    description = arg.description.strip()
    if description.startswith('*('):
        end = description.find(')*')
        if end != -1:
            header = description[2:end]
            description = description[end + 2:].lstrip()
        else:
            # Fall back to the first ')' when the closing '*' is missing.
            end = description.index(')')
            header = description[2:end]
            description = description[end + 1:].lstrip()
        properties = header.split(';')
        if len(properties) == 1 and ',' in properties[0]:
            properties = properties[0].split(',')
        for prop in properties:
            parts = prop.split(':')
            name = parts[0].strip()
            if not name:
                # Empty fragment, e.g. left by a trailing ';' or ','.
                continue
            if name == 'type':
                raw_type = parts[1].strip()
                if raw_type in ignored_types:
                    continue
                attribute['type'] = update_argument_type(raw_type)
            elif name == 'default':
                # A default can only be converted once the type is known.
                if 'type' in attribute:
                    attr_type = attribute['type']
                    default = parts[1].strip()
                    if default == '2, possible values':
                        default = '2'
                    # Skip a float attribute documented with default 'NCHW',
                    # and list-valued (int[]) defaults, which are not parsed.
                    if (attr_type == 'float' and default == '\'NCHW\'') or attr_type == 'int[]':
                        continue
                    attribute['default'] = update_argument_default(default, attr_type)
            elif name == 'optional':
                attribute['option'] = 'optional'
            elif name in ignored_properties:
                continue
            elif name == 'List(string)':
                attribute['type'] = 'string[]'
            else:
                raise Exception('Unknown property ' + str(name))
    attribute['description'] = description
    if not arg.required:
        attribute['option'] = 'optional'
def update_input(schema, input_desc):
    """Set the description of a named input in ``schema['inputs']``.

    The input entry is matched by name and created/appended if absent.

    Args:
        schema: Operator schema dict; mutated in place.
        input_desc: Sequence whose first two items are the input name and
            its description (items from ``op_schema.input_desc``). Any
            further items are ignored.
    """
    input_name, description = input_desc[0], input_desc[1]
    inputs = schema.setdefault('inputs', [])
    input_arg = next((i for i in inputs if 'name' in i and i['name'] == input_name), None)
    if input_arg is None:
        input_arg = {'name': input_name}
        inputs.append(input_arg)
    input_arg['description'] = description
def update_output(operator_name, schema, output_desc):
    """Set the description of a named output in ``schema['outputs']``.

    The output entry is matched by name and created/appended if absent. For
    the 'Y' output of Int8Conv and Int8AveragePool any description is removed
    instead of being set.

    Args:
        operator_name: Name of the operator owning ``schema``.
        schema: Operator schema dict; mutated in place.
        output_desc: Sequence whose first two items are the output name and
            its description (items from ``op_schema.output_desc``). Any
            further items are ignored.
    """
    # Operators whose 'Y' output is kept without a description.
    undocumented_y_ops = ('Int8Conv', 'Int8AveragePool')
    output_name, description = output_desc[0], output_desc[1]
    outputs = schema.setdefault('outputs', [])
    output_arg = next((o for o in outputs if 'name' in o and o['name'] == output_name), None)
    if output_arg is None:
        output_arg = {'name': output_name}
        outputs.append(output_arg)
    if operator_name in undocumented_y_ops and output_name == 'Y':
        output_arg.pop('description', None)
    else:
        output_arg['description'] = description
class Caffe2Filter(logging.Filter):
    """Logging filter that suppresses caffe2's "no GPU support" warning.

    ``logging.Filter.filter`` must return a true value to KEEP a record. The
    previous implementation returned True only for messages starting with
    'WARNING:root:...'. That prefix is added by the formatter and never
    appears in ``getMessage()``, so every record passing through the filter
    was dropped.
    """

    # Text of the warning to hide (message body, without formatter prefix).
    _GPU_WARNING = 'This caffe2 python run does not have GPU support.'

    def filter(self, record):
        """Return True to keep ``record``; False only for the GPU warning."""
        return self._GPU_WARNING not in record.getMessage()
def metadata():
    """Refresh '../src/caffe2-metadata.json' from caffe2's registered operators.

    The path is relative to the current working directory. For each
    registered operator that has an OpSchema, its entry's 'schema' is updated
    with the description, attributes, inputs, outputs and support level.
    Operators not yet in the file get a new ``{'name', 'schema'}`` entry
    appended. The file is rewritten with sorted keys, 2-space indentation
    and trailing whitespace removed.
    """
    logging.getLogger('').addFilter(Caffe2Filter())
    import caffe2.python.core
    import caffe2.python.workspace
    json_file = '../src/caffe2-metadata.json'
    # Use a context manager so the read handle is closed before the same
    # file is reopened for writing below.
    with io.open(json_file, 'r', encoding='utf-8') as fin:
        json_root = json.load(fin)
    # operator name -> that operator's 'schema' dict inside json_root.
    schema_map = {}
    for entry in json_root:
        schema_map[entry['name']] = entry['schema']
    # Operators whose outputs are not refreshed from caffe2.
    skip_output_operators = {'Int8ConvRelu', 'Int8Conv', 'Int8AveragePoolRelu', 'Int8AveragePool', 'Int8MaxPool'}
    for operator_name in caffe2.python.core._GetRegisteredOperators():
        op_schema = caffe2.python.workspace.C.OpSchema.get(operator_name)
        if not op_schema:
            continue
        if operator_name in schema_map:
            schema = schema_map[operator_name]
        else:
            schema = {}
            # Previously the new entry was stored only in schema_map (which
            # otherwise holds schema dicts) and never added to json_root, so
            # operators missing from the file were silently dropped on save.
            json_root.append({'name': operator_name, 'schema': schema})
            schema_map[operator_name] = schema
        schema['description'] = op_schema.doc
        for arg in op_schema.args:
            update_argument(schema, arg)
        for input_desc in op_schema.input_desc:
            update_input(schema, input_desc)
        if operator_name not in skip_output_operators:
            for output_desc in op_schema.output_desc:
                update_output(operator_name, schema, output_desc)
        schema['support_level'] = get_support_level(os.path.dirname(op_schema.file))
    # Serialize before opening for write so a serialization error cannot
    # leave a truncated metadata file behind.
    json_data = json.dumps(json_root, sort_keys=True, indent=2)
    with io.open(json_file, 'w', newline='') as fout:
        for line in json_data.splitlines():
            line = line.rstrip()
            if sys.version_info[0] < 3:
                # Python 2: io text files require unicode, json.dumps gives str.
                line = unicode(line)
            fout.write(line)
            fout.write('\n')
if __name__ == '__main__':
    # Sub-commands accepted as the first command-line argument.
    command_table = {'metadata': metadata}
    command = sys.argv[1] if len(sys.argv) > 1 else None
    if command not in command_table:
        # Print usage instead of crashing with IndexError/KeyError.
        print('usage: ' + sys.argv[0] + ' {' + '|'.join(sorted(command_table)) + '}', file=sys.stderr)
        sys.exit(2)
    command_table[command]()