Skip to content

Commit 55e6672

Browse files
author
Dex
committed
Altered to allow function_inputs and desired_output to be passed as parameters on initiation.
1 parent 76bb230 commit 55e6672

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

pygad.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,11 @@ def __init__(self,
4747
save_best_solutions=False,
4848
save_solutions=False,
4949
suppress_warnings=False,
50-
stop_criteria=None):
50+
stop_criteria=None,
51+
# Initialization for data parameters on fitness function
52+
desired_output=None,
53+
function_inputs=None
54+
):
5155

5256
"""
5357
The constructor of the GA class accepts all parameters required to create an instance of the GA class. It validates such parameters.
@@ -628,14 +632,26 @@ def __init__(self,
628632
elif (self.keep_parents > 0): # Keep the specified number of parents in the next population.
629633
self.num_offspring = self.sol_per_pop - self.keep_parents
630634

635+
# Initialization for data parameters for fitness function
636+
self.fitness_function_extra_data_set = False
637+
if not desired_output.empty and not function_inputs.empty:
638+
self.fitness_function_extra_data_set = True
639+
self.desired_output = desired_output
640+
self.function_inputs = function_inputs
641+
631642
# Check if the fitness_func is a function.
632643
if callable(fitness_func):
633644
# Check if the fitness function accepts 2 paramaters.
634645
if (fitness_func.__code__.co_argcount == 2):
635646
self.fitness_func = fitness_func
647+
elif self.fitness_function_extra_data_set and (fitness_func.__code__.co_argcount == 4):
648+
self.fitness_func = fitness_func
636649
else:
637650
self.valid_parameters = False
638-
raise ValueError("The fitness function must accept 2 parameters:\n1) A solution to calculate its fitness value.\n2) The solution's index within the population.\n\nThe passed fitness function named '{funcname}' accepts {argcount} parameter(s).".format(funcname=fitness_func.__code__.co_name, argcount=fitness_func.__code__.co_argcount))
651+
if not self.fitness_function_extra_data_set:
652+
raise ValueError("The fitness function must accept 2 parameters:\n1) A solution to calculate its fitness value.\n2) The solution's index within the population.\n\nThe passed fitness function named '{funcname}' accepts {argcount} parameter(s).".format(funcname=fitness_func.__code__.co_name, argcount=fitness_func.__code__.co_argcount))
653+
else:
654+
raise ValueError("The fitness function with extra_data must accept 4 parameters:\n1) A solution to calculate its fitness value.\n2) The solution's index within the population\n3) The desired output.\n4) The function input\n\nThe passed fitness function named '{funcname}' accepts {argcount} parameter(s).".format(funcname=fitness_func.__code__.co_name, argcount=fitness_func.__code__.co_argcount))
639655
else:
640656
self.valid_parameters = False
641657
raise ValueError("The value assigned to the fitness_func parameter is expected to be of type function but ({fitness_func_type}) found.".format(fitness_func_type=type(fitness_func)))
@@ -1156,7 +1172,10 @@ def cal_pop_fitness(self):
11561172
# Use the parent's index to return its pre-calculated fitness value.
11571173
fitness = self.previous_generation_fitness[parent_idx]
11581174
else:
1159-
fitness = self.fitness_func(sol, sol_idx)
1175+
if not self.fitness_function_extra_data_set:
1176+
fitness = self.fitness_func(sol, sol_idx)
1177+
else:
1178+
fitness = self.fitness_func(sol, sol_idx, self.desired_output, self.function_inputs)
11601179
if type(fitness) in GA.supported_int_float_types:
11611180
pass
11621181
else:

0 commit comments

Comments
 (0)