forked from astropy/astropy
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstructured.py
More file actions
577 lines (474 loc) · 20.1 KB
/
structured.py
File metadata and controls
577 lines (474 loc) · 20.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
This module defines structured units and quantities.
"""
# Standard library
import operator
from functools import cached_property
import numpy as np
from astropy.utils.compat.numpycompat import NUMPY_LT_1_24
from .core import UNITY, Unit, UnitBase
# Names exported by ``from <this module> import *``.
__all__ = ["StructuredUnit"]
# Object dtype; used below as the dtype of every field of the structured
# scalar in which a StructuredUnit stores its (possibly nested) units.
DTYPE_OBJECT = np.dtype("O")
def _names_from_dtype(dtype):
"""Recursively extract field names from a dtype."""
names = []
for name in dtype.names:
subdtype = dtype.fields[name][0].base
if subdtype.names:
names.append([name, _names_from_dtype(subdtype)])
else:
names.append(name)
return tuple(names)
def _normalize_names(names):
"""Recursively normalize, inferring upper level names for unadorned tuples.
Generally, we want the field names to be organized like dtypes, as in
``(['pv', ('p', 'v')], 't')``. But we automatically infer upper
field names if the list is absent from items like ``(('p', 'v'), 't')``,
by concatenating the names inside the tuple.
"""
result = []
for name in names:
if isinstance(name, str) and len(name) > 0:
result.append(name)
elif (
isinstance(name, list)
and len(name) == 2
and isinstance(name[0], str)
and len(name[0]) > 0
and isinstance(name[1], tuple)
and len(name[1]) > 0
):
result.append([name[0], _normalize_names(name[1])])
elif isinstance(name, tuple) and len(name) > 0:
new_tuple = _normalize_names(name)
name = "".join([(i[0] if isinstance(i, list) else i) for i in new_tuple])
result.append([name, new_tuple])
else:
raise ValueError(
f"invalid entry {name!r}. Should be a name, "
"tuple of names, or 2-element list of the "
"form [name, tuple of names]."
)
return tuple(result)
class StructuredUnit:
    """Container for units for a structured Quantity.

    Parameters
    ----------
    units : unit-like, tuple of unit-like, or `~astropy.units.StructuredUnit`
        Tuples can be nested.  If a `~astropy.units.StructuredUnit` is passed
        in, it will be returned unchanged unless different names are requested.
    names : tuple of str, tuple or list; `~numpy.dtype`; or `~astropy.units.StructuredUnit`, optional
        Field names for the units, possibly nested. Can be inferred from a
        structured `~numpy.dtype` or another `~astropy.units.StructuredUnit`.
        For nested tuples, by default the name of the upper entry will be the
        concatenation of the names of the lower levels.  One can pass in a
        list with the upper-level name and a tuple of lower-level names to
        avoid this.  For tuples, not all levels have to be given; for any level
        not passed in, default field names of 'f0', 'f1', etc., will be used.

    Notes
    -----
    It is recommended to initialize the class indirectly, using
    `~astropy.units.Unit`.  E.g., ``u.Unit('AU,AU/day')``.

    When combined with a structured array to produce a structured
    `~astropy.units.Quantity`, array field names will take precedence.
    Generally, passing in ``names`` is needed only if the unit is used
    unattached to a `~astropy.units.Quantity` and one needs to access its
    fields.

    Examples
    --------
    Various ways to initialize a `~astropy.units.StructuredUnit`::

        >>> import astropy.units as u
        >>> su = u.Unit('(AU,AU/day),yr')
        >>> su
        Unit("((AU, AU / d), yr)")
        >>> su.field_names
        (['f0', ('f0', 'f1')], 'f1')
        >>> su['f1']
        Unit("yr")
        >>> su2 = u.StructuredUnit(((u.AU, u.AU/u.day), u.yr), names=(('p', 'v'), 't'))
        >>> su2 == su
        True
        >>> su2.field_names
        (['pv', ('p', 'v')], 't')
        >>> su3 = u.StructuredUnit((su2['pv'], u.day), names=(['p_v', ('p', 'v')], 't'))
        >>> su3.field_names
        (['p_v', ('p', 'v')], 't')
        >>> su3.keys()
        ('p_v', 't')
        >>> su3.values()
        (Unit("(AU, AU / d)"), Unit("d"))

    Structured units share most methods with regular units::

        >>> su.physical_type
        astropy.units.structured.Structure((astropy.units.structured.Structure((PhysicalType('length'), PhysicalType({'speed', 'velocity'})), dtype=[('f0', 'O'), ('f1', 'O')]), PhysicalType('time')), dtype=[('f0', 'O'), ('f1', 'O')])
        >>> su.si
        Unit("((1.49598e+11 m, 1.73146e+06 m / s), 3.15576e+07 s)")

    """

    def __new__(cls, units, names=None):
        """Build the structured unit (or return ``units`` itself if possible).

        Raises
        ------
        ValueError
            If ``names`` is a dtype without fields, if the numbers of units
            and names differ, or if the nesting depth of ``units`` exceeds
            that of names taken from a dtype or structured unit.
        """
        # ``dtype`` is only set when the names come from an existing dtype or
        # structured unit; otherwise it is constructed from ``names`` below.
        dtype = None
        if names is not None:
            if isinstance(names, StructuredUnit):
                dtype = names._units.dtype
                names = names.field_names
            elif isinstance(names, np.dtype):
                if not names.fields:
                    raise ValueError("dtype should be structured, with fields.")
                dtype = np.dtype([(name, DTYPE_OBJECT) for name in names.names])
                names = _names_from_dtype(names)
            else:
                if not isinstance(names, tuple):
                    names = (names,)
                names = _normalize_names(names)

        if not isinstance(units, tuple):
            units = Unit(units)
            if isinstance(units, StructuredUnit):
                # Avoid constructing a new StructuredUnit if no field names
                # are given, or if all field names are the same already anyway.
                if names is None or units.field_names == names:
                    return units

                # Otherwise, turn (the upper level) into a tuple, for renaming.
                units = units.values()
            else:
                # Single regular unit: make a tuple for iteration below.
                units = (units,)

        if names is None:
            names = tuple(f"f{i}" for i in range(len(units)))

        elif len(units) != len(names):
            raise ValueError("lengths of units and field names must match.")

        converted = []
        for unit, name in zip(units, names):
            if isinstance(name, list):
                # For list, the first item is the name of our level,
                # and the second another tuple of names, i.e., we recurse.
                unit = cls(unit, name[1])
                name = name[0]
            else:
                # We are at the lowest level.  Check unit.
                unit = Unit(unit)
                if dtype is not None and isinstance(unit, StructuredUnit):
                    raise ValueError(
                        "units do not match in depth with field "
                        "names from dtype or structured unit."
                    )

            converted.append(unit)

        self = super().__new__(cls)
        if dtype is None:
            dtype = np.dtype(
                [
                    ((name[0] if isinstance(name, list) else name), DTYPE_OBJECT)
                    for name in names
                ]
            )
        # Decay array to void so we can access by field name and number.
        self._units = np.array(tuple(converted), dtype)[()]
        return self

    def __getnewargs__(self):
        """When de-serializing, e.g. pickle, start with a blank structure."""
        return (), None

    @property
    def field_names(self):
        """Possibly nested tuple of the field names of the parts."""
        return tuple(
            ([name, unit.field_names] if isinstance(unit, StructuredUnit) else name)
            for name, unit in self.items()
        )

    # Allow StructuredUnit to be treated as an (ordered) mapping.
    def __len__(self):
        """Number of (top-level) fields."""
        return len(self._units.dtype.names)

    def __getitem__(self, item):
        """Return the unit of a top-level field, by name or by number."""
        # Since we are based on np.void, indexing by field number works too.
        return self._units[item]

    def values(self):
        """Tuple with the units of the top-level fields."""
        return self._units.item()

    def keys(self):
        """Tuple with the names of the top-level fields."""
        return self._units.dtype.names

    def items(self):
        """Tuple of ``(name, unit)`` pairs for the top-level fields."""
        return tuple(zip(self._units.dtype.names, self._units.item()))

    def __iter__(self):
        """Iterate over the names of the top-level fields."""
        yield from self._units.dtype.names

    # Helpers for methods below.
    def _recursively_apply(self, func, cls=None):
        """Apply func recursively.

        Parameters
        ----------
        func : callable
            Function to apply to all parts of the structured unit,
            recursing as needed.
        cls : type, optional
            If given, should be a subclass of `~numpy.void`. By default,
            will return a new `~astropy.units.StructuredUnit` instance.
        """
        applied = tuple(func(part) for part in self.values())
        # Build a structured scalar with our own field names; which numpy
        # call is used for that depends on the numpy version.
        if NUMPY_LT_1_24:
            results = np.array(applied, self._units.dtype)[()]
        else:
            results = np.void(applied, self._units.dtype)
        if cls is not None:
            return results.view((cls, results.dtype))

        # Short-cut; no need to interpret field names, etc.
        result = super().__new__(self.__class__)
        result._units = results
        return result

    def _recursively_get_dtype(self, value, enter_lists=True):
        """Get structured dtype according to value, using our field names.

        This is useful since ``np.array(value)`` would treat tuples as lower
        levels of the array, rather than as elements of a structured array.
        The routine does presume that the type of the first tuple is
        representative of the rest.  Used in ``get_converter``.

        For the special value of ``UNITY``, all fields are assumed to be 1.0,
        and hence this will return an all-float dtype.

        Raises
        ------
        ValueError
            If ``value`` (after entering lists) is neither ``UNITY`` nor a
            tuple with as many entries as this unit has fields.
        """
        if enter_lists:
            # Only the first element of (nested) lists is inspected.
            while isinstance(value, list):
                value = value[0]
        if value is UNITY:
            value = (UNITY,) * len(self)
        elif not isinstance(value, tuple) or len(self) != len(value):
            raise ValueError(f"cannot interpret value {value} for unit {self}.")

        descr = []
        for (name, unit), part in zip(self.items(), value):
            if isinstance(unit, StructuredUnit):
                descr.append(
                    (name, unit._recursively_get_dtype(part, enter_lists=False))
                )
            else:
                # Got a part associated with a regular unit. Gets its dtype.
                # Like for Quantity, we cast integers to float.
                part = np.array(part)
                part_dtype = part.dtype
                if part_dtype.kind in "iu":
                    part_dtype = np.dtype(float)
                descr.append((name, part_dtype, part.shape))
        return np.dtype(descr)

    @property
    def si(self):
        """The `StructuredUnit` instance in SI units."""
        return self._recursively_apply(operator.attrgetter("si"))

    @property
    def cgs(self):
        """The `StructuredUnit` instance in cgs units."""
        return self._recursively_apply(operator.attrgetter("cgs"))

    # Needed to pass through Unit initializer, so might as well use it.
    @cached_property
    def _physical_type_id(self):
        return self._recursively_apply(
            operator.attrgetter("_physical_type_id"), cls=Structure
        )

    @property
    def physical_type(self):
        """Physical types of all the fields."""
        return self._recursively_apply(
            operator.attrgetter("physical_type"), cls=Structure
        )

    def decompose(self, bases=set()):
        """The `StructuredUnit` composed of only irreducible units.

        Parameters
        ----------
        bases : sequence of `~astropy.units.UnitBase`, optional
            The bases to decompose into.  When not provided,
            decomposes down to any irreducible units.  When provided,
            the decomposed result will only contain the given units.
            This will raises a `UnitsError` if it's not possible
            to do so.

        Returns
        -------
        `~astropy.units.StructuredUnit`
            With the unit for each field containing only irreducible units.
        """
        # The default ``bases`` is only passed on, never mutated here.
        return self._recursively_apply(operator.methodcaller("decompose", bases=bases))

    def is_equivalent(self, other, equivalencies=[]):
        """`True` if all fields are equivalent to the other's fields.

        Parameters
        ----------
        other : `~astropy.units.StructuredUnit`
            The structured unit to compare with, or what can initialize one.
        equivalencies : list of tuple, optional
            A list of equivalence pairs to try if the units are not
            directly convertible.  See :ref:`unit_equivalencies`.
            The list will be applied to all fields.

        Returns
        -------
        bool
        """
        # Anything that cannot be turned into a structured unit is simply
        # considered not equivalent.
        try:
            other = StructuredUnit(other)
        except Exception:
            return False

        if len(self) != len(other):
            return False

        for self_part, other_part in zip(self.values(), other.values()):
            if not self_part.is_equivalent(other_part, equivalencies=equivalencies):
                return False

        return True

    def get_converter(self, other, equivalencies=[]):
        # The docstring is copied from ``UnitBase.get_converter`` just below
        # this method.
        if not isinstance(other, type(self)):
            other = self.__class__(other, names=self)

        # One converter per top-level field; nested structured fields
        # provide their own (recursive) converter.
        converters = [
            self_part.get_converter(other_part, equivalencies=equivalencies)
            for (self_part, other_part) in zip(self.values(), other.values())
        ]

        def converter(value):
            if not hasattr(value, "dtype"):
                value = np.array(value, self._recursively_get_dtype(value))
            result = np.empty_like(value)
            for name, converter_ in zip(result.dtype.names, converters):
                result[name] = converter_(value[name])
            # Index with empty tuple to decay array scalars to numpy void.
            return result if result.shape else result[()]

        return converter

    get_converter.__doc__ = UnitBase.get_converter.__doc__

    def to(self, other, value=np._NoValue, equivalencies=[]):
        """Return values converted to the specified unit.

        Parameters
        ----------
        other : `~astropy.units.StructuredUnit`
            The unit to convert to.  If necessary, will be converted to
            a `~astropy.units.StructuredUnit` using the dtype of ``value``.
        value : array-like, optional
            Value(s) in the current unit to be converted to the
            specified unit.  If a sequence, the first element must have
            entries of the correct type to represent all elements (i.e.,
            not have, e.g., a ``float`` where other elements have ``complex``).
            If not given, assumed to have 1. in all fields.
        equivalencies : list of tuple, optional
            A list of equivalence pairs to try if the units are not
            directly convertible.  See :ref:`unit_equivalencies`.
            This list is in addition to possible global defaults set by, e.g.,
            `set_enabled_equivalencies`.
            Use `None` to turn off all equivalencies.

        Returns
        -------
        values : scalar or array
            Converted value(s).

        Raises
        ------
        UnitsError
            If units are inconsistent
        """
        if value is np._NoValue:
            # We do not have UNITY as a default, since then the docstring
            # would list 1.0 as default, yet one could not pass that in.
            value = UNITY
        return self.get_converter(other, equivalencies=equivalencies)(value)

    def to_string(self, format="generic"):
        """Output the unit in the given format as a string.

        Units are separated by commas.

        Parameters
        ----------
        format : `astropy.units.format.Base` subclass or str
            The name of a format or a formatter class.  If not
            provided, defaults to the generic format.

        Notes
        -----
        Structured units can be written to all formats, but can be
        re-read only with 'generic'.

        """
        parts = [part.to_string(format) for part in self.values()]
        # A single-field unit gets a trailing comma, like a 1-element tuple.
        out_fmt = "({})" if len(self) > 1 else "({},)"
        # ``format`` may be a formatter class rather than a string, in which
        # case it has no ``startswith``; look at the formatter's name instead.
        # NOTE(review): assumes formatter classes expose their registered
        # name as a ``name`` attribute -- confirm against
        # ``astropy.units.format.Base``.  Without it, no latex handling is
        # applied.
        format_name = format if isinstance(format, str) else getattr(format, "name", "")
        if isinstance(format_name, str) and format_name.startswith("latex"):
            # Strip $ from parts and add them on the outside.
            parts = [part[1:-1] for part in parts]
            out_fmt = "$" + out_fmt + "$"
        return out_fmt.format(", ".join(parts))

    def _repr_latex_(self):
        """Return the latex representation of the unit."""
        return self.to_string("latex")

    # With ``__array_ufunc__`` set to None, numpy arrays defer binary
    # operations to the reflected methods defined below (see the numpy
    # documentation on ``__array_ufunc__``).
    __array_ufunc__ = None

    def __mul__(self, other):
        """Multiply all fields by a unit, or attach this unit to a value.

        Strings are parsed as units; multiplication with another structured
        unit is not supported.  Anything else is used to try to initialize
        a structured `~astropy.units.Quantity` with this unit.
        """
        if isinstance(other, str):
            try:
                other = Unit(other, parse_strict="silent")
            except Exception:
                return NotImplemented
        if isinstance(other, UnitBase):
            new_units = tuple(part * other for part in self.values())
            return self.__class__(new_units, names=self)
        if isinstance(other, StructuredUnit):
            return NotImplemented

        # Anything not like a unit, try initialising as a structured quantity.
        try:
            # Imported here rather than at module level.
            from .quantity import Quantity

            return Quantity(other, unit=self)
        except Exception:
            return NotImplemented

    def __rmul__(self, other):
        """Reflected multiplication; handled identically to ``__mul__``."""
        return self.__mul__(other)

    def __truediv__(self, other):
        """Divide all fields by a regular unit (or a string parsed as one)."""
        if isinstance(other, str):
            try:
                other = Unit(other, parse_strict="silent")
            except Exception:
                return NotImplemented

        if isinstance(other, UnitBase):
            new_units = tuple(part / other for part in self.values())
            return self.__class__(new_units, names=self)
        return NotImplemented

    def __rlshift__(self, m):
        """Support ``m << unit``: create a Quantity from ``m`` with this unit.

        ``copy=False`` and ``subok=True`` are passed on to the
        `~astropy.units.Quantity` initializer.
        """
        try:
            # Imported here rather than at module level.
            from .quantity import Quantity

            return Quantity(m, self, copy=False, subok=True)
        except Exception:
            return NotImplemented

    def __str__(self):
        return self.to_string()

    def __repr__(self):
        return f'Unit("{self.to_string()}")'

    def __eq__(self, other):
        """Equal if the units of all fields are equal; names are ignored."""
        try:
            other = StructuredUnit(other)
        except Exception:
            return NotImplemented

        return self.values() == other.values()

    def __ne__(self, other):
        """Unequal if the units of any field differ; names are ignored."""
        if not isinstance(other, type(self)):
            try:
                other = StructuredUnit(other)
            except Exception:
                return NotImplemented

        return self.values() != other.values()
class Structure(np.void):
    """Structured scalar that compares by content only.

    Used for physical type IDs and the like.  It behaves like a
    `~numpy.void`, i.e., mostly like a tuple whose entries can also be
    accessed by field name, except that ``__eq__`` and ``__ne__`` compare
    just the contents and ignore the field names.  Comparing this way also
    means no `FutureWarning` about comparisons is given.
    """

    # Note that it is important for physical type IDs to not be stored in a
    # tuple, since then the physical types would be treated as alternatives in
    # :meth:`~astropy.units.UnitBase.is_equivalent`.  (Of course, in that
    # case, they could also not be indexed by name.)

    def __eq__(self, other):
        # Reduce another structured scalar to its plain contents first.
        contents = other.item() if isinstance(other, np.void) else other
        return self.item() == contents

    def __ne__(self, other):
        # Reduce another structured scalar to its plain contents first.
        contents = other.item() if isinstance(other, np.void) else other
        return self.item() != contents
def _structured_unit_like_dtype(
    unit: UnitBase | StructuredUnit, dtype: np.dtype
) -> StructuredUnit:
    """Replicate a single unit over all fields of a structured `numpy.dtype`.

    Parameters
    ----------
    unit : UnitBase
        The unit to put in every (possibly nested) field.  If it is
        already a `StructuredUnit`, it is returned as is.
    dtype : `numpy.dtype`
        Structured dtype whose field names and nesting are to be followed.

    Returns
    -------
    StructuredUnit
    """
    # A structured unit should already match the dtype.  The only user of
    # this function, Quantity, checks that itself, so just hand it back.
    if isinstance(unit, StructuredUnit):
        return unit

    # Recurse into fields that themselves have fields; all other fields
    # directly get the unit.
    field_dtypes = (dtype.fields[name][0] for name in dtype.names)
    field_units = tuple(
        (
            unit
            if field_dtype.names is None
            else _structured_unit_like_dtype(unit, field_dtype)
        )
        for field_dtype in field_dtypes
    )
    return StructuredUnit(field_units, names=dtype.names)