8585 struct_pb2 .Value (),
8686)
8787
88+ _TEST_TRAINING_TASK_INPUTS_WITH_UPTRAIN_BASE_MODEL = json_format .ParseDict (
89+ {
90+ "modelType" : "CLOUD" ,
91+ "budgetMilliNodeHours" : _TEST_TRAINING_BUDGET_MILLI_NODE_HOURS ,
92+ "multiLabel" : False ,
93+ "disableEarlyStopping" : _TEST_TRAINING_DISABLE_EARLY_STOPPING ,
94+ "uptrainBaseModelId" : _TEST_MODEL_ID ,
95+ },
96+ struct_pb2 .Value (),
97+ )
98+
8899_TEST_FRACTION_SPLIT_TRAINING = 0.6
89100_TEST_FRACTION_SPLIT_VALIDATION = 0.2
90101_TEST_FRACTION_SPLIT_TEST = 0.2
@@ -213,6 +224,20 @@ def mock_model():
213224 yield model
214225
215226
227+ @pytest .fixture
228+ def mock_uptrain_base_model ():
229+ model = mock .MagicMock (models .Model )
230+ model .name = _TEST_MODEL_ID
231+ model ._latest_future = None
232+ model ._exception = None
233+ model ._gca_resource = gca_model .Model (
234+ display_name = _TEST_MODEL_DISPLAY_NAME ,
235+ description = "This is the mock uptrain base Model's description" ,
236+ name = _TEST_MODEL_NAME ,
237+ )
238+ yield model
239+
240+
216241@pytest .mark .usefixtures ("google_auth_mock" )
217242class TestAutoMLImageTrainingJob :
218243 def setup_method (self ):
@@ -223,7 +248,7 @@ def teardown_method(self):
223248 initializer .global_pool .shutdown (wait = True )
224249
225250 def test_init_all_parameters (self , mock_model ):
226- """Ensure all private members are set correctly at initialization"""
251+ """Ensure all private members are set correctly at initialization. """
227252
228253 aiplatform .init (project = _TEST_PROJECT )
229254
@@ -275,7 +300,7 @@ def test_run_call_pipeline_service_create(
275300 mock_pipeline_service_get ,
276301 mock_dataset_image ,
277302 mock_model_service_get ,
278- mock_model ,
303+ mock_uptrain_base_model ,
279304 sync ,
280305 ):
281306 """Create and run an AutoML ICN training job, verify calls and return value"""
@@ -287,7 +312,7 @@ def test_run_call_pipeline_service_create(
287312
288313 job = training_jobs .AutoMLImageTrainingJob (
289314 display_name = _TEST_DISPLAY_NAME ,
290- base_model = mock_model ,
315+ incremental_train_base_model = mock_uptrain_base_model ,
291316 labels = _TEST_LABELS ,
292317 )
293318
@@ -315,8 +340,7 @@ def test_run_call_pipeline_service_create(
315340
316341 true_managed_model = gca_model .Model (
317342 display_name = _TEST_MODEL_DISPLAY_NAME ,
318- labels = mock_model ._gca_resource .labels ,
319- description = mock_model ._gca_resource .description ,
343+ labels = _TEST_MODEL_LABELS ,
320344 encryption_spec = _TEST_DEFAULT_ENCRYPTION_SPEC ,
321345 version_aliases = ["default" ],
322346 )
@@ -330,7 +354,7 @@ def test_run_call_pipeline_service_create(
330354 display_name = _TEST_DISPLAY_NAME ,
331355 labels = _TEST_LABELS ,
332356 training_task_definition = schema .training_job .definition .automl_image_classification ,
333- training_task_inputs = _TEST_TRAINING_TASK_INPUTS_WITH_BASE_MODEL ,
357+ training_task_inputs = _TEST_TRAINING_TASK_INPUTS_WITH_UPTRAIN_BASE_MODEL ,
334358 model_to_upload = true_managed_model ,
335359 input_data_config = true_input_data_config ,
336360 encryption_spec = _TEST_DEFAULT_ENCRYPTION_SPEC ,
@@ -754,7 +778,7 @@ def test_splits_default(
754778 mock_pipeline_service_get ,
755779 mock_dataset_image ,
756780 mock_model_service_get ,
757- mock_model ,
781+ mock_uptrain_base_model ,
758782 sync ,
759783 ):
760784 """
@@ -768,7 +792,8 @@ def test_splits_default(
768792 )
769793
770794 job = training_jobs .AutoMLImageTrainingJob (
771- display_name = _TEST_DISPLAY_NAME , base_model = mock_model
795+ display_name = _TEST_DISPLAY_NAME ,
796+ incremental_train_base_model = mock_uptrain_base_model ,
772797 )
773798
774799 model_from_job = job .run (
@@ -785,7 +810,6 @@ def test_splits_default(
785810
786811 true_managed_model = gca_model .Model (
787812 display_name = _TEST_MODEL_DISPLAY_NAME ,
788- description = mock_model ._gca_resource .description ,
789813 encryption_spec = _TEST_DEFAULT_ENCRYPTION_SPEC ,
790814 version_aliases = ["default" ],
791815 )
@@ -797,7 +821,7 @@ def test_splits_default(
797821 true_training_pipeline = gca_training_pipeline .TrainingPipeline (
798822 display_name = _TEST_DISPLAY_NAME ,
799823 training_task_definition = schema .training_job .definition .automl_image_classification ,
800- training_task_inputs = _TEST_TRAINING_TASK_INPUTS_WITH_BASE_MODEL ,
824+ training_task_inputs = _TEST_TRAINING_TASK_INPUTS_WITH_UPTRAIN_BASE_MODEL ,
801825 model_to_upload = true_managed_model ,
802826 input_data_config = true_input_data_config ,
803827 encryption_spec = _TEST_DEFAULT_ENCRYPTION_SPEC ,
0 commit comments