|
1 | 1 | """Provides cvt_archive_3d_plot.""" |
2 | 2 |
|
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from collections.abc import Iterable, Sequence |
| 6 | +from typing import Literal |
| 7 | + |
| 8 | +import matplotlib.colors |
3 | 9 | import matplotlib.pyplot as plt |
4 | 10 | import numpy as np |
| 11 | +from matplotlib.axes import Axes |
5 | 12 | from matplotlib.cm import ScalarMappable |
| 13 | +from matplotlib.typing import ColorType |
6 | 14 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection |
| 15 | +from mpl_toolkits.mplot3d.axes3d import Axes3D |
| 16 | +from pandas import DataFrame |
7 | 17 | from scipy.spatial import Voronoi |
8 | 18 |
|
| 19 | +from ribs.archives import ArchiveDataFrame, CVTArchive |
9 | 20 | from ribs.visualize._utils import ( |
10 | 21 | retrieve_cmap, |
11 | 22 | set_cbar, |
|
15 | 26 |
|
16 | 27 |
|
17 | 28 | def cvt_archive_3d_plot( |
18 | | - archive, |
19 | | - ax=None, |
| 29 | + archive: CVTArchive, |
| 30 | + ax: Axes3D | None = None, |
20 | 31 | *, |
21 | | - df=None, |
22 | | - measure_order=None, |
23 | | - cmap="magma", |
24 | | - lw=0.5, |
25 | | - ec=(0.0, 0.0, 0.0, 0.1), |
26 | | - cell_alpha=1.0, |
27 | | - vmin=None, |
28 | | - vmax=None, |
29 | | - cbar="auto", |
30 | | - cbar_kwargs=None, |
31 | | - plot_elites=False, |
32 | | - elite_ms=100, |
33 | | - elite_alpha=0.5, |
34 | | - plot_centroids=False, |
35 | | - plot_samples=False, |
36 | | - ms=1, |
37 | | -): |
| 32 | + df: DataFrame | ArchiveDataFrame | None = None, |
| 33 | + measure_order: Iterable[int] | None = None, |
| 34 | + cmap: str | Sequence[ColorType] | matplotlib.colors.Colormap = "magma", |
| 35 | + lw: float = 0.5, |
| 36 | + ec: ColorType = (0.0, 0.0, 0.0, 0.1), |
| 37 | + cell_alpha: float = 1.0, |
| 38 | + vmin: float | None = None, |
| 39 | + vmax: float | None = None, |
| 40 | + cbar: Literal["auto"] | None | Axes = "auto", |
| 41 | + cbar_kwargs: dict | None = None, |
| 42 | + plot_elites: bool = False, |
| 43 | + elite_ms: float = 100, |
| 44 | + elite_alpha: float = 0.5, |
| 45 | + plot_centroids: bool = False, |
| 46 | + plot_samples: bool = False, |
| 47 | + ms: float = 1, |
| 48 | +) -> None: |
38 | 49 | """Plots a :class:`~ribs.archives.CVTArchive` with 3D measure space. |
39 | 50 |
|
40 | 51 | This function relies on Matplotlib's `mplot3d |
@@ -152,52 +163,48 @@ def cvt_archive_3d_plot( |
152 | 163 | >>> plt.show() |
153 | 164 |
|
154 | 165 | Args: |
155 | | - archive (CVTArchive): A 3D :class:`~ribs.archives.CVTArchive`. |
156 | | - ax (matplotlib.axes.Axes): Axes on which to plot the heatmap. If ``None``, we |
157 | | - will create a new 3D axis. |
158 | | - df (ribs.archives.ArchiveDataFrame): If provided, we will plot data from this |
159 | | - argument instead of the data currently in the archive. This data can be |
160 | | - obtained by, for instance, calling :meth:`ribs.archives.ArchiveBase.data` |
161 | | - with ``return_type="pandas"`` and modifying the resulting |
162 | | - :class:`~ribs.archives.ArchiveDataFrame`. Note that, at a minimum, the data |
163 | | - must contain columns for index, objective, and measures. To display a custom |
164 | | - metric, replace the "objective" column. |
165 | | - measure_order (array-like of int): Specifies the axes order for plotting the |
166 | | - measures. By default, the first measure (measure 0) in the archive appears |
167 | | - on the x-axis, the second (measure 1) on y-axis, and third (measure 2) on |
168 | | - z-axis. This argument is an array of length 3 that specifies which measure |
169 | | - should appear on the x, y, and z axes. For instance, [2, 1, 0] will put |
170 | | - measure 2 on the x-axis, measure 1 on the y-axis, and measure 0 on the |
171 | | - z-axis. |
172 | | - cmap (str, list, matplotlib.colors.Colormap): The colormap to use when plotting |
173 | | - intensity. Either the name of a :class:`~matplotlib.colors.Colormap`, a list |
174 | | - of RGB or RGBA colors (i.e. an :math:`N \\times 3` or :math:`N \\times 4` |
175 | | - array), or a :class:`~matplotlib.colors.Colormap` object. |
176 | | - lw (float): Line width when plotting the Voronoi diagram. |
177 | | - ec (matplotlib color): Edge color of the cells in the Voronoi diagram. See |
| 166 | + archive: A 3D :class:`~ribs.archives.CVTArchive`. |
| 167 | + ax: Axes on which to plot the heatmap. If ``None``, we will create a new 3D |
| 168 | + axis. |
| 169 | + df: If provided, we will plot data from this argument instead of the data |
| 170 | + currently in the archive. This data can be obtained by, for instance, |
| 171 | + calling :meth:`ribs.archives.ArchiveBase.data` with ``return_type="pandas"`` |
| 172 | + and modifying the resulting :class:`~ribs.archives.ArchiveDataFrame`. Note |
| 173 | + that, at a minimum, the data must contain columns for index, objective, and |
| 174 | + measures. To display a custom metric, replace the "objective" column. |
| 175 | + measure_order: Specifies the axes order for plotting the measures. By default, |
| 176 | + the first measure (measure 0) in the archive appears on the x-axis, the |
| 177 | + second (measure 1) on y-axis, and third (measure 2) on z-axis. This argument |
| 178 | + is an array of length 3 that specifies which measure should appear on the x, |
| 179 | + y, and z axes. For instance, [2, 1, 0] will put measure 2 on the x-axis, |
| 180 | + measure 1 on the y-axis, and measure 0 on the z-axis. |
| 181 | + cmap: The colormap to use when plotting intensity. Either the name of a |
| 182 | + :class:`~matplotlib.colors.Colormap`, a list of Matplotlib color |
| 183 | + specifications (e.g., an :math:`N \\times 3` or :math:`N \\times 4` array -- |
| 184 | + see :class:`~matplotlib.colors.ListedColormap`), or a |
| 185 | + :class:`~matplotlib.colors.Colormap` object. |
| 186 | + lw: Line width when plotting the Voronoi diagram. |
| 187 | + ec: Edge color of the cells in the Voronoi diagram. See |
178 | 188 | `here <https://matplotlib.org/stable/tutorials/colors/colors.html>`_ for |
179 | 189 | more info on specifying colors in Matplotlib. |
180 | 190 | cell_alpha: Alpha value for the cell colors. Set to 1.0 for opaque cells; set to |
181 | 191 | 0.0 for fully transparent cells. |
182 | | - vmin (float): Minimum objective value to use in the plot. If ``None``, the |
183 | | - minimum objective value in the archive is used. |
184 | | - vmax (float): Maximum objective value to use in the plot. If ``None``, the |
185 | | - maximum objective value in the archive is used. |
186 | | - cbar ('auto', None, matplotlib.axes.Axes): By default, this is set to ``'auto'`` |
187 | | - which displays the colorbar on the archive's current |
188 | | - :class:`~matplotlib.axes.Axes`. If ``None``, then colorbar is not displayed. |
189 | | - If this is an :class:`~matplotlib.axes.Axes`, displays the colorbar on the |
190 | | - specified Axes. |
191 | | - cbar_kwargs (dict): Additional kwargs to pass to |
192 | | - :func:`~matplotlib.pyplot.colorbar`. |
193 | | - plot_elites (bool): If True, we will plot a scatter plot of the elites in the |
| 192 | + vmin: Minimum objective value to use in the plot. If ``None``, the minimum |
| 193 | + objective value in the archive is used. |
| 194 | + vmax: Maximum objective value to use in the plot. If ``None``, the maximum |
| 195 | + objective value in the archive is used. |
| 196 | + cbar: By default, this is set to ``'auto'`` which displays the colorbar on the |
| 197 | + archive's current :class:`~matplotlib.axes.Axes`. If ``None``, then colorbar |
| 198 | + is not displayed. If this is an :class:`~matplotlib.axes.Axes`, displays the |
| 199 | + colorbar on the specified Axes. |
| 200 | + cbar_kwargs: Additional kwargs to pass to :func:`~matplotlib.pyplot.colorbar`. |
| 201 | + plot_elites: If True, we will plot a scatter plot of the elites in the |
194 | 202 | archive. The elites will be colored according to their objective value. |
195 | | - elite_ms (float): Marker size for plotting elites. |
196 | | - elite_alpha (float): Alpha value for plotting elites. |
197 | | - plot_centroids (bool): Whether to plot the cluster centroids. |
198 | | - plot_samples (bool): Whether to plot the samples used when generating the |
199 | | - clusters. |
200 | | - ms (float): Marker size for both centroids and samples. |
| 203 | + elite_ms: Marker size for plotting elites. |
| 204 | + elite_alpha: Alpha value for plotting elites. |
| 205 | + plot_centroids: Whether to plot the cluster centroids. |
| 206 | + plot_samples: Whether to plot the samples used when generating the clusters. |
| 207 | + ms: Marker size for both centroids and samples. |
201 | 208 |
|
202 | 209 | Raises: |
203 | 210 | ValueError: The archive's measure dimension must be 1D or 2D. |
@@ -269,7 +276,7 @@ def cvt_archive_3d_plot( |
269 | 276 |
|
270 | 277 | # Default ax behavior. |
271 | 278 | if ax is None: |
272 | | - ax = plt.axes(projection="3d") |
| 279 | + ax: Axes3D = plt.axes(projection="3d") # ty: ignore[invalid-assignment] |
273 | 280 |
|
274 | 281 | ax.set_xlim(lower_bounds[0], upper_bounds[0]) |
275 | 282 | ax.set_ylim(lower_bounds[1], upper_bounds[1]) |
|
0 commit comments