Skip to content

Commit f16c888

Browse files
committed
add dask submodule
- fixes #278 - replaces/closes #279 - based on `tqdm.keras`
1 parent 4ff7753 commit f16c888

1 file changed

Lines changed: 45 additions & 0 deletions

File tree

tqdm/dask.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from __future__ import absolute_import
2+
from .auto import tqdm as tqdm_auto
3+
from functools import partial
4+
from dask.callbacks import Callback
5+
__author__ = {"github.com/": ["casperdcl"]}
6+
__all__ = ['TqdmCallback']
7+
8+
9+
class TqdmCallback(Callback):
10+
"""`dask` callback for task progress"""
11+
def __init__(self, start=None, start_state=None, pretask=None,
12+
posttask=None, finish=None, tqdm_class=tqdm_auto,
13+
**tqdm_kwargs):
14+
"""
15+
Parameters
16+
----------
17+
tqdm_class : optional
18+
`tqdm` class to use for bars [default: `tqdm.auto.tqdm`].
19+
tqdm_kwargs : optional
20+
Any other arguments used for all bars.
21+
"""
22+
super(TqdmCallback, self).__init__(
23+
start=start, start_state=start_state, pretask=pretask,
24+
posttask=posttask, finish=finish)
25+
if tqdm_kwargs:
26+
tqdm_class = partial(tqdm_class, **tqdm_kwargs)
27+
self.tqdm_class = tqdm_class
28+
29+
def _start_state(self, _, state):
30+
self.pbar = self.tqdm_class(total=sum(
31+
len(state[k]) for k in ['ready', 'waiting', 'running', 'finished']))
32+
33+
def _posttask(self, *args, **kwargs):
34+
self.pbar.update()
35+
36+
def _finish(self, *args, **kwargs):
37+
self.pbar.close()
38+
39+
def display(self):
40+
"""displays in the current cell in Notebooks"""
41+
container = getattr(self.bar, 'container', None)
42+
if container is None:
43+
return
44+
from .notebook import display
45+
display(container)

0 commit comments

Comments
 (0)