|
22 | 22 | from unittest import mock |
23 | 23 | from importlib import reload |
24 | 24 | from unittest.mock import patch |
| 25 | +from urllib import request |
25 | 26 | from datetime import datetime |
26 | 27 |
|
27 | 28 | from google.auth import credentials as auth_credentials |
|
50 | 51 | _TEST_SERVICE_ACCOUNT = "abcde@my-project.iam.gserviceaccount.com" |
51 | 52 |
|
52 | 53 | _TEST_TEMPLATE_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}/job_spec.json" |
| 54 | +_TEST_AR_TEMPLATE_PATH = "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest" |
53 | 55 | _TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" |
54 | 56 | _TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_ID}" |
55 | 57 |
|
@@ -289,6 +291,17 @@ def mock_load_yaml_and_json(job_spec): |
289 | 291 | yield mock_load_yaml_and_json |
290 | 292 |
|
291 | 293 |
|
| 294 | +@pytest.fixture |
| 295 | +def mock_request_urlopen(job_spec): |
| 296 | + with patch.object(request, "urlopen") as mock_urlopen: |
| 297 | + mock_read_response = mock.MagicMock() |
| 298 | + mock_decode_response = mock.MagicMock() |
| 299 | + mock_decode_response.return_value = job_spec.encode() |
| 300 | + mock_read_response.return_value.decode = mock_decode_response |
| 301 | + mock_urlopen.return_value.read = mock_read_response |
| 302 | + yield mock_urlopen |
| 303 | + |
| 304 | + |
292 | 305 | @pytest.mark.usefixtures("google_auth_mock") |
293 | 306 | class TestPipelineJob: |
294 | 307 | def setup_method(self): |
@@ -376,6 +389,85 @@ def test_run_call_pipeline_service_create( |
376 | 389 | gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED |
377 | 390 | ) |
378 | 391 |
|
| 392 | + @pytest.mark.parametrize( |
| 393 | + "job_spec", |
| 394 | + [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB], |
| 395 | + ) |
| 396 | + @pytest.mark.parametrize("sync", [True, False]) |
| 397 | + def test_run_call_pipeline_service_create_artifact_registry( |
| 398 | + self, |
| 399 | + mock_pipeline_service_create, |
| 400 | + mock_pipeline_service_get, |
| 401 | + mock_request_urlopen, |
| 402 | + job_spec, |
| 403 | + mock_load_yaml_and_json, |
| 404 | + sync, |
| 405 | + ): |
| 406 | + aiplatform.init( |
| 407 | + project=_TEST_PROJECT, |
| 408 | + staging_bucket=_TEST_GCS_BUCKET_NAME, |
| 409 | + location=_TEST_LOCATION, |
| 410 | + credentials=_TEST_CREDENTIALS, |
| 411 | + ) |
| 412 | + |
| 413 | + job = pipeline_jobs.PipelineJob( |
| 414 | + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, |
| 415 | + template_path=_TEST_AR_TEMPLATE_PATH, |
| 416 | + job_id=_TEST_PIPELINE_JOB_ID, |
| 417 | + parameter_values=_TEST_PIPELINE_PARAMETER_VALUES, |
| 418 | + enable_caching=True, |
| 419 | + ) |
| 420 | + |
| 421 | + job.run( |
| 422 | + service_account=_TEST_SERVICE_ACCOUNT, |
| 423 | + network=_TEST_NETWORK, |
| 424 | + sync=sync, |
| 425 | + create_request_timeout=None, |
| 426 | + ) |
| 427 | + |
| 428 | + if not sync: |
| 429 | + job.wait() |
| 430 | + |
| 431 | + expected_runtime_config_dict = { |
| 432 | + "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME, |
| 433 | + "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES, |
| 434 | + } |
| 435 | + runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb |
| 436 | + json_format.ParseDict(expected_runtime_config_dict, runtime_config) |
| 437 | + |
| 438 | + job_spec = yaml.safe_load(job_spec) |
| 439 | + pipeline_spec = job_spec.get("pipelineSpec") or job_spec |
| 440 | + |
| 441 | + # Construct expected request |
| 442 | + expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob( |
| 443 | + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, |
| 444 | + pipeline_spec={ |
| 445 | + "components": {}, |
| 446 | + "pipelineInfo": pipeline_spec["pipelineInfo"], |
| 447 | + "root": pipeline_spec["root"], |
| 448 | + "schemaVersion": "2.1.0", |
| 449 | + }, |
| 450 | + runtime_config=runtime_config, |
| 451 | + service_account=_TEST_SERVICE_ACCOUNT, |
| 452 | + network=_TEST_NETWORK, |
| 453 | + template_uri=_TEST_AR_TEMPLATE_PATH, |
| 454 | + ) |
| 455 | + |
| 456 | + mock_pipeline_service_create.assert_called_once_with( |
| 457 | + parent=_TEST_PARENT, |
| 458 | + pipeline_job=expected_gapic_pipeline_job, |
| 459 | + pipeline_job_id=_TEST_PIPELINE_JOB_ID, |
| 460 | + timeout=None, |
| 461 | + ) |
| 462 | + |
| 463 | + mock_pipeline_service_get.assert_called_with( |
| 464 | + name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY |
| 465 | + ) |
| 466 | + |
| 467 | + assert job._gca_resource == make_pipeline_job( |
| 468 | + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED |
| 469 | + ) |
| 470 | + |
379 | 471 | @pytest.mark.parametrize( |
380 | 472 | "job_spec", |
381 | 473 | [ |
|
0 commit comments