Skip to content

Commit d729be4

Browse files
committed
Add tests to best_solution() method
1 parent 6fdd7aa commit d729be4

File tree

1 file changed

+374
-0
lines changed

1 file changed

+374
-0
lines changed

tests/test_best_solution.py

Lines changed: 374 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,374 @@
1+
import numpy
2+
import pygad
3+
import random
4+
5+
# Global constants for testing
6+
num_generations = 100
7+
num_parents_mating = 5
8+
sol_per_pop = 10
9+
num_genes = 3
10+
random_seed = 42
11+
12+
def fitness_func(ga_instance, solution, solution_idx):
13+
"""Single-objective fitness function."""
14+
return numpy.sum(solution**2)
15+
16+
def fitness_func_multi(ga_instance, solution, solution_idx):
17+
"""Multi-objective fitness function."""
18+
return [numpy.sum(solution**2), numpy.sum(solution)]
19+
20+
def test_best_solution_consistency_single_objective():
21+
"""
22+
Test best_solution() consistency for single-objective optimization.
23+
"""
24+
ga_instance = pygad.GA(num_generations=num_generations,
25+
num_parents_mating=num_parents_mating,
26+
fitness_func=fitness_func,
27+
sol_per_pop=sol_per_pop,
28+
num_genes=num_genes,
29+
random_seed=random_seed,
30+
suppress_warnings=True
31+
)
32+
ga_instance.run()
33+
34+
# Call with last_generation_fitness
35+
sol1, fitness1, idx1 = ga_instance.best_solution(pop_fitness=ga_instance.last_generation_fitness)
36+
37+
# Call without pop_fitness
38+
sol2, fitness2, idx2 = ga_instance.best_solution()
39+
40+
assert numpy.array_equal(sol1, sol2)
41+
assert fitness1 == fitness2
42+
assert idx1 == idx2
43+
print("test_best_solution_consistency_single_objective passed.")
44+
45+
def test_best_solution_consistency_multi_objective():
46+
"""
47+
Test best_solution() consistency for multi-objective optimization.
48+
"""
49+
ga_instance = pygad.GA(num_generations=num_generations,
50+
num_parents_mating=num_parents_mating,
51+
fitness_func=fitness_func_multi,
52+
sol_per_pop=sol_per_pop,
53+
num_genes=num_genes,
54+
random_seed=random_seed,
55+
parent_selection_type="nsga2",
56+
suppress_warnings=True
57+
)
58+
ga_instance.run()
59+
60+
# Call with last_generation_fitness
61+
sol1, fitness1, idx1 = ga_instance.best_solution(pop_fitness=ga_instance.last_generation_fitness)
62+
63+
# Call without pop_fitness
64+
sol2, fitness2, idx2 = ga_instance.best_solution()
65+
66+
assert numpy.array_equal(sol1, sol2)
67+
assert numpy.array_equal(fitness1, fitness2)
68+
assert idx1 == idx2
69+
print("test_best_solution_consistency_multi_objective passed.")
70+
71+
def test_best_solution_before_run():
72+
"""
73+
Test best_solution() consistency before run() is called.
74+
"""
75+
ga_instance = pygad.GA(num_generations=num_generations,
76+
num_parents_mating=num_parents_mating,
77+
fitness_func=fitness_func,
78+
sol_per_pop=sol_per_pop,
79+
num_genes=num_genes,
80+
random_seed=random_seed,
81+
suppress_warnings=True
82+
)
83+
84+
# Before run(), last_generation_fitness is None
85+
# We can still call best_solution(), it should call cal_pop_fitness()
86+
sol2, fitness2, idx2 = ga_instance.best_solution()
87+
88+
# Now cal_pop_fitness() should match ga_instance.best_solution() output if we pass it
89+
pop_fitness = ga_instance.cal_pop_fitness()
90+
sol1, fitness1, idx1 = ga_instance.best_solution(pop_fitness=pop_fitness)
91+
92+
assert numpy.array_equal(sol1, sol2)
93+
assert fitness1 == fitness2
94+
assert idx1 == idx2
95+
print("test_best_solution_before_run passed.")
96+
97+
def test_best_solution_with_save_solutions():
98+
"""
99+
Test best_solution() consistency when save_solutions=True.
100+
This tests the caching mechanism in cal_pop_fitness().
101+
"""
102+
ga_instance = pygad.GA(num_generations=num_generations,
103+
num_parents_mating=num_parents_mating,
104+
fitness_func=fitness_func,
105+
sol_per_pop=sol_per_pop,
106+
num_genes=num_genes,
107+
random_seed=random_seed,
108+
save_solutions=True,
109+
suppress_warnings=True
110+
)
111+
ga_instance.run()
112+
113+
# Call with last_generation_fitness
114+
sol1, fitness1, idx1 = ga_instance.best_solution(pop_fitness=ga_instance.last_generation_fitness)
115+
116+
# Call without pop_fitness (this will trigger cal_pop_fitness which uses saved solutions)
117+
sol2, fitness2, idx2 = ga_instance.best_solution()
118+
119+
assert numpy.array_equal(sol1, sol2)
120+
assert fitness1 == fitness2
121+
assert idx1 == idx2
122+
print("test_best_solution_with_save_solutions passed.")
123+
124+
def test_best_solution_with_save_best_solutions():
125+
"""
126+
Test best_solution() consistency when save_best_solutions=True.
127+
"""
128+
ga_instance = pygad.GA(num_generations=num_generations,
129+
num_parents_mating=num_parents_mating,
130+
fitness_func=fitness_func,
131+
sol_per_pop=sol_per_pop,
132+
num_genes=num_genes,
133+
random_seed=random_seed,
134+
save_best_solutions=True,
135+
suppress_warnings=True
136+
)
137+
ga_instance.run()
138+
139+
# Call with last_generation_fitness
140+
sol1, fitness1, idx1 = ga_instance.best_solution(pop_fitness=ga_instance.last_generation_fitness)
141+
142+
# Call without pop_fitness
143+
sol2, fitness2, idx2 = ga_instance.best_solution()
144+
145+
assert numpy.array_equal(sol1, sol2)
146+
assert fitness1 == fitness2
147+
assert idx1 == idx2
148+
print("test_best_solution_with_save_best_solutions passed.")
149+
150+
def test_best_solution_with_keep_elitism():
151+
"""
152+
Test best_solution() consistency when keep_elitism > 0.
153+
"""
154+
ga_instance = pygad.GA(num_generations=num_generations,
155+
num_parents_mating=num_parents_mating,
156+
fitness_func=fitness_func,
157+
sol_per_pop=sol_per_pop,
158+
num_genes=num_genes,
159+
random_seed=random_seed,
160+
keep_elitism=2,
161+
suppress_warnings=True
162+
)
163+
ga_instance.run()
164+
165+
# Call with last_generation_fitness
166+
sol1, fitness1, idx1 = ga_instance.best_solution(pop_fitness=ga_instance.last_generation_fitness)
167+
168+
# Call without pop_fitness
169+
sol2, fitness2, idx2 = ga_instance.best_solution()
170+
171+
assert numpy.array_equal(sol1, sol2)
172+
assert fitness1 == fitness2
173+
assert idx1 == idx2
174+
print("test_best_solution_with_keep_elitism passed.")
175+
176+
def test_best_solution_with_keep_parents():
177+
"""
178+
Test best_solution() consistency when keep_parents > 0.
179+
Note: keep_parents is ignored if keep_elitism > 0 (default is 1).
180+
So this tests the case where keep_parents is passed but effectively ignored by population update,
181+
yet we check if best_solution() still works consistently.
182+
"""
183+
ga_instance = pygad.GA(num_generations=num_generations,
184+
num_parents_mating=num_parents_mating,
185+
fitness_func=fitness_func,
186+
sol_per_pop=sol_per_pop,
187+
num_genes=num_genes,
188+
random_seed=random_seed,
189+
keep_parents=2,
190+
suppress_warnings=True
191+
)
192+
ga_instance.run()
193+
194+
# Call with last_generation_fitness
195+
sol1, fitness1, idx1 = ga_instance.best_solution(pop_fitness=ga_instance.last_generation_fitness)
196+
197+
# Call without pop_fitness
198+
sol2, fitness2, idx2 = ga_instance.best_solution()
199+
200+
assert numpy.array_equal(sol1, sol2)
201+
assert fitness1 == fitness2
202+
assert idx1 == idx2
203+
print("test_best_solution_with_keep_parents passed.")
204+
205+
def test_best_solution_with_keep_parents_elitism_0():
206+
"""
207+
Test best_solution() consistency when keep_parents > 0 and keep_elitism = 0.
208+
This ensures the 'keep_parents' logic in cal_pop_fitness is exercised.
209+
"""
210+
ga_instance = pygad.GA(num_generations=num_generations,
211+
num_parents_mating=num_parents_mating,
212+
fitness_func=fitness_func,
213+
sol_per_pop=sol_per_pop,
214+
num_genes=num_genes,
215+
random_seed=random_seed,
216+
keep_elitism=0,
217+
keep_parents=2,
218+
suppress_warnings=True
219+
)
220+
ga_instance.run()
221+
222+
# Call with last_generation_fitness
223+
sol1, fitness1, idx1 = ga_instance.best_solution(pop_fitness=ga_instance.last_generation_fitness)
224+
225+
# Call without pop_fitness
226+
sol2, fitness2, idx2 = ga_instance.best_solution()
227+
228+
assert numpy.array_equal(sol1, sol2)
229+
assert fitness1 == fitness2
230+
assert idx1 == idx2
231+
print("test_best_solution_with_keep_parents_elitism_0 passed.")
232+
233+
def test_best_solution_pop_fitness_validation():
234+
"""
235+
Test validation of the pop_fitness parameter in best_solution().
236+
237+
Note: num_generations=1 is used for speed as evolution is not needed.
238+
sol_per_pop=5 is used to provide a small population for testing invalid lengths.
239+
"""
240+
ga_instance = pygad.GA(num_generations=1,
241+
num_parents_mating=1,
242+
fitness_func=fitness_func,
243+
sol_per_pop=5,
244+
num_genes=3,
245+
suppress_warnings=True
246+
)
247+
248+
# Test invalid type
249+
try:
250+
ga_instance.best_solution(pop_fitness="invalid")
251+
except ValueError as e:
252+
assert "expected to be list, tuple, or numpy.ndarray" in str(e)
253+
print("Validation: Invalid type caught.")
254+
255+
# Test invalid length
256+
try:
257+
ga_instance.best_solution(pop_fitness=[1, 2, 3]) # Length 3, but sol_per_pop is 5
258+
except ValueError as e:
259+
assert "must match the length of the 'self.population' attribute" in str(e)
260+
print("Validation: Invalid length caught.")
261+
262+
def test_best_solution_single_objective_tie():
263+
"""
264+
Test best_solution() when there is a tie in fitness values.
265+
It should return the first solution with the maximum fitness.
266+
267+
Note: sol_per_pop=5 must match the length of the manual pop_fitness array below.
268+
num_generations=1 is sufficient for testing selection logic.
269+
"""
270+
ga_instance = pygad.GA(num_generations=1,
271+
num_parents_mating=1,
272+
fitness_func=fitness_func,
273+
sol_per_pop=5,
274+
num_genes=3,
275+
suppress_warnings=True
276+
)
277+
278+
# Mock fitness with a tie at index 1 and 3
279+
pop_fitness = numpy.array([10, 50, 20, 50, 5])
280+
281+
sol, fitness, idx = ga_instance.best_solution(pop_fitness=pop_fitness)
282+
283+
assert fitness == 50
284+
assert idx == 1 # First occurrence
285+
print("test_best_solution_single_objective_tie passed.")
286+
287+
def test_best_solution_with_parallel_processing():
288+
"""
289+
Test best_solution() with parallel_processing enabled.
290+
291+
Note: num_generations=5 is used to ensure the initial population and first generation
292+
trigger parallel fitness calculation.
293+
"""
294+
ga_instance = pygad.GA(num_generations=5,
295+
num_parents_mating=2,
296+
fitness_func=fitness_func,
297+
sol_per_pop=10,
298+
num_genes=3,
299+
random_seed=random_seed,
300+
parallel_processing=["thread", 2],
301+
suppress_warnings=True
302+
)
303+
# best_solution() should work and trigger cal_pop_fitness() internally
304+
sol, fitness, idx = ga_instance.best_solution()
305+
assert sol is not None
306+
assert fitness is not None
307+
print("test_best_solution_with_parallel_processing passed.")
308+
309+
def test_best_solution_with_fitness_batch_size():
310+
"""
311+
Test best_solution() with fitness_batch_size > 1.
312+
313+
Note: num_generations=5 and sol_per_pop=10 provide enough work for batch processing.
314+
"""
315+
def fitness_func_batch(ga_instance, solutions, indices):
316+
return [numpy.sum(s**2) for s in solutions]
317+
318+
ga_instance = pygad.GA(num_generations=5,
319+
num_parents_mating=2,
320+
fitness_func=fitness_func_batch,
321+
sol_per_pop=10,
322+
num_genes=3,
323+
random_seed=random_seed,
324+
fitness_batch_size=2,
325+
suppress_warnings=True
326+
)
327+
328+
sol, fitness, idx = ga_instance.best_solution()
329+
assert sol is not None
330+
assert fitness is not None
331+
print("test_best_solution_with_fitness_batch_size passed.")
332+
333+
def test_best_solution_pop_fitness_types():
334+
"""
335+
Test best_solution() with different types for the pop_fitness parameter.
336+
337+
Note: sol_per_pop=3 must match the length of fitness_vals below.
338+
num_generations=1 is sufficient for this type-check test.
339+
"""
340+
ga_instance = pygad.GA(num_generations=1,
341+
num_parents_mating=1,
342+
fitness_func=fitness_func,
343+
sol_per_pop=3,
344+
num_genes=3,
345+
suppress_warnings=True
346+
)
347+
348+
fitness_vals = [1.0, 5.0, 2.0]
349+
350+
# Test list
351+
_, _, idx_list = ga_instance.best_solution(pop_fitness=fitness_vals)
352+
# Test tuple
353+
_, _, idx_tuple = ga_instance.best_solution(pop_fitness=tuple(fitness_vals))
354+
# Test numpy array
355+
_, _, idx_ndarray = ga_instance.best_solution(pop_fitness=numpy.array(fitness_vals))
356+
357+
assert idx_list == idx_tuple == idx_ndarray == 1
358+
print("test_best_solution_pop_fitness_types passed.")
359+
360+
if __name__ == "__main__":
361+
test_best_solution_consistency_single_objective()
362+
test_best_solution_consistency_multi_objective()
363+
test_best_solution_before_run()
364+
test_best_solution_with_save_solutions()
365+
test_best_solution_with_save_best_solutions()
366+
test_best_solution_with_keep_elitism()
367+
test_best_solution_with_keep_parents()
368+
test_best_solution_with_keep_parents_elitism_0()
369+
test_best_solution_pop_fitness_validation()
370+
test_best_solution_single_objective_tie()
371+
test_best_solution_with_parallel_processing()
372+
test_best_solution_with_fitness_batch_size()
373+
test_best_solution_pop_fitness_types()
374+
print("\nAll tests passed!")

0 commit comments

Comments
 (0)