Skip to content

Commit c78bbc7

Browse files
authored
Shard StatefulDataLoader tests
Differential Revision: D57227524
Pull Request resolved: meta-pytorch#1257
1 parent 11e16da commit c78bbc7

File tree

2 files changed: 28 additions and 21 deletions.

2 files changed: 28 additions and 21 deletions.

.github/workflows/stateful_dataloader_ci.yml

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,18 @@ jobs:
8181
pip3 install -r requirements.txt
8282
make doctest
8383
cd ..
84-
- name: Run StatefulDataLoader tests with pytest
84+
- name: Run StatefulDataLoader tests with pytest - dataloader
8585
if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }}
86-
run: pytest --durations=0 --no-header -v test/stateful_dataloader
86+
run: pytest --durations=0 --no-header -v test/stateful_dataloader/test_dataloader.py
87+
- name: Run StatefulDataLoader tests with pytest - state_dict 0
88+
if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }}
89+
run: pytest --durations=0 --no-header -v test/stateful_dataloader/test_state_dict.py -k _shard0
90+
- name: Run StatefulDataLoader tests with pytest - state_dict 1
91+
if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }}
92+
run: pytest --durations=0 --no-header -v test/stateful_dataloader/test_state_dict.py -k _shard1
93+
- name: Run StatefulDataLoader tests with pytest - state_dict 2
94+
if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }}
95+
run: pytest --durations=0 --no-header -v test/stateful_dataloader/test_state_dict.py -k _shard2
96+
- name: Run StatefulDataLoader tests with pytest - state_dict 3
97+
if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }}
98+
run: pytest --durations=0 --no-header -v test/stateful_dataloader/test_state_dict.py -k _shard3

test/stateful_dataloader/test_state_dict.py

Lines changed: 14 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111

1212
import torch
1313
import torch.utils.data
14-
from torch.testing._internal.common_utils import IS_MACOS
14+
from torch.testing._internal.common_utils import IS_MACOS, TestCase
1515
from torchdata.stateful_dataloader import Stateful, StatefulDataLoader
1616

1717

@@ -130,7 +130,7 @@ def identity(x):
130130
return x
131131

132132

133-
class TestStatefulDataLoaderIterable(unittest.TestCase):
133+
class TestStatefulDataLoaderIterable_shard0(TestCase):
134134
def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
135135
dataset = DummyIterableDataset([0, 100, 37], shuffle=shuffle)
136136
dl = StatefulDataLoader(
@@ -141,17 +141,12 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
141141
persistent_workers=pw,
142142
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
143143
)
144-
list(dl)
145-
146-
if interrupt is None:
147-
interrupt = len(exp)
148-
149-
exp = []
150144
it = iter(dl)
151145
for _ in range(interrupt):
152146
next(it)
153147

154148
state_dict = dl.state_dict()
149+
exp = []
155150
for data in it:
156151
exp.append(data)
157152

@@ -224,7 +219,7 @@ def test_random_state(self):
224219
)
225220

226221

227-
class TestStatefulDataLoaderMap(TestStatefulDataLoaderIterable):
222+
class TestStatefulDataLoaderMap_shard1(TestStatefulDataLoaderIterable_shard0):
228223
def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
229224
if num_workers == 0:
230225
return
@@ -277,7 +272,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
277272
self.assertEqual(batches, exp)
278273

279274

280-
class TestStatefulSampler(TestStatefulDataLoaderIterable):
275+
class TestStatefulSampler_shard1(TestStatefulDataLoaderIterable_shard0):
281276
def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
282277
dataset = DummyMapDataset(100, shuffle=shuffle)
283278
sampler = DummySampler(len(dataset))
@@ -369,7 +364,7 @@ def __iter__(self):
369364
yield from range(start, start + self.sizes_for_all_workers[worker_id])
370365

371366

372-
class TestStatefulDataLoaderGenerator(TestStatefulDataLoaderIterable):
367+
class TestStatefulDataLoaderGenerator_shard2(TestStatefulDataLoaderIterable_shard0):
373368
def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
374369
dataset = GeneratorIterable([0, 100, 37])
375370
dl = StatefulDataLoader(
@@ -419,7 +414,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
419414
self.assertEqual(batches, exp)
420415

421416

422-
class TestStatefulDataLoaderGeneratorNoState(TestStatefulDataLoaderIterable):
417+
class TestStatefulDataLoaderGeneratorNoState_shard2(TestStatefulDataLoaderIterable_shard0):
423418
def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
424419
dataset = GeneratorIterableNoState([0, 100, 37])
425420
dl = StatefulDataLoader(
@@ -469,7 +464,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
469464
self.assertEqual(batches, exp)
470465

471466

472-
class TestSnapshotZero(unittest.TestCase):
467+
class TestSnapshotZero_shard2(TestCase):
473468
def test_generator(self):
474469
num_workers = 3
475470
every_n_steps = 10
@@ -598,7 +593,7 @@ def test_map_iterrupted_shuffle(self):
598593
self.assertEqual(batches, exp)
599594

600595

601-
class TestSnapshotEnd(unittest.TestCase):
596+
class TestSnapshotEnd_shard2(TestCase):
602597
def test_generator(self):
603598
num_workers = 3
604599
every_n_steps = 10
@@ -785,7 +780,7 @@ def test_map_shuffle(self):
785780
self.assertEqual(batches, exp)
786781

787782

788-
class TestNumWorkersMismatch(unittest.TestCase):
783+
class TestNumWorkersMismatch_shard3(TestCase):
789784
def test_num_workers_mismatch(self):
790785
for initial_num_workers, num_workers in ((0, 3), (3, 0)):
791786
if initial_num_workers == num_workers:
@@ -819,7 +814,7 @@ def test_num_workers_mismatch(self):
819814
self.assertTrue(False, "Error should be of type AssertionError")
820815

821816

822-
class TestTorchDataLazyImport(unittest.TestCase):
817+
class TestTorchDataLazyImport_shard3(TestCase):
823818
def test_lazy_imports(self) -> None:
824819
import torchdata
825820

@@ -831,7 +826,7 @@ def test_lazy_imports(self) -> None:
831826
dp.iter.IterableWrapper([1, 2])
832827

833828

834-
class TestConcurrentDataLoaders(unittest.TestCase):
829+
class TestConcurrentDataLoaders_shard3(TestCase):
835830
def test_two_dataloaders(self) -> None:
836831
dataset = DummyMapDataset(100, shuffle=False)
837832
sdl = StatefulDataLoader(
@@ -852,7 +847,7 @@ def test_two_dataloaders(self) -> None:
852847
self.assertEqual(data, exp)
853848

854849

855-
class TestFastStateDictRequest(unittest.TestCase):
850+
class TestFastStateDictRequest_shard3(TestCase):
856851
def _run_test(self, snapshot_every_n_steps, interrupt):
857852
num_workers = 4
858853
dataset = DummyIterableDataset([25, 25, 25, 25], shuffle=True)
@@ -903,7 +898,7 @@ def test_fast_state_dict_request_skip_steps(self) -> None:
903898
self._run_test(17, 19)
904899

905900

906-
class TestJsonSerDe(unittest.TestCase):
901+
class TestJsonSerDe_shard3(TestCase):
907902
def _run_test_iterable(self, num_workers):
908903
interrupt = 4
909904
dataset = DummyIterableDataset([0, 100, 37], shuffle=False, include_generator=False)

0 commit comments

Comments
 (0)