4141 io as gca_io_compat ,
4242 job_state as gca_job_state ,
4343 hyperparameter_tuning_job as gca_hyperparameter_tuning_job_compat ,
44- machine_resources as gca_machine_resources_compat ,
45- manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat ,
4644 study as gca_study_compat ,
4745 model_deployment_monitoring_job as gca_model_deployment_monitoring_job_compat ,
48- )
46+ job_state_v1beta1 as gca_job_state_v1beta1 ,
47+ model_monitoring_v1beta1 as gca_model_monitoring_v1beta1 ,
48+ ) # TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
4949
5050from google .cloud .aiplatform .constants import base as constants
5151from google .cloud .aiplatform import initializer
6363
6464_LOGGER = base .Logger (__name__ )
6565
66+ # TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
6667_JOB_COMPLETE_STATES = (
6768 gca_job_state .JobState .JOB_STATE_SUCCEEDED ,
6869 gca_job_state .JobState .JOB_STATE_FAILED ,
6970 gca_job_state .JobState .JOB_STATE_CANCELLED ,
7071 gca_job_state .JobState .JOB_STATE_PAUSED ,
72+ gca_job_state_v1beta1 .JobState .JOB_STATE_SUCCEEDED ,
73+ gca_job_state_v1beta1 .JobState .JOB_STATE_FAILED ,
74+ gca_job_state_v1beta1 .JobState .JOB_STATE_CANCELLED ,
75+ gca_job_state_v1beta1 .JobState .JOB_STATE_PAUSED ,
7176)
7277
7378_JOB_ERROR_STATES = (
7479 gca_job_state .JobState .JOB_STATE_FAILED ,
7580 gca_job_state .JobState .JOB_STATE_CANCELLED ,
81+ gca_job_state_v1beta1 .JobState .JOB_STATE_FAILED ,
82+ gca_job_state_v1beta1 .JobState .JOB_STATE_CANCELLED ,
7683)
7784
7885# _block_until_complete wait times
@@ -583,6 +590,23 @@ def create(
583590 (jobs.BatchPredictionJob):
584591 Instantiated representation of the created batch prediction job.
585592 """
593+ # TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
594+ if model_monitoring_objective_config :
595+ from google .cloud .aiplatform .compat .types import (
596+ batch_prediction_job_v1beta1 as gca_bp_job_compat ,
597+ io_v1beta1 as gca_io_compat ,
598+ explanation_v1beta1 as gca_explanation_v1beta1 ,
599+ machine_resources_v1beta1 as gca_machine_resources_compat ,
600+ manual_batch_tuning_parameters_v1beta1 as gca_manual_batch_tuning_parameters_compat ,
601+ )
602+ else :
603+ from google .cloud .aiplatform .compat .types import (
604+ batch_prediction_job as gca_bp_job_compat ,
605+ io as gca_io_compat ,
606+ explanation as gca_explanation_v1beta1 ,
607+ machine_resources as gca_machine_resources_compat ,
608+ manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat ,
609+ )
586610 if not job_display_name :
587611 job_display_name = cls ._generate_display_name ()
588612
@@ -629,18 +653,7 @@ def create(
629653 f"{ predictions_format } is not an accepted prediction format "
630654 f"type. Please choose from: { constants .BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS } "
631655 )
632- # TODO(b/242108750): remove temporary re-import statements once model monitoring for batch prediction is GA
633- if model_monitoring_objective_config :
634- from google .cloud .aiplatform .compat .types import (
635- io_v1beta1 as gca_io_compat ,
636- batch_prediction_job_v1beta1 as gca_bp_job_compat ,
637- model_monitoring_v1beta1 as gca_model_monitoring_compat ,
638- )
639- else :
640- from google .cloud .aiplatform .compat .types import (
641- io as gca_io_compat ,
642- batch_prediction_job as gca_bp_job_compat ,
643- )
656+
644657 gapic_batch_prediction_job = gca_bp_job_compat .BatchPredictionJob ()
645658
646659 # Required Fields
@@ -721,40 +734,44 @@ def create(
721734 gapic_batch_prediction_job .generate_explanation = generate_explanation
722735
723736 if explanation_metadata or explanation_parameters :
724- gapic_batch_prediction_job .explanation_spec = (
725- gca_explanation_compat .ExplanationSpec (
726- metadata = explanation_metadata , parameters = explanation_parameters
727- )
737+ explanation_spec = gca_explanation_compat .ExplanationSpec (
738+ metadata = explanation_metadata , parameters = explanation_parameters
728739 )
740+ # TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
741+ if model_monitoring_objective_config :
729742
730- # Model Monitoring
731- if model_monitoring_objective_config :
732- if model_monitoring_objective_config .drift_detection_config :
733- _LOGGER .info (
734- "Drift detection config is currently not supported for monitoring models associated with batch prediction jobs."
735- )
736- if model_monitoring_objective_config .explanation_config :
737- _LOGGER .info (
738- "XAI config is currently not supported for monitoring models associated with batch prediction jobs."
743+ explanation_spec = gca_explanation_v1beta1 .ExplanationSpec .deserialize (
744+ gca_explanation_compat .ExplanationSpec .serialize (explanation_spec )
739745 )
740- gapic_batch_prediction_job .model_monitoring_config = (
741- gca_model_monitoring_compat .ModelMonitoringConfig (
742- objective_configs = [
743- model_monitoring_objective_config .as_proto (config_for_bp = True )
744- ],
745- alert_config = model_monitoring_alert_config .as_proto (
746- config_for_bp = True
747- ),
748- analysis_instance_schema_uri = analysis_instance_schema_uri ,
749- )
750- )
746+ gapic_batch_prediction_job .explanation_spec = explanation_spec
751747
752748 empty_batch_prediction_job = cls ._empty_constructor (
753749 project = project ,
754750 location = location ,
755751 credentials = credentials ,
756752 )
753+ if model_monitoring_objective_config :
754+ empty_batch_prediction_job .api_client = (
755+ empty_batch_prediction_job .api_client .select_version ("v1beta1" )
756+ )
757757
758+ # TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
759+ if model_monitoring_objective_config :
760+ model_monitoring_objective_config ._config_for_bp = True
761+ if model_monitoring_alert_config is not None :
762+ model_monitoring_alert_config ._config_for_bp = True
763+ gapic_mm_config = gca_model_monitoring_v1beta1 .ModelMonitoringConfig (
764+ objective_configs = [model_monitoring_objective_config .as_proto ()],
765+ alert_config = model_monitoring_alert_config .as_proto ()
766+ if model_monitoring_alert_config is not None
767+ else None ,
768+ analysis_instance_schema_uri = analysis_instance_schema_uri
769+ if analysis_instance_schema_uri is not None
770+ else None ,
771+ )
772+ gapic_batch_prediction_job .model_monitoring_config = gapic_mm_config
773+
774+ # TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
758775 return cls ._create (
759776 empty_batch_prediction_job = empty_batch_prediction_job ,
760777 model_or_model_name = model_name ,
@@ -763,11 +780,6 @@ def create(
763780 sync = sync ,
764781 create_request_timeout = create_request_timeout ,
765782 )
766- # TODO(b/242108750): remove temporary re-import statements once model monitoring for batch prediction is GA
767- from google .cloud .aiplatform .compat .types import (
768- io as gca_io_compat ,
769- batch_prediction_job as gca_bp_job_compat ,
770- )
771783
772784 @classmethod
773785 @base .optional_sync (return_input_arg = "empty_batch_prediction_job" )
0 commit comments