Skip to content
105 changes: 58 additions & 47 deletions lib/mpl_toolkits/mplot3d/art3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,11 @@ def set_3d_properties(self, z=0, zdir='z', axlim_clip=False):

@artist.allow_rasterization
def draw(self, renderer):
pos3d = np.array([self._x, self._y, self._z], dtype=float)
if self._axlim_clip:
mask = _viewlim_mask(self._x, self._y, self._z, self.axes)
pos3d = np.ma.array([self._x, self._y, self._z],
mask=mask, dtype=float).filled(np.nan)
else:
pos3d = np.array([self._x, self._y, self._z], dtype=float)
if np.any(mask):
pos3d = np.where(mask, np.nan, pos3d)

dir_end = pos3d + self._dir_vec
points = np.asarray([pos3d, dir_end])
Expand Down Expand Up @@ -328,15 +327,13 @@ def get_data_3d(self):

@artist.allow_rasterization
def draw(self, renderer):
xs3d, ys3d, zs3d = self._verts3d
if self._axlim_clip:
mask = np.broadcast_to(
_viewlim_mask(*self._verts3d, self.axes),
(len(self._verts3d), *self._verts3d[0].shape)
)
xs3d, ys3d, zs3d = np.ma.array(self._verts3d,
dtype=float, mask=mask).filled(np.nan)
else:
xs3d, ys3d, zs3d = self._verts3d
mask = _viewlim_mask(xs3d, ys3d, zs3d, self.axes)
if np.any(mask):
xs3d = np.where(mask, np.nan, xs3d)
ys3d = np.where(mask, np.nan, ys3d)
zs3d = np.where(mask, np.nan, zs3d)
xs, ys, zs, tis = proj3d._scale_proj_transform_clip(xs3d, ys3d, zs3d, self.axes)
self.set_data(xs, ys)
super().draw(renderer)
Expand Down Expand Up @@ -425,15 +422,20 @@ def do_3d_projection(self):
"""Project the points according to renderer matrix."""
vs_list = [vs for vs, _ in self._3dverts_codes]
if self._axlim_clip:
vs_list = [np.ma.array(vs, mask=np.broadcast_to(
_viewlim_mask(*vs.T, self.axes), vs.shape))
for vs in vs_list]
vs_list_new = []
for vs in vs_list:
mask = _viewlim_mask(*vs.T, self.axes)
if np.any(mask):
vs = vs.astype(float, copy=True)
vs[mask] = np.nan
vs_list_new.append(vs)
vs_list = vs_list_new
xyzs_list = [proj3d._scale_proj_transform(
vs[:, 0], vs[:, 1], vs[:, 2], self.axes) for vs in vs_list]
self._paths = [mpath.Path(np.ma.column_stack([xs, ys]), cs)
self._paths = [mpath.Path(np.column_stack([xs, ys]), cs)
for (xs, ys, _), (_, cs) in zip(xyzs_list, self._3dverts_codes)]
zs = np.concatenate([zs for _, _, zs in xyzs_list])
return zs.min() if len(zs) else 1e9
return np.nanmin(zs) if len(zs) else 1e9


def collection_2d_to_3d(col, zs=0, zdir='z', axlim_clip=False):
Expand Down Expand Up @@ -499,9 +501,13 @@ def do_3d_projection(self):
"""
segments = self._segments3d

# Handle ragged inputs, but prefer a faster path for same-length segments
segment_lengths = [len(segment) for segment in segments]
ragged = len(set(segment_lengths)) > 1
if isinstance(segments, np.ndarray):
ragged = False
else:
# Handle ragged inputs, but prefer a faster path for same-length segments
segment_lengths = [len(segment) for segment in segments]
ragged = len(set(segment_lengths)) > 1

if ragged:
# Branch masked / non-masked for speed
if any(np.ma.isMA(segment) for segment in segments):
Expand Down Expand Up @@ -531,17 +537,17 @@ def do_3d_projection(self):
(*viewlim_mask.shape, 3))
mask = mask | viewlim_mask

xyzs = proj3d._scale_proj_transform_vectors(segments, self.axes)
if mask is not False:
xyzs = np.ma.array(xyzs, mask=mask)
xyzs = np.asarray(
proj3d._scale_proj_transform_vectors(segments, self.axes))
xyzs[mask] = np.nan
segments_2d = xyzs[..., 0:2]
if ragged:
segments_2d = np.split(segments_2d, np.cumsum(segment_lengths[:-1]))
LineCollection.set_segments(self, segments_2d)

# FIXME
if len(xyzs) > 0:
minz = min(xyzs[..., 2].min(), 1e9)
minz = min(np.nanmin(xyzs[..., 2]), 1e9)
else:
minz = np.nan
return minz
Expand Down Expand Up @@ -612,15 +618,17 @@ def get_path(self):

def do_3d_projection(self):
s = self._segment3d
mask = False
xs, ys, zs = zip(*s)
if self._axlim_clip:
mask = _viewlim_mask(*zip(*s), self.axes)
xs, ys, zs = np.ma.array(zip(*s),
dtype=float, mask=mask).filled(np.nan)
else:
xs, ys, zs = zip(*s)
mask = _viewlim_mask(xs, ys, zs, self.axes)
if np.any(mask):
xs = np.where(mask, np.nan, xs)
ys = np.where(mask, np.nan, ys)
zs = np.where(mask, np.nan, zs)
vxs, vys, vzs, vis = proj3d._scale_proj_transform_clip(xs, ys, zs, self.axes)
self._path2d = mpath.Path(np.ma.column_stack([vxs, vys]))
return min(vzs)
self._path2d = mpath.Path(np.column_stack([vxs, vys]))
return np.nanmin(vzs)


class PathPatch3D(Patch3D):
Expand Down Expand Up @@ -672,15 +680,18 @@ def set_3d_properties(self, path, zs=0, zdir='z', axlim_clip=False):

def do_3d_projection(self):
s = self._segment3d
mask = False
xs, ys, zs = zip(*s)
if self._axlim_clip:
mask = _viewlim_mask(*zip(*s), self.axes)
xs, ys, zs = np.ma.array(zip(*s),
dtype=float, mask=mask).filled(np.nan)
else:
xs, ys, zs = zip(*s)
mask = _viewlim_mask(xs, ys, zs, self.axes)
if np.any(mask):
xs = np.where(mask, np.nan, xs)
ys = np.where(mask, np.nan, ys)
zs = np.where(mask, np.nan, zs)
vxs, vys, vzs, vis = proj3d._scale_proj_transform_clip(xs, ys, zs, self.axes)
self._path2d = mpath.Path(np.ma.column_stack([vxs, vys]), self._code3d)
return min(vzs)
self._path2d = mpath.Path(np.column_stack([vxs, vys]), self._code3d)

return np.nanmin(vzs)


def _get_patch_verts(patch):
Expand Down Expand Up @@ -816,20 +827,20 @@ def set_3d_properties(self, zs, zdir, axlim_clip=False):
self.stale = True

def do_3d_projection(self):
xs, ys, zs = self._offsets3d
mask = False
if self._axlim_clip:
mask = _viewlim_mask(*self._offsets3d, self.axes)
xs, ys, zs = np.ma.array(self._offsets3d, mask=mask)
else:
xs, ys, zs = self._offsets3d
mask = _viewlim_mask(xs, ys, zs, self.axes)
if np.any(mask):
xs = np.where(mask, np.nan, xs)
ys = np.where(mask, np.nan, ys)
zs = np.where(mask, np.nan, zs)
vxs, vys, vzs, vis = proj3d._scale_proj_transform_clip(xs, ys, zs, self.axes)
self._vzs = vzs
if np.ma.isMA(vxs):
super().set_offsets(np.ma.column_stack([vxs, vys]))
else:
super().set_offsets(np.column_stack([vxs, vys]))
super().set_offsets(np.column_stack([vxs, vys]))

if vzs.size > 0:
return min(vzs)
return np.nanmin(vzs)
else:
return np.nan

Expand Down
94 changes: 89 additions & 5 deletions lib/mpl_toolkits/mplot3d/axes3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def __init__(
self.set_axis_on()
self.M = None
self.invM = None
self._draw_cache = {}

self._view_margin = 1/48 # default value to match mpl3.8
self.autoscale_view()
Expand Down Expand Up @@ -278,6 +279,20 @@ def _transformed_cube(self, vals):
(minx, maxy, maxz)]
return np.column_stack(proj3d._proj_trans_points(xyzs, self.M))

def _get_transformed_cube(self):
"""Get transformed cube from draw cache, or compute if outside draw cycle."""
if 'transformed_cube' in self._draw_cache:
return self._draw_cache['transformed_cube']
bounds = self._draw_cache.get('bounds',
(*self.get_xbound(), *self.get_ybound(), *self.get_zbound()))
return self._transformed_cube(bounds)

def _get_bounds(self):
"""Get data bounds from draw cache, or compute if outside draw cycle."""
if 'bounds' in self._draw_cache:
return self._draw_cache['bounds']
return (*self.get_xbound(), *self.get_ybound(), *self.get_zbound())

def _update_transScale(self):
"""
Override transScale to always use identity transforms.
Expand Down Expand Up @@ -474,6 +489,14 @@ def draw(self, renderer):
self.M = self.get_proj()
self.invM = np.linalg.inv(self.M)

# Cache values used multiple times during axis drawing
self._draw_cache = {}
bounds = (*self.get_xbound(), *self.get_ybound(), *self.get_zbound())
self._draw_cache['scaled_limits'] = self._calc_scaled_limits()
self._draw_cache['bounds'] = bounds
self._draw_cache['transformed_cube'] = self._transformed_cube(bounds)
self._draw_cache['coord_info'] = self._calc_coord_info()

collections_and_patches = (
artist for artist in self._children
if isinstance(artist, (mcoll.Collection, mpatches.Patch))
Expand All @@ -499,6 +522,9 @@ def draw(self, renderer):
artist.do_3d_projection()

if self._axis3don:
# Update ticks on all axes
for axis in self._axis_map.values():
axis._ticks_to_draw = axis._update_ticks()
# Draw panes first
for axis in self._axis_map.values():
axis.draw_pane(renderer)
Expand Down Expand Up @@ -1358,7 +1384,7 @@ def _roll_to_vertical(
else:
return np.roll(arr, (self._vertical_axis - 2))

def _get_scaled_limits(self):
def _calc_scaled_limits(self):
"""
Get axis limits transformed through their respective scale transforms.

Expand All @@ -1373,6 +1399,61 @@ def _get_scaled_limits(self):
zmin, zmax = self.zaxis.get_transform().transform(self.get_zlim3d())
return xmin, xmax, ymin, ymax, zmin, zmax

def _get_scaled_limits(self):
"""Get scaled limits from draw cache, or compute if outside draw cycle."""
if 'scaled_limits' in self._draw_cache:
return self._draw_cache['scaled_limits']
return self._calc_scaled_limits()

def _calc_coord_info(self):
"""
Compute coordinate info for axis drawing.

Returns
-------
mins : ndarray
Minimum values [xmin, ymin, zmin] in scaled coordinates.
maxs : ndarray
Maximum values [xmax, ymax, zmax] in scaled coordinates.
bounds_proj : ndarray
Projected cube corners.
highs : ndarray
Boolean array indicating which planes are higher up.
"""
xmin, xmax, ymin, ymax, zmin, zmax = self._get_scaled_limits()
mins = np.array([xmin, ymin, zmin])
maxs = np.array([xmax, ymax, zmax])
bounds_proj = self._get_transformed_cube()

# Determine which one of the parallel planes are higher up
means_z0 = np.zeros(3)
means_z1 = np.zeros(3)
for i in range(3):
means_z0[i] = np.mean(bounds_proj[axis3d.Axis._PLANES[2 * i], 2])
means_z1[i] = np.mean(bounds_proj[axis3d.Axis._PLANES[2 * i + 1], 2])
highs = means_z0 < means_z1

# Special handling for edge-on views
equals = np.abs(means_z0 - means_z1) <= np.finfo(float).eps
if np.sum(equals) == 2:
vertical = np.where(~equals)[0][0]
if vertical == 2: # looking at XY plane
highs = np.array([True, True, highs[2]])
elif vertical == 1: # looking at XZ plane
highs = np.array([True, highs[1], False])
elif vertical == 0: # looking at YZ plane
highs = np.array([highs[0], False, False])

return mins, maxs, bounds_proj, highs

def _get_coord_info(self):
"""Get coord info from draw cache, or compute if outside draw cycle."""
if 'coord_info' in self._draw_cache:
return self._draw_cache['coord_info']
coord_info = self._calc_coord_info()
self._draw_cache['coord_info'] = coord_info
return coord_info

def _untransform_point(self, x, y, z):
"""
Convert a point from transformed coordinates to data coordinates.
Expand Down Expand Up @@ -1431,8 +1512,11 @@ def get_proj(self):
# For non-linear scales, we use the scaled limits so the world
# transformation maps transformed coordinates (not data coordinates)
# to the unit cube
scaled_limits = self._get_scaled_limits()
worldM = proj3d.world_transformation(*scaled_limits, pb_aspect=box_aspect)
scaled_limits = self._calc_scaled_limits()
worldM = proj3d.world_transformation(
*scaled_limits,
pb_aspect=box_aspect,
)

# Look into the middle of the world coordinates:
R = 0.5 * box_aspect
Expand Down Expand Up @@ -1872,7 +1956,7 @@ def drag_pan(self, button, key, x, y):
duvw_projected = R.T @ np.array([du, dv, dw])

# Calculate pan distance in transformed coordinates for non-linear scales
minx, maxx, miny, maxy, minz, maxz = self._get_scaled_limits()
minx, maxx, miny, maxy, minz, maxz = self._calc_scaled_limits()
dx = (maxx - minx) * duvw_projected[0]
dy = (maxy - miny) * duvw_projected[1]
dz = (maxz - minz) * duvw_projected[2]
Expand Down Expand Up @@ -2026,7 +2110,7 @@ def _get_w_centers_ranges(self):
computed in transformed coordinates to ensure uniform zoom/pan behavior.
"""
# Get limits in transformed coordinates for non-linear scale zoom/pan
minx, maxx, miny, maxy, minz, maxz = self._get_scaled_limits()
minx, maxx, miny, maxy, minz, maxz = self._calc_scaled_limits()
cx = (maxx + minx)/2
cy = (maxy + miny)/2
cz = (maxz + minz)/2
Expand Down
Loading
Loading