diff --git a/lib/mpl_toolkits/mplot3d/art3d.py b/lib/mpl_toolkits/mplot3d/art3d.py index f664127dcb59..114bb152ae80 100644 --- a/lib/mpl_toolkits/mplot3d/art3d.py +++ b/lib/mpl_toolkits/mplot3d/art3d.py @@ -497,7 +497,19 @@ def do_3d_projection(self): """ Project the points according to renderer matrix. """ - segments = np.asanyarray(self._segments3d) + segments = self._segments3d + + # Handle ragged inputs, but prefer a faster path for same-length segments + segment_lengths = [len(segment) for segment in segments] + ragged = len(set(segment_lengths)) > 1 + if ragged: + # Branch masked / non-masked for speed + if any(np.ma.isMA(segment) for segment in segments): + segments = np.ma.concatenate(segments) + else: + segments = np.concatenate(segments) + else: + segments = np.asanyarray(segments) # Handle empty segments if segments.size == 0: @@ -505,7 +517,7 @@ def do_3d_projection(self): return np.nan mask = False - if np.ma.isMA(segments): + if np.ma.isMA(segments) and segments.mask is not np.ma.nomask: mask = segments.mask if self._axlim_clip: @@ -519,9 +531,12 @@ def do_3d_projection(self): (*viewlim_mask.shape, 3)) mask = mask | viewlim_mask - xyzs = np.ma.array( - proj3d._scale_proj_transform_vectors(segments, self.axes), mask=mask) + xyzs = proj3d._scale_proj_transform_vectors(segments, self.axes) + if mask is not False: + xyzs = np.ma.array(xyzs, mask=mask) segments_2d = xyzs[..., 0:2] + if ragged: + segments_2d = np.split(segments_2d, np.cumsum(segment_lengths[:-1])) LineCollection.set_segments(self, segments_2d) # FIXME diff --git a/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/wireframe3dasymmetric.png b/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/wireframe3dasymmetric.png index 73507bf2f6c1..18a4312497b3 100644 Binary files a/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/wireframe3dasymmetric.png and b/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/wireframe3dasymmetric.png differ diff --git a/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py b/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py index 2a5593a641c9..14ad67661ad5 100644 --- a/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py +++ b/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py @@ -854,6 +854,7 @@ def test_wireframe3dasymmetric(): fig = plt.figure() ax = fig.add_subplot(projection='3d') X, Y, Z = axes3d.get_test_data(0.05) + X, Y, Z = X[:-1], Y[:-1], Z[:-1] # Drop a row so the grid is non-square ax.plot_wireframe(X, Y, Z, rcount=3, ccount=13)