diff --git a/doc/release/next_whats_new/get_gridlines_which.rst b/doc/release/next_whats_new/get_gridlines_which.rst new file mode 100644 index 000000000000..57fe7013cbe7 --- /dev/null +++ b/doc/release/next_whats_new/get_gridlines_which.rst @@ -0,0 +1,25 @@ +``Axis.get_gridlines`` can return minor gridlines +------------------------------------------------- +`~matplotlib.axis.Axis.get_gridlines` now accepts a *which* keyword argument +to select major, minor, or both groups of gridlines. The default value +``'major'`` preserves the previous behavior. + +.. plot:: + :include-source: true + :alt: Highlight every minor gridline of the x-axis in red. + + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + ax.plot(range(10)) + ax.minorticks_on() + ax.grid(which='both') + + for line in ax.xaxis.get_gridlines(which='minor'): + line.set_color('red') + + plt.show() + +Previously there was no public API to access minor gridlines, so downstream +libraries reached into the private ``Axis._minor_tick_kw`` mapping to detect +their state. diff --git a/lib/matplotlib/axes/_base.pyi b/lib/matplotlib/axes/_base.pyi index 4a70405346a5..42eb924007f0 100644 --- a/lib/matplotlib/axes/_base.pyi +++ b/lib/matplotlib/axes/_base.pyi @@ -403,9 +403,13 @@ class _AxesBase(martist.Artist): # itself with a method modified from the Axis methods for the x or y Axis. # As such, they are typed according to the resultant method rather than as that object. - def get_xgridlines(self) -> list[Line2D]: ... + def get_xgridlines( + self, which: Literal["major", "minor", "both"] = ... + ) -> list[Line2D]: ... def get_xticklines(self, minor: bool = ...) -> list[Line2D]: ... - def get_ygridlines(self) -> list[Line2D]: ... + def get_ygridlines( + self, which: Literal["major", "minor", "both"] = ... + ) -> list[Line2D]: ... def get_yticklines(self, minor: bool = ...) -> list[Line2D]: ... def _sci(self, im: ColorizingArtist) -> None: ... def get_autoscalex_on(self) -> bool: ... diff --git a/lib/matplotlib/axis.py b/lib/matplotlib/axis.py index c526b8a2aa6a..b0bbb288e7c2 100644 --- a/lib/matplotlib/axis.py +++ b/lib/matplotlib/axis.py @@ -1470,11 +1470,39 @@ def draw(self, renderer): renderer.close_group(__name__) self.stale = False - def get_gridlines(self): - r"""Return this Axis' grid lines as a list of `.Line2D`\s.""" - ticks = self.get_major_ticks() - return cbook.silent_list('Line2D gridline', - [tick.gridline for tick in ticks]) + def get_gridlines(self, which='major'): + r""" + Return this Axis' grid lines as a list of `.Line2D`\s. + + Parameters + ---------- + which : {'major', 'minor', 'both'}, default: 'major' + Which set of gridlines to return. + + .. versionchanged:: 3.11 + Added the *which* parameter; previously only major gridlines + were returned. + + Returns + ------- + list of `.Line2D` + The gridline `.Line2D` objects. For ``which='both'``, major + gridlines come before minor gridlines. + + Notes + ----- + The returned list contains every gridline managed by this Axis + regardless of its visibility. Use ``Line2D.get_visible()`` on each + returned object to check whether a particular gridline is currently + drawn. + """ + _api.check_in_list(['major', 'minor', 'both'], which=which) + lines = [] + if which in ('major', 'both'): + lines.extend(tick.gridline for tick in self.get_major_ticks()) + if which in ('minor', 'both'): + lines.extend(tick.gridline for tick in self.get_minor_ticks()) + return cbook.silent_list('Line2D gridline', lines) def set_label(self, s): """Assigning legend labels is not supported. Raises RuntimeError.""" diff --git a/lib/matplotlib/axis.pyi b/lib/matplotlib/axis.pyi index 4bcfb1e1cfb7..409c2d931384 100644 --- a/lib/matplotlib/axis.pyi +++ b/lib/matplotlib/axis.pyi @@ -176,7 +176,9 @@ class Axis(martist.Artist): self, renderer: RendererBase | None = ..., *, for_layout_only: bool = ... ) -> Bbox | None: ... def get_tick_padding(self) -> float: ... - def get_gridlines(self) -> list[Line2D]: ... + def get_gridlines( + self, which: Literal["major", "minor", "both"] = ... + ) -> list[Line2D]: ... def get_label(self) -> Text: ... def get_offset_text(self) -> Text: ... def get_pickradius(self) -> float: ... diff --git a/lib/matplotlib/tests/test_axes.py b/lib/matplotlib/tests/test_axes.py index 209593aee15e..4404a8e110f4 100644 --- a/lib/matplotlib/tests/test_axes.py +++ b/lib/matplotlib/tests/test_axes.py @@ -6300,6 +6300,54 @@ def test_grid(): assert not ax.xaxis.majorTicks[0].gridline.get_visible() +def test_get_gridlines_which(): + """`Axis.get_gridlines` selects major, minor, or both via *which*.""" + fig, ax = plt.subplots() + ax.minorticks_on() + + n_major = len(ax.xaxis.get_major_ticks()) + n_minor = len(ax.xaxis.get_minor_ticks()) + + major_lines = ax.xaxis.get_gridlines() + assert len(major_lines) == n_major + assert major_lines == ax.xaxis.get_gridlines(which='major') + + minor_lines = ax.xaxis.get_gridlines(which='minor') + assert len(minor_lines) == n_minor + assert all(line not in major_lines for line in minor_lines) + + both_lines = ax.xaxis.get_gridlines(which='both') + assert len(both_lines) == n_major + n_minor + assert both_lines[:n_major] == major_lines + assert both_lines[n_major:] == minor_lines + + with pytest.raises(ValueError, + match="'invalid' is not a valid value for which"): + ax.xaxis.get_gridlines(which='invalid') + + +def test_get_gridlines_visibility_reflects_grid_state(): + """Visibility of returned gridlines tracks the active grid state.""" + fig, ax = plt.subplots() + ax.minorticks_on() + + ax.grid(visible=False, which='both') + fig.canvas.draw() + assert not any(line.get_visible() + for line in ax.xaxis.get_gridlines(which='both')) + + ax.grid(visible=True, which='major') + fig.canvas.draw() + assert all(line.get_visible() for line in ax.xaxis.get_gridlines('major')) + assert not any(line.get_visible() + for line in ax.xaxis.get_gridlines('minor')) + + ax.grid(visible=True, which='minor') + fig.canvas.draw() + assert all(line.get_visible() + for line in ax.xaxis.get_gridlines(which='both')) + + def test_grid_color_with_alpha(): """Test that grid(color=(..., alpha)) respects the alpha value.""" fig, ax = plt.subplots()