diff --git a/tests/snippets/set.py b/tests/snippets/set.py index 84754bca81..a7a20b0efb 100644 --- a/tests/snippets/set.py +++ b/tests/snippets/set.py @@ -1,4 +1,4 @@ -from testutils import assert_raises +from testutils import assert_raises, assertRaises assert set([1,2]) == set([1,2]) assert not set([1,2,3]) == set([1,2]) @@ -109,12 +109,8 @@ def __hash__(self): a = set([1,2,3]) a |= set([3,4,5]) assert a == set([1,2,3,4,5]) -try: +with assertRaises(TypeError): a |= 1 -except TypeError: - pass -else: - assert False, "TypeError not raised" a = set([1,2,3]) a.intersection_update([2,3,4,5]) @@ -124,12 +120,8 @@ def __hash__(self): a = set([1,2,3]) a &= set([2,3,4,5]) assert a == set([2,3]) -try: +with assertRaises(TypeError): a &= 1 -except TypeError: - pass -else: - assert False, "TypeError not raised" a = set([1,2,3]) a.difference_update([3,4,5]) @@ -139,12 +131,8 @@ def __hash__(self): a = set([1,2,3]) a -= set([3,4,5]) assert a == set([1,2]) -try: +with assertRaises(TypeError): a -= 1 -except TypeError: - pass -else: - assert False, "TypeError not raised" a = set([1,2,3]) a.symmetric_difference_update([3,4,5]) @@ -154,9 +142,5 @@ def __hash__(self): a = set([1,2,3]) a ^= set([3,4,5]) assert a == set([1,2,4,5]) -try: +with assertRaises(TypeError): a ^= 1 -except TypeError: - pass -else: - assert False, "TypeError not raised" diff --git a/tests/snippets/testutils.py b/tests/snippets/testutils.py index 3c039f4f1e..8237ceb621 100644 --- a/tests/snippets/testutils.py +++ b/tests/snippets/testutils.py @@ -18,3 +18,19 @@ def assert_raises(exc_type, expr, msg=None): if msg is not None: failmsg += ': {!s}'.format(msg) assert False, failmsg + + +class assertRaises: + def __init__(self, expected): + self.expected = expected + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + failmsg = '{!s} was not raised'.format(self.expected.__name_) + assert False, failmsg + if not issubclass(exc_type, self.expected): + return False + return True