Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Demonstrate test.regrtest unittest sharding.
An incomplete implementation with details to be worked out, but it
works!  It makes our long tail tests take significantly less time. At
least when run on their own.

Example: ~25 seconds wall time to run test_multiprocessing_spawn and
test_concurrent_futures on a 12-thread machine.

`python -m test -r -j 20 test_multiprocessing_spawn test_concurrent_futures`

Known Issues to work out: result reporting and libregrtest accounting.
You'll see each sharded test "complete" multiple times, and the total
tests-run count goes higher than the actual number of tests. 😂

Real caveat: This exposes ordering and concurrency weaknesses in some
tests like test_asyncio that'll need fixing.

Which tests get sharded is explicitly opt-in.  Currently not in a
maintainable spot.  How best to maintain that needs to be worked out,
but I expect we only ever have 10-20 test modules that we declare as
worth sharding.

This implementation is inspired by and with the unittest TestLoader bits
derived directly from the Apache 2.0 licensed
https://github.com/abseil/abseil-py/blob/v1.3.0/absl/testing/absltest.py#L2359

```
:~/oss/cpython (performance/test-sharding)$ ../b/python -m test -r -j 20
test_multiprocessing_spawn test_concurrent_futures
Using random seed 8555091
0:00:00 load avg: 0.98 Run tests in parallel using 20 child processes
0:00:08 load avg: 1.30 [1/2] test_multiprocessing_spawn passed
0:00:10 load avg: 1.68 [2/2] test_concurrent_futures passed
0:00:11 load avg: 1.68 [3/2] test_multiprocessing_spawn passed
0:00:12 load avg: 1.68 [4/2] test_multiprocessing_spawn passed
0:00:12 load avg: 1.68 [5/2] test_multiprocessing_spawn passed
0:00:14 load avg: 1.87 [6/2] test_multiprocessing_spawn passed
0:00:15 load avg: 1.87 [7/2] test_multiprocessing_spawn passed
0:00:16 load avg: 1.87 [8/2] test_concurrent_futures passed
0:00:16 load avg: 1.87 [9/2] test_multiprocessing_spawn passed
0:00:18 load avg: 1.87 [10/2] test_concurrent_futures passed
0:00:20 load avg: 1.72 [11/2] test_concurrent_futures passed
0:00:20 load avg: 1.72 [12/2] test_concurrent_futures passed
0:00:21 load avg: 1.72 [13/2] test_multiprocessing_spawn passed
0:00:21 load avg: 1.72 [14/2] test_concurrent_futures passed
0:00:22 load avg: 1.72 [15/2] test_concurrent_futures passed
0:00:25 load avg: 1.58 [16/2] test_concurrent_futures passed

== Tests result: SUCCESS ==

All 16 tests OK.

Total duration: 25.6 sec
Tests result: SUCCESS
```
  • Loading branch information
gpshead committed Nov 21, 2022
commit 2748b2dbc502b2c51a16b40fe375506acfbf3683
17 changes: 17 additions & 0 deletions Lib/test/libregrtest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,23 @@ def __init__(self):
# tests
self.tests = []
self.selected = []
self.tests_to_shard = set()
# TODO(gpshead): this list belongs elsewhere - it'd be nice to tag
# these within the test module/package itself but loading everything
# to detect those tags is complicated. As is a feedback mechanism
# from a shard file.
# Our slowest tests per a "-o" run:
self.tests_to_shard.add('test_concurrent_futures')
self.tests_to_shard.add('test_multiprocessing_spawn')
self.tests_to_shard.add('test_asyncio')
self.tests_to_shard.add('test_tools')
self.tests_to_shard.add('test_multiprocessing_forkserver')
self.tests_to_shard.add('test_multiprocessing_fork')
self.tests_to_shard.add('test_signal')
self.tests_to_shard.add('test_socket')
self.tests_to_shard.add('test_io')
self.tests_to_shard.add('test_imaplib')
self.tests_to_shard.add('test_subprocess')

# test results
self.good = []
Expand Down
66 changes: 54 additions & 12 deletions Lib/test/libregrtest/runtest_mp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import faulthandler
from dataclasses import dataclass
import json
import os.path
import queue
Expand All @@ -9,7 +10,7 @@
import threading
import time
import traceback
from typing import NamedTuple, NoReturn, Literal, Any, TextIO
from typing import Iterator, NamedTuple, NoReturn, Literal, Any, TextIO

from test import support
from test.support import os_helper
Expand Down Expand Up @@ -42,6 +43,13 @@
USE_PROCESS_GROUP = (hasattr(os, "setsid") and hasattr(os, "killpg"))


@dataclass
class ShardInfo:
    """Describes one shard of a sharded test-module run.

    Follows the Bazel test sharding protocol: a shardable test module is
    split into ``total_shards`` pieces and the child process identified by
    this record runs only piece ``number``.
    """
    # 0-based index of this shard; exported to the child as TEST_SHARD_INDEX.
    number: int
    # Total number of shards; exported to the child as TEST_TOTAL_SHARDS.
    total_shards: int
    # Path to the shard status file the child must touch to prove it
    # implements the sharding protocol; filled in by run_test_in_subprocess
    # and exported as TEST_SHARD_STATUS_FILE.
    status_file: str = ""


def must_stop(result: TestResult, ns: Namespace) -> bool:
if isinstance(result, Interrupted):
return True
Expand All @@ -56,7 +64,7 @@ def parse_worker_args(worker_args) -> tuple[Namespace, str]:
return (ns, test_name)


def run_test_in_subprocess(testname: str, ns: Namespace, tmp_dir: str, stdout_fh: TextIO) -> subprocess.Popen:
def run_test_in_subprocess(testname: str, ns: Namespace, tmp_dir: str, stdout_fh: TextIO, shard: ShardInfo|None = None) -> subprocess.Popen:
ns_dict = vars(ns)
worker_args = (ns_dict, testname)
worker_args = json.dumps(worker_args)
Expand All @@ -75,6 +83,13 @@ def run_test_in_subprocess(testname: str, ns: Namespace, tmp_dir: str, stdout_fh
env['TEMP'] = tmp_dir
env['TMP'] = tmp_dir

if shard:
# This follows the "Bazel test sharding protocol"
shard.status_file = os.path.join(tmp_dir, 'sharded')
env['TEST_SHARD_STATUS_FILE'] = shard.status_file
env['TEST_SHARD_INDEX'] = str(shard.number)
env['TEST_TOTAL_SHARDS'] = str(shard.total_shards)

# Running the child from the same working directory as regrtest's original
# invocation ensures that TEMPDIR for the child is the same when
# sysconfig.is_python_build() is true. See issue 15300.
Expand Down Expand Up @@ -109,7 +124,7 @@ class MultiprocessIterator:

"""A thread-safe iterator over tests for multiprocess mode."""

def __init__(self, tests_iter):
def __init__(self, tests_iter: Iterator[tuple[str, ShardInfo|None]]):
self.lock = threading.Lock()
self.tests_iter = tests_iter

Expand Down Expand Up @@ -215,12 +230,17 @@ def mp_result_error(
test_result.duration_sec = time.monotonic() - self.start_time
return MultiprocessResult(test_result, stdout, err_msg)

def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO) -> int:
def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO,
shard: ShardInfo|None = None) -> int:
self.start_time = time.monotonic()

self.current_test_name = test_name
if shard:
self.current_test_name = f'{test_name}-shard-{shard.number:02}/{shard.total_shards-1:02}'
else:
self.current_test_name = test_name
try:
popen = run_test_in_subprocess(test_name, self.ns, tmp_dir, stdout_fh)
popen = run_test_in_subprocess(
test_name, self.ns, tmp_dir, stdout_fh, shard)

self._killed = False
self._popen = popen
Expand All @@ -240,6 +260,17 @@ def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO) -> int:
# gh-94026: stdout+stderr are written to tempfile
retcode = popen.wait(timeout=self.timeout)
assert retcode is not None
if shard and shard.status_file:
if os.path.exists(shard.status_file):
try:
os.unlink(shard.status_file)
except IOError:
pass
else:
print_warning(
f"{self.current_test_name} process exited "
f"{retcode} without touching a shard status "
f"file. Does it really support sharding?")
return retcode
except subprocess.TimeoutExpired:
if self._stopped:
Expand Down Expand Up @@ -269,7 +300,7 @@ def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO) -> int:
self._popen = None
self.current_test_name = None

def _runtest(self, test_name: str) -> MultiprocessResult:
def _runtest(self, test_name: str, shard: ShardInfo|None) -> MultiprocessResult:
if sys.platform == 'win32':
# gh-95027: When stdout is not a TTY, Python uses the ANSI code
# page for the sys.stdout encoding. If the main process runs in a
Expand All @@ -290,7 +321,7 @@ def _runtest(self, test_name: str) -> MultiprocessResult:
tmp_dir = tempfile.mkdtemp(prefix="test_python_")
tmp_dir = os.path.abspath(tmp_dir)
try:
retcode = self._run_process(test_name, tmp_dir, stdout_fh)
retcode = self._run_process(test_name, tmp_dir, stdout_fh, shard)
finally:
tmp_files = os.listdir(tmp_dir)
os_helper.rmtree(tmp_dir)
Expand Down Expand Up @@ -335,11 +366,11 @@ def run(self) -> None:
while not self._stopped:
try:
try:
test_name = next(self.pending)
test_name, shard_info = next(self.pending)
except StopIteration:
break

mp_result = self._runtest(test_name)
mp_result = self._runtest(test_name, shard_info)
self.output.put((False, mp_result))

if must_stop(mp_result.result, self.ns):
Expand Down Expand Up @@ -402,8 +433,19 @@ def __init__(self, regrtest: Regrtest) -> None:
self.regrtest = regrtest
self.log = self.regrtest.log
self.ns = regrtest.ns
self.num_procs: int = self.ns.use_mp
self.output: queue.Queue[QueueOutput] = queue.Queue()
self.pending = MultiprocessIterator(self.regrtest.tests)
tests_and_shards = []
for test in self.regrtest.tests:
if self.num_procs > 2 and test in self.regrtest.tests_to_shard:
# Split shardable tests across multiple processes to run
# distinct subsets of tests within a given test module.
shards = min(self.num_procs//2+1, 8) # avoid diminishing returns
for shard_no in range(shards):
tests_and_shards.append((test, ShardInfo(shard_no, shards)))
else:
tests_and_shards.append((test, None))
self.pending = MultiprocessIterator(iter(tests_and_shards))
if self.ns.timeout is not None:
# Rely on faulthandler to kill a worker process. This timeout is
# when faulthandler fails to kill a worker process. Give a maximum
Expand All @@ -416,7 +458,7 @@ def __init__(self, regrtest: Regrtest) -> None:

def start_workers(self) -> None:
self.workers = [TestWorkerProcess(index, self)
for index in range(1, self.ns.use_mp + 1)]
for index in range(1, self.num_procs + 1)]
msg = f"Run tests in parallel using {len(self.workers)} child processes"
if self.ns.timeout:
msg += (" (timeout: %s, worker timeout: %s)"
Expand Down
64 changes: 61 additions & 3 deletions Lib/unittest/loader.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Loading unittests."""

import itertools
import functools
import os
import re
import sys
import traceback
import types
import functools
import warnings

from fnmatch import fnmatch, fnmatchcase
Expand Down Expand Up @@ -63,7 +64,7 @@ def _jython_aware_splitext(path):
return os.path.splitext(path)[0]


class TestLoader(object):
class TestLoader:
"""
This class is responsible for loading tests according to various criteria
and returning them wrapped in a TestSuite
Expand All @@ -73,6 +74,43 @@ class TestLoader(object):
testNamePatterns = None
suiteClass = suite.TestSuite
_top_level_dir = None
_sharding_setup_complete = False
_shard_bucket_iterator = None
_shard_index = None

def __new__(cls, *args, **kwargs):
    """Create the loader, performing one-time sharding setup per class.

    The first TestLoader construction reads the Bazel test sharding
    protocol environment variables (TEST_SHARD_STATUS_FILE,
    TEST_TOTAL_SHARDS, TEST_SHARD_INDEX) and records the resulting
    sharding state on the class for all subsequent instances.

    Raises:
        RuntimeError: If the shard status file cannot be written, or the
            shard index is outside ``[0, total_shards)``.
    """
    new_instance = super().__new__(cls, *args, **kwargs)
    if cls._sharding_setup_complete:
        return new_instance
    # This assumes single threaded TestLoader construction.
    cls._sharding_setup_complete = True

    # It may be useful to write the shard file even if the other sharding
    # environment variables are not set. Test runners may use this
    # functionality to query whether a test binary implements the test
    # sharding protocol.
    if 'TEST_SHARD_STATUS_FILE' in os.environ:
        status_name = os.environ['TEST_SHARD_STATUS_FILE']
        try:
            with open(status_name, 'w') as f:
                f.write('')
        except OSError as error:
            # Chain the OS error so the root cause is not lost.
            raise RuntimeError(
                f'Error opening TEST_SHARD_STATUS_FILE {status_name=}.'
            ) from error

    if 'TEST_TOTAL_SHARDS' not in os.environ:
        # Not using sharding? nothing more to do.
        return new_instance

    total_shards = int(os.environ['TEST_TOTAL_SHARDS'])
    cls._shard_index = int(os.environ['TEST_SHARD_INDEX'])

    if not 0 <= cls._shard_index < total_shards:
        raise RuntimeError(
            'ERROR: Bad sharding values. '
            f'index={cls._shard_index}, {total_shards=}')

    # Deal sorted test names round-robin into total_shards buckets;
    # getTestCaseNames keeps only the names landing in _shard_index.
    cls._shard_bucket_iterator = itertools.cycle(range(total_shards))
    return new_instance

def __init__(self):
super(TestLoader, self).__init__()
Expand Down Expand Up @@ -198,8 +236,28 @@ def loadTestsFromNames(self, names, module=None):
suites = [self.loadTestsFromName(name, module) for name in names]
return self.suiteClass(suites)

def _getShardedTestCaseNames(self, testCaseClass):
filtered_names = []
# We need to sort the list of tests in order to determine which tests this
# shard is responsible for; however, it's important to preserve the order
# returned by the base loader, e.g. in the case of randomized test ordering.
ordered_names = self._getTestCaseNames(testCaseClass)
for testcase in sorted(ordered_names):
bucket = next(self._shard_bucket_iterator)
if bucket == self._shard_index:
filtered_names.append(testcase)
return [x for x in ordered_names if x in filtered_names]

def getTestCaseNames(self, testCaseClass):
    """Return a sorted sequence of method names found within testCaseClass.
    Or a unique sharded subset thereof if sharding is enabled.
    """
    sharding_enabled = bool(self._shard_bucket_iterator)
    name_source = (self._getShardedTestCaseNames if sharding_enabled
                   else self._getTestCaseNames)
    return name_source(testCaseClass)

def _getTestCaseNames(self, testCaseClass):
"""Return a sorted sequence of all method names found within testCaseClass.
"""
def shouldIncludeMethod(attrname):
if not attrname.startswith(self.testMethodPrefix):
Expand Down