-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathbase.py
More file actions
38 lines (29 loc) · 984 Bytes
/
base.py
File metadata and controls
38 lines (29 loc) · 984 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import abc
from abc import ABC
from types import ModuleType
import numpy.testing as npt
from xarray.namedarray._typing import duckarray
class DuckArrayTestMixin(ABC):
    """Mixin of shared assertion helpers for tests run against a duck-array backend.

    Concrete test classes must provide ``xp`` and ``array_strategy_fn``;
    ``array_type`` is an optional hook whose default returns ``None``.
    """

    @property
    @abc.abstractmethod
    def xp(self) -> ModuleType:
        """Array namespace module used by the helpers in this mixin.

        ``assert_dimension_indexers_equal`` calls ``xp.all`` and ``xp.equal``,
        so the returned module must provide at least those two functions.
        """
        # ``self`` is required: a property getter is always invoked with the
        # instance, so a zero-argument getter raises TypeError whenever the
        # base implementation is reached (e.g. via ``super().xp``).
        pass

    @staticmethod
    def array_type(op: str) -> "type[duckarray]":
        """Hook returning the array type associated with ``op``.

        The default implementation returns ``None``.
        NOTE(review): the meaning of ``op`` is defined by overriding
        subclasses and their callers; confirm it there.
        """
        pass

    @staticmethod
    @abc.abstractmethod
    def array_strategy_fn(*, shape, dtype):
        """Produce test arrays for the given ``shape`` and ``dtype`` (keyword-only).

        Abstract: subclasses must override it. Going by the name, this
        presumably returns a hypothesis strategy; TODO confirm against callers.
        """
        raise NotImplementedError("has to be overridden")

    @staticmethod
    def assert_equal(a, b):
        """Assert ``a`` and ``b`` are equal using ``numpy.testing.assert_equal``."""
        npt.assert_equal(a, b)

    def assert_dimension_indexers_equal(self, a, b):
        """Assert that two dimension indexers are equal.

        Both arguments must have exactly the same type. Dicts must have the
        same keys (order-insensitive), and each pair of values must compare
        element-wise equal under ``self.xp``. Any other values are compared
        with ``numpy.testing.assert_equal``.

        Raises:
            AssertionError: on a type, key, or value mismatch.
        """
        assert type(a) is type(b), f"types don't match: {type(a)} vs {type(b)}"
        if not isinstance(a, dict):
            npt.assert_equal(a, b)
            return

        assert a.keys() == b.keys(), f"Different dimensions: {list(a)} vs {list(b)}"
        # Collect every mismatching dimension instead of stopping at the first,
        # so the failure message says which indexers differ.
        differing = [
            dim for dim in a if not self.xp.all(self.xp.equal(a[dim], b[dim]))
        ]
        assert not differing, f"Differing indexers for dimensions: {differing}"