forked from astropy/astropy
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_convolve_nddata.py
More file actions
64 lines (48 loc) · 1.79 KB
/
test_convolve_nddata.py
File metadata and controls
64 lines (48 loc) · 1.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Licensed under a 3-clause BSD style license - see LICENSE.rst
import numpy as np
import pytest
from astropy.convolution.convolve import convolve, convolve_fft
from astropy.convolution.kernels import Gaussian2DKernel
from astropy.nddata import NDData
def test_basic_nddata():
    """Check that ``convolve`` and ``convolve_fft`` accept an ``NDData`` input.

    A single unit impulse at the centre of an 11x11 zero array is wrapped in
    ``NDData`` and convolved with ``Gaussian2DKernel(1)``.  The result is
    compared against ``exp(-r**2 / 2)`` scaled by the peak value of the
    direct-convolution result, where ``r`` is the pixel distance from the
    centre.  The FFT-based result is checked against that same profile.
    """
    size, center = 11, 5

    impulse = np.zeros((size, size))
    impulse[center, center] = 1
    data = NDData(impulse)
    kernel = Gaussian2DKernel(1)

    direct = convolve(data, kernel)

    # Build the reference profile on the same pixel grid; the amplitude is
    # taken from the convolved peak rather than computed analytically.
    rows, cols = np.mgrid[:size, :size]
    radius_sq = (rows - center) ** 2 + (cols - center) ** 2
    profile = direct[center, center] * np.exp(-0.5 * radius_sq)

    np.testing.assert_allclose(direct, profile, atol=1e-6)

    via_fft = convolve_fft(data, kernel)
    np.testing.assert_allclose(via_fft, profile, atol=1e-6)
@pytest.mark.parametrize(
    "convfunc",
    [
        lambda *positional: convolve(
            *positional, nan_treatment="interpolate", normalize_kernel=True
        ),
        lambda *positional: convolve_fft(
            *positional, nan_treatment="interpolate", normalize_kernel=True
        ),
    ],
)
def test_masked_nddata(convfunc):
    """Check that a masked ``NDData`` pixel convolves like a NaN pixel.

    Three inputs are built from the same 11x11 pattern (a 1.5 peak at the
    centre with four 0.2 neighbours): one unmasked, one with the centre pixel
    flagged in the ``NDData`` mask, and one with the centre pixel replaced by
    NaN.  The masked and NaN results must agree with each other and both must
    differ from the unmasked result.

    Parameters
    ----------
    convfunc : callable
        Wrapper around ``convolve`` or ``convolve_fft`` that forwards its
        positional arguments and sets ``nan_treatment="interpolate"`` and
        ``normalize_kernel=True``.
    """
    center = (5, 5)

    plain = np.zeros((11, 11))
    for neighbour in [(4, 5), (6, 5), (5, 4), (5, 6)]:
        plain[neighbour] = 0.2
    plain[center] = 1.5

    # Mask with only the central pixel flagged.
    flagged = np.zeros(plain.shape, dtype=bool)
    flagged[center] = True

    # Copy of the data where the central pixel is NaN instead of masked.
    with_nan = plain.copy()
    with_nan[center] = np.nan

    # The unmasked and masked containers deliberately wrap the very same
    # array so that any write-back by the masked run would show up in
    # ``ndd_base.data``.
    ndd_base = NDData(plain)
    ndd_mask = NDData(plain, mask=flagged)
    ndd_nan = NDData(with_nan)

    kernel = Gaussian2DKernel(1)

    out_base = convfunc(ndd_base, kernel)
    out_nan = convfunc(ndd_nan, kernel)
    out_mask = convfunc(ndd_mask, kernel)

    assert np.allclose(out_nan, out_mask)
    assert not np.allclose(out_base, out_mask)
    assert not np.allclose(out_base, out_nan)

    # The masked run must not have written NaN into the shared input array:
    # if it had, the NaN counts of the base and NaN inputs would now match.
    base_nan_count = np.isnan(ndd_base.data).sum()
    nan_nan_count = np.isnan(ndd_nan.data).sum()
    assert base_nan_count != nan_nan_count