Skip to content

Commit 299789b

Browse files
committed
tests: test kwargs for keras, dask
1 parent 2a82405 commit 299789b

2 files changed

Lines changed: 57 additions & 67 deletions

File tree

tests/tests_dask.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ def test_dask(capsys):
1313
dask = importorskip('dask')
1414

1515
schedule = [dask.delayed(sleep)(i / 10) for i in range(5)]
16-
with ProgressBar():
16+
with ProgressBar(desc="computing"):
1717
dask.compute(schedule)
1818
_, err = capsys.readouterr()
19+
assert "computing: " in err
1920
assert '5/5' in err

tests/tests_keras.py

Lines changed: 55 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
from __future__ import division
22

3-
from tqdm import tqdm
4-
5-
from .tests_tqdm import StringIO, closing, importorskip, mark
3+
from .tests_tqdm import importorskip, mark
64

75
pytestmark = mark.slow
86

97

108
@mark.filterwarnings("ignore:.*:DeprecationWarning")
11-
def test_keras():
9+
def test_keras(capsys):
1210
"""Test tqdm.keras.TqdmCallback"""
1311
TqdmCallback = importorskip('tqdm.keras').TqdmCallback
1412
np = importorskip('numpy')
@@ -27,67 +25,58 @@ def test_keras():
2725
batches = len(x) / batch_size
2826
epochs = 5
2927

30-
with closing(StringIO()) as our_file:
31-
32-
class Tqdm(tqdm):
33-
"""redirected I/O class"""
34-
def __init__(self, *a, **k):
35-
k.setdefault("file", our_file)
36-
super(Tqdm, self).__init__(*a, **k)
37-
38-
# just epoch (no batch) progress
39-
model.fit(
40-
x,
41-
x,
42-
epochs=epochs,
43-
batch_size=batch_size,
44-
verbose=False,
45-
callbacks=[
46-
TqdmCallback(
47-
epochs,
48-
data_size=len(x),
49-
batch_size=batch_size,
50-
verbose=0,
51-
tqdm_class=Tqdm,
52-
)],
53-
)
54-
res = our_file.getvalue()
55-
assert "{epochs}/{epochs}".format(epochs=epochs) in res
56-
assert "{batches}/{batches}".format(batches=batches) not in res
28+
# just epoch (no batch) progress
29+
model.fit(
30+
x,
31+
x,
32+
epochs=epochs,
33+
batch_size=batch_size,
34+
verbose=False,
35+
callbacks=[
36+
TqdmCallback(
37+
epochs,
38+
desc="training",
39+
data_size=len(x),
40+
batch_size=batch_size,
41+
verbose=0,
42+
)],
43+
)
44+
_, res = capsys.readouterr()
45+
assert "training: " in res
46+
assert "{epochs}/{epochs}".format(epochs=epochs) in res
47+
assert "{batches}/{batches}".format(batches=batches) not in res
5748

58-
# full (epoch and batch) progress
59-
our_file.seek(0)
60-
our_file.truncate()
61-
model.fit(
62-
x,
63-
x,
64-
epochs=epochs,
65-
batch_size=batch_size,
66-
verbose=False,
67-
callbacks=[
68-
TqdmCallback(
69-
epochs,
70-
data_size=len(x),
71-
batch_size=batch_size,
72-
verbose=2,
73-
tqdm_class=Tqdm,
74-
)],
75-
)
76-
res = our_file.getvalue()
77-
assert "{epochs}/{epochs}".format(epochs=epochs) in res
78-
assert "{batches}/{batches}".format(batches=batches) in res
49+
# full (epoch and batch) progress
50+
model.fit(
51+
x,
52+
x,
53+
epochs=epochs,
54+
batch_size=batch_size,
55+
verbose=False,
56+
callbacks=[
57+
TqdmCallback(
58+
epochs,
59+
desc="training",
60+
data_size=len(x),
61+
batch_size=batch_size,
62+
verbose=2,
63+
)],
64+
)
65+
_, res = capsys.readouterr()
66+
assert "training: " in res
67+
assert "{epochs}/{epochs}".format(epochs=epochs) in res
68+
assert "{batches}/{batches}".format(batches=batches) in res
7969

80-
# auto-detect epochs and batches
81-
our_file.seek(0)
82-
our_file.truncate()
83-
model.fit(
84-
x,
85-
x,
86-
epochs=epochs,
87-
batch_size=batch_size,
88-
verbose=False,
89-
callbacks=[TqdmCallback(verbose=2, tqdm_class=Tqdm)],
90-
)
91-
res = our_file.getvalue()
92-
assert "{epochs}/{epochs}".format(epochs=epochs) in res
93-
assert "{batches}/{batches}".format(batches=batches) in res
70+
# auto-detect epochs and batches
71+
model.fit(
72+
x,
73+
x,
74+
epochs=epochs,
75+
batch_size=batch_size,
76+
verbose=False,
77+
callbacks=[TqdmCallback(desc="training", verbose=2)],
78+
)
79+
_, res = capsys.readouterr()
80+
assert "training: " in res
81+
assert "{epochs}/{epochs}".format(epochs=epochs) in res
82+
assert "{batches}/{batches}".format(batches=batches) in res

0 commit comments

Comments
 (0)