Skip to content

Commit 79e33a0

Browse files
authored
add signature def to model.json (tensorflow#2326)
* add signature def to model.json * address comments * fix test
1 parent fbf536d commit 79e33a0

5 files changed

Lines changed: 80 additions & 3 deletions

File tree

tfjs-converter/python/tensorflowjs/converters/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
GENERATED_BY_KEY = 'generatedBy'
3535
CONVERTED_BY_KEY = 'convertedBy'
3636

37+
SIGNATURE_KEY = 'signature'
38+
3739
# Model formats.
3840
KERAS_SAVED_MODEL = 'keras_saved_model'
3941
KERAS_MODEL = 'keras'

tfjs-converter/python/tensorflowjs/converters/converter_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,10 @@ def testConvertKerasModelToTfGraphModel(self):
303303
model_json = json.load(f)
304304
self.assertTrue(model_json['modelTopology'])
305305
self.assertIsNot(model_json['modelTopology']['versions'], None)
306+
signature = model_json['signature']
307+
self.assertIsNot(signature, None)
308+
self.assertIsNot(signature['inputs'], None)
309+
self.assertIsNot(signature['outputs'], None)
306310
weights_manifest = model_json['weightsManifest']
307311
self.assertEqual(len(weights_manifest), 1)
308312
# Check meta-data in the artifact JSON.

tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,20 +185,23 @@ def optimize_graph(graph, signature_def, output_graph,
185185
', '.join(unsupported))
186186

187187
extract_weights(
188-
optimized_graph, output_graph, tf_version, quantization_dtype)
188+
optimized_graph, output_graph, tf_version,
189+
signature_def, quantization_dtype)
189190
return optimize_graph
190191

191192

192193
def extract_weights(graph_def,
193194
output_graph,
194195
tf_version,
196+
signature_def,
195197
quantization_dtype=None):
196198
"""Takes a Python GraphDef object and extract the weights.
197199
198200
Args:
199201
graph_def: tf.GraphDef TensorFlow GraphDef proto object, which represents
200202
the model topology.
201203
tf_version: Tensorflow version of the input graph.
204+
signature_def: the SignatureDef of the inference graph.
202205
quantization_dtype: An optional numpy dtype to quantize weights to for
203206
compression. Only np.uint8 and np.uint16 are supported.
204207
"""
@@ -233,13 +236,15 @@ def extract_weights(graph_def,
233236
const.attr["value"].tensor.ClearField(field_name)
234237

235238
write_artifacts(MessageToDict(graph_def), [const_manifest], output_graph,
236-
tf_version, quantization_dtype=quantization_dtype)
239+
tf_version, signature_def,
240+
quantization_dtype=quantization_dtype)
237241

238242

239243
def write_artifacts(topology,
240244
weights,
241245
output_graph,
242246
tf_version,
247+
signature_def,
243248
quantization_dtype=None):
244249
"""Writes weights and topology to the output_dir.
245250
@@ -251,6 +256,7 @@ def write_artifacts(topology,
251256
weights: an array of weight groups (as defined in tfjs write_weights).
252257
output_graph: the output file name to hold all the contents.
253258
tf_version: Tensorflow version of the input graph.
259+
signature_def: the SignatureDef of the inference graph.
254260
quantization_dtype: An optional numpy dtype to quantize weights to for
255261
compression. Only np.uint8 and np.uint16 are supported.
256262
"""
@@ -259,6 +265,7 @@ def write_artifacts(topology,
259265
# TODO(piyu): Add tensorflow version below by using `meta_info_def`.
260266
common.GENERATED_BY_KEY: tf_version,
261267
common.CONVERTED_BY_KEY: common.get_converted_by(),
268+
common.SIGNATURE_KEY: MessageToDict(signature_def)
262269
}
263270

264271
model_json[common.ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None
@@ -407,7 +414,7 @@ def convert_tf_saved_model(saved_model_dir,
407414
output_graph = os.path.join(
408415
output_dir, common.ARTIFACT_MODEL_JSON_FILE_NAME)
409416

410-
saved_model_tags = saved_model_tags.split(', ')
417+
saved_model_tags = saved_model_tags.split(',')
411418
model = load(saved_model_dir, saved_model_tags)
412419

413420
_check_signature_in_model(model, signature_def)

tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,10 @@ def test_convert_saved_model_v1(self):
297297
model_json = json.load(f)
298298
self.assertTrue(model_json['modelTopology'])
299299
self.assertIsNot(model_json['modelTopology']['versions'], None)
300+
signature = model_json['signature']
301+
self.assertIsNot(signature, None)
302+
self.assertIsNot(signature['inputs'], None)
303+
self.assertIsNot(signature['outputs'], None)
300304
weights_manifest = model_json['weightsManifest']
301305
self.assertCountEqual(weights_manifest[0]['paths'],
302306
['group1-shard1of1.bin'])
@@ -331,6 +335,12 @@ def test_convert_saved_model_v1_with_hashtable(self):
331335
model_json = json.load(f)
332336
self.assertTrue(model_json['modelTopology'])
333337
self.assertIsNot(model_json['modelTopology']['versions'], None)
338+
signature = model_json['signature']
339+
self.assertIsNot(signature, None)
340+
self.assertIsNot(signature['inputs'], None)
341+
self.assertIsNot(signature['outputs'], None)
342+
343+
334344
weights_manifest = model_json['weightsManifest']
335345
self.assertEqual(weights_manifest, expected_weights_manifest)
336346
# Check meta-data in the artifact JSON.
@@ -356,6 +366,10 @@ def test_convert_saved_model(self):
356366
model_json = json.load(f)
357367
self.assertTrue(model_json['modelTopology'])
358368
self.assertIsNot(model_json['modelTopology']['versions'], None)
369+
signature = model_json['signature']
370+
self.assertIsNot(signature, None)
371+
self.assertIsNot(signature['inputs'], None)
372+
self.assertIsNot(signature['outputs'], None)
359373
weights_manifest = model_json['weightsManifest']
360374
self.assertCountEqual(weights_manifest[0]['paths'],
361375
['group1-shard1of1.bin'])
@@ -374,6 +388,11 @@ def test_convert_saved_model_with_fused_conv2d(self):
374388
model_json = json.load(f)
375389
self.assertTrue(model_json['modelTopology'])
376390
self.assertIsNot(model_json['modelTopology']['versions'], None)
391+
signature = model_json['signature']
392+
self.assertIsNot(signature, None)
393+
self.assertIsNot(signature['inputs'], None)
394+
self.assertIsNot(signature['outputs'], None)
395+
377396
nodes = model_json['modelTopology']['node']
378397

379398
fusedOp = None
@@ -415,6 +434,11 @@ def test_convert_saved_model_with_prelu(self):
415434
model_json = json.load(f)
416435
self.assertTrue(model_json['modelTopology'])
417436
self.assertIsNot(model_json['modelTopology']['versions'], None)
437+
signature = model_json['signature']
438+
self.assertIsNot(signature, None)
439+
self.assertIsNot(signature['inputs'], None)
440+
self.assertIsNot(signature['outputs'], None)
441+
418442
nodes = model_json['modelTopology']['node']
419443

420444
prelu_op = None
@@ -455,6 +479,11 @@ def test_convert_saved_model_with_unfusable_prelu(self):
455479
model_json = json.load(f)
456480
self.assertTrue(model_json['modelTopology'])
457481
self.assertIsNot(model_json['modelTopology']['versions'], None)
482+
signature = model_json['signature']
483+
self.assertIsNot(signature, None)
484+
self.assertIsNot(signature['inputs'], None)
485+
self.assertIsNot(signature['outputs'], None)
486+
458487
nodes = model_json['modelTopology']['node']
459488

460489
prelu_op = None
@@ -490,6 +519,11 @@ def test_convert_saved_model_with_control_flow(self):
490519
model_json = json.load(f)
491520
self.assertTrue(model_json['modelTopology'])
492521
self.assertIsNot(model_json['modelTopology']['versions'], None)
522+
signature = model_json['signature']
523+
self.assertIsNot(signature, None)
524+
self.assertIsNot(signature['inputs'], None)
525+
self.assertIsNot(signature['outputs'], None)
526+
493527
weights_manifest = model_json['weightsManifest']
494528
self.assertCountEqual(weights_manifest[0]['paths'],
495529
['group1-shard1of1.bin'])
@@ -529,6 +563,11 @@ def test_convert_saved_model_skip_op_check(self):
529563
model_json = json.load(f)
530564
self.assertTrue(model_json['modelTopology'])
531565
self.assertIsNot(model_json['modelTopology']['versions'], None)
566+
signature = model_json['signature']
567+
self.assertIsNot(signature, None)
568+
self.assertIsNot(signature['inputs'], None)
569+
self.assertIsNot(signature['outputs'], None)
570+
532571
weights_manifest = model_json['weightsManifest']
533572
self.assertCountEqual(weights_manifest[0]['paths'],
534573
['group1-shard1of1.bin'])
@@ -554,6 +593,11 @@ def test_convert_saved_model_strip_debug_ops(self):
554593
model_json = json.load(f)
555594
self.assertTrue(model_json['modelTopology'])
556595
self.assertIsNot(model_json['modelTopology']['versions'], None)
596+
signature = model_json['signature']
597+
self.assertIsNot(signature, None)
598+
self.assertIsNot(signature['inputs'], None)
599+
self.assertIsNot(signature['outputs'], None)
600+
557601
weights_manifest = model_json['weightsManifest']
558602
self.assertCountEqual(weights_manifest[0]['paths'],
559603
['group1-shard1of1.bin'])
@@ -574,6 +618,11 @@ def test_convert_hub_module_v1(self):
574618
model_json = json.load(f)
575619
self.assertTrue(model_json['modelTopology'])
576620
self.assertIsNot(model_json['modelTopology']['versions'], None)
621+
signature = model_json['signature']
622+
self.assertIsNot(signature, None)
623+
self.assertIsNot(signature['inputs'], None)
624+
self.assertIsNot(signature['outputs'], None)
625+
577626
weights_manifest = model_json['weightsManifest']
578627
self.assertCountEqual(weights_manifest[0]['paths'],
579628
['group1-shard1of1.bin'])
@@ -596,6 +645,11 @@ def test_convert_hub_module_v2(self):
596645
model_json = json.load(f)
597646
self.assertTrue(model_json['modelTopology'])
598647
self.assertIsNot(model_json['modelTopology']['versions'], None)
648+
signature = model_json['signature']
649+
self.assertIsNot(signature, None)
650+
self.assertIsNot(signature['inputs'], None)
651+
self.assertIsNot(signature['outputs'], None)
652+
599653
weights_manifest = model_json['weightsManifest']
600654
self.assertCountEqual(weights_manifest[0]['paths'],
601655
['group1-shard1of1.bin'])
@@ -621,6 +675,11 @@ def test_convert_frozen_model(self):
621675
model_json = json.load(f)
622676
self.assertTrue(model_json['modelTopology'])
623677
self.assertIsNot(model_json['modelTopology']['versions'], None)
678+
signature = model_json['signature']
679+
self.assertIsNot(signature, None)
680+
# frozen model signature has no input nodes.
681+
self.assertIsNot(signature['outputs'], None)
682+
624683
weights_manifest = model_json['weightsManifest']
625684
self.assertCountEqual(weights_manifest[0]['paths'],
626685
['group1-shard1of1.bin'])

tfjs-converter/python/test_pip_package.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,11 @@ def testConvertTFFrozenModelWithCommandLineWorks(self):
561561
model_json = json.load(f)
562562
self.assertTrue(model_json['modelTopology'])
563563
self.assertIsNot(model_json['modelTopology']['versions'], None)
564+
signature = model_json['signature']
565+
self.assertIsNot(signature, None)
566+
# frozen model signature has no inputs
567+
self.assertIsNot(signature['outputs'], None)
568+
564569
weights_manifest = model_json['weightsManifest']
565570
weights_manifest = model_json['weightsManifest']
566571
self.assertCountEqual(weights_manifest[0]['paths'],

0 commit comments

Comments
 (0)