|
| 1 | +from __future__ import absolute_import |
| 2 | +from .auto import tqdm as tqdm_auto |
| 3 | +from functools import partial |
| 4 | +from dask.callbacks import Callback |
| 5 | +__author__ = {"github.com/": ["casperdcl"]} |
| 6 | +__all__ = ['TqdmCallback'] |
| 7 | + |
| 8 | + |
| 9 | +class TqdmCallback(Callback): |
| 10 | + """`dask` callback for task progress""" |
| 11 | + def __init__(self, start=None, start_state=None, pretask=None, |
| 12 | + posttask=None, finish=None, tqdm_class=tqdm_auto, |
| 13 | + **tqdm_kwargs): |
| 14 | + """ |
| 15 | + Parameters |
| 16 | + ---------- |
| 17 | + tqdm_class : optional |
| 18 | + `tqdm` class to use for bars [default: `tqdm.auto.tqdm`]. |
| 19 | + tqdm_kwargs : optional |
| 20 | + Any other arguments used for all bars. |
| 21 | + """ |
| 22 | + super(TqdmCallback, self).__init__( |
| 23 | + start=start, start_state=start_state, pretask=pretask, |
| 24 | + posttask=posttask, finish=finish) |
| 25 | + if tqdm_kwargs: |
| 26 | + tqdm_class = partial(tqdm_class, **tqdm_kwargs) |
| 27 | + self.tqdm_class = tqdm_class |
| 28 | + |
| 29 | + def _start_state(self, _, state): |
| 30 | + self.pbar = self.tqdm_class(total=sum( |
| 31 | + len(state[k]) for k in ['ready', 'waiting', 'running', 'finished'])) |
| 32 | + |
| 33 | + def _posttask(self, *args, **kwargs): |
| 34 | + self.pbar.update() |
| 35 | + |
| 36 | + def _finish(self, *args, **kwargs): |
| 37 | + self.pbar.close() |
| 38 | + |
| 39 | + def display(self): |
| 40 | + """displays in the current cell in Notebooks""" |
| 41 | + container = getattr(self.bar, 'container', None) |
| 42 | + if container is None: |
| 43 | + return |
| 44 | + from .notebook import display |
| 45 | + display(container) |
0 commit comments