7676_TEST_BQ_JOB_ID = "123459876"
7777_TEST_BQ_MAX_RESULTS = 100
7878_TEST_GCS_BUCKET_NAME = "my-bucket"
79+ _TEST_SERVICE_ACCOUNT = "vinnys@my-project.iam.gserviceaccount.com"
80+
7981
8082_TEST_BQ_PATH = f"bq://{ _TEST_BQ_PROJECT_ID } .{ _TEST_BQ_DATASET_ID } "
8183_TEST_GCS_BUCKET_PATH = f"gs://{ _TEST_GCS_BUCKET_NAME } "
@@ -719,6 +721,7 @@ def test_batch_predict_gcs_source_and_dest(
719721 gcs_destination_prefix = _TEST_BATCH_PREDICTION_GCS_DEST_PREFIX ,
720722 sync = sync ,
721723 create_request_timeout = None ,
724+ service_account = _TEST_SERVICE_ACCOUNT ,
722725 )
723726
724727 batch_prediction_job .wait_for_resource_creation ()
@@ -741,6 +744,7 @@ def test_batch_predict_gcs_source_and_dest(
741744 ),
742745 predictions_format = "jsonl" ,
743746 ),
747+ service_account = _TEST_SERVICE_ACCOUNT ,
744748 )
745749
746750 create_batch_prediction_job_mock .assert_called_once_with (
@@ -766,6 +770,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout(
766770 gcs_destination_prefix = _TEST_BATCH_PREDICTION_GCS_DEST_PREFIX ,
767771 sync = sync ,
768772 create_request_timeout = 180.0 ,
773+ service_account = _TEST_SERVICE_ACCOUNT ,
769774 )
770775
771776 batch_prediction_job .wait_for_resource_creation ()
@@ -788,6 +793,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout(
788793 ),
789794 predictions_format = "jsonl" ,
790795 ),
796+ service_account = _TEST_SERVICE_ACCOUNT ,
791797 )
792798
793799 create_batch_prediction_job_mock .assert_called_once_with (
@@ -812,6 +818,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout_not_explicitly_set(
812818 gcs_source = _TEST_BATCH_PREDICTION_GCS_SOURCE ,
813819 gcs_destination_prefix = _TEST_BATCH_PREDICTION_GCS_DEST_PREFIX ,
814820 sync = sync ,
821+ service_account = _TEST_SERVICE_ACCOUNT ,
815822 )
816823
817824 batch_prediction_job .wait_for_resource_creation ()
@@ -834,6 +841,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout_not_explicitly_set(
834841 ),
835842 predictions_format = "jsonl" ,
836843 ),
844+ service_account = _TEST_SERVICE_ACCOUNT ,
837845 )
838846
839847 create_batch_prediction_job_mock .assert_called_once_with (
@@ -855,6 +863,7 @@ def test_batch_predict_job_done_create(self, create_batch_prediction_job_mock):
855863 gcs_source = _TEST_BATCH_PREDICTION_GCS_SOURCE ,
856864 gcs_destination_prefix = _TEST_BATCH_PREDICTION_GCS_DEST_PREFIX ,
857865 sync = False ,
866+ service_account = _TEST_SERVICE_ACCOUNT ,
858867 )
859868
860869 batch_prediction_job .wait_for_resource_creation ()
@@ -881,6 +890,7 @@ def test_batch_predict_gcs_source_bq_dest(
881890 bigquery_destination_prefix = _TEST_BATCH_PREDICTION_BQ_PREFIX ,
882891 sync = sync ,
883892 create_request_timeout = None ,
893+ service_account = _TEST_SERVICE_ACCOUNT ,
884894 )
885895
886896 batch_prediction_job .wait_for_resource_creation ()
@@ -908,6 +918,7 @@ def test_batch_predict_gcs_source_bq_dest(
908918 ),
909919 predictions_format = "bigquery" ,
910920 ),
921+ service_account = _TEST_SERVICE_ACCOUNT ,
911922 )
912923
913924 create_batch_prediction_job_mock .assert_called_once_with (
@@ -946,6 +957,7 @@ def test_batch_predict_with_all_args(
946957 sync = sync ,
947958 create_request_timeout = None ,
948959 batch_size = _TEST_BATCH_SIZE ,
960+ service_account = _TEST_SERVICE_ACCOUNT ,
949961 )
950962
951963 batch_prediction_job .wait_for_resource_creation ()
@@ -986,6 +998,7 @@ def test_batch_predict_with_all_args(
986998 parameters = _TEST_EXPLANATION_PARAMETERS ,
987999 ),
9881000 labels = _TEST_LABEL ,
1001+ service_account = _TEST_SERVICE_ACCOUNT ,
9891002 )
9901003
9911004 create_batch_prediction_job_with_explanations_mock .assert_called_once_with (
@@ -1047,6 +1060,7 @@ def test_batch_predict_with_all_args_and_model_monitoring(
10471060 model_monitoring_objective_config = mm_obj_cfg ,
10481061 model_monitoring_alert_config = mm_alert_cfg ,
10491062 analysis_instance_schema_uri = "" ,
1063+ service_account = _TEST_SERVICE_ACCOUNT ,
10501064 )
10511065
10521066 batch_prediction_job .wait_for_resource_creation ()
@@ -1086,6 +1100,7 @@ def test_batch_predict_with_all_args_and_model_monitoring(
10861100 generate_explanation = True ,
10871101 model_monitoring_config = _TEST_MODEL_MONITORING_CFG ,
10881102 labels = _TEST_LABEL ,
1103+ service_account = _TEST_SERVICE_ACCOUNT ,
10891104 )
10901105 create_batch_prediction_job_v1beta1_mock .assert_called_once_with (
10911106 parent = f"projects/{ _TEST_PROJECT } /locations/{ _TEST_LOCATION } " ,
@@ -1103,6 +1118,7 @@ def test_batch_predict_create_fails(self):
11031118 gcs_source = _TEST_BATCH_PREDICTION_GCS_SOURCE ,
11041119 bigquery_destination_prefix = _TEST_BATCH_PREDICTION_BQ_PREFIX ,
11051120 sync = False ,
1121+ service_account = _TEST_SERVICE_ACCOUNT ,
11061122 )
11071123
11081124 with pytest .raises (RuntimeError ) as e :
@@ -1143,6 +1159,7 @@ def test_batch_predict_no_source(self, create_batch_prediction_job_mock):
11431159 model_name = _TEST_MODEL_NAME ,
11441160 job_display_name = _TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME ,
11451161 bigquery_destination_prefix = _TEST_BATCH_PREDICTION_BQ_PREFIX ,
1162+ service_account = _TEST_SERVICE_ACCOUNT ,
11461163 )
11471164
11481165 assert e .match (regexp = r"source" )
@@ -1159,6 +1176,7 @@ def test_batch_predict_two_sources(self, create_batch_prediction_job_mock):
11591176 gcs_source = _TEST_BATCH_PREDICTION_GCS_SOURCE ,
11601177 bigquery_source = _TEST_BATCH_PREDICTION_BQ_PREFIX ,
11611178 bigquery_destination_prefix = _TEST_BATCH_PREDICTION_BQ_PREFIX ,
1179+ service_account = _TEST_SERVICE_ACCOUNT ,
11621180 )
11631181
11641182 assert e .match (regexp = r"source" )
@@ -1173,6 +1191,7 @@ def test_batch_predict_no_destination(self):
11731191 model_name = _TEST_MODEL_NAME ,
11741192 job_display_name = _TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME ,
11751193 gcs_source = _TEST_BATCH_PREDICTION_GCS_SOURCE ,
1194+ service_account = _TEST_SERVICE_ACCOUNT ,
11761195 )
11771196
11781197 assert e .match (regexp = r"destination" )
@@ -1189,6 +1208,7 @@ def test_batch_predict_wrong_instance_format(self):
11891208 gcs_source = _TEST_BATCH_PREDICTION_GCS_SOURCE ,
11901209 instances_format = "wrong" ,
11911210 bigquery_destination_prefix = _TEST_BATCH_PREDICTION_BQ_PREFIX ,
1211+ service_account = _TEST_SERVICE_ACCOUNT ,
11921212 )
11931213
11941214 assert e .match (regexp = r"accepted instances format" )
@@ -1205,6 +1225,7 @@ def test_batch_predict_wrong_prediction_format(self):
12051225 gcs_source = _TEST_BATCH_PREDICTION_GCS_SOURCE ,
12061226 predictions_format = "wrong" ,
12071227 bigquery_destination_prefix = _TEST_BATCH_PREDICTION_BQ_PREFIX ,
1228+ service_account = _TEST_SERVICE_ACCOUNT ,
12081229 )
12091230
12101231 assert e .match (regexp = r"accepted prediction format" )
@@ -1222,6 +1243,7 @@ def test_batch_predict_job_with_versioned_model(
12221243 gcs_source = _TEST_BATCH_PREDICTION_GCS_SOURCE ,
12231244 gcs_destination_prefix = _TEST_BATCH_PREDICTION_GCS_DEST_PREFIX ,
12241245 sync = True ,
1246+ service_account = _TEST_SERVICE_ACCOUNT ,
12251247 )
12261248 assert (
12271249 create_batch_prediction_job_mock .call_args_list [0 ][1 ][
@@ -1237,6 +1259,7 @@ def test_batch_predict_job_with_versioned_model(
12371259 gcs_source = _TEST_BATCH_PREDICTION_GCS_SOURCE ,
12381260 gcs_destination_prefix = _TEST_BATCH_PREDICTION_GCS_DEST_PREFIX ,
12391261 sync = True ,
1262+ service_account = _TEST_SERVICE_ACCOUNT ,
12401263 )
12411264 assert (
12421265 create_batch_prediction_job_mock .call_args_list [0 ][1 ][
0 commit comments