From b8d2e1d439e19c615dbfb829185d527d32fae0d1 Mon Sep 17 00:00:00 2001 From: Eric Schmidt Date: Wed, 12 May 2021 13:34:28 -0700 Subject: [PATCH 1/5] samples: adds batch prediction, training job for text SDK use cases --- samples/model-builder/conftest.py | 13 ++++ .../create_batch_prediction_job_sample.py | 9 ++- ...ing_pipeline_text_classification_sample.py | 59 +++++++++++++++++++ ...ipeline_text_classification_sample_test.py | 58 ++++++++++++++++++ ..._pipeline_text_entity_extraction_sample.py | 59 +++++++++++++++++++ ...line_text_entity_extraction_sample_test.py | 58 ++++++++++++++++++ ...pipeline_text_sentiment_analysis_sample.py | 59 +++++++++++++++++++ ...ine_text_sentiment_analysis_sample_test.py | 58 ++++++++++++++++++ 8 files changed, 371 insertions(+), 2 deletions(-) create mode 100644 samples/model-builder/create_training_pipeline_text_classification_sample.py create mode 100644 samples/model-builder/create_training_pipeline_text_classification_sample_test.py create mode 100644 samples/model-builder/create_training_pipeline_text_entity_extraction_sample.py create mode 100644 samples/model-builder/create_training_pipeline_text_entity_extraction_sample_test.py create mode 100644 samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample.py create mode 100644 samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample_test.py diff --git a/samples/model-builder/conftest.py b/samples/model-builder/conftest.py index 01756f668b..c7ac736f84 100644 --- a/samples/model-builder/conftest.py +++ b/samples/model-builder/conftest.py @@ -209,6 +209,19 @@ def mock_run_automl_image_training_job(mock_image_training_job): yield mock +@pytest.fixture +def mock_get_automl_text_training_job(mock_text_training_job): + with patch.object(aiplatform, "AutoMLTextTrainingJob") as mock: + mock.return_value = mock_text_training_job + yield mock + + +@pytest.fixture +def mock_run_automl_text_training_job(mock_text_training_job): + with patch.object(mock_text_training_job, "run") as mock: + yield mock + + @pytest.fixture def mock_get_custom_training_job(mock_custom_training_job): with patch.object(aiplatform, "CustomTrainingJob") as mock: diff --git a/samples/model-builder/create_batch_prediction_job_sample.py b/samples/model-builder/create_batch_prediction_job_sample.py index 9bd5c697a5..9dd9ef016d 100644 --- a/samples/model-builder/create_batch_prediction_job_sample.py +++ b/samples/model-builder/create_batch_prediction_job_sample.py @@ -16,7 +16,9 @@ from google.cloud import aiplatform - +# [START aiplatform_sdk_create_batch_prediction_job_text_classification_sample] +# [START aiplatform_sdk_create_batch_prediction_job_text_entity_extraction_sample] +# [START aiplatform_sdk_create_batch_prediction_job_text_sentiment_analysis_sample] # [START aiplatform_sdk_create_batch_prediction_job_sample] def create_batch_prediction_job_sample( project: str, @@ -46,4 +48,7 @@ def create_batch_prediction_job_sample( return batch_prediction_job -# [END aiplatform_sdk_create_batch_prediction_job_sample] +# [END aiplatform_sdk_create_batch_prediction_job_text_sentiment_analysis_sample] +# [END aiplatform_sdk_create_batch_prediction_job_text_entity_extraction_sample] +# [END aiplatform_sdk_create_batch_prediction_job_text_classification_sample] +# [END aiplatform_sdk_create_batch_prediction_job_sample] \ No newline at end of file diff --git a/samples/model-builder/create_training_pipeline_text_classification_sample.py b/samples/model-builder/create_training_pipeline_text_classification_sample.py new file mode 100644 index 0000000000..5834db4727 --- /dev/null +++ b/samples/model-builder/create_training_pipeline_text_classification_sample.py @@ -0,0 +1,59 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_training_pipeline_text_classification_sample] +def create_training_pipeline_text_classification_sample( + project: str, + location: str, + display_name: str, + dataset_id: int, + model_display_name: Optional[str] = None, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + budget_milli_node_hours: int = 8000, + disable_early_stopping: bool = False, + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + job = aiplatform.AutoMLTextTrainingJob(display_name=display_name) + + text_dataset = aiplatform.TextDataset(dataset_id) + + model = job.run( + dataset=text_dataset, + model_display_name=model_display_name, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + budget_milli_node_hours=budget_milli_node_hours, + disable_early_stopping=disable_early_stopping, + sync=sync, + ) + + model.wait() + + print(model.display_name) + print(model.resource_name) + print(model.uri) + return model + + +# [END aiplatform_sdk_create_training_pipeline_text_classification_sample] diff --git a/samples/model-builder/create_training_pipeline_text_classification_sample_test.py b/samples/model-builder/create_training_pipeline_text_classification_sample_test.py new file mode 100644 index 0000000000..cb50cbef19 --- /dev/null +++ b/samples/model-builder/create_training_pipeline_text_classification_sample_test.py @@ -0,0 +1,58 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import create_training_pipeline_text_classification_sample +import test_constants as constants + + +def test_create_training_pipeline_text_classification_sample( + mock_sdk_init, + mock_text_dataset, + mock_get_automl_text_training_job, + mock_run_automl_text_training_job, + mock_get_text_dataset, +): + + create_training_pipeline_text_classification_sample.create_training_pipeline_text_classification_sample( + project=constants.PROJECT, + location=constants.LOCATION, + display_name=constants.DISPLAY_NAME, + dataset_id=constants.RESOURCE_ID, + model_display_name=constants.DISPLAY_NAME_2, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + budget_milli_node_hours=constants.BUDGET_MILLI_NODE_HOURS_8000, + disable_early_stopping=False, + ) + + mock_get_text_dataset.assert_called_once_with(constants.RESOURCE_ID) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_get_automl_text_training_job.assert_called_once_with( + display_name=constants.DISPLAY_NAME + ) + mock_run_automl_text_training_job.assert_called_once_with( + dataset=mock_text_dataset, + model_display_name=constants.DISPLAY_NAME_2, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + budget_milli_node_hours=constants.BUDGET_MILLI_NODE_HOURS_8000, + disable_early_stopping=False, + sync=True, + ) diff --git a/samples/model-builder/create_training_pipeline_text_entity_extraction_sample.py b/samples/model-builder/create_training_pipeline_text_entity_extraction_sample.py new file mode 100644 index 0000000000..8f1c78a948 --- /dev/null +++ b/samples/model-builder/create_training_pipeline_text_entity_extraction_sample.py @@ -0,0 +1,59 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_training_pipeline_text_entity_extraction_sample] +def create_training_pipeline_text_entity_extraction_sample( + project: str, + location: str, + display_name: str, + dataset_id: int, + model_display_name: Optional[str] = None, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + budget_milli_node_hours: int = 8000, + disable_early_stopping: bool = False, + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + job = aiplatform.AutoMLTextTrainingJob(display_name=display_name) + + text_dataset = aiplatform.TextDataset(dataset_id) + + model = job.run( + dataset=text_dataset, + model_display_name=model_display_name, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + budget_milli_node_hours=budget_milli_node_hours, + disable_early_stopping=disable_early_stopping, + sync=sync, + ) + + model.wait() + + print(model.display_name) + print(model.resource_name) + print(model.uri) + return model + + +# [END aiplatform_sdk_create_training_pipeline_text_entity_extraction_sample] diff --git a/samples/model-builder/create_training_pipeline_text_entity_extraction_sample_test.py b/samples/model-builder/create_training_pipeline_text_entity_extraction_sample_test.py new file mode 100644 index 0000000000..b8dd8f75bc --- /dev/null +++ b/samples/model-builder/create_training_pipeline_text_entity_extraction_sample_test.py @@ -0,0 +1,58 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import create_training_pipeline_text_entity_extraction_sample +import test_constants as constants + + +def test_create_training_pipeline_text_clentity_extraction_sample( + mock_sdk_init, + mock_text_dataset, + mock_get_automl_text_training_job, + mock_run_automl_text_training_job, + mock_get_text_dataset, +): + + create_training_pipeline_text_entity_extraction_sample.create_training_pipeline_text_entity_extraction_sample( + project=constants.PROJECT, + location=constants.LOCATION, + display_name=constants.DISPLAY_NAME, + dataset_id=constants.RESOURCE_ID, + model_display_name=constants.DISPLAY_NAME_2, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + budget_milli_node_hours=constants.BUDGET_MILLI_NODE_HOURS_8000, + disable_early_stopping=False, + ) + + mock_get_text_dataset.assert_called_once_with(constants.RESOURCE_ID) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_get_automl_text_training_job.assert_called_once_with( + display_name=constants.DISPLAY_NAME + ) + mock_run_automl_text_training_job.assert_called_once_with( + dataset=mock_text_dataset, + model_display_name=constants.DISPLAY_NAME_2, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + budget_milli_node_hours=constants.BUDGET_MILLI_NODE_HOURS_8000, + disable_early_stopping=False, + sync=True, + ) diff --git a/samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample.py b/samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample.py new file mode 100644 index 0000000000..7f41cf680c --- /dev/null +++ b/samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample.py @@ -0,0 +1,59 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_training_pipeline_text_sentiment_analysis_sample] +def create_training_pipeline_text_sentiment_analysis_sample( + project: str, + location: str, + display_name: str, + dataset_id: int, + model_display_name: Optional[str] = None, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + budget_milli_node_hours: int = 8000, + disable_early_stopping: bool = False, + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + job = aiplatform.AutoMLTextTrainingJob(display_name=display_name) + + text_dataset = aiplatform.TextDataset(dataset_id) + + model = job.run( + dataset=text_dataset, + model_display_name=model_display_name, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + budget_milli_node_hours=budget_milli_node_hours, + disable_early_stopping=disable_early_stopping, + sync=sync, + ) + + model.wait() + + print(model.display_name) + print(model.resource_name) + print(model.uri) + return model + + +# [END aiplatform_sdk_create_training_pipeline_text_sentiment_analysis_sample] diff --git a/samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample_test.py b/samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample_test.py new file mode 100644 index 0000000000..2081e589d7 --- /dev/null +++ b/samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample_test.py @@ -0,0 +1,58 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import create_training_pipeline_text_sentiment_analysis_sample +import test_constants as constants + + +def test_create_training_pipeline_text_clentity_extraction_sample( + mock_sdk_init, + mock_text_dataset, + mock_get_automl_text_training_job, + mock_run_automl_text_training_job, + mock_get_text_dataset, +): + + create_training_pipeline_text_sentiment_analysis_sample.create_training_pipeline_text_sentiment_analysis_sample( + project=constants.PROJECT, + location=constants.LOCATION, + display_name=constants.DISPLAY_NAME, + dataset_id=constants.RESOURCE_ID, + model_display_name=constants.DISPLAY_NAME_2, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + budget_milli_node_hours=constants.BUDGET_MILLI_NODE_HOURS_8000, + disable_early_stopping=False, + ) + + mock_get_text_dataset.assert_called_once_with(constants.RESOURCE_ID) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_get_automl_text_training_job.assert_called_once_with( + display_name=constants.DISPLAY_NAME + ) + mock_run_automl_text_training_job.assert_called_once_with( + dataset=mock_text_dataset, + model_display_name=constants.DISPLAY_NAME_2, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + budget_milli_node_hours=constants.BUDGET_MILLI_NODE_HOURS_8000, + disable_early_stopping=False, + sync=True, + ) From b9369742d28801eaa63c6f5499b2b8f527c08e9f Mon Sep 17 00:00:00 2001 From: Eric Schmidt Date: Wed, 12 May 2021 13:37:22 -0700 Subject: [PATCH 2/5] fix: lint --- samples/model-builder/create_batch_prediction_job_sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/samples/model-builder/create_batch_prediction_job_sample.py b/samples/model-builder/create_batch_prediction_job_sample.py index 9dd9ef016d..08530df5e0 100644 --- a/samples/model-builder/create_batch_prediction_job_sample.py +++ b/samples/model-builder/create_batch_prediction_job_sample.py @@ -51,4 +51,4 @@ def create_batch_prediction_job_sample( # [END aiplatform_sdk_create_batch_prediction_job_text_sentiment_analysis_sample] # [END aiplatform_sdk_create_batch_prediction_job_text_entity_extraction_sample] # [END aiplatform_sdk_create_batch_prediction_job_text_classification_sample] -# [END aiplatform_sdk_create_batch_prediction_job_sample] \ No newline at end of file +# [END aiplatform_sdk_create_batch_prediction_job_sample] From 68cdf2d218364d1f97f8094b66f4fe6163b3f2e9 Mon Sep 17 00:00:00 2001 From: Eric Schmidt Date: Wed, 12 May 2021 14:43:31 -0700 Subject: [PATCH 3/5] fix: lint --- samples/model-builder/create_batch_prediction_job_sample.py | 1 + 1 file changed, 1 insertion(+) diff --git a/samples/model-builder/create_batch_prediction_job_sample.py b/samples/model-builder/create_batch_prediction_job_sample.py index 08530df5e0..cb5a5d3ad8 100644 --- a/samples/model-builder/create_batch_prediction_job_sample.py +++ b/samples/model-builder/create_batch_prediction_job_sample.py @@ -16,6 +16,7 @@ from google.cloud import aiplatform + # [START aiplatform_sdk_create_batch_prediction_job_text_classification_sample] # [START aiplatform_sdk_create_batch_prediction_job_text_entity_extraction_sample] # [START aiplatform_sdk_create_batch_prediction_job_text_sentiment_analysis_sample] From a644b60e644beb96f5b19ef37bdcee85af38a387 Mon Sep 17 00:00:00 2001 From: Eric Schmidt Date: Thu, 13 May 2021 08:43:12 -0700 Subject: [PATCH 4/5] fix: per reviewer --- .../create_training_pipeline_text_classification_sample.py | 7 ++++++- ...eate_training_pipeline_text_entity_extraction_sample.py | 4 +++- ...ate_training_pipeline_text_sentiment_analysis_sample.py | 7 ++++++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/samples/model-builder/create_training_pipeline_text_classification_sample.py b/samples/model-builder/create_training_pipeline_text_classification_sample.py index 5834db4727..9306a82084 100644 --- a/samples/model-builder/create_training_pipeline_text_classification_sample.py +++ b/samples/model-builder/create_training_pipeline_text_classification_sample.py @@ -24,6 +24,7 @@ def create_training_pipeline_text_classification_sample( display_name: str, dataset_id: int, model_display_name: Optional[str] = None, + multi_label: bool = False, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, test_fraction_split: float = 0.1, @@ -33,7 +34,11 @@ def create_training_pipeline_text_classification_sample( ): aiplatform.init(project=project, location=location) - job = aiplatform.AutoMLTextTrainingJob(display_name=display_name) + job = aiplatform.AutoMLTextTrainingJob( + display_name=display_name, + prediction_type="classification", + multi_label=multi_label, + ) text_dataset = aiplatform.TextDataset(dataset_id) diff --git a/samples/model-builder/create_training_pipeline_text_entity_extraction_sample.py b/samples/model-builder/create_training_pipeline_text_entity_extraction_sample.py index 8f1c78a948..2d53cb2d63 100644 --- a/samples/model-builder/create_training_pipeline_text_entity_extraction_sample.py +++ b/samples/model-builder/create_training_pipeline_text_entity_extraction_sample.py @@ -33,7 +33,9 @@ def create_training_pipeline_text_entity_extraction_sample( ): aiplatform.init(project=project, location=location) - job = aiplatform.AutoMLTextTrainingJob(display_name=display_name) + job = aiplatform.AutoMLTextTrainingJob( + display_name=display_name, prediction_type="extraction" + ) text_dataset = aiplatform.TextDataset(dataset_id) diff --git a/samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample.py b/samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample.py index 7f41cf680c..685bed6feb 100644 --- a/samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample.py +++ b/samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample.py @@ -24,6 +24,7 @@ def create_training_pipeline_text_sentiment_analysis_sample( display_name: str, dataset_id: int, model_display_name: Optional[str] = None, + sentiment_max: int = 10, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, test_fraction_split: float = 0.1, @@ -33,7 +34,11 @@ def create_training_pipeline_text_sentiment_analysis_sample( ): aiplatform.init(project=project, location=location) - job = aiplatform.AutoMLTextTrainingJob(display_name=display_name) + job = aiplatform.AutoMLTextTrainingJob( + display_name=display_name, + prediction_type="sentiment", + sentiment_max=sentiment_max, + ) text_dataset = aiplatform.TextDataset(dataset_id) From 49a80f208d131700f20fbc7d380457c8e3a4d890 Mon Sep 17 00:00:00 2001 From: Eric Schmidt Date: Fri, 14 May 2021 14:25:03 -0700 Subject: [PATCH 5/5] fix: tests --- ...ate_training_pipeline_text_classification_sample_test.py | 4 +++- ..._training_pipeline_text_entity_extraction_sample_test.py | 2 +- ...training_pipeline_text_sentiment_analysis_sample_test.py | 6 ++++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/samples/model-builder/create_training_pipeline_text_classification_sample_test.py b/samples/model-builder/create_training_pipeline_text_classification_sample_test.py index cb50cbef19..6f54218e45 100644 --- a/samples/model-builder/create_training_pipeline_text_classification_sample_test.py +++ b/samples/model-builder/create_training_pipeline_text_classification_sample_test.py @@ -44,7 +44,9 @@ def test_create_training_pipeline_text_classification_sample( project=constants.PROJECT, location=constants.LOCATION ) mock_get_automl_text_training_job.assert_called_once_with( - display_name=constants.DISPLAY_NAME + display_name=constants.DISPLAY_NAME, + multi_label=False, + prediction_type="classification", ) mock_run_automl_text_training_job.assert_called_once_with( dataset=mock_text_dataset, diff --git a/samples/model-builder/create_training_pipeline_text_entity_extraction_sample_test.py b/samples/model-builder/create_training_pipeline_text_entity_extraction_sample_test.py index b8dd8f75bc..215b123942 100644 --- a/samples/model-builder/create_training_pipeline_text_entity_extraction_sample_test.py +++ b/samples/model-builder/create_training_pipeline_text_entity_extraction_sample_test.py @@ -44,7 +44,7 @@ def test_create_training_pipeline_text_clentity_extraction_sample( project=constants.PROJECT, location=constants.LOCATION ) mock_get_automl_text_training_job.assert_called_once_with( - display_name=constants.DISPLAY_NAME + display_name=constants.DISPLAY_NAME, prediction_type="extraction" ) mock_run_automl_text_training_job.assert_called_once_with( dataset=mock_text_dataset, diff --git a/samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample_test.py b/samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample_test.py index 2081e589d7..6ae5f414bd 100644 --- a/samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample_test.py +++ b/samples/model-builder/create_training_pipeline_text_sentiment_analysis_sample_test.py @@ -17,7 +17,7 @@ import test_constants as constants -def test_create_training_pipeline_text_clentity_extraction_sample( +def test_create_training_pipeline_text_sentiment_analysis_sample( mock_sdk_init, mock_text_dataset, mock_get_automl_text_training_job, @@ -44,7 +44,9 @@ def test_create_training_pipeline_text_clentity_extraction_sample( project=constants.PROJECT, location=constants.LOCATION ) mock_get_automl_text_training_job.assert_called_once_with( - display_name=constants.DISPLAY_NAME + display_name=constants.DISPLAY_NAME, + prediction_type="sentiment", + sentiment_max=10, ) mock_run_automl_text_training_job.assert_called_once_with( dataset=mock_text_dataset,