
Commit 886cff3

sasha-gitg authored and dizcology committed
feat: add support for distributed custom training (googleapis#66)
1 parent 2b6f407 commit 886cff3

2 files changed

Lines changed: 579 additions & 39 deletions

File tree

google/cloud/aiplatform/training_jobs.py

Lines changed: 211 additions & 39 deletions
@@ -24,7 +24,7 @@
 import sys
 import tempfile
 import time
-from typing import Callable, List, Optional, Sequence, Union
+from typing import Callable, Dict, List, Optional, NamedTuple, Sequence, Union
 
 
 from google.auth import credentials as auth_credentials
@@ -339,6 +339,190 @@ def package_and_copy_to_gcs(
         return self.package_and_copy(copy_method=copy_method)
 
 
+class _MachineSpec(NamedTuple):
+    """Specification container for Machine specs used for distributed training.
+
+    Usage:
+
+    spec = _MachineSpec(
+        replica_count=10,
+        machine_type='n1-standard-4',
+        accelerator_count=2,
+        accelerator_type='NVIDIA_TESLA_K80')
+
+    Note that container and python package specs are not stored with this spec.
+    """
+
+    replica_count: int = 0
+    machine_type: str = "n1-standard-4"
+    accelerator_count: int = 0
+    accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED"
+
+    def _get_accelerator_type(self) -> Optional[str]:
+        """Validates accelerator_type and returns the name of the accelerator.
+
+        Returns:
+            None if no accelerator is used, otherwise the validated accelerator name.
+
+        Raises:
+            ValueError if accelerator type is invalid.
+        """
+
+        # validate accelerator type
+        if (
+            self.accelerator_type
+            not in gca_accelerator_type.AcceleratorType._member_names_
+        ):
+            raise ValueError(
+                f"accelerator_type `{self.accelerator_type}` invalid. "
+                f"Choose one of {gca_accelerator_type.AcceleratorType._member_names_}"
+            )
+
+        accelerator_enum = getattr(
+            gca_accelerator_type.AcceleratorType, self.accelerator_type
+        )
+
+        if (
+            accelerator_enum
+            != gca_accelerator_type.AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED
+        ):
+            return self.accelerator_type
+
+    @property
+    def spec_dict(self) -> Dict[str, Union[int, str, Dict[str, Union[int, str]]]]:
+        """Return specification as a Dict."""
+        spec = {
+            "machineSpec": {"machineType": self.machine_type},
+            "replicaCount": self.replica_count,
+        }
+        accelerator_type = self._get_accelerator_type()
+        if accelerator_type and self.accelerator_count:
+            spec["machineSpec"]["acceleratorType"] = accelerator_type
+            spec["machineSpec"]["acceleratorCount"] = self.accelerator_count
+
+        return spec
+
+    @property
+    def is_empty(self) -> bool:
+        """Returns True if replica_count <= 0, False otherwise."""
+        return self.replica_count <= 0
+
+
+class _DistributedTrainingSpec(NamedTuple):
+    """Configuration for distributed training worker pool specs.
+
+    AI Platform Training expects configuration in this order:
+    [
+        chief spec,  # can only have one replica
+        worker spec,
+        parameter server spec,
+        evaluator spec
+    ]
+
+    Usage:
+
+    dist_training_spec = _DistributedTrainingSpec(
+        chief_spec=_MachineSpec(
+            replica_count=1,
+            machine_type='n1-standard-4',
+            accelerator_count=2,
+            accelerator_type='NVIDIA_TESLA_K80'
+        ),
+        worker_spec=_MachineSpec(
+            replica_count=10,
+            machine_type='n1-standard-4',
+            accelerator_count=2,
+            accelerator_type='NVIDIA_TESLA_K80'
+        )
+    )
+    """
+
+    chief_spec: _MachineSpec = _MachineSpec()
+    worker_spec: _MachineSpec = _MachineSpec()
+    parameter_server_spec: _MachineSpec = _MachineSpec()
+    evaluator_spec: _MachineSpec = _MachineSpec()
+
+    @property
+    def pool_specs(
+        self,
+    ) -> List[Dict[str, Union[int, str, Dict[str, Union[int, str]]]]]:
+        """Return each pool's spec, in the order expected by AI Platform, as a list of dicts.
+
+        Also removes trailing empty specs, but keeps intermediate empty specs
+        (e.g. 0 chief replicas, 10 worker replicas, 3 ps replicas) so that unusual
+        configurations do not break the ordering in AI Platform Training.
+
+        Returns:
+            Ordered list of worker pool specs suitable for AI Platform Training.
+        """
+        if self.chief_spec.replica_count > 1:
+            raise ValueError("Chief spec replica count cannot be greater than 1.")
+
+        spec_order = [
+            self.chief_spec,
+            self.worker_spec,
+            self.parameter_server_spec,
+            self.evaluator_spec,
+        ]
+        specs = [s.spec_dict for s in spec_order]
+        for i in reversed(range(len(spec_order))):
+            if spec_order[i].is_empty:
+                specs.pop()
+            else:
+                break
+        return specs
+
+    @classmethod
+    def chief_worker_pool(
+        cls,
+        replica_count: int = 0,
+        machine_type: str = "n1-standard-4",
+        accelerator_count: int = 0,
+        accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
+    ) -> "_DistributedTrainingSpec":
+        """Parameterizes the config to support only a chief with worker replicas.
+
+        One replica is assigned to the chief and the remainder to workers. All
+        specs have the same machine type, accelerator count, and accelerator type.
+
+        Args:
+            replica_count (int):
+                The number of worker replicas. Assigns 1 chief replica and
+                replica_count - 1 worker replicas.
+            machine_type (str):
+                The type of machine to use for training.
+            accelerator_type (str):
+                Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
+                NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
+                NVIDIA_TESLA_T4, TPU_V2, TPU_V3
+            accelerator_count (int):
+                The number of accelerators to attach to a worker replica.
+
+        Returns:
+            _DistributedTrainingSpec representing one chief and n workers all of the
+            same type. If replica_count <= 0 then an empty spec is returned.
+        """
+        if replica_count <= 0:
+            return cls()
+
+        chief_spec = _MachineSpec(
+            replica_count=1,
+            machine_type=machine_type,
+            accelerator_count=accelerator_count,
+            accelerator_type=accelerator_type,
+        )
+
+        worker_spec = _MachineSpec(
+            replica_count=replica_count - 1,
+            machine_type=machine_type,
+            accelerator_count=accelerator_count,
+            accelerator_type=accelerator_type,
+        )
+
+        return cls(chief_spec=chief_spec, worker_spec=worker_spec)
+
+
 # TODO(b/172368325) add scheduling, custom_job.Scheduling
 class CustomTrainingJob(base.AiPlatformResourceNoun):
     """Class to launch a Custom Training Job in AI Platform using a script.
@@ -469,6 +653,12 @@ def run(
     ) -> Optional[models.Model]:
         """Runs the custom training job.
 
+        Distributed Training Support:
+            If replica_count = 1 then one chief replica will be provisioned. If
+            replica_count > 1 the remainder will be provisioned as a worker replica
+            pool, i.e. replica_count = 10 will result in 1 chief and 9 workers. All
+            replicas have the same machine_type, accelerator_type, and accelerator_count.
+
         Data fraction splits:
             Any of ``training_fraction_split``, ``validation_fraction_split`` and
             ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If
@@ -498,7 +688,11 @@ def run(
             args (List[Unions[str, int, float]]):
                 Command line arguments to be passed to the Python script.
             replica_count (int):
-                The number of worker replicas.
+                The number of worker replicas. If replica_count = 1 then one chief
+                replica will be provisioned. If replica_count > 1 the remainder will
+                be provisioned as a worker replica pool.
+            machine_type (str):
+                The type of machine to use for training.
             accelerator_type (str):
                 Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
                 NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
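
For a concrete sense of the replica_count semantics documented above, here is a hedged sketch of the chief/worker split (the split_replicas helper is hypothetical; the real work happens in _DistributedTrainingSpec.chief_worker_pool):

    # Hypothetical illustration: one replica goes to the chief, the rest
    # become the worker pool; an empty worker pool is trimmed away.
    def split_replicas(replica_count: int):
        if replica_count <= 0:
            return []  # empty spec: no pools at all
        pools = [("chief", 1)]
        if replica_count > 1:
            pools.append(("worker", replica_count - 1))
        return pools

    print(split_replicas(1))   # [('chief', 1)]
    print(split_replicas(10))  # [('chief', 1), ('worker', 9)]
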
@@ -523,23 +717,10 @@ def run(
             RuntimeError if Training job has already been run, staging_bucket has not
                 been set, or model_display_name was provided but required arguments
                 were not provided in constructor.
-            NotImplementedError more then one replica.
-            ValueError if accelerator type is not valid.
         """
         if self._has_run:
             raise RuntimeError("Custom Training has already run.")
 
-        # TODO(b/172369809) Add support for distributed training.
-        if replica_count > 1:
-            raise NotImplementedError("Distributed training not supported.")
-
-        # validate accelerator type
-        if accelerator_type not in gca_accelerator_type.AcceleratorType._member_names_:
-            raise ValueError(
-                f"accelerator_type {accelerator_type} invalid. "
-                f"Choose one of {gca_accelerator_type.AcceleratorType._member_names_}"
-            )
-
         staging_bucket = (
             self._staging_bucket or initializer.global_config.staging_bucket
         )
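
The inline checks removed here now live in _MachineSpec._get_accelerator_type, so an invalid accelerator name still fails fast with the same ValueError. A sketch of exercising that check; the import path is an assumption based on the gca_accelerator_type alias used in this module:

    # Assumed import path for the enum aliased as gca_accelerator_type above.
    from google.cloud.aiplatform_v1beta1.types import accelerator_type as gca_accelerator_type

    name = "NVIDIA_TESLA_K81"  # a typo'd accelerator name
    if name not in gca_accelerator_type.AcceleratorType._member_names_:
        raise ValueError(
            f"accelerator_type `{name}` invalid. "
            f"Choose one of {gca_accelerator_type.AcceleratorType._member_names_}"
        )
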
@@ -550,9 +731,7 @@ def run(
                 "set using aiplatform.init(staging_bucket='gs://my-bucket'"
             )
 
-        # if args need for model is incomplete
-        # TODO (b/162273530) lift requirement for predict/health route when
-        # validation lifted and move these args down
+        # if args needed for model are incomplete
         if model_display_name and not self._model_serving_container_image_uri:
             raise RuntimeError(
                 """model_display_name was provided but
@@ -561,6 +740,14 @@ def run(
                 """
             )
 
+        # validates args and will raise
+        worker_pool_specs = _DistributedTrainingSpec.chief_worker_pool(
+            replica_count=replica_count,
+            machine_type=machine_type,
+            accelerator_count=accelerator_count,
+            accelerator_type=accelerator_type,
+        ).pool_specs
+
         # make and copy package
         python_packager = _TrainingScriptPythonPackager(
             script_path=self._script_path, requirements=self._requirements
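
Given the classes added above, worker_pool_specs should come out looking like the following for, say, replica_count=3 with two K80s per replica (an expected value shown for illustration, derived from the _MachineSpec.spec_dict logic):

    # Expected shape of worker_pool_specs for replica_count=3,
    # machine_type="n1-standard-4", accelerator_count=2,
    # accelerator_type="NVIDIA_TESLA_K80" (illustrative).
    expected_worker_pool_specs = [
        {   # chief: always exactly one replica
            "machineSpec": {
                "machineType": "n1-standard-4",
                "acceleratorType": "NVIDIA_TESLA_K80",
                "acceleratorCount": 2,
            },
            "replicaCount": 1,
        },
        {   # workers: the remaining replica_count - 1 replicas
            "machineSpec": {
                "machineType": "n1-standard-4",
                "acceleratorType": "NVIDIA_TESLA_K80",
                "acceleratorCount": 2,
            },
            "replicaCount": 2,
        },
    ]
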
@@ -577,30 +764,15 @@ def run(
             staging_bucket, "aiplatform-custom-training"
         )
 
-        # create worker pool spec
-        worker_pool_spec = {
-            "replicaCount": replica_count,
-            "machineSpec": {"machineType": machine_type},
-            "pythonPackageSpec": {
+        for spec in worker_pool_specs:
+            spec["pythonPackageSpec"] = {
                 "executorImageUri": self._container_uri,
                 "pythonModule": python_packager.module_name,
                 "packageUris": [package_gcs_uri],
-            },
-        }
-
-        accelerator_enum = getattr(
-            gca_accelerator_type.AcceleratorType, accelerator_type
-        )
-
-        if (
-            accelerator_enum
-            != gca_accelerator_type.AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED
-        ):
-            worker_pool_spec["machineSpec"]["acceleratorType"] = accelerator_type
-            worker_pool_spec["machineSpec"]["acceleratorCount"] = accelerator_count
+            }
 
-        if args:
-            worker_pool_spec["pythonPackageSpec"]["args"] = args
+            if args:
+                spec["pythonPackageSpec"]["args"] = args
 
         managed_model = None
         # create model payload
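
Because every pool runs the same training package, the loop above simply stamps an identical pythonPackageSpec onto each surviving pool. A self-contained sketch of that mutation; the image URI, module name, and GCS path are placeholders, not values produced by the real packager:

    # Self-contained sketch of the per-pool mutation above.
    worker_pool_specs = [{"replicaCount": 1}, {"replicaCount": 2}]
    args = ["--epochs", "10"]
    for spec in worker_pool_specs:
        spec["pythonPackageSpec"] = {
            "executorImageUri": "gcr.io/my-project/my-training-image",   # placeholder
            "pythonModule": "trainer.task",                              # placeholder
            "packageUris": ["gs://my-bucket/trainer-0.1.tar.gz"],        # placeholder
        }
        if args:
            spec["pythonPackageSpec"]["args"] = args
    print(worker_pool_specs[0]["pythonPackageSpec"]["args"])  # ['--epochs', '10']
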
@@ -640,7 +812,7 @@ def run(
             training_task_definition=schema.training_job.definition.custom_task,
             training_task_inputs=json_format.ParseDict(
                 {
-                    "workerPoolSpecs": [worker_pool_spec],
+                    "workerPoolSpecs": worker_pool_specs,
                     "baseOutputDirectory": {"output_uri_prefix": base_output_dir},
                 },
                 struct_pb2.Value(),
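
json_format.ParseDict into a struct_pb2.Value accepts the full nested payload, so the list of pool dicts drops in directly with no per-pool conversion. A runnable sketch with placeholder values (assumes the protobuf package is installed):

    # Runnable sketch of packing the payload into a protobuf Value;
    # the bucket path is a placeholder.
    from google.protobuf import json_format, struct_pb2

    worker_pool_specs = [{"replicaCount": 1}, {"replicaCount": 2}]
    training_task_inputs = json_format.ParseDict(
        {
            "workerPoolSpecs": worker_pool_specs,
            "baseOutputDirectory": {"output_uri_prefix": "gs://my-bucket/output"},
        },
        struct_pb2.Value(),
    )
    print(json_format.MessageToDict(training_task_inputs))
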
