Skip to content

Commit 304e0bf

Browse files
andrewkho authored and facebook-github-bot committed
add timing info to ci, and update tests for mac to use forkserver (meta-pytorch#1248)
Summary: Fixes meta-pytorch#1247

### Changes

* Adjusts the pytest CI command to show timing info
* Sets multiprocessing_context = "forkserver" for all macOS tests, as the default spawn is too slow

Pull Request resolved: meta-pytorch#1248

Test Plan: Green checks in GitHub

Reviewed By: gokulavasan

Differential Revision: D56774574

Pulled By: andrewkho

fbshipit-source-id: 75e9ff063439f658fc26792a2c8fe288f19371cd
1 parent 905cdd4 commit 304e0bf

File tree

3 files changed

+36
-4
lines changed

3 files changed

+36
-4
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ jobs:
8686
- name: Run DataPipes tests with pytest
8787
if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }}
8888
run:
89-
pytest --no-header -v test --ignore=test/test_period.py --ignore=test/test_text_examples.py
89+
pytest --durations=0 --no-header -v test --ignore=test/test_period.py --ignore=test/test_text_examples.py
9090
--ignore=test/test_audio_examples.py --ignore=test/test_aistore.py
9191
--ignore=test/dataloader2/test_dataloader2.py --ignore=test/dataloader2/test_mprs.py
9292
--ignore=test/test_distributed.py --ignore=test/stateful_dataloader/test_dataloader.py

.github/workflows/stateful_dataloader_ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,4 @@ jobs:
8585
cd ..
8686
- name: Run StatefulDataLoader tests with pytest
8787
if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }}
88-
run: pytest --no-header -v test/stateful_dataloader
88+
run: pytest --durations=0 --no-header -v test/stateful_dataloader

test/stateful_dataloader/test_state_dict.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import torch
1313
import torch.utils.data
14+
from torch.testing._internal.common_utils import IS_MACOS
1415
from torchdata.stateful_dataloader import Stateful, StatefulDataLoader
1516

1617

@@ -127,6 +128,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
127128
collate_fn=identity,
128129
snapshot_every_n_steps=every_n_steps,
129130
persistent_workers=pw,
131+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
130132
)
131133
list(dl)
132134

@@ -150,6 +152,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
150152
collate_fn=identity,
151153
snapshot_every_n_steps=every_n_steps,
152154
persistent_workers=pw,
155+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
153156
)
154157
dl.load_state_dict(state_dict)
155158
for batch in iter(dl):
@@ -226,6 +229,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
226229
persistent_workers=pw,
227230
batch_size=batch_size,
228231
sampler=sampler,
232+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
229233
)
230234

231235
if interrupt is None:
@@ -252,6 +256,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
252256
persistent_workers=pw,
253257
batch_size=batch_size,
254258
sampler=sampler,
259+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
255260
)
256261
dl.load_state_dict(state_dict)
257262
batches = []
@@ -273,6 +278,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
273278
persistent_workers=pw,
274279
batch_size=batch_size,
275280
sampler=sampler,
281+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
276282
)
277283

278284
if interrupt is None:
@@ -297,6 +303,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
297303
persistent_workers=pw,
298304
batch_size=batch_size,
299305
sampler=sampler,
306+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
300307
)
301308
dl.load_state_dict(state_dict)
302309
batches = []
@@ -360,6 +367,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
360367
collate_fn=identity,
361368
snapshot_every_n_steps=every_n_steps,
362369
persistent_workers=pw,
370+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
363371
)
364372
exp = list(dl)
365373

@@ -373,6 +381,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
373381
collate_fn=identity,
374382
snapshot_every_n_steps=every_n_steps,
375383
persistent_workers=pw,
384+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
376385
)
377386
batches = []
378387
it = iter(dl)
@@ -390,6 +399,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
390399
collate_fn=identity,
391400
snapshot_every_n_steps=every_n_steps,
392401
persistent_workers=pw,
402+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
393403
)
394404
dl.load_state_dict(state_dict)
395405
for batch in dl:
@@ -407,6 +417,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
407417
collate_fn=identity,
408418
snapshot_every_n_steps=every_n_steps,
409419
persistent_workers=pw,
420+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
410421
)
411422
exp = list(dl)
412423

@@ -420,6 +431,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
420431
collate_fn=identity,
421432
snapshot_every_n_steps=every_n_steps,
422433
persistent_workers=pw,
434+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
423435
)
424436
batches = []
425437
it = iter(dl)
@@ -437,6 +449,7 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
437449
collate_fn=identity,
438450
snapshot_every_n_steps=every_n_steps,
439451
persistent_workers=pw,
452+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
440453
)
441454
dl.load_state_dict(state_dict)
442455
for batch in dl:
@@ -457,6 +470,7 @@ def test_generator(self):
457470
collate_fn=identity,
458471
snapshot_every_n_steps=every_n_steps,
459472
persistent_workers=pw,
473+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
460474
)
461475

462476
it = iter(dl)
@@ -479,6 +493,7 @@ def test_iterable(self):
479493
collate_fn=identity,
480494
snapshot_every_n_steps=every_n_steps,
481495
persistent_workers=pw,
496+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
482497
)
483498

484499
it = iter(dl)
@@ -501,6 +516,7 @@ def test_map(self):
501516
collate_fn=identity,
502517
snapshot_every_n_steps=every_n_steps,
503518
persistent_workers=pw,
519+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
504520
)
505521

506522
it = iter(dl)
@@ -524,6 +540,7 @@ def test_map_shuffle(self):
524540
collate_fn=identity,
525541
snapshot_every_n_steps=every_n_steps,
526542
persistent_workers=pw,
543+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
527544
)
528545

529546
it = iter(dl)
@@ -538,7 +555,7 @@ def test_map_shuffle(self):
538555
def test_map_iterrupted_shuffle(self):
539556
every_n_steps = 10
540557

541-
for pw, num_workers, every_n_steps in itertools.product([False, True], [0, 2], [1, 5, 10, 15]):
558+
for pw, num_workers, every_n_steps in itertools.product([False, True], [0, 2], [1, 15]):
542559
dataset = DummyMapDataset(10, shuffle=True)
543560
dl = StatefulDataLoader(
544561
dataset=dataset,
@@ -547,6 +564,7 @@ def test_map_iterrupted_shuffle(self):
547564
collate_fn=identity,
548565
snapshot_every_n_steps=every_n_steps,
549566
persistent_workers=pw if num_workers > 0 else False,
567+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
550568
)
551569

552570
it = iter(dl)
@@ -582,6 +600,7 @@ def test_generator(self):
582600
snapshot_every_n_steps=every_n_steps,
583601
persistent_workers=pw,
584602
batch_size=bs,
603+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
585604
)
586605
exp = list(dl)
587606
state_end = dl.state_dict()
@@ -597,6 +616,7 @@ def test_generator(self):
597616
snapshot_every_n_steps=every_n_steps,
598617
persistent_workers=pw,
599618
batch_size=bs,
619+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
600620
)
601621
it = iter(dl)
602622
for _ in range(2):
@@ -618,6 +638,7 @@ def test_generator_no_state(self):
618638
snapshot_every_n_steps=every_n_steps,
619639
persistent_workers=pw,
620640
batch_size=bs,
641+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
621642
)
622643
exp = list(dl)
623644
state_end = dl.state_dict()
@@ -633,6 +654,7 @@ def test_generator_no_state(self):
633654
snapshot_every_n_steps=every_n_steps,
634655
persistent_workers=pw,
635656
batch_size=bs,
657+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
636658
)
637659
it = iter(dl)
638660
for _ in range(2):
@@ -657,6 +679,7 @@ def test_iterable(self):
657679
persistent_workers=pw,
658680
batch_size=bs,
659681
generator=g,
682+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
660683
)
661684
list(dl)
662685
state_end = dl.state_dict()
@@ -671,6 +694,7 @@ def test_iterable(self):
671694
persistent_workers=pw,
672695
batch_size=bs,
673696
generator=g,
697+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
674698
)
675699
dl.load_state_dict(state_end)
676700
batches = list(dl)
@@ -692,6 +716,7 @@ def test_map(self):
692716
persistent_workers=pw,
693717
batch_size=bs,
694718
generator=generator,
719+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
695720
)
696721
list(dl)
697722
state_end = dl.state_dict()
@@ -706,6 +731,7 @@ def test_map(self):
706731
persistent_workers=pw,
707732
batch_size=bs,
708733
generator=generator,
734+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
709735
)
710736
dl.load_state_dict(state_end)
711737
batches = list(dl)
@@ -725,6 +751,7 @@ def test_map_shuffle(self):
725751
snapshot_every_n_steps=every_n_steps,
726752
persistent_workers=pw,
727753
batch_size=bs,
754+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
728755
)
729756
list(dl)
730757
state_end = dl.state_dict()
@@ -739,6 +766,7 @@ def test_map_shuffle(self):
739766
snapshot_every_n_steps=every_n_steps,
740767
persistent_workers=pw,
741768
batch_size=bs,
769+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
742770
)
743771
dl.load_state_dict(state_end)
744772
batches = list(dl)
@@ -748,14 +776,15 @@ def test_map_shuffle(self):
748776

749777
class TestNumWorkersMismatch(unittest.TestCase):
750778
def test_num_workers_mismatch(self):
751-
for initial_num_workers, num_workers in itertools.product([0, 5], [0, 3, 7]):
779+
for initial_num_workers, num_workers in ((0, 3), (3, 0)):
752780
if initial_num_workers == num_workers:
753781
continue
754782
dataset = DummyMapDataset(100, shuffle=False)
755783
dl = StatefulDataLoader(
756784
dataset=dataset,
757785
num_workers=initial_num_workers,
758786
collate_fn=identity,
787+
multiprocessing_context="forkserver" if IS_MACOS and initial_num_workers else None,
759788
)
760789
state = dl.state_dict()
761790
self.assertEqual(len(state), 0)
@@ -768,6 +797,7 @@ def test_num_workers_mismatch(self):
768797
dataset=dataset,
769798
num_workers=num_workers,
770799
collate_fn=identity,
800+
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
771801
)
772802
dl.load_state_dict(state)
773803
try:
@@ -797,13 +827,15 @@ def test_two_dataloaders(self) -> None:
797827
dataset=dataset,
798828
num_workers=2,
799829
collate_fn=identity,
830+
multiprocessing_context="forkserver" if IS_MACOS else None,
800831
)
801832
exp = list(sdl)
802833

803834
dl = torch.utils.data.DataLoader(
804835
dataset=dataset,
805836
num_workers=2,
806837
collate_fn=identity,
838+
multiprocessing_context="forkserver" if IS_MACOS else None,
807839
)
808840
data = list(dl)
809841
self.assertEqual(data, exp)

0 commit comments

Comments
 (0)