Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions pygad/pygad.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,151 @@ def to_csv(self, filename, delimiter=','):

self.logger.info(f"Run metrics saved to {filename}")

def plot_metrics(self,
title="PyGAD - Run Metrics",
font_size=12,
figsize=(12, 8),
plot_type="plot",
colors=["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"],
labels=["Best Fitness", "Mean Fitness", "Time Elapsed (s)", "Diversity"],
save_dir=None,
show=True):
"""
Creates, shows, and returns a figure with 4 subplots showing the run metrics
recorded during the genetic algorithm evolution.

This method visualizes the metrics recorded by the run metrics recorder, including:
- Best fitness over generations
- Mean fitness over generations
- Time elapsed per generation
- Population diversity (gene variance) over generations

Parameters
----------
title : str, optional
Main title of the figure. Default is "PyGAD - Run Metrics".

font_size : int, optional
Font size for labels and titles. Default is 12.

figsize : tuple, optional
Figure size as (width, height) in inches. Default is (12, 8).

plot_type : str, optional
Type of plot. Can be "plot", "scatter", or "bar". Default is "plot".

colors : list, optional
List of 4 colors for the 4 subplots. Default is:
["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"] (blue, orange, green, red).

labels : list, optional
List of 4 labels for the 4 subplots' y-axes. Default is:
["Best Fitness", "Mean Fitness", "Time Elapsed (s)", "Diversity"].

save_dir : str, optional
Directory path to save the figure. If None, the figure is not saved.

show : bool, optional
Whether to show the figure. Default is True.

Returns
-------
matplotlib.figure.Figure
The matplotlib figure object containing all subplots.

Notes
-----
This method can only be called after completing at least 1 generation.
If no generation is completed, a RuntimeError is raised.

For multi-objective optimization problems, only the first objective is plotted
for best fitness and mean fitness. Use separate plots for other objectives.

Examples
--------
>>> import pygad
>>> import numpy
>>>
>>> def fitness_func(ga_instance, solution, solution_idx):
... output = numpy.sum(solution * [4, -2, 3.5, 5, -11, -4.7])
... fitness = 1.0 / (numpy.abs(output - 44) + 0.000001)
... return fitness
>>>
>>> ga_instance = pygad.GA(num_generations=10,
... num_parents_mating=4,
... sol_per_pop=8,
... num_genes=6,
... fitness_func=fitness_func)
>>>
>>> ga_instance.run()
>>> fig = ga_instance.plot_metrics()
"""
from pygad.visualize.plot import get_matplotlib
import matplotlib

if self.run_metrics is None or len(self.run_metrics['generation']) == 0:
self.logger.error("The plot_metrics() method can only be called after completing at least 1 generation.")
raise RuntimeError("The plot_metrics() method can only be called after completing at least 1 generation.")

matplt = get_matplotlib()

generations = numpy.array(self.run_metrics['generation'])
best_fitness = numpy.array(self.run_metrics['best_fitness'])
mean_fitness = numpy.array(self.run_metrics['mean_fitness'])
time_elapsed = numpy.array(self.run_metrics['time_elapsed'])
diversity = numpy.array(self.run_metrics['diversity'])

is_multi_objective = False
if len(best_fitness) > 0:
if type(self.run_metrics['best_fitness'][0]) in [list, tuple, numpy.ndarray]:
is_multi_objective = True
best_fitness = numpy.array([bf[0] for bf in self.run_metrics['best_fitness']])
if type(self.run_metrics['mean_fitness'][0]) in [list, tuple, numpy.ndarray]:
mean_fitness = numpy.array([mf[0] for mf in self.run_metrics['mean_fitness']])

fig, axs = matplt.subplots(2, 2, figsize=figsize)
fig.suptitle(title, fontsize=font_size + 2, fontweight='bold')

axs = axs.flatten()

metrics_data = [
(best_fitness, labels[0], colors[0], "Best Fitness Over Generations"),
(mean_fitness, labels[1], colors[1], "Mean Fitness Over Generations"),
(time_elapsed, labels[2], colors[2], "Time Elapsed Per Generation"),
(diversity, labels[3], colors[3], "Population Diversity Over Generations")
]

for idx, (data, ylabel, color, subplot_title) in enumerate(metrics_data):
ax = axs[idx]

if plot_type == "plot":
ax.plot(generations, data, color=color, linewidth=2, marker='o', markersize=4)
elif plot_type == "scatter":
ax.scatter(generations, data, color=color, s=30, alpha=0.7)
elif plot_type == "bar":
ax.bar(generations, data, color=color, alpha=0.7)

ax.set_title(subplot_title, fontsize=font_size)
ax.set_xlabel("Generation", fontsize=font_size - 2)
ax.set_ylabel(ylabel, fontsize=font_size - 2)
ax.grid(True, alpha=0.3)
ax.tick_params(axis='both', labelsize=font_size - 3)

matplt.tight_layout(rect=[0, 0, 1, 0.96])

if is_multi_objective:
fig.text(0.5, 0.01,
"Note: Only the first objective is shown for fitness metrics (multi-objective problem).",
ha='center', fontsize=font_size - 2, style='italic')

if save_dir is not None:
fig.savefig(fname=save_dir, bbox_inches='tight', dpi=150)

if show:
matplt.show()

return fig

def load(filename):
"""
Reads a saved instance of the genetic algorithm:
Expand Down
Loading