forked from astropy/astropy
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_shape_manipulation.py
More file actions
337 lines (314 loc) · 17.5 KB
/
test_shape_manipulation.py
File metadata and controls
337 lines (314 loc) · 17.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
# Licensed under a 3-clause BSD style license - see LICENSE.rst
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from astropy import units as u
from astropy.coordinates import EarthLocation, Latitude, Longitude, SkyCoord
# test on frame with most complicated frame attributes.
from astropy.coordinates.builtin_frames import GCRS, ICRS, AltAz
from astropy.time import Time
from astropy.units.quantity_helper.function_helpers import ARRAY_FUNCTION_ENABLED
@pytest.fixture(params=[True, False] if ARRAY_FUNCTION_ENABLED else [True])
def method(request):
return request.param
needs_array_function = pytest.mark.xfail(
not ARRAY_FUNCTION_ENABLED, reason="Needs __array_function__ support"
)
class TestManipulation:
"""Manipulation of Frame shapes.
Checking that attributes are manipulated correctly.
Even more exhaustive tests are done in time.tests.test_methods
"""
def setup_class(cls):
# For these tests, we set up frames and coordinates using copy=False,
# so we can check that broadcasting is handled correctly.
lon = Longitude(np.arange(0, 24, 4), u.hourangle)
lat = Latitude(np.arange(-90, 91, 30), u.deg)
# With same-sized arrays, no attributes.
cls.s0 = ICRS(
lon[:, np.newaxis] * np.ones(lat.shape),
lat * np.ones(lon.shape)[:, np.newaxis],
copy=False,
)
# Make an AltAz frame since that has many types of attributes.
# Match one axis with times.
cls.obstime = Time("2012-01-01") + np.arange(len(lon))[:, np.newaxis] * u.s
# And another with location.
cls.location = EarthLocation(20.0 * u.deg, lat, 100 * u.m)
# Ensure we have a quantity scalar.
cls.pressure = 1000 * u.hPa
# As well as an array.
cls.temperature = (
np.random.uniform(0.0, 20.0, size=(lon.size, lat.size)) * u.deg_C
)
cls.s1 = AltAz(
az=lon[:, np.newaxis],
alt=lat,
obstime=cls.obstime,
location=cls.location,
pressure=cls.pressure,
temperature=cls.temperature,
copy=False,
)
# For some tests, also try a GCRS, since that has representation
# attributes. We match the second dimension (via the location)
cls.obsgeoloc, cls.obsgeovel = cls.location.get_gcrs_posvel(cls.obstime[0, 0])
cls.s2 = GCRS(
ra=lon[:, np.newaxis],
dec=lat,
obstime=cls.obstime,
obsgeoloc=cls.obsgeoloc,
obsgeovel=cls.obsgeovel,
copy=False,
)
# For completeness, also some tests on an empty frame.
cls.s3 = GCRS(
obstime=cls.obstime,
obsgeoloc=cls.obsgeoloc,
obsgeovel=cls.obsgeovel,
copy=False,
)
# And make a SkyCoord
cls.sc = SkyCoord(ra=lon[:, np.newaxis], dec=lat, frame=cls.s3, copy=False)
def test_getitem0101(self):
# We on purpose take a slice with only one element, as for the
# general tests it doesn't matter, but it allows us to check
# for a few cases that shapes correctly become scalar if we
# index our size-1 array down to a scalar. See gh-10113.
item = (slice(0, 1), slice(0, 1))
s0_0101 = self.s0[item]
assert s0_0101.shape == (1, 1)
assert_array_equal(s0_0101.data.lon, self.s0.data.lon[item])
assert np.may_share_memory(s0_0101.data.lon, self.s0.data.lon)
assert np.may_share_memory(s0_0101.data.lat, self.s0.data.lat)
s0_0101_00 = s0_0101[0, 0]
assert s0_0101_00.shape == ()
assert s0_0101_00.data.lon.shape == ()
assert_array_equal(s0_0101_00.data.lon, self.s0.data.lon[0, 0])
s1_0101 = self.s1[item]
assert s1_0101.shape == (1, 1)
assert_array_equal(s1_0101.data.lon, self.s1.data.lon[item])
assert np.may_share_memory(s1_0101.data.lat, self.s1.data.lat)
assert np.all(s1_0101.obstime == self.s1.obstime[item])
assert np.may_share_memory(s1_0101.obstime.jd1, self.s1.obstime.jd1)
assert_array_equal(s1_0101.location, self.s1.location[0, 0])
assert np.may_share_memory(s1_0101.location, self.s1.location)
assert_array_equal(s1_0101.temperature, self.s1.temperature[item])
assert np.may_share_memory(s1_0101.temperature, self.s1.temperature)
# scalar should just be transferred.
assert s1_0101.pressure is self.s1.pressure
s1_0101_00 = s1_0101[0, 0]
assert s1_0101_00.shape == ()
assert s1_0101_00.obstime.shape == ()
assert s1_0101_00.obstime == self.s1.obstime[0, 0]
s2_0101 = self.s2[item]
assert s2_0101.shape == (1, 1)
assert np.all(s2_0101.data.lon == self.s2.data.lon[item])
assert np.may_share_memory(s2_0101.data.lat, self.s2.data.lat)
assert np.all(s2_0101.obstime == self.s2.obstime[item])
assert np.may_share_memory(s2_0101.obstime.jd1, self.s2.obstime.jd1)
assert_array_equal(s2_0101.obsgeoloc.xyz, self.s2.obsgeoloc[item].xyz)
s3_0101 = self.s3[item]
assert s3_0101.shape == (1, 1)
assert s3_0101.obstime.shape == (1, 1)
assert np.all(s3_0101.obstime == self.s3.obstime[item])
assert np.may_share_memory(s3_0101.obstime.jd1, self.s3.obstime.jd1)
assert_array_equal(s3_0101.obsgeoloc.xyz, self.s3.obsgeoloc[item].xyz)
sc_0101 = self.sc[item]
assert sc_0101.shape == (1, 1)
assert_array_equal(sc_0101.data.lon, self.sc.data.lon[item])
assert np.may_share_memory(sc_0101.data.lat, self.sc.data.lat)
assert np.all(sc_0101.obstime == self.sc.obstime[item])
assert np.may_share_memory(sc_0101.obstime.jd1, self.sc.obstime.jd1)
assert_array_equal(sc_0101.obsgeoloc.xyz, self.sc.obsgeoloc[item].xyz)
def test_ravel(self):
s0_ravel = self.s0.ravel()
assert s0_ravel.shape == (self.s0.size,)
assert np.all(s0_ravel.data.lon == self.s0.data.lon.ravel())
assert np.may_share_memory(s0_ravel.data.lon, self.s0.data.lon)
assert np.may_share_memory(s0_ravel.data.lat, self.s0.data.lat)
# Since s1 lon, lat were broadcast, ravel needs to make a copy.
s1_ravel = self.s1.ravel()
assert s1_ravel.shape == (self.s1.size,)
assert np.all(s1_ravel.data.lon == self.s1.data.lon.ravel())
assert not np.may_share_memory(s1_ravel.data.lat, self.s1.data.lat)
assert np.all(s1_ravel.obstime == self.s1.obstime.ravel())
assert not np.may_share_memory(s1_ravel.obstime.jd1, self.s1.obstime.jd1)
assert np.all(s1_ravel.location == self.s1.location.ravel())
assert not np.may_share_memory(s1_ravel.location, self.s1.location)
assert np.all(s1_ravel.temperature == self.s1.temperature.ravel())
assert np.may_share_memory(s1_ravel.temperature, self.s1.temperature)
assert s1_ravel.pressure == self.s1.pressure
s2_ravel = self.s2.ravel()
assert s2_ravel.shape == (self.s2.size,)
assert np.all(s2_ravel.data.lon == self.s2.data.lon.ravel())
assert not np.may_share_memory(s2_ravel.data.lat, self.s2.data.lat)
assert np.all(s2_ravel.obstime == self.s2.obstime.ravel())
assert not np.may_share_memory(s2_ravel.obstime.jd1, self.s2.obstime.jd1)
# CartesianRepresentation do not allow direct comparisons, as this is
# too tricky to get right in the face of rounding issues. Here, though,
# it cannot be an issue, so we compare the xyz quantities.
assert np.all(s2_ravel.obsgeoloc.xyz == self.s2.obsgeoloc.ravel().xyz)
assert not np.may_share_memory(s2_ravel.obsgeoloc.x, self.s2.obsgeoloc.x)
s3_ravel = self.s3.ravel()
assert s3_ravel.shape == (42,) # cannot use .size on frame w/o data.
assert np.all(s3_ravel.obstime == self.s3.obstime.ravel())
assert not np.may_share_memory(s3_ravel.obstime.jd1, self.s3.obstime.jd1)
assert np.all(s3_ravel.obsgeoloc.xyz == self.s3.obsgeoloc.ravel().xyz)
assert not np.may_share_memory(s3_ravel.obsgeoloc.x, self.s3.obsgeoloc.x)
sc_ravel = self.sc.ravel()
assert sc_ravel.shape == (self.sc.size,)
assert np.all(sc_ravel.data.lon == self.sc.data.lon.ravel())
assert not np.may_share_memory(sc_ravel.data.lat, self.sc.data.lat)
assert np.all(sc_ravel.obstime == self.sc.obstime.ravel())
assert not np.may_share_memory(sc_ravel.obstime.jd1, self.sc.obstime.jd1)
assert np.all(sc_ravel.obsgeoloc.xyz == self.sc.obsgeoloc.ravel().xyz)
assert not np.may_share_memory(sc_ravel.obsgeoloc.x, self.sc.obsgeoloc.x)
def test_flatten(self):
s0_flatten = self.s0.flatten()
assert s0_flatten.shape == (self.s0.size,)
assert np.all(s0_flatten.data.lon == self.s0.data.lon.flatten())
# Flatten always copies.
assert not np.may_share_memory(s0_flatten.data.lat, self.s0.data.lat)
s1_flatten = self.s1.flatten()
assert s1_flatten.shape == (self.s1.size,)
assert np.all(s1_flatten.data.lat == self.s1.data.lat.flatten())
assert not np.may_share_memory(s1_flatten.data.lon, self.s1.data.lat)
assert np.all(s1_flatten.obstime == self.s1.obstime.flatten())
assert not np.may_share_memory(s1_flatten.obstime.jd1, self.s1.obstime.jd1)
assert np.all(s1_flatten.location == self.s1.location.flatten())
assert not np.may_share_memory(s1_flatten.location, self.s1.location)
assert np.all(s1_flatten.temperature == self.s1.temperature.flatten())
assert not np.may_share_memory(s1_flatten.temperature, self.s1.temperature)
assert s1_flatten.pressure == self.s1.pressure
def test_transpose(self):
s0_transpose = self.s0.transpose()
assert s0_transpose.shape == (7, 6)
assert np.all(s0_transpose.data.lon == self.s0.data.lon.transpose())
assert np.may_share_memory(s0_transpose.data.lat, self.s0.data.lat)
s1_transpose = self.s1.transpose()
assert s1_transpose.shape == (7, 6)
assert np.all(s1_transpose.data.lat == self.s1.data.lat.transpose())
assert np.may_share_memory(s1_transpose.data.lon, self.s1.data.lon)
assert np.all(s1_transpose.obstime == self.s1.obstime.transpose())
assert np.may_share_memory(s1_transpose.obstime.jd1, self.s1.obstime.jd1)
assert np.all(s1_transpose.location == self.s1.location.transpose())
assert np.may_share_memory(s1_transpose.location, self.s1.location)
assert np.all(s1_transpose.temperature == self.s1.temperature.transpose())
assert np.may_share_memory(s1_transpose.temperature, self.s1.temperature)
assert s1_transpose.pressure == self.s1.pressure
# Only one check on T, since it just calls transpose anyway.
s1_T = self.s1.T
assert s1_T.shape == (7, 6)
assert np.all(s1_T.temperature == self.s1.temperature.T)
assert np.may_share_memory(s1_T.location, self.s1.location)
def test_diagonal(self):
s0_diagonal = self.s0.diagonal()
assert s0_diagonal.shape == (6,)
assert np.all(s0_diagonal.data.lat == self.s0.data.lat.diagonal())
assert np.may_share_memory(s0_diagonal.data.lat, self.s0.data.lat)
def test_swapaxes(self):
s1_swapaxes = self.s1.swapaxes(0, 1)
assert s1_swapaxes.shape == (7, 6)
assert np.all(s1_swapaxes.data.lat == self.s1.data.lat.swapaxes(0, 1))
assert np.may_share_memory(s1_swapaxes.data.lat, self.s1.data.lat)
assert np.all(s1_swapaxes.obstime == self.s1.obstime.swapaxes(0, 1))
assert np.may_share_memory(s1_swapaxes.obstime.jd1, self.s1.obstime.jd1)
assert np.all(s1_swapaxes.location == self.s1.location.swapaxes(0, 1))
assert s1_swapaxes.location.shape == (7, 6)
assert np.may_share_memory(s1_swapaxes.location, self.s1.location)
assert np.all(s1_swapaxes.temperature == self.s1.temperature.swapaxes(0, 1))
assert np.may_share_memory(s1_swapaxes.temperature, self.s1.temperature)
assert s1_swapaxes.pressure == self.s1.pressure
def test_reshape(self):
s0_reshape = self.s0.reshape(2, 3, 7)
assert s0_reshape.shape == (2, 3, 7)
assert np.all(s0_reshape.data.lon == self.s0.data.lon.reshape(2, 3, 7))
assert np.all(s0_reshape.data.lat == self.s0.data.lat.reshape(2, 3, 7))
assert np.may_share_memory(s0_reshape.data.lon, self.s0.data.lon)
assert np.may_share_memory(s0_reshape.data.lat, self.s0.data.lat)
s1_reshape = self.s1.reshape(3, 2, 7)
assert s1_reshape.shape == (3, 2, 7)
assert np.all(s1_reshape.data.lat == self.s1.data.lat.reshape(3, 2, 7))
assert np.may_share_memory(s1_reshape.data.lat, self.s1.data.lat)
assert np.all(s1_reshape.obstime == self.s1.obstime.reshape(3, 2, 7))
assert np.may_share_memory(s1_reshape.obstime.jd1, self.s1.obstime.jd1)
assert np.all(s1_reshape.location == self.s1.location.reshape(3, 2, 7))
assert np.may_share_memory(s1_reshape.location, self.s1.location)
assert np.all(s1_reshape.temperature == self.s1.temperature.reshape(3, 2, 7))
assert np.may_share_memory(s1_reshape.temperature, self.s1.temperature)
assert s1_reshape.pressure == self.s1.pressure
# For reshape(3, 14), copying is necessary for lon, lat, location, time
s1_reshape2 = self.s1.reshape(3, 14)
assert s1_reshape2.shape == (3, 14)
assert np.all(s1_reshape2.data.lon == self.s1.data.lon.reshape(3, 14))
assert not np.may_share_memory(s1_reshape2.data.lon, self.s1.data.lon)
assert np.all(s1_reshape2.obstime == self.s1.obstime.reshape(3, 14))
assert not np.may_share_memory(s1_reshape2.obstime.jd1, self.s1.obstime.jd1)
assert np.all(s1_reshape2.location == self.s1.location.reshape(3, 14))
assert not np.may_share_memory(s1_reshape2.location, self.s1.location)
assert np.all(s1_reshape2.temperature == self.s1.temperature.reshape(3, 14))
assert np.may_share_memory(s1_reshape2.temperature, self.s1.temperature)
assert s1_reshape2.pressure == self.s1.pressure
s2_reshape = self.s2.reshape(3, 2, 7)
assert s2_reshape.shape == (3, 2, 7)
assert np.all(s2_reshape.data.lon == self.s2.data.lon.reshape(3, 2, 7))
assert np.may_share_memory(s2_reshape.data.lat, self.s2.data.lat)
assert np.all(s2_reshape.obstime == self.s2.obstime.reshape(3, 2, 7))
assert np.may_share_memory(s2_reshape.obstime.jd1, self.s2.obstime.jd1)
assert np.all(
s2_reshape.obsgeoloc.xyz == self.s2.obsgeoloc.reshape(3, 2, 7).xyz
)
assert np.may_share_memory(s2_reshape.obsgeoloc.x, self.s2.obsgeoloc.x)
s3_reshape = self.s3.reshape(3, 2, 7)
assert s3_reshape.shape == (3, 2, 7)
assert np.all(s3_reshape.obstime == self.s3.obstime.reshape(3, 2, 7))
assert np.may_share_memory(s3_reshape.obstime.jd1, self.s3.obstime.jd1)
assert np.all(
s3_reshape.obsgeoloc.xyz == self.s3.obsgeoloc.reshape(3, 2, 7).xyz
)
assert np.may_share_memory(s3_reshape.obsgeoloc.x, self.s3.obsgeoloc.x)
sc_reshape = self.sc.reshape(3, 2, 7)
assert sc_reshape.shape == (3, 2, 7)
assert np.all(sc_reshape.data.lon == self.sc.data.lon.reshape(3, 2, 7))
assert np.may_share_memory(sc_reshape.data.lat, self.sc.data.lat)
assert np.all(sc_reshape.obstime == self.sc.obstime.reshape(3, 2, 7))
assert np.may_share_memory(sc_reshape.obstime.jd1, self.sc.obstime.jd1)
assert np.all(
sc_reshape.obsgeoloc.xyz == self.sc.obsgeoloc.reshape(3, 2, 7).xyz
)
assert np.may_share_memory(sc_reshape.obsgeoloc.x, self.sc.obsgeoloc.x)
# For reshape(3, 14), the arrays all need to be copied.
sc_reshape2 = self.sc.reshape(3, 14)
assert sc_reshape2.shape == (3, 14)
assert np.all(sc_reshape2.data.lon == self.sc.data.lon.reshape(3, 14))
assert not np.may_share_memory(sc_reshape2.data.lat, self.sc.data.lat)
assert np.all(sc_reshape2.obstime == self.sc.obstime.reshape(3, 14))
assert not np.may_share_memory(sc_reshape2.obstime.jd1, self.sc.obstime.jd1)
assert np.all(sc_reshape2.obsgeoloc.xyz == self.sc.obsgeoloc.reshape(3, 14).xyz)
assert not np.may_share_memory(sc_reshape2.obsgeoloc.x, self.sc.obsgeoloc.x)
def test_squeeze(self):
s0_squeeze = self.s0.reshape(3, 1, 2, 1, 7).squeeze()
assert s0_squeeze.shape == (3, 2, 7)
assert np.all(s0_squeeze.data.lat == self.s0.data.lat.reshape(3, 2, 7))
assert np.may_share_memory(s0_squeeze.data.lat, self.s0.data.lat)
def test_add_dimension(self, method):
if method:
s0_adddim = self.s0[:, np.newaxis, :]
else:
s0_adddim = np.expand_dims(self.s0, 1)
assert s0_adddim.shape == (6, 1, 7)
assert np.all(s0_adddim.data.lon == self.s0.data.lon[:, np.newaxis, :])
assert np.may_share_memory(s0_adddim.data.lat, self.s0.data.lat)
def test_take(self):
s0_take = self.s0.take((5, 2))
assert s0_take.shape == (2,)
assert np.all(s0_take.data.lon == self.s0.data.lon.take((5, 2)))
# Much more detailed tests of shape manipulation via numpy functions done
# in test_representation_methods.
@needs_array_function
def test_broadcast_to(self):
s1_broadcast = np.broadcast_to(self.s1, (20, 6, 7))
assert s1_broadcast.shape == (20, 6, 7)
assert np.all(s1_broadcast.data.lon == self.s1.data.lon[np.newaxis])
assert np.may_share_memory(s1_broadcast.data.lat, self.s1.data.lat)