diff --git a/lib/matplotlib/axis.py b/lib/matplotlib/axis.py index 349e728ba8ff..59ac405975b0 100644 --- a/lib/matplotlib/axis.py +++ b/lib/matplotlib/axis.py @@ -1546,6 +1546,46 @@ def get_tick_padding(self): values.append(self.minorTicks[0].get_tick_padding()) return max(values, default=0) + def _align_xtick_label_baselines(self, renderer): + + # Shift x-axis tick label1's down so their baselines align, + # compensating for differing ascents. + + if self.axis_name != 'x': + return + + entries = [] + for tick in self.get_major_ticks(): + label = tick.label1 + if not (label.get_visible() and label.get_text()): + continue + try: + _, info, _ = label._get_layout(renderer) + except Exception: + continue + if not info: + continue + ascent = info[0][1][1] + entries.append((tick, label, ascent)) + + if not entries: + return + + max_ascent = max(a for _, _, a in entries) + fig = self.get_figure(root=True) + dpi_scale_trans = fig.dpi_scale_trans + dpi = fig.dpi + + for tick, label, ascent in entries: + base_trans = tick._get_text1_transform()[0] + extra_px = max_ascent - ascent + if extra_px <= 0.5: + label.set_transform(base_trans) + continue + shift = mtransforms.ScaledTranslation( + 0, -extra_px / dpi, dpi_scale_trans) + label.set_transform(label.get_transform() + shift) + @martist.allow_rasterization def draw(self, renderer): # docstring inherited @@ -1557,6 +1597,8 @@ def draw(self, renderer): ticks_to_draw = self._update_ticks() tlb1, tlb2 = self._get_ticklabel_bboxes(ticks_to_draw, renderer) + self._align_xtick_label_baselines(renderer) + for tick in ticks_to_draw: tick.draw(renderer) diff --git a/lib/matplotlib/tests/test_axis.py b/lib/matplotlib/tests/test_axis.py index 3776b6f054b9..535b09e31d09 100644 --- a/lib/matplotlib/tests/test_axis.py +++ b/lib/matplotlib/tests/test_axis.py @@ -128,3 +128,27 @@ def test_set_ticks_emits_lim_changed(): ax2.callbacks.connect("ylim_changed", called_polar.append) ax2.set_rticks([1, 2, 3]) assert called_polar + + def test_xtick_label_baselines_alignment(): + + fig, ax = plt.subplots(figsize=(3, 3), dpi=300) + ax.set_xlim(-0.5, 2.5) + ax.set_ylim(-3, 3) + labels = [r"$w^{(2)}_1$", r"$w^{(2)}_2$", r"$b^{(2)}$"] + ax.set_xticks([0, 1, 2], labels) + ax.set_yticks([]) + ax.tick_params(which="both", length=0) + fig.tight_layout() + + fig.canvas.draw() + renderer = fig.canvas.get_renderer() + + baselines = [] + for tick in ax.xaxis.get_major_ticks(): + label = tick.label1 + bbox = label.get_window_extent(renderer) + _, info, _ = label._get_layout(renderer) + descent = info[0][1][2] + baselines.append(bbox.y0 + descent) + + assert max(baselines) - min(baselines) < 0.01