Skip to content

Commit 6bafca5

Browse files
chore: refactor model monitoring system test to make it more efficient (googleapis#1624)
* chore: refactor model monitoring system test to make it more efficient * formatting Co-authored-by: sina chavoshi <sina.chavoshi@gmail.com>
1 parent b69a061 commit 6bafca5

1 file changed

Lines changed: 27 additions & 146 deletions

File tree

tests/system/aiplatform/test_model_monitoring.py

Lines changed: 27 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -31,31 +31,9 @@
3131

3232
# constants used for testing
3333
USER_EMAIL = ""
34-
MODEL_DISPLAYNAME_KEY = "churn"
35-
MODEL_DISPLAYNAME_KEY2 = "churn2"
36-
IMAGE = "us-docker.pkg.dev/cloud-aiplatform/prediction/tf2-cpu.2-5:latest"
37-
ENDPOINT = "us-central1-aiplatform.googleapis.com"
34+
PERMANENT_CHURN_ENDPOINT_ID = "8289570005524152320"
3835
CHURN_MODEL_PATH = "gs://mco-mm/churn"
39-
_DEFAULT_INPUT = {
40-
"cnt_ad_reward": 0,
41-
"cnt_challenge_a_friend": 0,
42-
"cnt_completed_5_levels": 1,
43-
"cnt_level_complete_quickplay": 3,
44-
"cnt_level_end_quickplay": 5,
45-
"cnt_level_reset_quickplay": 2,
46-
"cnt_level_start_quickplay": 6,
47-
"cnt_post_score": 34,
48-
"cnt_spend_virtual_currency": 0,
49-
"cnt_use_extra_steps": 0,
50-
"cnt_user_engagement": 120,
51-
"country": "Denmark",
52-
"dayofweek": 3,
53-
"julianday": 254,
54-
"language": "da-dk",
55-
"month": 9,
56-
"operating_system": "IOS",
57-
"user_pseudo_id": "104B0770BAE16E8B53DF330C95881893",
58-
}
36+
5937
JOB_NAME = "churn"
6038

6139
# Sampling rate (optional, default=.8)
@@ -72,35 +50,22 @@
7250

7351
# Skew and drift thresholds.
7452
DEFAULT_THRESHOLD_VALUE = 0.001
75-
SKEW_DEFAULT_THRESHOLDS = {
53+
SKEW_THRESHOLDS = {
7654
"country": DEFAULT_THRESHOLD_VALUE,
7755
"cnt_user_engagement": DEFAULT_THRESHOLD_VALUE,
7856
}
79-
SKEW_CUSTOM_THRESHOLDS = {"cnt_level_start_quickplay": 0.01}
80-
DRIFT_DEFAULT_THRESHOLDS = {
57+
DRIFT_THRESHOLDS = {
8158
"country": DEFAULT_THRESHOLD_VALUE,
8259
"cnt_user_engagement": DEFAULT_THRESHOLD_VALUE,
8360
}
84-
DRIFT_CUSTOM_THRESHOLDS = {"cnt_level_start_quickplay": 0.01}
85-
ATTRIB_SKEW_DEFAULT_THRESHOLDS = {
61+
ATTRIB_SKEW_THRESHOLDS = {
8662
"country": DEFAULT_THRESHOLD_VALUE,
8763
"cnt_user_engagement": DEFAULT_THRESHOLD_VALUE,
8864
}
89-
ATTRIB_SKEW_CUSTOM_THRESHOLDS = {"cnt_level_start_quickplay": 0.01}
90-
ATTRIB_DRIFT_DEFAULT_THRESHOLDS = {
65+
ATTRIB_DRIFT_THRESHOLDS = {
9166
"country": DEFAULT_THRESHOLD_VALUE,
9267
"cnt_user_engagement": DEFAULT_THRESHOLD_VALUE,
9368
}
94-
ATTRIB_DRIFT_CUSTOM_THRESHOLDS = {"cnt_level_start_quickplay": 0.01}
95-
96-
skew_thresholds = SKEW_DEFAULT_THRESHOLDS.copy()
97-
skew_thresholds.update(SKEW_CUSTOM_THRESHOLDS)
98-
drift_thresholds = DRIFT_DEFAULT_THRESHOLDS.copy()
99-
drift_thresholds.update(DRIFT_CUSTOM_THRESHOLDS)
100-
attrib_skew_thresholds = ATTRIB_SKEW_DEFAULT_THRESHOLDS.copy()
101-
attrib_skew_thresholds.update(ATTRIB_SKEW_CUSTOM_THRESHOLDS)
102-
attrib_drift_thresholds = ATTRIB_DRIFT_DEFAULT_THRESHOLDS.copy()
103-
attrib_drift_thresholds.update(ATTRIB_DRIFT_CUSTOM_THRESHOLDS)
10469

10570
# global test constants
10671
sampling_strategy = model_monitoring.RandomSampleConfig(sample_rate=LOG_SAMPLE_RATE)
@@ -113,89 +78,34 @@
11378

11479
skew_config = model_monitoring.SkewDetectionConfig(
11580
data_source=DATASET_BQ_URI,
116-
skew_thresholds=skew_thresholds,
117-
attribute_skew_thresholds=attrib_skew_thresholds,
81+
skew_thresholds=SKEW_THRESHOLDS,
82+
attribute_skew_thresholds=ATTRIB_SKEW_THRESHOLDS,
11883
target_field=TARGET,
11984
)
12085

12186
drift_config = model_monitoring.DriftDetectionConfig(
122-
drift_thresholds=drift_thresholds,
123-
attribute_drift_thresholds=attrib_drift_thresholds,
87+
drift_thresholds=DRIFT_THRESHOLDS,
88+
attribute_drift_thresholds=ATTRIB_DRIFT_THRESHOLDS,
12489
)
12590

12691
drift_config2 = model_monitoring.DriftDetectionConfig(
127-
drift_thresholds=drift_thresholds,
128-
attribute_drift_thresholds=ATTRIB_DRIFT_DEFAULT_THRESHOLDS,
92+
drift_thresholds=DRIFT_THRESHOLDS,
12993
)
13094

13195
objective_config = model_monitoring.ObjectiveConfig(skew_config, drift_config)
13296

13397
objective_config2 = model_monitoring.ObjectiveConfig(skew_config, drift_config2)
13498

13599

136-
@pytest.mark.usefixtures("tear_down_resources")
137100
class TestModelDeploymentMonitoring(e2e_base.TestEndToEnd):
138101
_temp_prefix = "temp_e2e_model_monitoring_test_"
102+
endpoint = aiplatform.Endpoint(PERMANENT_CHURN_ENDPOINT_ID)
139103

140-
def temp_endpoint(self, shared_state):
141-
aiplatform.init(
142-
project=e2e_base._PROJECT,
143-
location=e2e_base._LOCATION,
144-
)
145-
146-
model = aiplatform.Model.upload(
147-
display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY),
148-
artifact_uri=CHURN_MODEL_PATH,
149-
serving_container_image_uri=IMAGE,
150-
)
151-
shared_state["resources"] = [model]
152-
endpoint = model.deploy(machine_type="n1-standard-2")
153-
predict_response = endpoint.predict(instances=[_DEFAULT_INPUT])
154-
assert len(predict_response.predictions) == 1
155-
shared_state["resources"].append(endpoint)
156-
return [endpoint, model]
157-
158-
def temp_endpoint_with_two_models(self, shared_state):
159-
aiplatform.init(
160-
project=e2e_base._PROJECT,
161-
location=e2e_base._LOCATION,
162-
)
163-
164-
model1 = aiplatform.Model.upload(
165-
display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY),
166-
artifact_uri=CHURN_MODEL_PATH,
167-
serving_container_image_uri=IMAGE,
168-
)
169-
170-
model2 = aiplatform.Model.upload(
171-
display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY2),
172-
artifact_uri=CHURN_MODEL_PATH,
173-
serving_container_image_uri=IMAGE,
174-
)
175-
shared_state["resources"] = [model1, model2]
176-
endpoint = aiplatform.Endpoint.create(
177-
display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY)
178-
)
179-
endpoint.deploy(
180-
model=model1, machine_type="n1-standard-2", traffic_percentage=100
181-
)
182-
endpoint.deploy(
183-
model=model2, machine_type="n1-standard-2", traffic_percentage=30
184-
)
185-
predict_response = endpoint.predict(instances=[_DEFAULT_INPUT])
186-
assert len(predict_response.predictions) == 1
187-
shared_state["resources"].append(endpoint)
188-
return [endpoint, model1, model2]
189-
190-
def test_mdm_one_model_one_valid_config(self, shared_state):
104+
def test_mdm_two_models_one_valid_config(self):
191105
"""
192-
Upload pre-trained churn model from local file and deploy it for prediction.
106+
Enable model monitoring on two existing models deployed to the same endpoint.
193107
"""
194108
# test model monitoring configurations
195-
[temp_endpoint, model] = self.temp_endpoint(shared_state)
196-
197-
job = None
198-
199109
job = aiplatform.ModelDeploymentMonitoringJob.create(
200110
display_name=self._make_display_name(key=JOB_NAME),
201111
logging_sampling_strategy=sampling_strategy,
@@ -205,7 +115,7 @@ def test_mdm_one_model_one_valid_config(self, shared_state):
205115
create_request_timeout=3600,
206116
project=e2e_base._PROJECT,
207117
location=e2e_base._LOCATION,
208-
endpoint=temp_endpoint,
118+
endpoint=self.endpoint,
209119
predict_instance_schema_uri="",
210120
analysis_instance_schema_uri="",
211121
)
@@ -261,18 +171,10 @@ def test_mdm_one_model_one_valid_config(self, shared_state):
261171
job.delete()
262172
with pytest.raises(core_exceptions.NotFound):
263173
job.api_client.get_model_deployment_monitoring_job(name=job_resource)
264-
temp_endpoint.undeploy_all()
265-
temp_endpoint.delete()
266-
model.delete()
267-
268-
def test_mdm_two_models_two_valid_configs(self, shared_state):
269-
[
270-
temp_endpoint_with_two_models,
271-
model1,
272-
model2,
273-
] = self.temp_endpoint_with_two_models(shared_state)
174+
175+
def test_mdm_two_models_two_valid_configs(self):
274176
[deployed_model1, deployed_model2] = list(
275-
map(lambda x: x.id, temp_endpoint_with_two_models.list_models())
177+
map(lambda x: x.id, self.endpoint.list_models())
276178
)
277179
all_configs = {
278180
deployed_model1: objective_config,
@@ -288,7 +190,7 @@ def test_mdm_two_models_two_valid_configs(self, shared_state):
288190
create_request_timeout=3600,
289191
project=e2e_base._PROJECT,
290192
location=e2e_base._LOCATION,
291-
endpoint=temp_endpoint_with_two_models,
193+
endpoint=self.endpoint,
292194
predict_instance_schema_uri="",
293195
analysis_instance_schema_uri="",
294196
)
@@ -330,13 +232,8 @@ def test_mdm_two_models_two_valid_configs(self, shared_state):
330232
)
331233

332234
job.delete()
333-
temp_endpoint_with_two_models.undeploy_all()
334-
temp_endpoint_with_two_models.delete()
335-
model1.delete()
336-
model2.delete()
337235

338-
def test_mdm_invalid_config_incorrect_model_id(self, shared_state):
339-
[temp_endpoint, model] = self.temp_endpoint(shared_state)
236+
def test_mdm_invalid_config_incorrect_model_id(self):
340237
with pytest.raises(ValueError) as e:
341238
aiplatform.ModelDeploymentMonitoringJob.create(
342239
display_name=self._make_display_name(key=JOB_NAME),
@@ -347,18 +244,14 @@ def test_mdm_invalid_config_incorrect_model_id(self, shared_state):
347244
create_request_timeout=3600,
348245
project=e2e_base._PROJECT,
349246
location=e2e_base._LOCATION,
350-
endpoint=temp_endpoint,
247+
endpoint=self.endpoint,
351248
predict_instance_schema_uri="",
352249
analysis_instance_schema_uri="",
353250
deployed_model_ids=[""],
354251
)
355252
assert "Invalid model ID" in str(e.value)
356-
temp_endpoint.undeploy_all()
357-
temp_endpoint.delete()
358-
model.delete()
359253

360-
def test_mdm_invalid_config_xai(self, shared_state):
361-
[temp_endpoint, model] = self.temp_endpoint(shared_state)
254+
def test_mdm_invalid_config_xai(self):
362255
with pytest.raises(RuntimeError) as e:
363256
objective_config.explanation_config = model_monitoring.ExplanationConfig()
364257
aiplatform.ModelDeploymentMonitoringJob.create(
@@ -370,26 +263,18 @@ def test_mdm_invalid_config_xai(self, shared_state):
370263
create_request_timeout=3600,
371264
project=e2e_base._PROJECT,
372265
location=e2e_base._LOCATION,
373-
endpoint=temp_endpoint,
266+
endpoint=self.endpoint,
374267
predict_instance_schema_uri="",
375268
analysis_instance_schema_uri="",
376269
)
377270
assert (
378271
"`explanation_config` should only be enabled if the model has `explanation_spec populated"
379272
in str(e.value)
380273
)
381-
temp_endpoint.undeploy_all()
382-
temp_endpoint.delete()
383-
model.delete()
384-
385-
def test_mdm_two_models_invalid_configs_xai(self, shared_state):
386-
[
387-
temp_endpoint_with_two_models,
388-
model1,
389-
model2,
390-
] = self.temp_endpoint_with_two_models(shared_state)
274+
275+
def test_mdm_two_models_invalid_configs_xai(self):
391276
[deployed_model1, deployed_model2] = list(
392-
map(lambda x: x.id, temp_endpoint_with_two_models.list_models())
277+
map(lambda x: x.id, self.endpoint.list_models())
393278
)
394279
objective_config.explanation_config = model_monitoring.ExplanationConfig()
395280
all_configs = {
@@ -407,15 +292,11 @@ def test_mdm_two_models_invalid_configs_xai(self, shared_state):
407292
create_request_timeout=3600,
408293
project=e2e_base._PROJECT,
409294
location=e2e_base._LOCATION,
410-
endpoint=temp_endpoint_with_two_models,
295+
endpoint=self.endpoint,
411296
predict_instance_schema_uri="",
412297
analysis_instance_schema_uri="",
413298
)
414299
assert (
415300
"`explanation_config` should only be enabled if the model has `explanation_spec populated"
416301
in str(e.value)
417302
)
418-
temp_endpoint_with_two_models.undeploy_all()
419-
temp_endpoint_with_two_models.delete()
420-
model1.delete()
421-
model2.delete()

0 commit comments

Comments
 (0)