diff --git a/adk/ADK.py b/adk/ADK.py index 9107bff..51a73f0 100644 --- a/adk/ADK.py +++ b/adk/ADK.py @@ -38,7 +38,7 @@ def __init__(self, apply_func, load_func=None, client=None): self.is_local = not os.path.exists(self.FIFO_PATH) self.load_result = None self.loading_exception = None - self.manifest_path = "model_manifest.json.freeze" + self.manifest_path = "model_manifest.json" self.model_data = self.init_manifest(self.manifest_path) def init_manifest(self, path): diff --git a/adk/modeldata.py b/adk/modeldata.py index 39dadf9..529d8ff 100644 --- a/adk/modeldata.py +++ b/adk/modeldata.py @@ -6,8 +6,9 @@ class ModelData(object): def __init__(self, client, model_manifest_path): - self.manifest_freeze_path = model_manifest_path - self.manifest_data = get_manifest(self.manifest_freeze_path) + self.manifest_path = model_manifest_path + self.manifest_freeze_path = "{}.freeze".format(self.manifest_path) + self.manifest_data = get_manifest(self.manifest_freeze_path, self.manifest_path) self.client = client self.models = {} self.usr_key = "__user__" @@ -27,7 +28,6 @@ def data(self): output[without_usr_key] = __dict[key] return output - def available(self): if self.manifest_data: return True @@ -39,14 +39,16 @@ def initialize(self): raise Exception("Client was not defined, please define a Client when using Model Manifests.") for required_file in self.manifest_data['required_files']: name = required_file['name'] + source_uri = required_file['source_uri'] + fail_on_tamper = required_file.get('fail_on_tamper', False) + expected_hash = required_file.get('md5_checksum', None) if name in self.models: raise Exception("Duplicate 'name' detected. \n" + name + " was found to be used by more than one data file, please rename.") - expected_hash = required_file['md5_checksum'] - with self.client.file(required_file['source_uri']).getFile() as f: + with self.client.file(source_uri).getFile() as f: local_data_path = f.name real_hash = md5_for_file(local_data_path) - if real_hash != expected_hash and required_file['fail_on_tamper']: + if real_hash != expected_hash and fail_on_tamper: raise Exception("Model File Mismatch for " + name + "\nexpected hash: " + expected_hash + "\nreal hash: " + real_hash) else: @@ -70,32 +72,46 @@ def find_optional_model(self, file_name): raise Exception("file with name '" + file_name + "' not found in model manifest.") model_info = found_models[0] self.models[file_name] = {} - expected_hash = model_info['md5_checksum'] - with self.client.file(model_info['source_uri']).getFile() as f: + source_uri = model_info['source_uri'] + fail_on_tamper = model_info.get("fail_on_tamper", False) + expected_hash = model_info.get('md5_checksum', None) + with self.client.file(source_uri).getFile() as f: local_data_path = f.name real_hash = md5_for_file(local_data_path) - if real_hash != expected_hash and model_info['fail_on_tamper']: + if real_hash != expected_hash and fail_on_tamper: raise Exception("Model File Mismatch for " + file_name + "\nexpected hash: " + expected_hash + "\nreal hash: " + real_hash) else: self.models[file_name] = FileData(real_hash, local_data_path) -def get_manifest(manifest_path): - if os.path.exists(manifest_path): - with open(manifest_path) as f: +def get_manifest(manifest_frozen_path, manifest_reg_path): + if os.path.exists(manifest_frozen_path): + with open(manifest_frozen_path) as f: manifest_data = json.load(f) - expected_lock_checksum = manifest_data.get('lock_checksum') - del manifest_data['lock_checksum'] - detected_lock_checksum = md5_for_str(str(manifest_data)) - if expected_lock_checksum != detected_lock_checksum: + if check_lock(manifest_data): + return manifest_data + else: raise Exception("Manifest FreezeFile Tamper Detected; please use the CLI and 'algo freeze' to rebuild your " "algorithm's freeze file.") + elif os.path.exists(manifest_reg_path): + with open(manifest_reg_path) as f: + manifest_data = json.load(f) return manifest_data else: return None +def check_lock(manifest_data): + expected_lock_checksum = manifest_data.get('lock_checksum') + del manifest_data['lock_checksum'] + detected_lock_checksum = md5_for_str(str(manifest_data)) + if expected_lock_checksum != detected_lock_checksum: + return False + else: + return True + + def md5_for_file(fname): hash_md5 = hashlib.md5() with open(fname, "rb") as f: diff --git a/tests/test_adk_local.py b/tests/test_adk_local.py index 5592bfc..4b5585c 100644 --- a/tests/test_adk_local.py +++ b/tests/test_adk_local.py @@ -23,7 +23,7 @@ def execute_example(self, input, apply, load=None): algo.init(input, pprint=lambda x: output.append(x)) return output[0] - def execute_manifest_example(self, input, apply, load, manifest_path="manifests/good_model_manifest.json.freeze"): + def execute_manifest_example(self, input, apply, load, manifest_path): client = Algorithmia.client() algo = ADKTest(apply, load, manifest_path=manifest_path, client=client) output = [] @@ -131,7 +131,7 @@ def test_manifest_file_success(self): actual_output = json.loads(self.execute_manifest_example(input, apply_successful_manifest_parsing, loading_with_manifest, manifest_path="tests/manifests/good_model_manifest" - ".json.freeze")) + ".json")) self.assertEqual(expected_output, actual_output) def test_manifest_file_tampered(self): @@ -145,7 +145,7 @@ def test_manifest_file_tampered(self): actual_output = json.loads(self.execute_manifest_example(input, apply_successful_manifest_parsing, loading_with_manifest, manifest_path="tests/manifests/bad_model_manifest" - ".json.freeze")) + ".json")) self.assertEqual(expected_output, actual_output) diff --git a/tests/test_adk_remote.py b/tests/test_adk_remote.py index f0d69ba..5108dd7 100644 --- a/tests/test_adk_remote.py +++ b/tests/test_adk_remote.py @@ -173,7 +173,7 @@ def test_manifest_file_success(self): actual_output = self.execute_manifest_example(input, apply_successful_manifest_parsing, loading_with_manifest, manifest_path="tests/manifests/good_model_manifest" - ".json.freeze") + ".json") self.assertEqual(expected_output, actual_output)