Skip to content

Commit 9f1bc1e

Browse files
committed
Fix bilinear interpolation for SegmentedBivarColormap
1 parent 702c669 commit 9f1bc1e

File tree

2 files changed

+68
-26
lines changed

2 files changed

+68
-26
lines changed

lib/matplotlib/colors.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555

5656
import matplotlib as mpl
5757
import numpy as np
58-
from matplotlib import _api, _cm, cbook, scale, _image
58+
from matplotlib import _api, _cm, cbook, scale
5959
from ._color_data import BASE_COLORS, TABLEAU_COLORS, CSS4_COLORS, XKCD_COLORS
6060

6161

@@ -2211,16 +2211,37 @@ def __init__(self, patch, N=256, shape='square', origin=(0, 0),
22112211
super().__init__(N, N, shape, origin, name=name)
22122212

22132213
def _init(self):
2214+
# Perform bilinear interpolation
2215+
22142216
s = self.patch.shape
2215-
_patch = np.empty((s[0], s[1], 4))
2216-
_patch[:, :, :3] = self.patch
2217-
_patch[:, :, 3] = 1
2218-
transform = mpl.transforms.Affine2D().translate(-0.5, -0.5)\
2219-
.scale(self.N / (s[1] - 1), self.N / (s[0] - 1))
2220-
self._lut = np.empty((self.N, self.N, 4))
2221-
2222-
_image.resample(_patch, self._lut, transform, _image.BILINEAR,
2223-
resample=False, alpha=1)
2217+
2218+
# Indices (whole and fraction) of the new grid points
2219+
row = np.linspace(0, s[0] - 1, self.N)[:, np.newaxis]
2220+
col = np.linspace(0, s[1] - 1, self.N)[np.newaxis, :]
2221+
left = np.floor(row).astype(int)
2222+
top = np.floor(col).astype(int)
2223+
row_frac = (row - left)[:, :, np.newaxis]
2224+
col_frac = (col - top)[:, :, np.newaxis]
2225+
2226+
# Indices of the next edges, clipping where needed
2227+
right = np.clip(left + 1, 0, s[0] - 1)
2228+
bottom = np.clip(top + 1, 0, s[1] - 1)
2229+
2230+
# Values at the corners
2231+
tl = self.patch[left, top, :]
2232+
tr = self.patch[right, top, :]
2233+
bl = self.patch[left, bottom, :]
2234+
br = self.patch[right, bottom, :]
2235+
2236+
# Interpolate between the corners
2237+
lut = (tl * (1 - row_frac) * (1 - col_frac) +
2238+
tr * row_frac * (1 - col_frac) +
2239+
bl * (1 - row_frac) * col_frac +
2240+
br * row_frac * col_frac)
2241+
2242+
# Add the alpha channel
2243+
self._lut = np.concatenate([lut, np.ones((self.N, self.N, 1))], axis=2)
2244+
22242245
self._isinit = True
22252246

22262247

lib/matplotlib/tests/test_multivariate_colormaps.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,18 @@ def test_multivar_resample():
212212

213213
def test_bivar_cmap_call_tuple():
214214
cmap = mpl.bivar_colormaps['BiOrangeBlue']
215-
assert_allclose(cmap((1.0, 1.0)), (1, 1, 1, 1), atol=0.01)
216-
assert_allclose(cmap((0.0, 0.0)), (0, 0, 0, 1), atol=0.1)
217-
assert_allclose(cmap((0.0, 0.0), alpha=0.1), (0, 0, 0, 0.1), atol=0.1)
215+
assert_allclose(cmap((1.0, 1.0)), (1, 1, 1, 1))
216+
assert_allclose(cmap((0.0, 0.0)), (0, 0, 0, 1))
217+
assert_allclose(cmap((0.2, 0.8)), (0.2, 0.5, 0.8, 1))
218+
assert_allclose(cmap((0.0, 0.0), alpha=0.1), (0, 0, 0, 0.1))
219+
220+
221+
def test_bivar_cmap_lut_smooth():
222+
cmap = mpl.bivar_colormaps['BiOrangeBlue']
223+
assert_allclose(cmap.lut[:, 0, 0], np.linspace(0, 1, 256))
224+
assert_allclose(cmap.lut[:, 0, 1], np.linspace(0, 0.5, 256), atol=1e-3)
225+
assert_allclose(cmap.lut[0, :, 1], np.linspace(0, 0.5, 256), atol=1e-3)
226+
assert_allclose(cmap.lut[0, :, 2], np.linspace(0, 1, 256))
218227

219228

220229
def test_bivar_cmap_call():
@@ -312,17 +321,29 @@ def test_bivar_cmap_call():
312321
match="only implemented for use with with floats"):
313322
cs = cmap([(0, 5, 9, 0, 0, 9), (0, 0, 0, 5, 11, 11)])
314323

315-
# test origin
324+
325+
def test_bivar_cmap_1d_origin():
326+
"""
327+
Test getting 1D colormaps with different origins
328+
"""
329+
cmap = mpl.bivar_colormaps['BiOrangeBlue']
330+
assert_allclose(cmap[0](1.), (1., 0.5, 0., 1.))
331+
assert_allclose(cmap[1](1.), (0., 0.5, 1., 1.))
332+
333+
cmap = mpl.bivar_colormaps['BiOrangeBlue'].with_extremes(origin=(0, 1))
334+
assert_allclose(cmap[0](1.), (1., 1., 1., 1.))
335+
assert_allclose(cmap[1](1.), (0., 0.5, 1., 1.))
336+
316337
cmap = mpl.bivar_colormaps['BiOrangeBlue'].with_extremes(origin=(0.5, 0.5))
317338
assert_allclose(cmap[0](0.5),
318-
(0.50244140625, 0.5024222412109375, 0.50244140625, 1))
339+
(0.5019607843137255, 0.5019453440984237, 0.5019607843137255, 1))
319340
assert_allclose(cmap[1](0.5),
320-
(0.50244140625, 0.5024222412109375, 0.50244140625, 1))
341+
(0.5019607843137255, 0.5019453440984237, 0.5019607843137255, 1))
342+
321343
cmap = mpl.bivar_colormaps['BiOrangeBlue'].with_extremes(origin=(1, 1))
322-
assert_allclose(cmap[0](1.),
323-
(0.99853515625, 0.9985467529296875, 0.99853515625, 1.0))
324-
assert_allclose(cmap[1](1.),
325-
(0.99853515625, 0.9985467529296875, 0.99853515625, 1.0))
344+
assert_allclose(cmap[0](1.), (1., 1., 1., 1.))
345+
assert_allclose(cmap[1](1.), (1., 1., 1., 1.))
346+
326347
with pytest.raises(KeyError,
327348
match="only 0 or 1 are valid keys"):
328349
cs = cmap[2]
@@ -434,21 +455,21 @@ def test_bivar_cmap_from_image():
434455

435456
def test_bivar_resample():
436457
cmap = mpl.bivar_colormaps['BiOrangeBlue'].resampled((2, 2))
437-
assert_allclose(cmap((0.25, 0.25)), (0, 0, 0, 1), atol=1e-2)
458+
assert_allclose(cmap((0.25, 0.25)), (0, 0, 0, 1))
438459

439460
cmap = mpl.bivar_colormaps['BiOrangeBlue'].resampled((-2, 2))
440-
assert_allclose(cmap((0.25, 0.25)), (1., 0.5, 0., 1.), atol=1e-2)
461+
assert_allclose(cmap((0.25, 0.25)), (1., 0.5, 0., 1.))
441462

442463
cmap = mpl.bivar_colormaps['BiOrangeBlue'].resampled((2, -2))
443-
assert_allclose(cmap((0.25, 0.25)), (0., 0.5, 1., 1.), atol=1e-2)
464+
assert_allclose(cmap((0.25, 0.25)), (0., 0.5, 1., 1.))
444465

445466
cmap = mpl.bivar_colormaps['BiOrangeBlue'].resampled((-2, -2))
446-
assert_allclose(cmap((0.25, 0.25)), (1, 1, 1, 1), atol=1e-2)
467+
assert_allclose(cmap((0.25, 0.25)), (1, 1, 1, 1))
447468

448469
cmap = mpl.bivar_colormaps['BiOrangeBlue'].reversed()
449-
assert_allclose(cmap((0.25, 0.25)), (0.748535, 0.748547, 0.748535, 1.), atol=1e-2)
470+
assert_allclose(cmap((0.25, 0.25)), (0.74902, 0.74902, 0.74902, 1.), atol=1e-5)
450471
cmap = mpl.bivar_colormaps['BiOrangeBlue'].transposed()
451-
assert_allclose(cmap((0.25, 0.25)), (0.252441, 0.252422, 0.252441, 1.), atol=1e-2)
472+
assert_allclose(cmap((0.25, 0.25)), (0.25098, 0.25098, 0.25098, 1.), atol=1e-5)
452473

453474
with pytest.raises(ValueError, match="lutshape must be of length"):
454475
cmap = cmap.resampled(4)

0 commit comments

Comments
 (0)