diff --git a/lib/mpl_toolkits/mplot3d/art3d.py b/lib/mpl_toolkits/mplot3d/art3d.py index 114bb152ae80..e463bf18e671 100644 --- a/lib/mpl_toolkits/mplot3d/art3d.py +++ b/lib/mpl_toolkits/mplot3d/art3d.py @@ -191,12 +191,11 @@ def set_3d_properties(self, z=0, zdir='z', axlim_clip=False): @artist.allow_rasterization def draw(self, renderer): + pos3d = np.array([self._x, self._y, self._z], dtype=float) if self._axlim_clip: mask = _viewlim_mask(self._x, self._y, self._z, self.axes) - pos3d = np.ma.array([self._x, self._y, self._z], - mask=mask, dtype=float).filled(np.nan) - else: - pos3d = np.array([self._x, self._y, self._z], dtype=float) + if np.any(mask): + pos3d = np.where(mask, np.nan, pos3d) dir_end = pos3d + self._dir_vec points = np.asarray([pos3d, dir_end]) @@ -328,15 +327,13 @@ def get_data_3d(self): @artist.allow_rasterization def draw(self, renderer): + xs3d, ys3d, zs3d = self._verts3d if self._axlim_clip: - mask = np.broadcast_to( - _viewlim_mask(*self._verts3d, self.axes), - (len(self._verts3d), *self._verts3d[0].shape) - ) - xs3d, ys3d, zs3d = np.ma.array(self._verts3d, - dtype=float, mask=mask).filled(np.nan) - else: - xs3d, ys3d, zs3d = self._verts3d + mask = _viewlim_mask(xs3d, ys3d, zs3d, self.axes) + if np.any(mask): + xs3d = np.where(mask, np.nan, xs3d) + ys3d = np.where(mask, np.nan, ys3d) + zs3d = np.where(mask, np.nan, zs3d) xs, ys, zs, tis = proj3d._scale_proj_transform_clip(xs3d, ys3d, zs3d, self.axes) self.set_data(xs, ys) super().draw(renderer) @@ -425,15 +422,20 @@ def do_3d_projection(self): """Project the points according to renderer matrix.""" vs_list = [vs for vs, _ in self._3dverts_codes] if self._axlim_clip: - vs_list = [np.ma.array(vs, mask=np.broadcast_to( - _viewlim_mask(*vs.T, self.axes), vs.shape)) - for vs in vs_list] + vs_list_new = [] + for vs in vs_list: + mask = _viewlim_mask(*vs.T, self.axes) + if np.any(mask): + vs = vs.astype(float, copy=True) + vs[mask] = np.nan + vs_list_new.append(vs) + vs_list = vs_list_new xyzs_list = [proj3d._scale_proj_transform( vs[:, 0], vs[:, 1], vs[:, 2], self.axes) for vs in vs_list] - self._paths = [mpath.Path(np.ma.column_stack([xs, ys]), cs) + self._paths = [mpath.Path(np.column_stack([xs, ys]), cs) for (xs, ys, _), (_, cs) in zip(xyzs_list, self._3dverts_codes)] zs = np.concatenate([zs for _, _, zs in xyzs_list]) - return zs.min() if len(zs) else 1e9 + return np.nanmin(zs) if len(zs) else 1e9 def collection_2d_to_3d(col, zs=0, zdir='z', axlim_clip=False): @@ -499,9 +501,13 @@ def do_3d_projection(self): """ segments = self._segments3d - # Handle ragged inputs, but prefer a faster path for same-length segments - segment_lengths = [len(segment) for segment in segments] - ragged = len(set(segment_lengths)) > 1 + if isinstance(segments, np.ndarray): + ragged = False + else: + # Handle ragged inputs, but prefer a faster path for same-length segments + segment_lengths = [len(segment) for segment in segments] + ragged = len(set(segment_lengths)) > 1 + if ragged: # Branch masked / non-masked for speed if any(np.ma.isMA(segment) for segment in segments): @@ -531,9 +537,9 @@ def do_3d_projection(self): (*viewlim_mask.shape, 3)) mask = mask | viewlim_mask - xyzs = proj3d._scale_proj_transform_vectors(segments, self.axes) - if mask is not False: - xyzs = np.ma.array(xyzs, mask=mask) + xyzs = np.asarray( + proj3d._scale_proj_transform_vectors(segments, self.axes)) + xyzs[mask] = np.nan segments_2d = xyzs[..., 0:2] if ragged: segments_2d = np.split(segments_2d, np.cumsum(segment_lengths[:-1])) @@ -541,7 +547,7 @@ def do_3d_projection(self): # FIXME if len(xyzs) > 0: - minz = min(xyzs[..., 2].min(), 1e9) + minz = min(np.nanmin(xyzs[..., 2]), 1e9) else: minz = np.nan return minz @@ -612,15 +618,17 @@ def get_path(self): def do_3d_projection(self): s = self._segment3d + mask = False + xs, ys, zs = zip(*s) if self._axlim_clip: - mask = _viewlim_mask(*zip(*s), self.axes) - xs, ys, zs = np.ma.array(zip(*s), - dtype=float, mask=mask).filled(np.nan) - else: - xs, ys, zs = zip(*s) + mask = _viewlim_mask(xs, ys, zs, self.axes) + if np.any(mask): + xs = np.where(mask, np.nan, xs) + ys = np.where(mask, np.nan, ys) + zs = np.where(mask, np.nan, zs) vxs, vys, vzs, vis = proj3d._scale_proj_transform_clip(xs, ys, zs, self.axes) - self._path2d = mpath.Path(np.ma.column_stack([vxs, vys])) - return min(vzs) + self._path2d = mpath.Path(np.column_stack([vxs, vys])) + return np.nanmin(vzs) class PathPatch3D(Patch3D): @@ -672,15 +680,18 @@ def set_3d_properties(self, path, zs=0, zdir='z', axlim_clip=False): def do_3d_projection(self): s = self._segment3d + mask = False + xs, ys, zs = zip(*s) if self._axlim_clip: - mask = _viewlim_mask(*zip(*s), self.axes) - xs, ys, zs = np.ma.array(zip(*s), - dtype=float, mask=mask).filled(np.nan) - else: - xs, ys, zs = zip(*s) + mask = _viewlim_mask(xs, ys, zs, self.axes) + if np.any(mask): + xs = np.where(mask, np.nan, xs) + ys = np.where(mask, np.nan, ys) + zs = np.where(mask, np.nan, zs) vxs, vys, vzs, vis = proj3d._scale_proj_transform_clip(xs, ys, zs, self.axes) - self._path2d = mpath.Path(np.ma.column_stack([vxs, vys]), self._code3d) - return min(vzs) + self._path2d = mpath.Path(np.column_stack([vxs, vys]), self._code3d) + + return np.nanmin(vzs) def _get_patch_verts(patch): @@ -816,20 +827,20 @@ def set_3d_properties(self, zs, zdir, axlim_clip=False): self.stale = True def do_3d_projection(self): + xs, ys, zs = self._offsets3d + mask = False if self._axlim_clip: - mask = _viewlim_mask(*self._offsets3d, self.axes) - xs, ys, zs = np.ma.array(self._offsets3d, mask=mask) - else: - xs, ys, zs = self._offsets3d + mask = _viewlim_mask(xs, ys, zs, self.axes) + if np.any(mask): + xs = np.where(mask, np.nan, xs) + ys = np.where(mask, np.nan, ys) + zs = np.where(mask, np.nan, zs) vxs, vys, vzs, vis = proj3d._scale_proj_transform_clip(xs, ys, zs, self.axes) self._vzs = vzs - if np.ma.isMA(vxs): - super().set_offsets(np.ma.column_stack([vxs, vys])) - else: - super().set_offsets(np.column_stack([vxs, vys])) + super().set_offsets(np.column_stack([vxs, vys])) if vzs.size > 0: - return min(vzs) + return np.nanmin(vzs) else: return np.nan diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index 45f9319355e0..50e3f9f8b860 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -191,6 +191,7 @@ def __init__( self.set_axis_on() self.M = None self.invM = None + self._draw_cache = {} self._view_margin = 1/48 # default value to match mpl3.8 self.autoscale_view() @@ -278,6 +279,20 @@ def _transformed_cube(self, vals): (minx, maxy, maxz)] return np.column_stack(proj3d._proj_trans_points(xyzs, self.M)) + def _get_transformed_cube(self): + """Get transformed cube from draw cache, or compute if outside draw cycle.""" + if 'transformed_cube' in self._draw_cache: + return self._draw_cache['transformed_cube'] + bounds = self._draw_cache.get('bounds', + (*self.get_xbound(), *self.get_ybound(), *self.get_zbound())) + return self._transformed_cube(bounds) + + def _get_bounds(self): + """Get data bounds from draw cache, or compute if outside draw cycle.""" + if 'bounds' in self._draw_cache: + return self._draw_cache['bounds'] + return (*self.get_xbound(), *self.get_ybound(), *self.get_zbound()) + def _update_transScale(self): """ Override transScale to always use identity transforms. @@ -474,6 +489,14 @@ def draw(self, renderer): self.M = self.get_proj() self.invM = np.linalg.inv(self.M) + # Cache values used multiple times during axis drawing + self._draw_cache = {} + bounds = (*self.get_xbound(), *self.get_ybound(), *self.get_zbound()) + self._draw_cache['scaled_limits'] = self._calc_scaled_limits() + self._draw_cache['bounds'] = bounds + self._draw_cache['transformed_cube'] = self._transformed_cube(bounds) + self._draw_cache['coord_info'] = self._calc_coord_info() + collections_and_patches = ( artist for artist in self._children if isinstance(artist, (mcoll.Collection, mpatches.Patch)) @@ -499,6 +522,9 @@ def draw(self, renderer): artist.do_3d_projection() if self._axis3don: + # Update ticks on all axes + for axis in self._axis_map.values(): + axis._ticks_to_draw = axis._update_ticks() # Draw panes first for axis in self._axis_map.values(): axis.draw_pane(renderer) @@ -1358,7 +1384,7 @@ def _roll_to_vertical( else: return np.roll(arr, (self._vertical_axis - 2)) - def _get_scaled_limits(self): + def _calc_scaled_limits(self): """ Get axis limits transformed through their respective scale transforms. @@ -1373,6 +1399,61 @@ def _get_scaled_limits(self): zmin, zmax = self.zaxis.get_transform().transform(self.get_zlim3d()) return xmin, xmax, ymin, ymax, zmin, zmax + def _get_scaled_limits(self): + """Get scaled limits from draw cache, or compute if outside draw cycle.""" + if 'scaled_limits' in self._draw_cache: + return self._draw_cache['scaled_limits'] + return self._calc_scaled_limits() + + def _calc_coord_info(self): + """ + Compute coordinate info for axis drawing. + + Returns + ------- + mins : ndarray + Minimum values [xmin, ymin, zmin] in scaled coordinates. + maxs : ndarray + Maximum values [xmax, ymax, zmax] in scaled coordinates. + bounds_proj : ndarray + Projected cube corners. + highs : ndarray + Boolean array indicating which planes are higher up. + """ + xmin, xmax, ymin, ymax, zmin, zmax = self._get_scaled_limits() + mins = np.array([xmin, ymin, zmin]) + maxs = np.array([xmax, ymax, zmax]) + bounds_proj = self._get_transformed_cube() + + # Determine which one of the parallel planes are higher up + means_z0 = np.zeros(3) + means_z1 = np.zeros(3) + for i in range(3): + means_z0[i] = np.mean(bounds_proj[axis3d.Axis._PLANES[2 * i], 2]) + means_z1[i] = np.mean(bounds_proj[axis3d.Axis._PLANES[2 * i + 1], 2]) + highs = means_z0 < means_z1 + + # Special handling for edge-on views + equals = np.abs(means_z0 - means_z1) <= np.finfo(float).eps + if np.sum(equals) == 2: + vertical = np.where(~equals)[0][0] + if vertical == 2: # looking at XY plane + highs = np.array([True, True, highs[2]]) + elif vertical == 1: # looking at XZ plane + highs = np.array([True, highs[1], False]) + elif vertical == 0: # looking at YZ plane + highs = np.array([highs[0], False, False]) + + return mins, maxs, bounds_proj, highs + + def _get_coord_info(self): + """Get coord info from draw cache, or compute if outside draw cycle.""" + if 'coord_info' in self._draw_cache: + return self._draw_cache['coord_info'] + coord_info = self._calc_coord_info() + self._draw_cache['coord_info'] = coord_info + return coord_info + def _untransform_point(self, x, y, z): """ Convert a point from transformed coordinates to data coordinates. @@ -1431,8 +1512,11 @@ def get_proj(self): # For non-linear scales, we use the scaled limits so the world # transformation maps transformed coordinates (not data coordinates) # to the unit cube - scaled_limits = self._get_scaled_limits() - worldM = proj3d.world_transformation(*scaled_limits, pb_aspect=box_aspect) + scaled_limits = self._calc_scaled_limits() + worldM = proj3d.world_transformation( + *scaled_limits, + pb_aspect=box_aspect, + ) # Look into the middle of the world coordinates: R = 0.5 * box_aspect @@ -1872,7 +1956,7 @@ def drag_pan(self, button, key, x, y): duvw_projected = R.T @ np.array([du, dv, dw]) # Calculate pan distance in transformed coordinates for non-linear scales - minx, maxx, miny, maxy, minz, maxz = self._get_scaled_limits() + minx, maxx, miny, maxy, minz, maxz = self._calc_scaled_limits() dx = (maxx - minx) * duvw_projected[0] dy = (maxy - miny) * duvw_projected[1] dz = (maxz - minz) * duvw_projected[2] @@ -2026,7 +2110,7 @@ def _get_w_centers_ranges(self): computed in transformed coordinates to ensure uniform zoom/pan behavior. """ # Get limits in transformed coordinates for non-linear scale zoom/pan - minx, maxx, miny, maxy, minz, maxz = self._get_scaled_limits() + minx, maxx, miny, maxy, minz, maxz = self._calc_scaled_limits() cx = (maxx + minx)/2 cy = (maxy + miny)/2 cz = (maxz + minz)/2 diff --git a/lib/mpl_toolkits/mplot3d/axis3d.py b/lib/mpl_toolkits/mplot3d/axis3d.py index 0ac2e50b1a1a..64c06be5007e 100644 --- a/lib/mpl_toolkits/mplot3d/axis3d.py +++ b/lib/mpl_toolkits/mplot3d/axis3d.py @@ -158,6 +158,7 @@ def _init3d(self): self.axes._set_artist_props(self.line) self.axes._set_artist_props(self.pane) + self._ticks_to_draw = [] self.gridlines = art3d.Line3DCollection([]) self.axes._set_artist_props(self.gridlines) self.axes._set_artist_props(self.label) @@ -266,39 +267,6 @@ def get_rotate_label(self, text): else: return len(text) > 4 - def _get_coord_info(self): - # Get scaled limits directly from the axes helper - xmin, xmax, ymin, ymax, zmin, zmax = self.axes._get_scaled_limits() - mins = np.array([xmin, ymin, zmin]) - maxs = np.array([xmax, ymax, zmax]) - - # Get data-space bounds for _transformed_cube - bounds = (*self.axes.get_xbound(), - *self.axes.get_ybound(), - *self.axes.get_zbound()) - bounds_proj = self.axes._transformed_cube(bounds) - - # Determine which one of the parallel planes are higher up: - means_z0 = np.zeros(3) - means_z1 = np.zeros(3) - for i in range(3): - means_z0[i] = np.mean(bounds_proj[self._PLANES[2 * i], 2]) - means_z1[i] = np.mean(bounds_proj[self._PLANES[2 * i + 1], 2]) - highs = means_z0 < means_z1 - - # Special handling for edge-on views - equals = np.abs(means_z0 - means_z1) <= np.finfo(float).eps - if np.sum(equals) == 2: - vertical = np.where(~equals)[0][0] - if vertical == 2: # looking at XY plane - highs = np.array([True, True, highs[2]]) - elif vertical == 1: # looking at XZ plane - highs = np.array([True, highs[1], False]) - elif vertical == 0: # looking at YZ plane - highs = np.array([highs[0], False, False]) - - return mins, maxs, bounds_proj, highs - def _calc_centers_deltas(self, maxs, mins): centers = 0.5 * (maxs + mins) # In mpl3.8, the scale factor was 1/12. mpl3.9 changes this to @@ -403,7 +371,7 @@ def _get_tickdir(self, position): return tickdir def active_pane(self): - mins, maxs, tc, highs = self._get_coord_info() + mins, maxs, tc, highs = self.axes._get_coord_info() info = self._axinfo index = info['i'] if not highs[index]: @@ -436,12 +404,16 @@ def _axmask(self): def _draw_ticks(self, renderer, edgep1, centers, deltas, highs, deltas_per_point, pos): - ticks = self._update_ticks() + ticks = self._ticks_to_draw # Set with _update_ticks() in axes3d.draw() + n_ticks = len(ticks) + if n_ticks == 0: + return + info = self._axinfo index = info["i"] juggled = info["juggled"] - mins, maxs, tc, highs = self._get_coord_info() + mins, maxs, tc, highs = self.axes._get_coord_info() centers, deltas = self._calc_centers_deltas(maxs, mins) # Get the scale transform for this axis to transform tick locations @@ -462,23 +434,39 @@ def _draw_ticks(self, renderer, edgep1, centers, deltas, highs, default_label_offset = 8. # A rough estimate points = deltas_per_point * deltas - # All coordinates below are in transformed coordinates for proper projection - for tick in ticks: - # Get tick line positions - pos = edgep1.copy() - pos[index] = axis_trans.transform([tick.get_loc()])[0] - pos[tickdir] = out_tickdir - x1, y1, z1 = proj3d.proj_transform(*pos, self.axes.M) - pos[tickdir] = in_tickdir - x2, y2, z2 = proj3d.proj_transform(*pos, self.axes.M) - - # Get position of label - labeldeltas = (tick.get_pad() + default_label_offset) * points - pos[tickdir] = edgep1_tickdir - pos = _move_from_center(pos, centers, labeldeltas, self._axmask()) - lx, ly, lz = proj3d.proj_transform(*pos, self.axes.M) - - _tick_update_position(tick, (x1, x2), (y1, y2), (lx, ly)) + + # Collect tick data and batch transform tick locations + tick_locs = np.array([tick.get_loc() for tick in ticks]) + tick_pads = np.array([tick.get_pad() for tick in ticks]) + transformed_locs = axis_trans.transform(tick_locs) + + # Build position arrays for tick line endpoints (shape: n_ticks x 3) + pos1 = np.tile(edgep1, (n_ticks, 1)) + pos1[:, index] = transformed_locs + pos1[:, tickdir] = out_tickdir + + pos2 = pos1.copy() + pos2[:, tickdir] = in_tickdir + + # Batch proj_transform for tick lines + x1, y1, _ = proj3d.proj_transform(pos1[:, 0], pos1[:, 1], pos1[:, 2], + self.axes.M) + x2, y2, _ = proj3d.proj_transform(pos2[:, 0], pos2[:, 1], pos2[:, 2], + self.axes.M) + + # Build label positions + labeldeltas = (tick_pads + default_label_offset)[:, np.newaxis] * points + pos_label = pos1.copy() + pos_label[:, tickdir] = edgep1_tickdir + axmask = self._axmask() + pos_label = _move_from_center(pos_label, centers, labeldeltas, axmask) + lx, ly, _ = proj3d.proj_transform(pos_label[:, 0], pos_label[:, 1], + pos_label[:, 2], self.axes.M) + + # Update and draw each tick + for i, tick in enumerate(ticks): + _tick_update_position(tick, (x1[i], x2[i]), (y1[i], y2[i]), + (lx[i], ly[i])) tick.tick1line.set_linewidth(tick_lw[tick._major]) tick.draw(renderer) @@ -576,7 +564,7 @@ def draw(self, renderer): renderer.open_group("axis3d", gid=self.get_gid()) # Get general axis information: - mins, maxs, tc, highs = self._get_coord_info() + mins, maxs, tc, highs = self.axes._get_coord_info() centers, deltas = self._calc_centers_deltas(maxs, mins) # Calculate offset distances @@ -639,42 +627,43 @@ def draw_grid(self, renderer): if not self.axes._draw_grid: return + ticks = self._ticks_to_draw # Set with _update_ticks() in axes3d.draw() + if len(ticks) == 0: + return + renderer.open_group("grid3d", gid=self.get_gid()) - ticks = self._update_ticks() - if len(ticks): - # Get general axis information: - info = self._axinfo - index = info["i"] - - # Grid lines use data-space bounds (Line3DCollection applies transforms) - mins, maxs, tc, highs = self._get_coord_info() - xlim, ylim, zlim = (self.axes.get_xbound(), - self.axes.get_ybound(), - self.axes.get_zbound()) - data_mins = np.array([xlim[0], ylim[0], zlim[0]]) - data_maxs = np.array([xlim[1], ylim[1], zlim[1]]) - minmax = np.where(highs, data_maxs, data_mins) - maxmin = np.where(~highs, data_maxs, data_mins) - - # Grid points where the planes meet - xyz0 = np.tile(minmax, (len(ticks), 1)) - xyz0[:, index] = [tick.get_loc() for tick in ticks] - - # Grid lines go from the end of one plane through the plane - # intersection (at xyz0) to the end of the other plane. The first - # point (0) differs along dimension index-2 and the last (2) along - # dimension index-1. - lines = np.stack([xyz0, xyz0, xyz0], axis=1) - lines[:, 0, index - 2] = maxmin[index - 2] - lines[:, 2, index - 1] = maxmin[index - 1] - self.gridlines.set_segments(lines) - gridinfo = info['grid'] - self.gridlines.set_color(gridinfo['color']) - self.gridlines.set_linewidth(gridinfo['linewidth']) - self.gridlines.set_linestyle(gridinfo['linestyle']) - self.gridlines.do_3d_projection() - self.gridlines.draw(renderer) + # Get general axis information: + info = self._axinfo + index = info["i"] + + # Grid lines use data-space bounds (Line3DCollection applies transforms) + mins, maxs, tc, highs = self.axes._get_coord_info() + bounds = self.axes._get_bounds() + xlim, ylim, zlim = bounds[0:2], bounds[2:4], bounds[4:6] + data_mins = np.array([xlim[0], ylim[0], zlim[0]]) + data_maxs = np.array([xlim[1], ylim[1], zlim[1]]) + minmax = np.where(highs, data_maxs, data_mins) + maxmin = np.where(~highs, data_maxs, data_mins) + + # Grid points where the planes meet + xyz0 = np.tile(minmax, (len(ticks), 1)) + xyz0[:, index] = [tick.get_loc() for tick in ticks] + + # Grid lines go from the end of one plane through the plane + # intersection (at xyz0) to the end of the other plane. The first + # point (0) differs along dimension index-2 and the last (2) along + # dimension index-1. + lines = np.stack([xyz0, xyz0, xyz0], axis=1) + lines[:, 0, index - 2] = maxmin[index - 2] + lines[:, 2, index - 1] = maxmin[index - 1] + self.gridlines.set_segments(lines) + gridinfo = info['grid'] + self.gridlines.set_color(gridinfo['color']) + self.gridlines.set_linewidth(gridinfo['linewidth']) + self.gridlines.set_linestyle(gridinfo['linestyle']) + self.gridlines.do_3d_projection() + self.gridlines.draw(renderer) renderer.close_group('grid3d') diff --git a/lib/mpl_toolkits/mplot3d/proj3d.py b/lib/mpl_toolkits/mplot3d/proj3d.py index 81a5aacbdded..3cab750fb3fe 100644 --- a/lib/mpl_toolkits/mplot3d/proj3d.py +++ b/lib/mpl_toolkits/mplot3d/proj3d.py @@ -221,10 +221,17 @@ def inv_transform(xs, ys, zs, invM): def _vec_pad_ones(xs, ys, zs): + # Allocate and then fill for speed + shape = (4,) + np.shape(xs) if np.ma.isMA(xs) or np.ma.isMA(ys) or np.ma.isMA(zs): - return np.ma.array([xs, ys, zs, np.ones_like(xs)]) + result = np.ma.empty(shape) else: - return np.array([xs, ys, zs, np.ones_like(xs)]) + result = np.empty(shape) + result[0] = xs + result[1] = ys + result[2] = zs + result[3] = 1 + return result def proj_transform(xs, ys, zs, M):