1111
1212import torch
1313import torch .utils .data
14+ from torch .testing ._internal .common_utils import IS_MACOS
1415from torchdata .stateful_dataloader import Stateful , StatefulDataLoader
1516
1617
@@ -127,6 +128,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
127128 collate_fn = identity ,
128129 snapshot_every_n_steps = every_n_steps ,
129130 persistent_workers = pw ,
131+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
130132 )
131133 list (dl )
132134
@@ -150,6 +152,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
150152 collate_fn = identity ,
151153 snapshot_every_n_steps = every_n_steps ,
152154 persistent_workers = pw ,
155+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
153156 )
154157 dl .load_state_dict (state_dict )
155158 for batch in iter (dl ):
@@ -226,6 +229,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
226229 persistent_workers = pw ,
227230 batch_size = batch_size ,
228231 sampler = sampler ,
232+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
229233 )
230234
231235 if interrupt is None :
@@ -252,6 +256,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
252256 persistent_workers = pw ,
253257 batch_size = batch_size ,
254258 sampler = sampler ,
259+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
255260 )
256261 dl .load_state_dict (state_dict )
257262 batches = []
@@ -273,6 +278,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
273278 persistent_workers = pw ,
274279 batch_size = batch_size ,
275280 sampler = sampler ,
281+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
276282 )
277283
278284 if interrupt is None :
@@ -297,6 +303,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
297303 persistent_workers = pw ,
298304 batch_size = batch_size ,
299305 sampler = sampler ,
306+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
300307 )
301308 dl .load_state_dict (state_dict )
302309 batches = []
@@ -360,6 +367,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
360367 collate_fn = identity ,
361368 snapshot_every_n_steps = every_n_steps ,
362369 persistent_workers = pw ,
370+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
363371 )
364372 exp = list (dl )
365373
@@ -373,6 +381,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
373381 collate_fn = identity ,
374382 snapshot_every_n_steps = every_n_steps ,
375383 persistent_workers = pw ,
384+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
376385 )
377386 batches = []
378387 it = iter (dl )
@@ -390,6 +399,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
390399 collate_fn = identity ,
391400 snapshot_every_n_steps = every_n_steps ,
392401 persistent_workers = pw ,
402+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
393403 )
394404 dl .load_state_dict (state_dict )
395405 for batch in dl :
@@ -407,6 +417,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
407417 collate_fn = identity ,
408418 snapshot_every_n_steps = every_n_steps ,
409419 persistent_workers = pw ,
420+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
410421 )
411422 exp = list (dl )
412423
@@ -420,6 +431,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
420431 collate_fn = identity ,
421432 snapshot_every_n_steps = every_n_steps ,
422433 persistent_workers = pw ,
434+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
423435 )
424436 batches = []
425437 it = iter (dl )
@@ -437,6 +449,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
437449 collate_fn = identity ,
438450 snapshot_every_n_steps = every_n_steps ,
439451 persistent_workers = pw ,
452+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
440453 )
441454 dl .load_state_dict (state_dict )
442455 for batch in dl :
@@ -457,6 +470,7 @@ def test_generator(self):
457470 collate_fn = identity ,
458471 snapshot_every_n_steps = every_n_steps ,
459472 persistent_workers = pw ,
473+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
460474 )
461475
462476 it = iter (dl )
@@ -479,6 +493,7 @@ def test_iterable(self):
479493 collate_fn = identity ,
480494 snapshot_every_n_steps = every_n_steps ,
481495 persistent_workers = pw ,
496+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
482497 )
483498
484499 it = iter (dl )
@@ -501,6 +516,7 @@ def test_map(self):
501516 collate_fn = identity ,
502517 snapshot_every_n_steps = every_n_steps ,
503518 persistent_workers = pw ,
519+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
504520 )
505521
506522 it = iter (dl )
@@ -524,6 +540,7 @@ def test_map_shuffle(self):
524540 collate_fn = identity ,
525541 snapshot_every_n_steps = every_n_steps ,
526542 persistent_workers = pw ,
543+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
527544 )
528545
529546 it = iter (dl )
@@ -538,7 +555,7 @@ def test_map_shuffle(self):
538555 def test_map_iterrupted_shuffle (self ):
539556 every_n_steps = 10
540557
541- for pw , num_workers , every_n_steps in itertools .product ([False , True ], [0 , 2 ], [1 , 5 , 10 , 15 ]):
558+ for pw , num_workers , every_n_steps in itertools .product ([False , True ], [0 , 2 ], [1 , 15 ]):
542559 dataset = DummyMapDataset (10 , shuffle = True )
543560 dl = StatefulDataLoader (
544561 dataset = dataset ,
@@ -547,6 +564,7 @@ def test_map_iterrupted_shuffle(self):
547564 collate_fn = identity ,
548565 snapshot_every_n_steps = every_n_steps ,
549566 persistent_workers = pw if num_workers > 0 else False ,
567+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
550568 )
551569
552570 it = iter (dl )
@@ -582,6 +600,7 @@ def test_generator(self):
582600 snapshot_every_n_steps = every_n_steps ,
583601 persistent_workers = pw ,
584602 batch_size = bs ,
603+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
585604 )
586605 exp = list (dl )
587606 state_end = dl .state_dict ()
@@ -597,6 +616,7 @@ def test_generator(self):
597616 snapshot_every_n_steps = every_n_steps ,
598617 persistent_workers = pw ,
599618 batch_size = bs ,
619+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
600620 )
601621 it = iter (dl )
602622 for _ in range (2 ):
@@ -618,6 +638,7 @@ def test_generator_no_state(self):
618638 snapshot_every_n_steps = every_n_steps ,
619639 persistent_workers = pw ,
620640 batch_size = bs ,
641+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
621642 )
622643 exp = list (dl )
623644 state_end = dl .state_dict ()
@@ -633,6 +654,7 @@ def test_generator_no_state(self):
633654 snapshot_every_n_steps = every_n_steps ,
634655 persistent_workers = pw ,
635656 batch_size = bs ,
657+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
636658 )
637659 it = iter (dl )
638660 for _ in range (2 ):
@@ -657,6 +679,7 @@ def test_iterable(self):
657679 persistent_workers = pw ,
658680 batch_size = bs ,
659681 generator = g ,
682+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
660683 )
661684 list (dl )
662685 state_end = dl .state_dict ()
@@ -671,6 +694,7 @@ def test_iterable(self):
671694 persistent_workers = pw ,
672695 batch_size = bs ,
673696 generator = g ,
697+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
674698 )
675699 dl .load_state_dict (state_end )
676700 batches = list (dl )
@@ -692,6 +716,7 @@ def test_map(self):
692716 persistent_workers = pw ,
693717 batch_size = bs ,
694718 generator = generator ,
719+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
695720 )
696721 list (dl )
697722 state_end = dl .state_dict ()
@@ -706,6 +731,7 @@ def test_map(self):
706731 persistent_workers = pw ,
707732 batch_size = bs ,
708733 generator = generator ,
734+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
709735 )
710736 dl .load_state_dict (state_end )
711737 batches = list (dl )
@@ -725,6 +751,7 @@ def test_map_shuffle(self):
725751 snapshot_every_n_steps = every_n_steps ,
726752 persistent_workers = pw ,
727753 batch_size = bs ,
754+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
728755 )
729756 list (dl )
730757 state_end = dl .state_dict ()
@@ -739,6 +766,7 @@ def test_map_shuffle(self):
739766 snapshot_every_n_steps = every_n_steps ,
740767 persistent_workers = pw ,
741768 batch_size = bs ,
769+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
742770 )
743771 dl .load_state_dict (state_end )
744772 batches = list (dl )
@@ -748,14 +776,15 @@ def test_map_shuffle(self):
748776
749777class TestNumWorkersMismatch (unittest .TestCase ):
750778 def test_num_workers_mismatch (self ):
751- for initial_num_workers , num_workers in itertools . product ([ 0 , 5 ], [ 0 , 3 , 7 ] ):
779+ for initial_num_workers , num_workers in (( 0 , 3 ), ( 3 , 0 ) ):
752780 if initial_num_workers == num_workers :
753781 continue
754782 dataset = DummyMapDataset (100 , shuffle = False )
755783 dl = StatefulDataLoader (
756784 dataset = dataset ,
757785 num_workers = initial_num_workers ,
758786 collate_fn = identity ,
787+ multiprocessing_context = "forkserver" if IS_MACOS and initial_num_workers else None ,
759788 )
760789 state = dl .state_dict ()
761790 self .assertEqual (len (state ), 0 )
@@ -768,6 +797,7 @@ def test_num_workers_mismatch(self):
768797 dataset = dataset ,
769798 num_workers = num_workers ,
770799 collate_fn = identity ,
800+ multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
771801 )
772802 dl .load_state_dict (state )
773803 try :
@@ -797,13 +827,15 @@ def test_two_dataloaders(self) -> None:
797827 dataset = dataset ,
798828 num_workers = 2 ,
799829 collate_fn = identity ,
830+ multiprocessing_context = "forkserver" if IS_MACOS else None ,
800831 )
801832 exp = list (sdl )
802833
803834 dl = torch .utils .data .DataLoader (
804835 dataset = dataset ,
805836 num_workers = 2 ,
806837 collate_fn = identity ,
838+ multiprocessing_context = "forkserver" if IS_MACOS else None ,
807839 )
808840 data = list (dl )
809841 self .assertEqual (data , exp )
0 commit comments