Skip to content

Commit f2fc1d7

Browse files
vertex-sdk-bot and copybara-github
authored and committed
chore: adding submit() method to CustomJob, similar to PipelineJob.submit()
PiperOrigin-RevId: 495466621
1 parent d0017f9 commit f2fc1d7

2 files changed

Lines changed: 115 additions & 5 deletions

File tree

google/cloud/aiplatform/jobs.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1641,6 +1641,68 @@ def _run(
16411641
create_request_timeout (float):
16421642
Optional. The timeout for the create request in seconds.
16431643
"""
1644+
self.submit(
1645+
service_account=service_account,
1646+
network=network,
1647+
timeout=timeout,
1648+
restart_job_on_worker_restart=restart_job_on_worker_restart,
1649+
enable_web_access=enable_web_access,
1650+
tensorboard=tensorboard,
1651+
create_request_timeout=create_request_timeout,
1652+
)
1653+
1654+
self._block_until_complete()
1655+
1656+
def submit(
1657+
self,
1658+
*,
1659+
service_account: Optional[str] = None,
1660+
network: Optional[str] = None,
1661+
timeout: Optional[int] = None,
1662+
restart_job_on_worker_restart: bool = False,
1663+
enable_web_access: bool = False,
1664+
tensorboard: Optional[str] = None,
1665+
create_request_timeout: Optional[float] = None,
1666+
) -> None:
1667+
"""Submit the configured CustomJob.
1668+
1669+
Args:
1670+
service_account (str):
1671+
Optional. Specifies the service account for workload run-as account.
1672+
Users submitting jobs must have act-as permission on this run-as account.
1673+
network (str):
1674+
Optional. The full name of the Compute Engine network to which the job
1675+
should be peered. For example, projects/12345/global/networks/myVPC.
1676+
Private services access must already be configured for the network.
1677+
timeout (int):
1678+
The maximum job running time in seconds. The default is 7 days.
1679+
restart_job_on_worker_restart (bool):
1680+
Restarts the entire CustomJob if a worker
1681+
gets restarted. This feature can be used by
1682+
distributed training jobs that are not resilient
1683+
to workers leaving and joining a job.
1684+
enable_web_access (bool):
1685+
Whether you want Vertex AI to enable interactive shell access
1686+
to training containers.
1687+
https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell
1688+
tensorboard (str):
1689+
Optional. The name of a Vertex AI
1690+
[Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
1691+
resource to which this CustomJob will upload Tensorboard
1692+
logs. Format:
1693+
``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
1694+
1695+
The training script should write Tensorboard to following Vertex AI environment
1696+
variable:
1697+
1698+
AIP_TENSORBOARD_LOG_DIR
1699+
1700+
`service_account` is required with provided `tensorboard`.
1701+
For more information on configuring your service account please visit:
1702+
https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
1703+
create_request_timeout (float):
1704+
Optional. The timeout for the create request in seconds.
1705+
"""
16441706
if service_account:
16451707
self._gca_resource.job_spec.service_account = service_account
16461708

@@ -1682,8 +1744,6 @@ def _run(
16821744
)
16831745
)
16841746

1685-
self._block_until_complete()
1686-
16871747
@property
16881748
def job_spec(self):
16891749
return self._gca_resource.job_spec

tests/unit/aiplatform/test_custom_job.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,19 @@
2626
from google.rpc import status_pb2
2727

2828
import test_training_jobs
29-
from test_training_jobs import mock_python_package_to_gcs # noqa: F401
29+
from test_training_jobs import ( # noqa: F401
30+
mock_python_package_to_gcs,
31+
)
3032

3133
from google.cloud import aiplatform
3234
from google.cloud.aiplatform import base
33-
from google.cloud.aiplatform.compat.types import custom_job as gca_custom_job_compat
35+
from google.cloud.aiplatform.compat.types import (
36+
custom_job as gca_custom_job_compat,
37+
)
3438
from google.cloud.aiplatform.compat.types import io as gca_io_compat
35-
from google.cloud.aiplatform.compat.types import job_state as gca_job_state_compat
39+
from google.cloud.aiplatform.compat.types import (
40+
job_state as gca_job_state_compat,
41+
)
3642
from google.cloud.aiplatform.compat.types import (
3743
encryption_spec as gca_encryption_spec_compat,
3844
)
@@ -340,6 +346,50 @@ def test_create_custom_job(self, create_custom_job_mock, get_custom_job_mock, sy
340346
)
341347
assert job.network == _TEST_NETWORK
342348

349+
def test_submit_custom_job(self, create_custom_job_mock, get_custom_job_mock):
350+
351+
aiplatform.init(
352+
project=_TEST_PROJECT,
353+
location=_TEST_LOCATION,
354+
staging_bucket=_TEST_STAGING_BUCKET,
355+
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
356+
)
357+
358+
job = aiplatform.CustomJob(
359+
display_name=_TEST_DISPLAY_NAME,
360+
worker_pool_specs=_TEST_WORKER_POOL_SPEC,
361+
base_output_dir=_TEST_BASE_OUTPUT_DIR,
362+
labels=_TEST_LABELS,
363+
)
364+
365+
job.submit(
366+
service_account=_TEST_SERVICE_ACCOUNT,
367+
network=_TEST_NETWORK,
368+
timeout=_TEST_TIMEOUT,
369+
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
370+
create_request_timeout=None,
371+
)
372+
373+
job.wait_for_resource_creation()
374+
375+
assert job.resource_name == _TEST_CUSTOM_JOB_NAME
376+
377+
job.wait()
378+
379+
expected_custom_job = _get_custom_job_proto()
380+
381+
create_custom_job_mock.assert_called_once_with(
382+
parent=_TEST_PARENT,
383+
custom_job=expected_custom_job,
384+
timeout=None,
385+
)
386+
387+
assert job.job_spec == expected_custom_job.job_spec
388+
assert (
389+
job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_PENDING
390+
)
391+
assert job.network == _TEST_NETWORK
392+
343393
@pytest.mark.parametrize("sync", [True, False])
344394
def test_create_custom_job_with_timeout(
345395
self, create_custom_job_mock, get_custom_job_mock, sync

0 commit comments

Comments
 (0)