Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 74 additions & 38 deletions Lib/multiprocessing/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,9 @@ class Pool(object):
'''
_wrap_exception = True

def Process(self, *args, **kwds):
return self._ctx.Process(*args, **kwds)
@staticmethod
def Process(ctx, *args, **kwds):
return ctx.Process(*args, **kwds)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did this need to change? I feel like pool.Process existed solely as a convenience method, and this breaks any code that used it. Instead of changing this function, you could just use ctx.Process instead of self.Process below


def __init__(self, processes=None, initializer=None, initargs=(),
maxtasksperchild=None, context=None):
Expand Down Expand Up @@ -190,7 +191,10 @@ def __init__(self, processes=None, initializer=None, initargs=(),

self._worker_handler = threading.Thread(
target=Pool._handle_workers,
args=(self, )
args=(self._cache, self._taskqueue, self._ctx, self.Process,
self._processes, self._pool, self._inqueue, self._outqueue,
self._initializer, self._initargs, self._maxtasksperchild,
self._wrap_exception)
)
self._worker_handler.daemon = True
self._worker_handler._state = RUN
Expand Down Expand Up @@ -236,43 +240,61 @@ def __repr__(self):
f'state={self._state} '
f'pool_size={len(self._pool)}>')

def _join_exited_workers(self):
@staticmethod
def _join_exited_workers(pool):
"""Cleanup after any worker processes which have exited due to reaching
their specified lifetime. Returns True if any workers were cleaned up.
"""
cleaned = False
for i in reversed(range(len(self._pool))):
worker = self._pool[i]
for i in reversed(range(len(pool))):
worker = pool[i]
if worker.exitcode is not None:
# worker exited
util.debug('cleaning up worker %d' % i)
worker.join()
cleaned = True
del self._pool[i]
del pool[i]
return cleaned

def _repopulate_pool(self):
return self._repopulate_pool_static(self._ctx, self.Process,
self._processes,
self._pool, self._inqueue,
self._outqueue, self._initializer,
self._initargs,
self._maxtasksperchild,
self._wrap_exception)

@staticmethod
def _repopulate_pool_static(ctx, Process, processes, pool, inqueue,
outqueue, initializer, initargs,
maxtasksperchild, wrap_exception):
"""Bring the number of pool processes up to the specified number,
for use after reaping workers which have exited.
"""
for i in range(self._processes - len(self._pool)):
w = self.Process(target=worker,
args=(self._inqueue, self._outqueue,
self._initializer,
self._initargs, self._maxtasksperchild,
self._wrap_exception)
)
for i in range(processes - len(pool)):
w = Process(ctx, target=worker,
args=(inqueue, outqueue,
initializer,
initargs, maxtasksperchild,
wrap_exception))
w.name = w.name.replace('Process', 'PoolWorker')
w.daemon = True
w.start()
self._pool.append(w)
pool.append(w)
util.debug('added worker')

def _maintain_pool(self):
@staticmethod
def _maintain_pool(ctx, Process, processes, pool, inqueue, outqueue,
initializer, initargs, maxtasksperchild,
wrap_exception):
"""Clean up any exited workers and start replacements for them.
"""
if self._join_exited_workers():
self._repopulate_pool()
if Pool._join_exited_workers(pool):
Pool._repopulate_pool_static(ctx, Process, processes, pool,
inqueue, outqueue, initializer,
initargs, maxtasksperchild,
wrap_exception)

def _setup_queues(self):
self._inqueue = self._ctx.SimpleQueue()
Expand Down Expand Up @@ -331,7 +353,7 @@ def imap(self, func, iterable, chunksize=1):
'''
self._check_running()
if chunksize == 1:
result = IMapIterator(self._cache)
result = IMapIterator(self)
self._taskqueue.put(
(
self._guarded_task_generation(result._job, func, iterable),
Expand All @@ -344,7 +366,7 @@ def imap(self, func, iterable, chunksize=1):
"Chunksize must be 1+, not {0:n}".format(
chunksize))
task_batches = Pool._get_tasks(func, iterable, chunksize)
result = IMapIterator(self._cache)
result = IMapIterator(self)
self._taskqueue.put(
(
self._guarded_task_generation(result._job,
Expand All @@ -360,7 +382,7 @@ def imap_unordered(self, func, iterable, chunksize=1):
'''
self._check_running()
if chunksize == 1:
result = IMapUnorderedIterator(self._cache)
result = IMapUnorderedIterator(self)
self._taskqueue.put(
(
self._guarded_task_generation(result._job, func, iterable),
Expand All @@ -372,7 +394,7 @@ def imap_unordered(self, func, iterable, chunksize=1):
raise ValueError(
"Chunksize must be 1+, not {0!r}".format(chunksize))
task_batches = Pool._get_tasks(func, iterable, chunksize)
result = IMapUnorderedIterator(self._cache)
result = IMapUnorderedIterator(self)
self._taskqueue.put(
(
self._guarded_task_generation(result._job,
Expand All @@ -388,7 +410,7 @@ def apply_async(self, func, args=(), kwds={}, callback=None,
Asynchronous version of `apply()` method.
'''
self._check_running()
result = ApplyResult(self._cache, callback, error_callback)
result = ApplyResult(self, callback, error_callback)
self._taskqueue.put(([(result._job, 0, func, args, kwds)], None))
return result

Expand Down Expand Up @@ -417,7 +439,7 @@ def _map_async(self, func, iterable, mapper, chunksize=None, callback=None,
chunksize = 0

task_batches = Pool._get_tasks(func, iterable, chunksize)
result = MapResult(self._cache, chunksize, len(iterable), callback,
result = MapResult(self, chunksize, len(iterable), callback,
error_callback=error_callback)
self._taskqueue.put(
(
Expand All @@ -430,16 +452,20 @@ def _map_async(self, func, iterable, mapper, chunksize=None, callback=None,
return result

@staticmethod
def _handle_workers(pool):
def _handle_workers(cache, taskqueue, ctx, Process, processes, pool,
inqueue, outqueue, initializer, initargs,
maxtasksperchild, wrap_exception):
thread = threading.current_thread()

# Keep maintaining workers until the cache gets drained, unless the pool
# is terminated.
while thread._state == RUN or (pool._cache and thread._state != TERMINATE):
pool._maintain_pool()
while thread._state == RUN or (cache and thread._state != TERMINATE):
Pool._maintain_pool(ctx, Process, processes, pool, inqueue,
outqueue, initializer, initargs,
maxtasksperchild, wrap_exception)
time.sleep(0.1)
# send sentinel to stop workers
pool._taskqueue.put(None)
taskqueue.put(None)
util.debug('worker handler exiting')

@staticmethod
Expand Down Expand Up @@ -656,13 +682,14 @@ def __exit__(self, exc_type, exc_val, exc_tb):

class ApplyResult(object):

def __init__(self, cache, callback, error_callback):
def __init__(self, pool, callback, error_callback):
self._pool = pool
self._event = threading.Event()
self._job = next(job_counter)
self._cache = cache
self._cache = pool._cache
self._callback = callback
self._error_callback = error_callback
cache[self._job] = self
self._cache[self._job] = self

def ready(self):
return self._event.is_set()
Expand Down Expand Up @@ -692,6 +719,7 @@ def _set(self, i, obj):
self._error_callback(self._value)
self._event.set()
del self._cache[self._job]
self._pool = None

AsyncResult = ApplyResult # create alias -- see #17805

Expand All @@ -701,16 +729,16 @@ def _set(self, i, obj):

class MapResult(ApplyResult):

def __init__(self, cache, chunksize, length, callback, error_callback):
ApplyResult.__init__(self, cache, callback,
def __init__(self, pool, chunksize, length, callback, error_callback):
ApplyResult.__init__(self, pool, callback,
error_callback=error_callback)
self._success = True
self._value = [None] * length
self._chunksize = chunksize
if chunksize <= 0:
self._number_left = 0
self._event.set()
del cache[self._job]
del self._cache[self._job]
else:
self._number_left = length//chunksize + bool(length % chunksize)

Expand All @@ -724,6 +752,7 @@ def _set(self, i, success_result):
self._callback(self._value)
del self._cache[self._job]
self._event.set()
self._pool = None
else:
if not success and self._success:
# only store first exception
Expand All @@ -735,22 +764,24 @@ def _set(self, i, success_result):
self._error_callback(self._value)
del self._cache[self._job]
self._event.set()
self._pool = None

#
# Class whose instances are returned by `Pool.imap()`
#

class IMapIterator(object):

def __init__(self, cache):
def __init__(self, pool):
self._pool = pool
self._cond = threading.Condition(threading.Lock())
self._job = next(job_counter)
self._cache = cache
self._cache = pool._cache
self._items = collections.deque()
self._index = 0
self._length = None
self._unsorted = {}
cache[self._job] = self
self._cache[self._job] = self

def __iter__(self):
return self
Expand All @@ -761,12 +792,14 @@ def next(self, timeout=None):
item = self._items.popleft()
except IndexError:
if self._index == self._length:
self._pool = None
raise StopIteration from None
self._cond.wait(timeout)
try:
item = self._items.popleft()
except IndexError:
if self._index == self._length:
self._pool = None
raise StopIteration from None
raise TimeoutError from None

Expand All @@ -792,13 +825,15 @@ def _set(self, i, obj):

if self._index == self._length:
del self._cache[self._job]
self._pool = None

def _set_length(self, length):
with self._cond:
self._length = length
if self._index == self._length:
self._cond.notify()
del self._cache[self._job]
self._pool = None

#
# Class whose instances are returned by `Pool.imap_unordered()`
Expand All @@ -813,6 +848,7 @@ def _set(self, i, obj):
self._cond.notify()
if self._index == self._length:
del self._cache[self._job]
self._pool = None

#
#
Expand All @@ -822,7 +858,7 @@ class ThreadPool(Pool):
_wrap_exception = False

@staticmethod
def Process(*args, **kwds):
def Process(ctx, *args, **kwds):
from .dummy import Process
return Process(*args, **kwds)

Expand Down
1 change: 0 additions & 1 deletion Lib/test/_test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2593,7 +2593,6 @@ def test_resource_warning(self):
pool = None
support.gc_collect()


def raising():
raise KeyError("key")

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Fix a reference issue inside :class:`multiprocessing.Pool` that caused
the pool to remain alive if it was deleted without being closed or
terminated explicitly. A new strong reference is added to the pool
iterators to link the lifetime of the pool to the lifetime of its
iterators, so the pool does not get destroyed while a pool iterator is
still alive.