diff --git a/lib/matplotlib/scale.py b/lib/matplotlib/scale.py index a4cce23562d3..d5a2b8dc0d5c 100644 --- a/lib/matplotlib/scale.py +++ b/lib/matplotlib/scale.py @@ -30,7 +30,6 @@ """ # noqa: E501 import inspect -import math import textwrap from functools import wraps @@ -119,17 +118,20 @@ def val_in_range(self, val): """ Return whether the value(s) are within the valid range for this scale. - This method is a generic implementation. Subclasses may implement more - efficient solutions for their domain. - """ - try: - if not math.isfinite(val): - return False + Accepts a scalar or array-like ``val``. For a scalar, returns a + Python ``bool``. For an array, returns a bool ndarray of the same + shape. This is a generic implementation, and subclasses may implement + more efficient solutions for their domain. + """ + arr = np.asarray(val) + with np.errstate(invalid='ignore'): + try: + vmin, vmax = self.limit_range_for_scale(arr, arr, minpos=1e-300) + except (TypeError, ValueError): + result = np.zeros(arr.shape, dtype=bool) else: - vmin, vmax = self.limit_range_for_scale(val, val, minpos=1e-300) - return vmin == val and vmax == val - except (TypeError, ValueError): - return False + result = np.isfinite(arr) & (vmin == arr) & (vmax == arr) + return bool(result) if arr.ndim == 0 else result def _make_axis_parameter_optional(init_func): @@ -219,11 +221,13 @@ def get_transform(self): def val_in_range(self, val): """ - Return whether the value is within the valid range for this scale. + Return whether the value(s) are within the valid range for this scale. This is True for all values, except +-inf and NaN. """ - return math.isfinite(val) + arr = np.asarray(val) + result = np.isfinite(arr) + return bool(result) if arr.ndim == 0 else result class FuncTransform(Transform): @@ -431,11 +435,14 @@ def limit_range_for_scale(self, vmin, vmax, minpos): def val_in_range(self, val): """ - Return whether the value is within the valid range for this scale. + Return whether the value(s) are within the valid range for this scale. This is True for value(s) > 0 except +inf and NaN. """ - return math.isfinite(val) and val > 0 + arr = np.asarray(val) + with np.errstate(invalid='ignore'): + result = np.isfinite(arr) & (arr > 0) + return bool(result) if arr.ndim == 0 else result class FuncScaleLog(LogScale): @@ -625,11 +632,13 @@ def get_transform(self): def val_in_range(self, val): """ - Return whether the value is within the valid range for this scale. + Return whether the value(s) are within the valid range for this scale. This is True for all values, except +-inf and NaN. """ - return math.isfinite(val) + arr = np.asarray(val) + result = np.isfinite(arr) + return bool(result) if arr.ndim == 0 else result class AsinhTransform(Transform): @@ -759,11 +768,13 @@ def set_default_locators_and_formatters(self, axis): def val_in_range(self, val): """ - Return whether the value is within the valid range for this scale. + Return whether the value(s) are within the valid range for this scale. This is True for all values, except +-inf and NaN. """ - return math.isfinite(val) + arr = np.asarray(val) + result = np.isfinite(arr) + return bool(result) if arr.ndim == 0 else result class LogitTransform(Transform): @@ -880,11 +891,14 @@ def limit_range_for_scale(self, vmin, vmax, minpos): def val_in_range(self, val): """ - Return whether the value is within the valid range for this scale. + Return whether the value(s) are within the valid range for this scale. This is True for value(s) which are between 0 and 1 (excluded). """ - return 0 < val < 1 + arr = np.asarray(val) + with np.errstate(invalid='ignore'): + result = (0 < arr) & (arr < 1) + return bool(result) if arr.ndim == 0 else result _scale_mapping = { diff --git a/lib/matplotlib/tests/test_scale.py b/lib/matplotlib/tests/test_scale.py index 104c87adab7b..ca00ee7ce054 100644 --- a/lib/matplotlib/tests/test_scale.py +++ b/lib/matplotlib/tests/test_scale.py @@ -477,3 +477,27 @@ def test_val_in_range_base_fallback(): assert s.val_in_range(np.nan) is False assert s.val_in_range(np.inf) is False assert s.val_in_range(-np.inf) is False + + +def test_val_in_range_array(): + # Vectorized: scalar in -> scalar bool, array in -> bool array. + arr = np.array([1.0, -1.0, 0.0, np.nan, np.inf, 5.0]) + cases = { + 'linear': [True, True, True, False, False, True], + 'log': [True, False, False, False, False, True], + 'symlog': [True, True, True, False, False, True], + 'asinh': [True, True, True, False, False, True], + } + for name, expected in cases.items(): + s = mscale._scale_mapping[name](axis=None) + np.testing.assert_array_equal(s.val_in_range(arr), expected) + + s = mscale._scale_mapping['logit'](axis=None) + np.testing.assert_array_equal( + s.val_in_range(np.array([0.1, 0.5, 0.0, 1.0, -0.1, 1.1])), + [True, True, False, False, False, False]) + + # 2D shape is preserved. + out = mscale._scale_mapping['log'](axis=None).val_in_range( + np.array([[1.0, -1.0], [0.5, np.nan]])) + np.testing.assert_array_equal(out, [[True, False], [True, False]]) diff --git a/lib/mpl_toolkits/mplot3d/art3d.py b/lib/mpl_toolkits/mplot3d/art3d.py index f664127dcb59..cc0c2a2b0c59 100644 --- a/lib/mpl_toolkits/mplot3d/art3d.py +++ b/lib/mpl_toolkits/mplot3d/art3d.py @@ -97,6 +97,30 @@ def _viewlim_mask(xs, ys, zs, axes): return mask +def _scale_invalid_mask(xs, ys, zs, axes): + """ + Return the mask of points whose coordinates are invalid for the axis + scale they live on (e.g. <=0 on a log axis). + + Parameters + ---------- + xs, ys, zs : array-like + The points to check, in data coordinates. + axes : Axes3D + The axes whose scales are queried. + + Returns + ------- + mask : np.ndarray + Boolean array, ``True`` where any of x/y/z is out of its scale's + valid domain. + """ + return np.logical_or.reduce(( + np.logical_not(axes.xaxis._scale.val_in_range(xs)), + np.logical_not(axes.yaxis._scale.val_in_range(ys)), + np.logical_not(axes.zaxis._scale.val_in_range(zs)))) + + class Text3D(mtext.Text): """ Text object with 3D position and direction. @@ -191,8 +215,10 @@ def set_3d_properties(self, z=0, zdir='z', axlim_clip=False): @artist.allow_rasterization def draw(self, renderer): + mask = _scale_invalid_mask(self._x, self._y, self._z, self.axes) if self._axlim_clip: - mask = _viewlim_mask(self._x, self._y, self._z, self.axes) + mask = mask | _viewlim_mask(self._x, self._y, self._z, self.axes) + if np.any(mask): pos3d = np.ma.array([self._x, self._y, self._z], mask=mask, dtype=float).filled(np.nan) else: @@ -328,9 +354,12 @@ def get_data_3d(self): @artist.allow_rasterization def draw(self, renderer): + scale_mask = _scale_invalid_mask(*self._verts3d, self.axes) if self._axlim_clip: + scale_mask = scale_mask | _viewlim_mask(*self._verts3d, self.axes) + if np.any(scale_mask): mask = np.broadcast_to( - _viewlim_mask(*self._verts3d, self.axes), + scale_mask, (len(self._verts3d), *self._verts3d[0].shape) ) xs3d, ys3d, zs3d = np.ma.array(self._verts3d, @@ -424,10 +453,13 @@ class Collection3D(Collection): def do_3d_projection(self): """Project the points according to renderer matrix.""" vs_list = [vs for vs, _ in self._3dverts_codes] + masks = [_scale_invalid_mask(*vs.T, self.axes) for vs in vs_list] if self._axlim_clip: - vs_list = [np.ma.array(vs, mask=np.broadcast_to( - _viewlim_mask(*vs.T, self.axes), vs.shape)) - for vs in vs_list] + masks = [m | _viewlim_mask(*vs.T, self.axes) + for m, vs in zip(masks, vs_list)] + vs_list = [np.ma.array(vs, mask=np.broadcast_to(m, vs.shape)) + if np.any(m) else vs + for vs, m in zip(vs_list, masks)] xyzs_list = [proj3d._scale_proj_transform( vs[:, 0], vs[:, 1], vs[:, 2], self.axes) for vs in vs_list] self._paths = [mpath.Path(np.ma.column_stack([xs, ys]), cs) @@ -508,6 +540,14 @@ def do_3d_projection(self): if np.ma.isMA(segments): mask = segments.mask + scale_mask = _scale_invalid_mask(segments[..., 0], + segments[..., 1], + segments[..., 2], + self.axes) + if np.any(scale_mask): + mask = mask | np.broadcast_to(scale_mask[..., np.newaxis], + (*scale_mask.shape, 3)) + if self._axlim_clip: viewlim_mask = _viewlim_mask(segments[..., 0], segments[..., 1], @@ -597,12 +637,15 @@ def get_path(self): def do_3d_projection(self): s = self._segment3d + xs0, ys0, zs0 = zip(*s) + mask = _scale_invalid_mask(xs0, ys0, zs0, self.axes) if self._axlim_clip: - mask = _viewlim_mask(*zip(*s), self.axes) + mask = mask | _viewlim_mask(xs0, ys0, zs0, self.axes) + if np.any(mask): xs, ys, zs = np.ma.array(zip(*s), dtype=float, mask=mask).filled(np.nan) else: - xs, ys, zs = zip(*s) + xs, ys, zs = xs0, ys0, zs0 vxs, vys, vzs, vis = proj3d._scale_proj_transform_clip(xs, ys, zs, self.axes) self._path2d = mpath.Path(np.ma.column_stack([vxs, vys])) return min(vzs) @@ -657,12 +700,15 @@ def set_3d_properties(self, path, zs=0, zdir='z', axlim_clip=False): def do_3d_projection(self): s = self._segment3d + xs0, ys0, zs0 = zip(*s) + mask = _scale_invalid_mask(xs0, ys0, zs0, self.axes) if self._axlim_clip: - mask = _viewlim_mask(*zip(*s), self.axes) + mask = mask | _viewlim_mask(xs0, ys0, zs0, self.axes) + if np.any(mask): xs, ys, zs = np.ma.array(zip(*s), dtype=float, mask=mask).filled(np.nan) else: - xs, ys, zs = zip(*s) + xs, ys, zs = xs0, ys0, zs0 vxs, vys, vzs, vis = proj3d._scale_proj_transform_clip(xs, ys, zs, self.axes) self._path2d = mpath.Path(np.ma.column_stack([vxs, vys]), self._code3d) return min(vzs) @@ -801,8 +847,10 @@ def set_3d_properties(self, zs, zdir, axlim_clip=False): self.stale = True def do_3d_projection(self): + mask = _scale_invalid_mask(*self._offsets3d, self.axes) if self._axlim_clip: - mask = _viewlim_mask(*self._offsets3d, self.axes) + mask = mask | _viewlim_mask(*self._offsets3d, self.axes) + if np.any(mask): xs, ys, zs = np.ma.array(self._offsets3d, mask=mask) else: xs, ys, zs = self._offsets3d @@ -1023,8 +1071,10 @@ def do_3d_projection(self): for xyz in self._offsets3d: if np.ma.isMA(xyz): mask = mask | xyz.mask + mask = mask | _scale_invalid_mask(*self._offsets3d, self.axes) if self._axlim_clip: mask = mask | _viewlim_mask(*self._offsets3d, self.axes) + if np.any(mask): mask = np.broadcast_to(mask, (len(self._offsets3d), *self._offsets3d[0].shape)) xyzs = np.ma.array(self._offsets3d, mask=mask) @@ -1362,9 +1412,11 @@ def do_3d_projection(self): if self._edge_is_mapped: self._edgecolor3d = self._edgecolors - needs_masking = np.any(self._invalid_vertices) num_faces = len(self._faces) - mask = self._invalid_vertices + mask = self._invalid_vertices | _scale_invalid_mask( + self._faces[..., 0], self._faces[..., 1], + self._faces[..., 2], self.axes) + needs_masking = np.any(mask) # Some faces might contain masked vertices, so we want to ignore any # errors that those might cause diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index 45f9319355e0..c2c11e6af137 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -58,6 +58,20 @@ from . import axis3d +def _mask_scale_invalid(data, axis): + """ + Return ``data`` with values invalid for ``axis``'s scale (e.g. ``<=0`` on a + log axis) replaced by NaN, so they don't pollute data limits. + """ + if data is None: + return data + data = np.asanyarray(data, dtype=float) + valid = axis._scale.val_in_range(data) + if np.all(valid): + return data + return np.where(valid, data, np.nan) + + @_docstring.interpd @_api.define_aliases({ "xlim": ["xlim3d"], "ylim": ["ylim3d"], "zlim": ["zlim3d"]}) @@ -640,6 +654,8 @@ def autoscale(self, enable=True, axis='both', tight=None): def auto_scale_xyz(self, X, Y, Z=None, had_data=None): # This updates the bounding boxes as to keep a record as to what the # minimum sized rectangular volume holds the data. + X = _mask_scale_invalid(X, self.xaxis) + Y = _mask_scale_invalid(Y, self.yaxis) if np.shape(X) == np.shape(Y): self.xy_dataLim.update_from_data_xy( np.column_stack([np.ravel(X), np.ravel(Y)]), not had_data) @@ -647,6 +663,7 @@ def auto_scale_xyz(self, X, Y, Z=None, had_data=None): self.xy_dataLim.update_from_data_x(X, not had_data) self.xy_dataLim.update_from_data_y(Y, not had_data) if Z is not None: + Z = _mask_scale_invalid(Z, self.zaxis) self.zz_dataLim.update_from_data_x(Z, not had_data) # Let autoscale_view figure out how to use this data. self.autoscale_view() diff --git a/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py b/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py index 2a5593a641c9..2a2c104372cf 100644 --- a/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py +++ b/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py @@ -3189,3 +3189,17 @@ def test_scale3d_calc_coord(): # Pane coordinate should match axis limit (y-pane at max) assert pane_idx == 1 assert point[pane_idx] == pytest.approx(ax.get_ylim()[1]) + + +def test_plot_surface_log_scale_invalid_values(): + """Ensure non-positive Z values on a log z-axis does not corrupt zlim.""" + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + ax.set_zscale('log') + X, Y = np.meshgrid(np.linspace(1, 3, 4), np.linspace(1, 3, 4)) + Z = X * Y - 4 # half the entries are <= 0, invalid for a log scale + ax.plot_surface(X, Y, Z) + fig.canvas.draw() + + zmin, zmax = ax.get_zlim() + assert 1e-3 < zmin < zmax < 1e3, f"zlim corrupted: {(zmin, zmax)}"