Skip to content

Commit 4e6904a

Browse files
authored
fix the directory error in wizard bin (tensorflow#2139)
BUG * fix the directory error in wizard bin * fix the saved model test * fixed lint error
1 parent 5aa8828 commit 4e6904a

2 files changed

Lines changed: 11 additions & 61 deletions

File tree

tfjs-converter/python/setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def _get_requirements(file):
2727

2828
CONSOLE_SCRIPTS = [
2929
'tensorflowjs_converter = tensorflowjs.converters.converter:pip_main',
30-
'tensorflowjs_wizard = tensorflowjs.wizard:pip_main',
30+
'tensorflowjs_wizard = tensorflowjs.converters.wizard:pip_main',
3131
]
3232

3333
setuptools.setup(
@@ -59,7 +59,6 @@ def _get_requirements(file):
5959
'tensorflowjs.quantization',
6060
'tensorflowjs.read_weights',
6161
'tensorflowjs.resource_loader',
62-
'tensorflowjs.wizard',
6362
'tensorflowjs.write_weights',
6463
'tensorflowjs.converters',
6564
'tensorflowjs.converters.common',
@@ -69,6 +68,7 @@ def _get_requirements(file):
6968
'tensorflowjs.converters.keras_h5_conversion',
7069
'tensorflowjs.converters.keras_tfjs_loader',
7170
'tensorflowjs.converters.tf_saved_model_conversion_v2',
71+
'tensorflowjs.converters.wizard',
7272
],
7373
include_package_data=True,
7474
packages=['tensorflowjs/op_list'],

tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py

Lines changed: 9 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -248,20 +248,16 @@ def test_convert_saved_model_v1(self):
248248
output_dir
249249
)
250250

251-
weights = [{
252-
'paths': ['group1-shard1of1.bin'],
253-
'weights': [{'dtype': 'float32', 'name': 'w', 'shape': [2, 2]}]}]
254-
255251
tfjs_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'js')
256252
# Check model.json and weights manifest.
257253
with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f:
258254
model_json = json.load(f)
259255
self.assertTrue(model_json['modelTopology'])
260256
weights_manifest = model_json['weightsManifest']
261257
self.assertCountEqual(weights_manifest[0]['paths'],
262-
weights[0]['paths'])
263-
self.assertCountEqual(weights_manifest[0]['weights'],
264-
weights[0]['weights'])
258+
['group1-shard1of1.bin'])
259+
self.assertEqual(len(weights_manifest[0]['weights']), 1)
260+
265261
# Check meta-data in the artifact JSON.
266262
self.assertEqual(model_json['format'], 'graph-model')
267263
self.assertEqual(
@@ -309,21 +305,15 @@ def test_convert_saved_model(self):
309305
os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
310306
)
311307

312-
weights = [{'dtype': 'float32',
313-
'name': 'StatefulPartitionedCall/mul',
314-
'shape': []}]
315-
316308
tfjs_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
317309
# Check model.json and weights manifest.
318310
with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f:
319311
model_json = json.load(f)
320312
self.assertTrue(model_json['modelTopology'])
321313
weights_manifest = model_json['weightsManifest']
322-
self.assertEqual(len(weights_manifest), len(weights))
323314
self.assertCountEqual(weights_manifest[0]['paths'],
324315
['group1-shard1of1.bin'])
325-
self.assertCountEqual(weights_manifest[0]['weights'],
326-
weights)
316+
self.assertEqual(len(weights_manifest[0]['weights']), 1)
327317

328318
def test_convert_saved_model_with_fused_conv2d(self):
329319
self._create_saved_model_with_fusable_conv2d()
@@ -445,21 +435,6 @@ def test_convert_saved_model_with_control_flow(self):
445435
os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
446436
)
447437

448-
weights = [
449-
{'dtype': 'int32', 'shape': [],
450-
'name': 'StatefulPartitionedCall/while/loop_counter'},
451-
{'dtype': 'int32', 'shape': [],
452-
'name': 'StatefulPartitionedCall/while/maximum_iterations'
453-
},
454-
{'dtype': 'int32', 'shape': [],
455-
'name': 'StatefulPartitionedCall/while/cond/_3/mod/y'},
456-
{'dtype': 'int32', 'shape': [],
457-
'name': 'StatefulPartitionedCall/while/cond/_3/Equal/y'},
458-
{'dtype': 'int32', 'shape': [],
459-
'name': 'StatefulPartitionedCall/while/body/_4/add_1/y'},
460-
{'name': 'StatefulPartitionedCall/add/y',
461-
'dtype': 'int32', 'shape': []}]
462-
463438
tfjs_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
464439
# Check model.json and weights manifest.
465440
with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f:
@@ -468,8 +443,7 @@ def test_convert_saved_model_with_control_flow(self):
468443
weights_manifest = model_json['weightsManifest']
469444
self.assertCountEqual(weights_manifest[0]['paths'],
470445
['group1-shard1of1.bin'])
471-
self.assertCountEqual(weights_manifest[0]['weights'],
472-
weights)
446+
self.assertEqual(len(weights_manifest[0]['weights']), 6)
473447

474448
# Check meta-data in the artifact JSON.
475449
self.assertEqual(model_json['format'], 'graph-model')
@@ -499,9 +473,6 @@ def test_convert_saved_model_skip_op_check(self):
499473
os.path.join(self._tmp_dir, SAVED_MODEL_DIR), skip_op_check=True
500474
)
501475

502-
weights = [{'dtype': 'float32',
503-
'name': 'StatefulPartitionedCall/MatrixDiag',
504-
'shape': [2, 2, 2]}]
505476
tfjs_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
506477
# Check model.json and weights manifest.
507478
with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f:
@@ -510,8 +481,7 @@ def test_convert_saved_model_skip_op_check(self):
510481
weights_manifest = model_json['weightsManifest']
511482
self.assertCountEqual(weights_manifest[0]['paths'],
512483
['group1-shard1of1.bin'])
513-
self.assertCountEqual(weights_manifest[0]['weights'],
514-
weights)
484+
self.assertEqual(len(weights_manifest[0]['weights']), 1)
515485
self.assertTrue(
516486
glob.glob(
517487
os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*')))
@@ -527,11 +497,6 @@ def test_convert_saved_model_strip_debug_ops(self):
527497
os.path.join(self._tmp_dir, SAVED_MODEL_DIR),
528498
strip_debug_ops=True)
529499

530-
weights = [{
531-
'dtype': 'float32',
532-
'name': 'add',
533-
'shape': [2, 2]
534-
}]
535500
tfjs_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
536501
# Check model.json and weights manifest.
537502
with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f:
@@ -540,8 +505,7 @@ def test_convert_saved_model_strip_debug_ops(self):
540505
weights_manifest = model_json['weightsManifest']
541506
self.assertCountEqual(weights_manifest[0]['paths'],
542507
['group1-shard1of1.bin'])
543-
self.assertCountEqual(weights_manifest[0]['weights'],
544-
weights)
508+
self.assertEqual(len(weights_manifest[0]['weights']), 1)
545509
self.assertTrue(
546510
glob.glob(
547511
os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*')))
@@ -553,12 +517,6 @@ def test_convert_hub_module_v1(self):
553517

554518
tf_saved_model_conversion_v2.convert_tf_hub_module(module_path, tfjs_path)
555519

556-
weights = [{
557-
'shape': [2],
558-
'name': 'module/Variable',
559-
'dtype': 'float32'
560-
}]
561-
562520
# Check model.json and weights manifest.
563521
with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f:
564522
model_json = json.load(f)
@@ -567,8 +525,7 @@ def test_convert_hub_module_v1(self):
567525
weights_manifest = model_json['weightsManifest']
568526
self.assertCountEqual(weights_manifest[0]['paths'],
569527
['group1-shard1of1.bin'])
570-
self.assertCountEqual(weights_manifest[0]['weights'],
571-
weights)
528+
self.assertEqual(len(weights_manifest[0]['weights']), 1)
572529

573530
self.assertTrue(
574531
glob.glob(
@@ -582,21 +539,14 @@ def test_convert_hub_module_v2(self):
582539
tf_saved_model_conversion_v2.convert_tf_hub_module(
583540
module_path, tfjs_path, "serving_default", "serve")
584541

585-
weights = [{
586-
'shape': [],
587-
'name': 'StatefulPartitionedCall/mul',
588-
'dtype': 'float32'
589-
}]
590-
591542
# Check model.json and weights manifest.
592543
with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f:
593544
model_json = json.load(f)
594545
self.assertTrue(model_json['modelTopology'])
595546
weights_manifest = model_json['weightsManifest']
596547
self.assertCountEqual(weights_manifest[0]['paths'],
597548
['group1-shard1of1.bin'])
598-
self.assertCountEqual(weights_manifest[0]['weights'],
599-
weights)
549+
self.assertEqual(len(weights_manifest[0]['weights']), 1)
600550

601551
self.assertTrue(
602552
glob.glob(

0 commit comments

Comments
 (0)