Skip to content

Commit f87fef0

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: updated proto message formatting logic for batch predict model monitoring
PiperOrigin-RevId: 499377219
1 parent 65300c4 commit f87fef0

File tree

5 files changed

+422
-101
lines changed

5 files changed

+422
-101
lines changed

google/cloud/aiplatform/jobs.py

Lines changed: 56 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,11 @@
4141
io as gca_io_compat,
4242
job_state as gca_job_state,
4343
hyperparameter_tuning_job as gca_hyperparameter_tuning_job_compat,
44-
machine_resources as gca_machine_resources_compat,
45-
manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat,
4644
study as gca_study_compat,
4745
model_deployment_monitoring_job as gca_model_deployment_monitoring_job_compat,
48-
)
46+
job_state_v1beta1 as gca_job_state_v1beta1,
47+
model_monitoring_v1beta1 as gca_model_monitoring_v1beta1,
48+
) # TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
4949

5050
from google.cloud.aiplatform.constants import base as constants
5151
from google.cloud.aiplatform import initializer
@@ -63,16 +63,23 @@
6363

6464
_LOGGER = base.Logger(__name__)
6565

66+
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
6667
_JOB_COMPLETE_STATES = (
6768
gca_job_state.JobState.JOB_STATE_SUCCEEDED,
6869
gca_job_state.JobState.JOB_STATE_FAILED,
6970
gca_job_state.JobState.JOB_STATE_CANCELLED,
7071
gca_job_state.JobState.JOB_STATE_PAUSED,
72+
gca_job_state_v1beta1.JobState.JOB_STATE_SUCCEEDED,
73+
gca_job_state_v1beta1.JobState.JOB_STATE_FAILED,
74+
gca_job_state_v1beta1.JobState.JOB_STATE_CANCELLED,
75+
gca_job_state_v1beta1.JobState.JOB_STATE_PAUSED,
7176
)
7277

7378
_JOB_ERROR_STATES = (
7479
gca_job_state.JobState.JOB_STATE_FAILED,
7580
gca_job_state.JobState.JOB_STATE_CANCELLED,
81+
gca_job_state_v1beta1.JobState.JOB_STATE_FAILED,
82+
gca_job_state_v1beta1.JobState.JOB_STATE_CANCELLED,
7683
)
7784

7885
# _block_until_complete wait times
@@ -583,6 +590,23 @@ def create(
583590
(jobs.BatchPredictionJob):
584591
Instantiated representation of the created batch prediction job.
585592
"""
593+
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
594+
if model_monitoring_objective_config:
595+
from google.cloud.aiplatform.compat.types import (
596+
batch_prediction_job_v1beta1 as gca_bp_job_compat,
597+
io_v1beta1 as gca_io_compat,
598+
explanation_v1beta1 as gca_explanation_v1beta1,
599+
machine_resources_v1beta1 as gca_machine_resources_compat,
600+
manual_batch_tuning_parameters_v1beta1 as gca_manual_batch_tuning_parameters_compat,
601+
)
602+
else:
603+
from google.cloud.aiplatform.compat.types import (
604+
batch_prediction_job as gca_bp_job_compat,
605+
io as gca_io_compat,
606+
explanation as gca_explanation_v1beta1,
607+
machine_resources as gca_machine_resources_compat,
608+
manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat,
609+
)
586610
if not job_display_name:
587611
job_display_name = cls._generate_display_name()
588612

@@ -629,18 +653,7 @@ def create(
629653
f"{predictions_format} is not an accepted prediction format "
630654
f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}"
631655
)
632-
# TODO(b/242108750): remove temporary re-import statements once model monitoring for batch prediction is GA
633-
if model_monitoring_objective_config:
634-
from google.cloud.aiplatform.compat.types import (
635-
io_v1beta1 as gca_io_compat,
636-
batch_prediction_job_v1beta1 as gca_bp_job_compat,
637-
model_monitoring_v1beta1 as gca_model_monitoring_compat,
638-
)
639-
else:
640-
from google.cloud.aiplatform.compat.types import (
641-
io as gca_io_compat,
642-
batch_prediction_job as gca_bp_job_compat,
643-
)
656+
644657
gapic_batch_prediction_job = gca_bp_job_compat.BatchPredictionJob()
645658

646659
# Required Fields
@@ -721,40 +734,44 @@ def create(
721734
gapic_batch_prediction_job.generate_explanation = generate_explanation
722735

723736
if explanation_metadata or explanation_parameters:
724-
gapic_batch_prediction_job.explanation_spec = (
725-
gca_explanation_compat.ExplanationSpec(
726-
metadata=explanation_metadata, parameters=explanation_parameters
727-
)
737+
explanation_spec = gca_explanation_compat.ExplanationSpec(
738+
metadata=explanation_metadata, parameters=explanation_parameters
728739
)
740+
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
741+
if model_monitoring_objective_config:
729742

730-
# Model Monitoring
731-
if model_monitoring_objective_config:
732-
if model_monitoring_objective_config.drift_detection_config:
733-
_LOGGER.info(
734-
"Drift detection config is currently not supported for monitoring models associated with batch prediction jobs."
735-
)
736-
if model_monitoring_objective_config.explanation_config:
737-
_LOGGER.info(
738-
"XAI config is currently not supported for monitoring models associated with batch prediction jobs."
743+
explanation_spec = gca_explanation_v1beta1.ExplanationSpec.deserialize(
744+
gca_explanation_compat.ExplanationSpec.serialize(explanation_spec)
739745
)
740-
gapic_batch_prediction_job.model_monitoring_config = (
741-
gca_model_monitoring_compat.ModelMonitoringConfig(
742-
objective_configs=[
743-
model_monitoring_objective_config.as_proto(config_for_bp=True)
744-
],
745-
alert_config=model_monitoring_alert_config.as_proto(
746-
config_for_bp=True
747-
),
748-
analysis_instance_schema_uri=analysis_instance_schema_uri,
749-
)
750-
)
746+
gapic_batch_prediction_job.explanation_spec = explanation_spec
751747

752748
empty_batch_prediction_job = cls._empty_constructor(
753749
project=project,
754750
location=location,
755751
credentials=credentials,
756752
)
753+
if model_monitoring_objective_config:
754+
empty_batch_prediction_job.api_client = (
755+
empty_batch_prediction_job.api_client.select_version("v1beta1")
756+
)
757757

758+
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
759+
if model_monitoring_objective_config:
760+
model_monitoring_objective_config._config_for_bp = True
761+
if model_monitoring_alert_config is not None:
762+
model_monitoring_alert_config._config_for_bp = True
763+
gapic_mm_config = gca_model_monitoring_v1beta1.ModelMonitoringConfig(
764+
objective_configs=[model_monitoring_objective_config.as_proto()],
765+
alert_config=model_monitoring_alert_config.as_proto()
766+
if model_monitoring_alert_config is not None
767+
else None,
768+
analysis_instance_schema_uri=analysis_instance_schema_uri
769+
if analysis_instance_schema_uri is not None
770+
else None,
771+
)
772+
gapic_batch_prediction_job.model_monitoring_config = gapic_mm_config
773+
774+
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
758775
return cls._create(
759776
empty_batch_prediction_job=empty_batch_prediction_job,
760777
model_or_model_name=model_name,
@@ -763,11 +780,6 @@ def create(
763780
sync=sync,
764781
create_request_timeout=create_request_timeout,
765782
)
766-
# TODO(b/242108750): remove temporary re-import statements once model monitoring for batch prediction is GA
767-
from google.cloud.aiplatform.compat.types import (
768-
io as gca_io_compat,
769-
batch_prediction_job as gca_bp_job_compat,
770-
)
771783

772784
@classmethod
773785
@base.optional_sync(return_input_arg="empty_batch_prediction_job")

google/cloud/aiplatform/model_monitoring/alert.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
model_monitoring as gca_model_monitoring_v1,
2121
)
2222

23-
# TODO(b/242108750): remove temporary re-import statements once model monitoring for batch prediction is GA
23+
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
2424
from google.cloud.aiplatform_v1beta1.types import (
2525
model_monitoring as gca_model_monitoring_v1beta1,
2626
)
@@ -46,17 +46,16 @@ def __init__(
4646
"""
4747
self.enable_logging = enable_logging
4848
self.user_emails = user_emails
49+
self._config_for_bp = False
4950

50-
# TODO(b/242108750): remove temporary re-import statements once model monitoring for batch prediction is GA
51-
def as_proto(self, config_for_bp: bool = False):
52-
"""Returns EmailAlertConfig as a proto message.
51+
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
52+
def as_proto(self) -> gca_model_monitoring.ModelMonitoringAlertConfig:
53+
"""Converts EmailAlertConfig to a proto message.
5354
54-
Args:
55-
config_for_bp (bool):
56-
Optional. Set this parameter to True if the config object
57-
is used for model monitoring on a batch prediction job.
55+
Returns:
56+
The GAPIC representation of the email alert config.
5857
"""
59-
if config_for_bp:
58+
if self._config_for_bp:
6059
gca_model_monitoring = gca_model_monitoring_v1beta1
6160
else:
6261
gca_model_monitoring = gca_model_monitoring_v1

google/cloud/aiplatform/model_monitoring/objective.py

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,16 @@
1818
from typing import Optional, Dict, Union
1919

2020
from google.cloud.aiplatform_v1.types import (
21-
io as gca_io_v1,
21+
io as gca_io,
2222
model_monitoring as gca_model_monitoring_v1,
2323
)
2424

25-
# TODO(b/242108750): remove temporary re-import statements once model monitoring for batch prediction is GA
25+
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
2626
from google.cloud.aiplatform_v1beta1.types import (
27-
io as gca_io_v1beta1,
2827
model_monitoring as gca_model_monitoring_v1beta1,
2928
)
3029

3130
gca_model_monitoring = gca_model_monitoring_v1
32-
gca_io = gca_io_v1
3331

3432
TF_RECORD = "tf-record"
3533
CSV = "csv"
@@ -92,8 +90,14 @@ def __init__(
9290
self.data_format = data_format
9391
self.target_field = target_field
9492

95-
def as_proto(self):
96-
"""Returns _SkewDetectionConfig as a proto message."""
93+
def as_proto(
94+
self,
95+
) -> gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingPredictionSkewDetectionConfig:
96+
"""Converts _SkewDetectionConfig to a proto message.
97+
98+
Returns:
99+
The GAPIC representation of the skew detection config.
100+
"""
97101
skew_thresholds_mapping = {}
98102
attribution_score_skew_thresholds_mapping = {}
99103
default_skew_threshold = None
@@ -147,8 +151,14 @@ def __init__(
147151
self.drift_thresholds = drift_thresholds
148152
self.attribute_drift_thresholds = attribute_drift_thresholds
149153

150-
def as_proto(self):
151-
"""Returns drift detection config as a proto message."""
154+
def as_proto(
155+
self,
156+
) -> gca_model_monitoring.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig:
157+
"""Converts _DriftDetectionConfig to a proto message.
158+
159+
Returns:
160+
The GAPIC representation of the drift detection config.
161+
"""
152162
drift_thresholds_mapping = {}
153163
attribution_score_drift_thresholds_mapping = {}
154164
if self.drift_thresholds is not None:
@@ -178,8 +188,14 @@ def __init__(self):
178188
"""Base class for ExplanationConfig."""
179189
self.enable_feature_attributes = False
180190

181-
def as_proto(self):
182-
"""Returns _ExplanationConfig as a proto message."""
191+
def as_proto(
192+
self,
193+
) -> gca_model_monitoring.ModelMonitoringObjectiveConfig.ExplanationConfig:
194+
"""Converts _ExplanationConfig to a proto message.
195+
196+
Returns:
197+
The GAPIC representation of the explanation config.
198+
"""
183199
return gca_model_monitoring.ModelMonitoringObjectiveConfig.ExplanationConfig(
184200
enable_feature_attributes=self.enable_feature_attributes
185201
)
@@ -208,22 +224,15 @@ def __init__(
208224
self.skew_detection_config = skew_detection_config
209225
self.drift_detection_config = drift_detection_config
210226
self.explanation_config = explanation_config
227+
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
228+
self._config_for_bp = False
211229

212-
# TODO(b/242108750): remove temporary re-import statements once model monitoring for batch prediction is GA
213-
def as_proto(self, config_for_bp: bool = False):
214-
"""Returns _SkewDetectionConfig as a proto message.
230+
def as_proto(self) -> gca_model_monitoring.ModelMonitoringObjectiveConfig:
231+
"""Converts _ObjectiveConfig to a proto message.
215232
216-
Args:
217-
config_for_bp (bool):
218-
Optional. Set this parameter to True if the config object
219-
is used for model monitoring on a batch prediction job.
233+
Returns:
234+
The GAPIC representation of the objective config.
220235
"""
221-
if config_for_bp:
222-
gca_io = gca_io_v1beta1
223-
gca_model_monitoring = gca_model_monitoring_v1beta1
224-
else:
225-
gca_io = gca_io_v1
226-
gca_model_monitoring = gca_model_monitoring_v1
227236
training_dataset = None
228237
if self.skew_detection_config is not None:
229238
training_dataset = (
@@ -252,7 +261,8 @@ def as_proto(self, config_for_bp: bool = False):
252261
else:
253262
training_dataset.dataset = self.skew_detection_config.data_source
254263

255-
return gca_model_monitoring.ModelMonitoringObjectiveConfig(
264+
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
265+
gapic_config = gca_model_monitoring.ModelMonitoringObjectiveConfig(
256266
training_dataset=training_dataset,
257267
training_prediction_skew_detection_config=self.skew_detection_config.as_proto()
258268
if self.skew_detection_config is not None
@@ -264,6 +274,15 @@ def as_proto(self, config_for_bp: bool = False):
264274
if self.explanation_config is not None
265275
else None,
266276
)
277+
if self._config_for_bp:
278+
return (
279+
gca_model_monitoring_v1beta1.ModelMonitoringObjectiveConfig.deserialize(
280+
gca_model_monitoring.ModelMonitoringObjectiveConfig.serialize(
281+
gapic_config
282+
)
283+
)
284+
)
285+
return gapic_config
267286

268287

269288
class SkewDetectionConfig(_SkewDetectionConfig):

0 commit comments

Comments
 (0)