2424import sys
2525import tempfile
2626import time
27- from typing import Callable , List , Optional , Sequence , Union
27+ from typing import Callable , Dict , List , Optional , NamedTuple , Sequence , Union
2828
2929
3030from google .auth import credentials as auth_credentials
@@ -339,6 +339,190 @@ def package_and_copy_to_gcs(
339339 return self .package_and_copy (copy_method = copy_method )
340340
341341
342+ class _MachineSpec (NamedTuple ):
343+ """Specification container for Machine specs used for distributed training.
344+
345+ Usage:
346+
347+ spec = _MachineSpec(
348+ replica_count=10,
349+ machine_type='n1-standard-4',
350+ accelerator_count=2,
351+ accelerator_type='NVIDIA_TESLA_K80')
352+
353+ Note that container and python package specs are not stored with this spec.
354+ """
355+
356+ replica_count : int = 0
357+ machine_type : str = "n1-standard-4"
358+ accelerator_count : int = 0
359+ accelerator_type : str = "ACCELERATOR_TYPE_UNSPECIFIED"
360+
361+ def _get_accelerator_type (self ) -> Optional [str ]:
362+ """Validates accelerator_type and returns the name of the accelerator.
363+
364+ Returns:
365+ None if no accelerator or valid accelerator name.
366+
367+ Raise:
368+ ValueError if accelerator type is invalid.
369+ """
370+
371+ # validate accelerator type
372+ if (
373+ self .accelerator_type
374+ not in gca_accelerator_type .AcceleratorType ._member_names_
375+ ):
376+ raise ValueError (
377+ f"accelerator_type `{ self .accelerator_type } ` invalid. "
378+ f"Choose one of { gca_accelerator_type .AcceleratorType ._member_names_ } "
379+ )
380+
381+ accelerator_enum = getattr (
382+ gca_accelerator_type .AcceleratorType , self .accelerator_type
383+ )
384+
385+ if (
386+ accelerator_enum
387+ != gca_accelerator_type .AcceleratorType .ACCELERATOR_TYPE_UNSPECIFIED
388+ ):
389+ return self .accelerator_type
390+
391+ @property
392+ def spec_dict (self ) -> Dict [str , Union [int , str , Dict [str , Union [int , str ]]]]:
393+ """Return specification as a Dict."""
394+ spec = {
395+ "machineSpec" : {"machineType" : self .machine_type },
396+ "replicaCount" : self .replica_count ,
397+ }
398+ accelerator_type = self ._get_accelerator_type ()
399+ if accelerator_type and self .accelerator_count :
400+ spec ["machineSpec" ]["acceleratorType" ] = accelerator_type
401+ spec ["machineSpec" ]["acceleratorCount" ] = self .accelerator_count
402+
403+ return spec
404+
405+ @property
406+ def is_empty (self ) -> bool :
407+ """Returns True is replica_count > 0 False otherwise."""
408+ return self .replica_count <= 0
409+
410+
class _DistributedTrainingSpec(NamedTuple):
    """Configuration for distributed training worker pool specs.

    AI Platform Training expects configuration in this order:
    [
        chief spec, # can only have one replica
        worker spec,
        parameter server spec,
        evaluator spec
    ]

    Usage:

        dist_training_spec = _DistributedTrainingSpec(
            chief_spec = _MachineSpec(
                replica_count=1,
                machine_type='n1-standard-4',
                accelerator_count=2,
                accelerator_type='NVIDIA_TESLA_K80'
            ),
            worker_spec = _MachineSpec(
                replica_count=10,
                machine_type='n1-standard-4',
                accelerator_count=2,
                accelerator_type='NVIDIA_TESLA_K80'
            )
        )

    """

    chief_spec: _MachineSpec = _MachineSpec()
    worker_spec: _MachineSpec = _MachineSpec()
    parameter_server_spec: _MachineSpec = _MachineSpec()
    evaluator_spec: _MachineSpec = _MachineSpec()

    @property
    def pool_specs(
        self,
    ) -> List[Dict[str, Union[int, str, Dict[str, Union[int, str]]]]]:
        """Return each pool's spec in the correct order for AI Platform as a list of dicts.

        Removes trailing empty specs, but leaves empty specs in place when a
        non-empty spec follows them, to not break the ordering expected by
        AI Platform Training.
        i.e. 0 chief replicas, 10 worker replicas, 3 ps replicas

        Returns:
            Ordered list of worker pool specs suitable for AI Platform Training.

        Raises:
            ValueError: If the chief spec has more than one replica.
        """
        if self.chief_spec.replica_count > 1:
            raise ValueError("Chief spec replica count cannot be greater than 1.")

        spec_order = [
            self.chief_spec,
            self.worker_spec,
            self.parameter_server_spec,
            self.evaluator_spec,
        ]
        specs = [s.spec_dict for s in spec_order]
        # Trim empty specs from the tail only; interior empties are kept to
        # preserve positional ordering.
        while spec_order and spec_order[-1].is_empty:
            spec_order.pop()
            specs.pop()
        return specs

    @classmethod
    def chief_worker_pool(
        cls,
        replica_count: int = 0,
        machine_type: str = "n1-standard-4",
        accelerator_count: int = 0,
        accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
    ) -> "_DistributedTrainingSpec":
        """Parameterizes Config to support only chief with worker replicas.

        One replica is assigned to the chief and the remainder to workers. All
        specs have the same machine type, accelerator count, and accelerator type.

        Args:
            replica_count (int):
                The number of worker replicas. Assigns 1 chief replica and
                replica_count - 1 worker replicas.
            machine_type (str):
                The type of machine to use for training.
            accelerator_type (str):
                Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
                NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
                NVIDIA_TESLA_T4, TPU_V2, TPU_V3
            accelerator_count (int):
                The number of accelerators to attach to a worker replica.

        Returns:
            _DistributedTrainingSpec representing one chief and n workers all of
            the same type. If replica_count <= 0 then an empty spec is returned.
        """
        if replica_count <= 0:
            return cls()

        chief_spec = _MachineSpec(
            replica_count=1,
            machine_type=machine_type,
            accelerator_count=accelerator_count,
            accelerator_type=accelerator_type,
        )

        worker_spec = _MachineSpec(
            replica_count=replica_count - 1,
            machine_type=machine_type,
            accelerator_count=accelerator_count,
            accelerator_type=accelerator_type,
        )

        return cls(chief_spec=chief_spec, worker_spec=worker_spec)
524+
525+
342526# TODO(b/172368325) add scheduling, custom_job.Scheduling
343527class CustomTrainingJob (base .AiPlatformResourceNoun ):
344528 """Class to launch a Custom Training Job in AI Platform using a script.
@@ -469,6 +653,12 @@ def run(
469653 ) -> Optional [models .Model ]:
470654 """Runs the custom training job.
471655
656+ Distributed Training Support:
657+ If replica_count = 1 then one chief replica will be provisioned. If
658+ replica_count > 1 the remainder will be provisioned as a worker replica pool.
659+ i.e. replica_count = 10 will result in 1 chief and 9 workers.
660+ All replicas have the same machine_type, accelerator_type, and accelerator_count.
661+
472662 Data fraction splits:
473663 Any of ``training_fraction_split``, ``validation_fraction_split`` and
474664 ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If
@@ -498,7 +688,11 @@ def run(
498688 args (List[Unions[str, int, float]]):
499689 Command line arguments to be passed to the Python script.
500690 replica_count (int):
501- The number of worker replicas.
691+ The number of worker replicas. If replica_count = 1 then one chief
692+ replica will be provisioned. If replica_count > 1 the remainder will be
693+ provisioned as a worker replica pool.
694+ machine_type (str):
695+ The type of machine to use for training.
502696 accelerator_type (str):
503697 Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
504698 NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
@@ -523,23 +717,10 @@ def run(
523717 RuntimeError if Training job has already been run, staging_bucket has not
524718 been set, or model_display_name was provided but required arguments
525719 were not provided in constructor.
526- NotImplementedError more then one replica.
527- ValueError if accelerator type is not valid.
528720 """
529721 if self ._has_run :
530722 raise RuntimeError ("Custom Training has already run." )
531723
532- # TODO(b/172369809) Add support for distributed training.
533- if replica_count > 1 :
534- raise NotImplementedError ("Distributed training not supported." )
535-
536- # validate accelerator type
537- if accelerator_type not in gca_accelerator_type .AcceleratorType ._member_names_ :
538- raise ValueError (
539- f"accelerator_type { accelerator_type } invalid. "
540- f"Choose one of { gca_accelerator_type .AcceleratorType ._member_names_ } "
541- )
542-
543724 staging_bucket = (
544725 self ._staging_bucket or initializer .global_config .staging_bucket
545726 )
@@ -550,9 +731,7 @@ def run(
550731 "set using aiplatform.init(staging_bucket='gs://my-bucket'"
551732 )
552733
553- # if args need for model is incomplete
554- # TODO (b/162273530) lift requirement for predict/health route when
555- # validation lifted and move these args down
555+ # if the args needed for the model are incomplete
556735 if model_display_name and not self ._model_serving_container_image_uri :
557736 raise RuntimeError (
558737 """model_display_name was provided but
@@ -561,6 +740,14 @@ def run(
561740 """
562741 )
563742
743+ # validates args and will raise
744+ worker_pool_specs = _DistributedTrainingSpec .chief_worker_pool (
745+ replica_count = replica_count ,
746+ machine_type = machine_type ,
747+ accelerator_count = accelerator_count ,
748+ accelerator_type = accelerator_type ,
749+ ).pool_specs
750+
564751 # make and copy package
565752 python_packager = _TrainingScriptPythonPackager (
566753 script_path = self ._script_path , requirements = self ._requirements
@@ -577,30 +764,15 @@ def run(
577764 staging_bucket , "aiplatform-custom-training"
578765 )
579766
580- # create worker pool spec
581- worker_pool_spec = {
582- "replicaCount" : replica_count ,
583- "machineSpec" : {"machineType" : machine_type },
584- "pythonPackageSpec" : {
767+ for spec in worker_pool_specs :
768+ spec ["pythonPackageSpec" ] = {
585769 "executorImageUri" : self ._container_uri ,
586770 "pythonModule" : python_packager .module_name ,
587771 "packageUris" : [package_gcs_uri ],
588- },
589- }
590-
591- accelerator_enum = getattr (
592- gca_accelerator_type .AcceleratorType , accelerator_type
593- )
594-
595- if (
596- accelerator_enum
597- != gca_accelerator_type .AcceleratorType .ACCELERATOR_TYPE_UNSPECIFIED
598- ):
599- worker_pool_spec ["machineSpec" ]["acceleratorType" ] = accelerator_type
600- worker_pool_spec ["machineSpec" ]["acceleratorCount" ] = accelerator_count
772+ }
601773
602- if args :
603- worker_pool_spec ["pythonPackageSpec" ]["args" ] = args
774+ if args :
775+ spec ["pythonPackageSpec" ]["args" ] = args
604776
605777 managed_model = None
606778 # create model payload
@@ -640,7 +812,7 @@ def run(
640812 training_task_definition = schema .training_job .definition .custom_task ,
641813 training_task_inputs = json_format .ParseDict (
642814 {
643- "workerPoolSpecs" : [ worker_pool_spec ] ,
815+ "workerPoolSpecs" : worker_pool_specs ,
644816 "baseOutputDirectory" : {"output_uri_prefix" : base_output_dir },
645817 },
646818 struct_pb2 .Value (),
0 commit comments