diff --git a/README.md b/README.md index 562a095..2ff4709 100644 --- a/README.md +++ b/README.md @@ -106,7 +106,7 @@ def load(modelData): # during runtime. # Any variables returned here, will be passed as the secondary argument to your 'algorithm' function - modelData.user_data['payload'] = "Loading has been completed." + modelData['payload'] = "Loading has been completed." return modelData @@ -176,9 +176,9 @@ def infer_image(image_url, n, globals): def load(modelData): - modelData.user_data["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x" - modelData.user_data["model"] = load_model(modelData.get_model("squeezenet")) - modelData.user_data["labels"] = load_labels(modelData.get_model("labels")) + modelData["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x" + modelData["model"] = load_model(modelData.get_model("squeezenet")) + modelData["labels"] = load_labels(modelData.get_model("labels")) return modelData @@ -190,10 +190,10 @@ def apply(input, modelData): n = 3 if "data" in input: if isinstance(input["data"], str): - output = infer_image(input["data"], n, modelData.user_data) + output = infer_image(input["data"], n, modelData) elif isinstance(input["data"], list): for row in input["data"]: - row["predictions"] = infer_image(row["image_url"], n, modelData.user_data) + row["predictions"] = infer_image(row["image_url"], n, modelData) output = input["data"] else: raise Exception("\"data\" must be a image url or a list of image urls (with labels)") @@ -257,4 +257,4 @@ Verify that it works on pytest, then: ```commandline python -m twine upload -r pypi dist/* ``` -and you're done :) +and you're done :) \ No newline at end of file diff --git a/adk/modeldata.py b/adk/modeldata.py index 7256cc7..39dadf9 100644 --- a/adk/modeldata.py +++ b/adk/modeldata.py @@ -10,8 +10,23 @@ def __init__(self, client, model_manifest_path): self.manifest_data = get_manifest(self.manifest_freeze_path) self.client = client self.models = {} - self.user_data = {} - self.system_data = {} + self.usr_key = "__user__" + + def __getitem__(self, key): + return getattr(self, self.usr_key + key) + + def __setitem__(self, key, value): + setattr(self, self.usr_key + key, value) + + def data(self): + __dict = self.__dict__ + output = {} + for key in __dict.keys(): + if self.usr_key in key: + without_usr_key = key.split(self.usr_key)[1] + output[without_usr_key] = __dict[key] + return output + def available(self): if self.manifest_data: diff --git a/examples/loaded_state_hello_world/src/Algorithm.py b/examples/loaded_state_hello_world/src/Algorithm.py index a58fd71..27ab7b0 100644 --- a/examples/loaded_state_hello_world/src/Algorithm.py +++ b/examples/loaded_state_hello_world/src/Algorithm.py @@ -17,7 +17,7 @@ def load(modelData): # during runtime. # Any variables returned here, will be passed as the secondary argument to your 'algorithm' function - modelData.user_data['payload'] = "Loading has been completed." + modelData['payload'] = "Loading has been completed." return modelData diff --git a/examples/pytorch_image_classification/src/Algorithm.py b/examples/pytorch_image_classification/src/Algorithm.py index 85e8e55..c1ba558 100644 --- a/examples/pytorch_image_classification/src/Algorithm.py +++ b/examples/pytorch_image_classification/src/Algorithm.py @@ -53,9 +53,9 @@ def infer_image(image_url, n, globals): def load(modelData): - modelData.user_data["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x" - modelData.user_data["model"] = load_model(modelData.get_model("squeezenet")) - modelData.user_data["labels"] = load_labels(modelData.get_model("labels")) + modelData["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x" + modelData["model"] = load_model(modelData.get_model("squeezenet")) + modelData["labels"] = load_labels(modelData.get_model("labels")) return modelData @@ -67,10 +67,10 @@ def apply(input, modelData): n = 3 if "data" in input: if isinstance(input["data"], str): - output = infer_image(input["data"], n, modelData.user_data) + output = infer_image(input["data"], n, modelData) elif isinstance(input["data"], list): for row in input["data"]: - row["predictions"] = infer_image(row["image_url"], n, modelData.user_data) + row["predictions"] = infer_image(row["image_url"], n, modelData) output = input["data"] else: raise Exception("\"data\" must be a image url or a list of image urls (with labels)") diff --git a/tests/adk_algorithms.py b/tests/adk_algorithms.py index 7a41600..7e86a1c 100644 --- a/tests/adk_algorithms.py +++ b/tests/adk_algorithms.py @@ -16,7 +16,7 @@ def apply_binary(input): def apply_input_or_context(input, model_data=None): if model_data: - return model_data.user_data + return model_data.data() else: return "hello " + input @@ -30,7 +30,7 @@ def apply_successful_manifest_parsing(input, model_data): # -- Loading functions --- # def loading_text(modelData): - modelData.user_data['message'] = 'This message was loaded prior to runtime' + modelData['message'] = 'This message was loaded prior to runtime' return modelData @@ -39,14 +39,14 @@ def loading_exception(modelData): def loading_file_from_algorithmia(modelData): - modelData.user_data['data_url'] = 'data://demo/collection/somefile.json' - modelData.user_data['data'] = modelData.client.file(modelData.user_data['data_url']).getJson() + modelData['data_url'] = 'data://demo/collection/somefile.json' + modelData['data'] = modelData.client.file(modelData['data_url']).getJson() return modelData def loading_with_manifest(modelData): - modelData.user_data["squeezenet"] = modelData.get_model("squeezenet") - modelData.user_data['labels'] = modelData.get_model("labels") + modelData["squeezenet"] = modelData.get_model("squeezenet") + modelData['labels'] = modelData.get_model("labels") # optional model - modelData.user_data['mobilenet'] = modelData.get_model("mobilenet") + modelData['mobilenet'] = modelData.get_model("mobilenet") return modelData