Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 35 additions & 21 deletions lib/matplotlib/scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
""" # noqa: E501

import inspect
import math
import textwrap
from functools import wraps

Expand Down Expand Up @@ -119,17 +118,20 @@ def val_in_range(self, val):
"""
Return whether the value(s) are within the valid range for this scale.

This method is a generic implementation. Subclasses may implement more
efficient solutions for their domain.
"""
try:
if not math.isfinite(val):
return False
Accepts a scalar or array-like ``val``. For a scalar, returns a
Python ``bool``. For an array, returns a bool ndarray of the same
shape. This is a generic implementation, and subclasses may implement
more efficient solutions for their domain.
"""
arr = np.asarray(val)
with np.errstate(invalid='ignore'):
try:
vmin, vmax = self.limit_range_for_scale(arr, arr, minpos=1e-300)
except (TypeError, ValueError):
result = np.zeros(arr.shape, dtype=bool)
else:
vmin, vmax = self.limit_range_for_scale(val, val, minpos=1e-300)
return vmin == val and vmax == val
except (TypeError, ValueError):
return False
result = np.isfinite(arr) & (vmin == arr) & (vmax == arr)
return bool(result) if arr.ndim == 0 else result


def _make_axis_parameter_optional(init_func):
Expand Down Expand Up @@ -219,11 +221,13 @@ def get_transform(self):

def val_in_range(self, val):
"""
Return whether the value is within the valid range for this scale.
Return whether the value(s) are within the valid range for this scale.

This is True for all values, except +-inf and NaN.
"""
return math.isfinite(val)
arr = np.asarray(val)
result = np.isfinite(arr)
return bool(result) if arr.ndim == 0 else result


class FuncTransform(Transform):
Expand Down Expand Up @@ -431,11 +435,14 @@ def limit_range_for_scale(self, vmin, vmax, minpos):

def val_in_range(self, val):
"""
Return whether the value is within the valid range for this scale.
Return whether the value(s) are within the valid range for this scale.

This is True for value(s) > 0 except +inf and NaN.
"""
return math.isfinite(val) and val > 0
arr = np.asarray(val)
with np.errstate(invalid='ignore'):
result = np.isfinite(arr) & (arr > 0)
return bool(result) if arr.ndim == 0 else result


class FuncScaleLog(LogScale):
Expand Down Expand Up @@ -625,11 +632,13 @@ def get_transform(self):

def val_in_range(self, val):
"""
Return whether the value is within the valid range for this scale.
Return whether the value(s) are within the valid range for this scale.

This is True for all values, except +-inf and NaN.
"""
return math.isfinite(val)
arr = np.asarray(val)
result = np.isfinite(arr)
return bool(result) if arr.ndim == 0 else result


class AsinhTransform(Transform):
Expand Down Expand Up @@ -759,11 +768,13 @@ def set_default_locators_and_formatters(self, axis):

def val_in_range(self, val):
"""
Return whether the value is within the valid range for this scale.
Return whether the value(s) are within the valid range for this scale.

This is True for all values, except +-inf and NaN.
"""
return math.isfinite(val)
arr = np.asarray(val)
result = np.isfinite(arr)
return bool(result) if arr.ndim == 0 else result


class LogitTransform(Transform):
Expand Down Expand Up @@ -880,11 +891,14 @@ def limit_range_for_scale(self, vmin, vmax, minpos):

def val_in_range(self, val):
"""
Return whether the value is within the valid range for this scale.
Return whether the value(s) are within the valid range for this scale.

This is True for value(s) which are between 0 and 1 (excluded).
"""
return 0 < val < 1
arr = np.asarray(val)
with np.errstate(invalid='ignore'):
result = (0 < arr) & (arr < 1)
return bool(result) if arr.ndim == 0 else result


_scale_mapping = {
Expand Down
24 changes: 24 additions & 0 deletions lib/matplotlib/tests/test_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,3 +477,27 @@ def test_val_in_range_base_fallback():
assert s.val_in_range(np.nan) is False
assert s.val_in_range(np.inf) is False
assert s.val_in_range(-np.inf) is False


def test_val_in_range_array():
# Vectorized: scalar in -> scalar bool, array in -> bool array.
arr = np.array([1.0, -1.0, 0.0, np.nan, np.inf, 5.0])
cases = {
'linear': [True, True, True, False, False, True],
'log': [True, False, False, False, False, True],
'symlog': [True, True, True, False, False, True],
'asinh': [True, True, True, False, False, True],
}
for name, expected in cases.items():
s = mscale._scale_mapping[name](axis=None)
np.testing.assert_array_equal(s.val_in_range(arr), expected)

s = mscale._scale_mapping['logit'](axis=None)
np.testing.assert_array_equal(
s.val_in_range(np.array([0.1, 0.5, 0.0, 1.0, -0.1, 1.1])),
[True, True, False, False, False, False])

# 2D shape is preserved.
out = mscale._scale_mapping['log'](axis=None).val_in_range(
np.array([[1.0, -1.0], [0.5, np.nan]]))
np.testing.assert_array_equal(out, [[True, False], [True, False]])
76 changes: 64 additions & 12 deletions lib/mpl_toolkits/mplot3d/art3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,30 @@ def _viewlim_mask(xs, ys, zs, axes):
return mask


def _scale_invalid_mask(xs, ys, zs, axes):
"""
Return the mask of points whose coordinates are invalid for the axis
scale they live on (e.g. <=0 on a log axis).

Parameters
----------
xs, ys, zs : array-like
The points to check, in data coordinates.
axes : Axes3D
The axes whose scales are queried.

Returns
-------
mask : np.ndarray
Boolean array, ``True`` where any of x/y/z is out of its scale's
valid domain.
"""
return np.logical_or.reduce((
np.logical_not(axes.xaxis._scale.val_in_range(xs)),
np.logical_not(axes.yaxis._scale.val_in_range(ys)),
np.logical_not(axes.zaxis._scale.val_in_range(zs))))


class Text3D(mtext.Text):
"""
Text object with 3D position and direction.
Expand Down Expand Up @@ -191,8 +215,10 @@ def set_3d_properties(self, z=0, zdir='z', axlim_clip=False):

@artist.allow_rasterization
def draw(self, renderer):
mask = _scale_invalid_mask(self._x, self._y, self._z, self.axes)
if self._axlim_clip:
mask = _viewlim_mask(self._x, self._y, self._z, self.axes)
mask = mask | _viewlim_mask(self._x, self._y, self._z, self.axes)
if np.any(mask):
pos3d = np.ma.array([self._x, self._y, self._z],
mask=mask, dtype=float).filled(np.nan)
else:
Expand Down Expand Up @@ -328,9 +354,12 @@ def get_data_3d(self):

@artist.allow_rasterization
def draw(self, renderer):
scale_mask = _scale_invalid_mask(*self._verts3d, self.axes)
if self._axlim_clip:
scale_mask = scale_mask | _viewlim_mask(*self._verts3d, self.axes)
if np.any(scale_mask):
mask = np.broadcast_to(
_viewlim_mask(*self._verts3d, self.axes),
scale_mask,
(len(self._verts3d), *self._verts3d[0].shape)
)
xs3d, ys3d, zs3d = np.ma.array(self._verts3d,
Expand Down Expand Up @@ -424,10 +453,13 @@ class Collection3D(Collection):
def do_3d_projection(self):
"""Project the points according to renderer matrix."""
vs_list = [vs for vs, _ in self._3dverts_codes]
masks = [_scale_invalid_mask(*vs.T, self.axes) for vs in vs_list]
if self._axlim_clip:
vs_list = [np.ma.array(vs, mask=np.broadcast_to(
_viewlim_mask(*vs.T, self.axes), vs.shape))
for vs in vs_list]
masks = [m | _viewlim_mask(*vs.T, self.axes)
for m, vs in zip(masks, vs_list)]
vs_list = [np.ma.array(vs, mask=np.broadcast_to(m, vs.shape))
if np.any(m) else vs
for vs, m in zip(vs_list, masks)]
xyzs_list = [proj3d._scale_proj_transform(
vs[:, 0], vs[:, 1], vs[:, 2], self.axes) for vs in vs_list]
self._paths = [mpath.Path(np.ma.column_stack([xs, ys]), cs)
Expand Down Expand Up @@ -508,6 +540,14 @@ def do_3d_projection(self):
if np.ma.isMA(segments):
mask = segments.mask

scale_mask = _scale_invalid_mask(segments[..., 0],
segments[..., 1],
segments[..., 2],
self.axes)
if np.any(scale_mask):
mask = mask | np.broadcast_to(scale_mask[..., np.newaxis],
(*scale_mask.shape, 3))

if self._axlim_clip:
viewlim_mask = _viewlim_mask(segments[..., 0],
segments[..., 1],
Expand Down Expand Up @@ -597,12 +637,15 @@ def get_path(self):

def do_3d_projection(self):
s = self._segment3d
xs0, ys0, zs0 = zip(*s)
mask = _scale_invalid_mask(xs0, ys0, zs0, self.axes)
if self._axlim_clip:
mask = _viewlim_mask(*zip(*s), self.axes)
mask = mask | _viewlim_mask(xs0, ys0, zs0, self.axes)
if np.any(mask):
xs, ys, zs = np.ma.array(zip(*s),
dtype=float, mask=mask).filled(np.nan)
else:
xs, ys, zs = zip(*s)
xs, ys, zs = xs0, ys0, zs0
vxs, vys, vzs, vis = proj3d._scale_proj_transform_clip(xs, ys, zs, self.axes)
self._path2d = mpath.Path(np.ma.column_stack([vxs, vys]))
return min(vzs)
Expand Down Expand Up @@ -657,12 +700,15 @@ def set_3d_properties(self, path, zs=0, zdir='z', axlim_clip=False):

def do_3d_projection(self):
s = self._segment3d
xs0, ys0, zs0 = zip(*s)
mask = _scale_invalid_mask(xs0, ys0, zs0, self.axes)
if self._axlim_clip:
mask = _viewlim_mask(*zip(*s), self.axes)
mask = mask | _viewlim_mask(xs0, ys0, zs0, self.axes)
if np.any(mask):
xs, ys, zs = np.ma.array(zip(*s),
dtype=float, mask=mask).filled(np.nan)
else:
xs, ys, zs = zip(*s)
xs, ys, zs = xs0, ys0, zs0
vxs, vys, vzs, vis = proj3d._scale_proj_transform_clip(xs, ys, zs, self.axes)
self._path2d = mpath.Path(np.ma.column_stack([vxs, vys]), self._code3d)
return min(vzs)
Expand Down Expand Up @@ -801,8 +847,10 @@ def set_3d_properties(self, zs, zdir, axlim_clip=False):
self.stale = True

def do_3d_projection(self):
mask = _scale_invalid_mask(*self._offsets3d, self.axes)
if self._axlim_clip:
mask = _viewlim_mask(*self._offsets3d, self.axes)
mask = mask | _viewlim_mask(*self._offsets3d, self.axes)
if np.any(mask):
xs, ys, zs = np.ma.array(self._offsets3d, mask=mask)
else:
xs, ys, zs = self._offsets3d
Expand Down Expand Up @@ -1023,8 +1071,10 @@ def do_3d_projection(self):
for xyz in self._offsets3d:
if np.ma.isMA(xyz):
mask = mask | xyz.mask
mask = mask | _scale_invalid_mask(*self._offsets3d, self.axes)
if self._axlim_clip:
mask = mask | _viewlim_mask(*self._offsets3d, self.axes)
if np.any(mask):
mask = np.broadcast_to(mask,
(len(self._offsets3d), *self._offsets3d[0].shape))
xyzs = np.ma.array(self._offsets3d, mask=mask)
Expand Down Expand Up @@ -1362,9 +1412,11 @@ def do_3d_projection(self):
if self._edge_is_mapped:
self._edgecolor3d = self._edgecolors

needs_masking = np.any(self._invalid_vertices)
num_faces = len(self._faces)
mask = self._invalid_vertices
mask = self._invalid_vertices | _scale_invalid_mask(
self._faces[..., 0], self._faces[..., 1],
self._faces[..., 2], self.axes)
needs_masking = np.any(mask)

# Some faces might contain masked vertices, so we want to ignore any
# errors that those might cause
Expand Down
17 changes: 17 additions & 0 deletions lib/mpl_toolkits/mplot3d/axes3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,20 @@
from . import axis3d


def _mask_scale_invalid(data, axis):
"""
Return ``data`` with values invalid for ``axis``'s scale (e.g. ``<=0`` on a
log axis) replaced by NaN, so they don't pollute data limits.
"""
if data is None:
return data
data = np.asanyarray(data, dtype=float)
valid = axis._scale.val_in_range(data)
if np.all(valid):
return data
return np.where(valid, data, np.nan)


@_docstring.interpd
@_api.define_aliases({
"xlim": ["xlim3d"], "ylim": ["ylim3d"], "zlim": ["zlim3d"]})
Expand Down Expand Up @@ -640,13 +654,16 @@ def autoscale(self, enable=True, axis='both', tight=None):
def auto_scale_xyz(self, X, Y, Z=None, had_data=None):
# This updates the bounding boxes as to keep a record as to what the
# minimum sized rectangular volume holds the data.
X = _mask_scale_invalid(X, self.xaxis)
Y = _mask_scale_invalid(Y, self.yaxis)
if np.shape(X) == np.shape(Y):
self.xy_dataLim.update_from_data_xy(
np.column_stack([np.ravel(X), np.ravel(Y)]), not had_data)
else:
self.xy_dataLim.update_from_data_x(X, not had_data)
self.xy_dataLim.update_from_data_y(Y, not had_data)
if Z is not None:
Z = _mask_scale_invalid(Z, self.zaxis)
self.zz_dataLim.update_from_data_x(Z, not had_data)
# Let autoscale_view figure out how to use this data.
self.autoscale_view()
Expand Down
14 changes: 14 additions & 0 deletions lib/mpl_toolkits/mplot3d/tests/test_axes3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3189,3 +3189,17 @@ def test_scale3d_calc_coord():
# Pane coordinate should match axis limit (y-pane at max)
assert pane_idx == 1
assert point[pane_idx] == pytest.approx(ax.get_ylim()[1])


def test_plot_surface_log_scale_invalid_values():
"""Ensure non-positive Z values on a log z-axis does not corrupt zlim."""
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.set_zscale('log')
X, Y = np.meshgrid(np.linspace(1, 3, 4), np.linspace(1, 3, 4))
Z = X * Y - 4 # half the entries are <= 0, invalid for a log scale
ax.plot_surface(X, Y, Z)
fig.canvas.draw()

zmin, zmax = ax.get_zlim()
assert 1e-3 < zmin < zmax < 1e3, f"zlim corrupted: {(zmin, zmax)}"
Loading