1111
1212import torch
1313import torch .utils .data
14- from torch .testing ._internal .common_utils import IS_MACOS
14+ from torch .testing ._internal .common_utils import IS_MACOS , TestCase
1515from torchdata .stateful_dataloader import Stateful , StatefulDataLoader
1616
1717
@@ -130,7 +130,7 @@ def identity(x):
130130 return x
131131
132132
133- class TestStatefulDataLoaderIterable ( unittest . TestCase ):
133+ class TestStatefulDataLoaderIterable_shard0 ( TestCase ):
134134 def _run_and_checkpoint (self , num_workers , batch_size , pw , interrupt , every_n_steps = 1 , shuffle = False ):
135135 dataset = DummyIterableDataset ([0 , 100 , 37 ], shuffle = shuffle )
136136 dl = StatefulDataLoader (
@@ -141,17 +141,12 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
141141 persistent_workers = pw ,
142142 multiprocessing_context = "forkserver" if IS_MACOS and num_workers else None ,
143143 )
144- list (dl )
145-
146- if interrupt is None :
147- interrupt = len (exp )
148-
149- exp = []
150144 it = iter (dl )
151145 for _ in range (interrupt ):
152146 next (it )
153147
154148 state_dict = dl .state_dict ()
149+ exp = []
155150 for data in it :
156151 exp .append (data )
157152
@@ -224,7 +219,7 @@ def test_random_state(self):
224219 )
225220
226221
227- class TestStatefulDataLoaderMap ( TestStatefulDataLoaderIterable ):
222+ class TestStatefulDataLoaderMap_shard1 ( TestStatefulDataLoaderIterable_shard0 ):
228223 def _run_and_checkpoint (self , num_workers , batch_size , pw , interrupt , every_n_steps = 1 , shuffle = False ):
229224 if num_workers == 0 :
230225 return
@@ -277,7 +272,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
277272 self .assertEqual (batches , exp )
278273
279274
280- class TestStatefulSampler ( TestStatefulDataLoaderIterable ):
275+ class TestStatefulSampler_shard1 ( TestStatefulDataLoaderIterable_shard0 ):
281276 def _run_and_checkpoint (self , num_workers , batch_size , pw , interrupt , every_n_steps = 1 , shuffle = False ):
282277 dataset = DummyMapDataset (100 , shuffle = shuffle )
283278 sampler = DummySampler (len (dataset ))
@@ -369,7 +364,7 @@ def __iter__(self):
369364 yield from range (start , start + self .sizes_for_all_workers [worker_id ])
370365
371366
372- class TestStatefulDataLoaderGenerator ( TestStatefulDataLoaderIterable ):
367+ class TestStatefulDataLoaderGenerator_shard2 ( TestStatefulDataLoaderIterable_shard0 ):
373368 def _run_and_checkpoint (self , num_workers , batch_size , pw , interrupt , every_n_steps = 1 , shuffle = False ):
374369 dataset = GeneratorIterable ([0 , 100 , 37 ])
375370 dl = StatefulDataLoader (
@@ -419,7 +414,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
419414 self .assertEqual (batches , exp )
420415
421416
422- class TestStatefulDataLoaderGeneratorNoState ( TestStatefulDataLoaderIterable ):
417+ class TestStatefulDataLoaderGeneratorNoState_shard2 ( TestStatefulDataLoaderIterable_shard0 ):
423418 def _run_and_checkpoint (self , num_workers , batch_size , pw , interrupt , every_n_steps = 1 , shuffle = False ):
424419 dataset = GeneratorIterableNoState ([0 , 100 , 37 ])
425420 dl = StatefulDataLoader (
@@ -469,7 +464,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
469464 self .assertEqual (batches , exp )
470465
471466
472- class TestSnapshotZero ( unittest . TestCase ):
467+ class TestSnapshotZero_shard2 ( TestCase ):
473468 def test_generator (self ):
474469 num_workers = 3
475470 every_n_steps = 10
@@ -598,7 +593,7 @@ def test_map_iterrupted_shuffle(self):
598593 self .assertEqual (batches , exp )
599594
600595
601- class TestSnapshotEnd ( unittest . TestCase ):
596+ class TestSnapshotEnd_shard2 ( TestCase ):
602597 def test_generator (self ):
603598 num_workers = 3
604599 every_n_steps = 10
@@ -785,7 +780,7 @@ def test_map_shuffle(self):
785780 self .assertEqual (batches , exp )
786781
787782
788- class TestNumWorkersMismatch ( unittest . TestCase ):
783+ class TestNumWorkersMismatch_shard3 ( TestCase ):
789784 def test_num_workers_mismatch (self ):
790785 for initial_num_workers , num_workers in ((0 , 3 ), (3 , 0 )):
791786 if initial_num_workers == num_workers :
@@ -819,7 +814,7 @@ def test_num_workers_mismatch(self):
819814 self .assertTrue (False , "Error should be of type AssertionError" )
820815
821816
822- class TestTorchDataLazyImport ( unittest . TestCase ):
817+ class TestTorchDataLazyImport_shard3 ( TestCase ):
823818 def test_lazy_imports (self ) -> None :
824819 import torchdata
825820
@@ -831,7 +826,7 @@ def test_lazy_imports(self) -> None:
831826 dp .iter .IterableWrapper ([1 , 2 ])
832827
833828
834- class TestConcurrentDataLoaders ( unittest . TestCase ):
829+ class TestConcurrentDataLoaders_shard3 ( TestCase ):
835830 def test_two_dataloaders (self ) -> None :
836831 dataset = DummyMapDataset (100 , shuffle = False )
837832 sdl = StatefulDataLoader (
@@ -852,7 +847,7 @@ def test_two_dataloaders(self) -> None:
852847 self .assertEqual (data , exp )
853848
854849
855- class TestFastStateDictRequest ( unittest . TestCase ):
850+ class TestFastStateDictRequest_shard3 ( TestCase ):
856851 def _run_test (self , snapshot_every_n_steps , interrupt ):
857852 num_workers = 4
858853 dataset = DummyIterableDataset ([25 , 25 , 25 , 25 ], shuffle = True )
@@ -903,7 +898,7 @@ def test_fast_state_dict_request_skip_steps(self) -> None:
903898 self ._run_test (17 , 19 )
904899
905900
906- class TestJsonSerDe ( unittest . TestCase ):
901+ class TestJsonSerDe_shard3 ( TestCase ):
907902 def _run_test_iterable (self , num_workers ):
908903 interrupt = 4
909904 dataset = DummyIterableDataset ([0 , 100 , 37 ], shuffle = False , include_generator = False )
0 commit comments