@@ -50,26 +50,6 @@ def densities(self) -> Tuple[float, ...]:
5050 def n_species (self ) -> int :
5151 return int (getattr (self , "_n_species" ))
5252
53- def validate (self ) -> None :
54- """Validate core CA invariants.
55-
56- Checks that `neighborhood` is valid, that `self.grid` has the
57- texpected shape `(rows, cols)`, and that any numpy arrays in
58- `self.cell_params` have matching shapes. Raises `ValueError` on
59- validation failure.
60- """
61- if self .neighborhood not in ("neumann" , "moore" ):
62- raise ValueError ("neighborhood must be 'neumann' or 'moore'" )
63-
64- expected_shape = (int (getattr (self , "_rows" )), int (getattr (self , "_cols" )))
65- if self .grid .shape != expected_shape :
66- raise ValueError (f"grid shape { self .grid .shape } does not match expected { expected_shape } " )
67-
68- # Ensure any array in cell_params matches grid shape
69- for k , v in (self .cell_params or {}).items ():
70- if isinstance (v , np .ndarray ) and v .shape != expected_shape :
71- raise ValueError (f"cell_params['{ k } '] must have shape equal to grid" )
72-
7353 def __init__ (
7454 self ,
7555 rows : int ,
@@ -113,6 +93,16 @@ def __init__(
11393 self ._densities : Tuple [float , ...] = tuple (densities )
11494 self .params : Dict [str , object ] = dict (params ) if params is not None else {}
11595 self .cell_params : Dict [str , object ] = dict (cell_params ) if cell_params is not None else {}
96+
97+ # per-parameter evolve metadata and evolution state
98+ # maps parameter name -> dict with keys 'sd','min','max','species'
99+ self ._evolve_info : Dict [str , Dict [str , float ]] = {}
100+ # when True, inheritance uses deterministic copy from parent (no mutation)
101+ self ._evolution_stopped : bool = False
102+
103+ # human-readable species names (useful for visualization). Default
104+ # generates generic names based on n_species; subclasses may override.
105+ self .species_names : Tuple [str , ...] = tuple (f"species{ i + 1 } " for i in range (self ._n_species ))
116106 self .neighborhood : str = neighborhood
117107 self .generator : np .random .Generator = np .random .default_rng (seed )
118108
@@ -136,6 +126,67 @@ def __init__(
136126 c = chosen % cols
137127 self .grid [r , c ] = i + 1
138128
129+ def validate (self ) -> None :
130+ """Validate core CA invariants.
131+
132+ Checks that `neighborhood` is valid, that `self.grid` has the
133+ texpected shape `(rows, cols)`, and that any numpy arrays in
134+ `self.cell_params` have matching shapes. Raises `ValueError` on
135+ validation failure.
136+ """
137+ if self .neighborhood not in ("neumann" , "moore" ):
138+ raise ValueError ("neighborhood must be 'neumann' or 'moore'" )
139+
140+ expected_shape = (int (getattr (self , "_rows" )), int (getattr (self , "_cols" )))
141+ if self .grid .shape != expected_shape :
142+ raise ValueError (f"grid shape { self .grid .shape } does not match expected { expected_shape } " )
143+
144+ # Ensure any array in cell_params matches grid shape
145+ for k , v in (self .cell_params or {}).items ():
146+ if isinstance (v , np .ndarray ) and v .shape != expected_shape :
147+ raise ValueError (f"cell_params['{ k } '] must have shape equal to grid" )
148+
149+ def _infer_species_from_param_name (self , param_name : str ) -> Optional [int ]:
150+ """Infer species index (1-based) from a parameter name using `species_names`.
151+
152+ Returns the 1-based species index if a matching prefix is found,
153+ otherwise `None`.
154+ """
155+ if not isinstance (param_name , str ):
156+ return None
157+ for idx , name in enumerate (self .species_names or ()): # type: ignore
158+ if isinstance (name , str ) and param_name .startswith (f"{ name } _" ):
159+ return idx + 1
160+ return None
161+
162+ def evolve (self , param : str , species : Optional [int ] = None , sd : float = 0.05 , min_val : Optional [float ] = None , max_val : Optional [float ] = None ) -> None :
163+ """Enable per-cell evolution for `param` on `species`.
164+
165+ If `species` is None, attempt to infer the species using
166+ `_infer_species_from_param_name(param)` which matches against
167+ `self.species_names`. This keeps `CA` free of domain-specific
168+ (predator/prey) logic while preserving backward compatibility when
169+ subclasses set `species_names` (e.g. `('prey','predator')`).
170+ """
171+ if min_val is None :
172+ min_val = 0.01
173+ if max_val is None :
174+ max_val = 0.99
175+ if param not in self .params :
176+ raise ValueError (f"Unknown parameter '{ param } '" )
177+ if species is None :
178+ species = self ._infer_species_from_param_name (param )
179+ if species is None :
180+ raise ValueError ("species must be provided or inferable from param name and species_names" )
181+ if not isinstance (species , int ) or species <= 0 or species > self ._n_species :
182+ raise ValueError ("species must be an integer between 1 and n_species" )
183+
184+ arr = np .full (self .grid .shape , np .nan , dtype = float )
185+ mask = (self .grid == int (species ))
186+ arr [mask ] = float (self .params [param ])
187+ self .cell_params [param ] = arr
188+ self ._evolve_info [param ] = {"sd" : float (sd ), "min" : float (min_val ), "max" : float (max_val ), "species" : int (species )}
189+
139190 def count_neighbors (self ) -> Tuple [np .ndarray , ...]:
140191 """Count neighbors for each non-zero state.
141192
@@ -200,14 +251,6 @@ def run(self, steps: int, stop_evolution_at: Optional[int] = None, snapshot_iter
200251 """
201252 assert isinstance (steps , int ) and steps >= 0 , "steps must be a non-negative integer"
202253
203- # NOTE: validation of `cell_params` and evolved parameters has been
204- # moved to the `validate()` method on the class. The run loop no
205- # longer performs per-cell validation for performance; call
206- # `validate()` explicitly when needed.
207-
208- # normalize snapshot iteration list
209- snapshot_set = set (snapshot_iters ) if snapshot_iters is not None else set ()
210-
211254 # normalize snapshot iteration list
212255 snapshot_set = set (snapshot_iters ) if snapshot_iters is not None else set ()
213256
@@ -241,32 +284,9 @@ def run(self, steps: int, stop_evolution_at: Optional[int] = None, snapshot_iter
241284
242285 # stop evolution at specified time-step (disable further evolution)
243286 if stop_evolution_at is not None and (i + 1 ) == int (stop_evolution_at ):
244- # disable further evolution
245- self ._evolve_info = {}
246-
247- # create snapshots if requested at this iteration
248- if (i + 1 ) in snapshot_set :
249- try :
250- # create snapshot folder if not present
251- if not hasattr (self , "_viz_snapshot_dir" ) or self ._viz_snapshot_dir is None :
252- import os , time
253-
254- base = "results"
255- ts = int (time .time ())
256- run_folder = f"run-{ ts } "
257- full = os .path .join (base , run_folder )
258- os .makedirs (full , exist_ok = True )
259- self ._viz_snapshot_dir = full
260- self ._viz_save_snapshot (i + 1 )
261- except Exception :
262- pass
263-
264- # stop evolution at specified time-step (disable further evolution)
265- if stop_evolution_at is not None and (i + 1 ) == int (stop_evolution_at ):
266- try :
267- self ._evolve_info = {}
268- except Exception :
269- pass
287+ # mark evolution as stopped; do not erase evolve metadata so
288+ # deterministic inheritance can still use parent values
289+ self ._evolution_stopped = True
270290
271291 def visualize (
272292 self ,
@@ -859,51 +879,11 @@ def __init__(
859879 super ().__init__ (rows , cols , densities , neighborhood , merged_params , cell_params , seed )
860880
861881 self .synchronous : bool = bool (synchronous )
882+ # set human-friendly species names for PP
883+ self .species_names = ("prey" , "predator" )
862884
863- # Information about which parameters are being evolved and their mutation specs
864- # Maps parameter name -> dict with keys: 'sd', 'min', 'max'
865- self ._evolve_info : Dict [str , Dict [str , float ]] = {}
866-
867-
868- def evolve (self , param : str , sd : float = 0.05 , min_val : Optional [float ] = None , max_val : Optional [float ] = None ) -> None :
869- """Enable per-cell evolution for a given parameter.
870-
871- Creates a per-cell array in `self.cell_params[param]` with the same
872- shape as the grid. Cells currently occupied by the relevant species are
873- initialized to the global value in `self.params[param]`; other cells are
874- set to NaN. Mutation metadata (sd, min, max) are stored in
875- `self._evolve_info[param]`.
876-
877- Args:
878- - param: one of the keys in `self.params` (e.g. 'prey_death')
879- - sd: standard deviation for Gaussian mutations
880- - min: minimum clipped value after mutation
881- - max: maximum clipped value after mutation
882- """
883- # Note: deprecated keyword names `min` and `max` have been removed.
884- # Callers must use `min_val` and `max_val` explicitly.
885- # Provide sensible defaults when not specified
886- if min_val is None :
887- min_val = 0.01
888- if max_val is None :
889- max_val = 0.99
890- if param not in self .params :
891- raise ValueError (f"Unknown parameter '{ param } '" )
892-
893- # determine target species for this parameter
894- if param .startswith ("prey_" ):
895- species = 1
896- elif param .startswith ("predator_" ):
897- species = 2
898- else :
899- raise ValueError ("Parameter must start with 'prey_' or 'predator_' to evolve" )
900885
901- # create per-cell float array with NaNs for non-relevant cells
902- arr = np .full (self .grid .shape , np .nan , dtype = float )
903- mask = (self .grid == species )
904- arr [mask ] = float (self .params [param ])
905- self .cell_params [param ] = arr
906- self ._evolve_info [param ] = {"sd" : float (sd ), "min" : float (min_val ), "max" : float (max_val )}
886+ # Remove PP-specific evolve wrapper; use CA.evolve with optional species
907887
908888 def validate (self ) -> None :
909889 """Validate PP-specific invariants in addition to base CA checks.
@@ -932,8 +912,15 @@ def validate(self) -> None:
932912 # shape already checked in super().validate(), but be explicit
933913 if arr .shape != self .grid .shape :
934914 raise ValueError (f"cell_params['{ pname } '] must match grid shape" )
935- # expected non-NaN positions correspond to species in grid
936- species = 1 if pname .startswith ("prey_" ) else 2
915+ # expected non-NaN positions correspond to species stored in metadata
916+ species = None
917+ if isinstance (meta , dict ) and "species" in meta :
918+ species = int (meta .get ("species" ))
919+ else :
920+ # try to infer species from parameter name using species_names
921+ species = self ._infer_species_from_param_name (pname )
922+ if species is None :
923+ raise ValueError (f"cell_params['{ pname } '] missing species metadata and could not infer from name" )
937924 nonnan = ~ np .isnan (arr )
938925 expected = (self .grid == species )
939926 if not np .array_equal (nonnan , expected ):
@@ -970,15 +957,23 @@ def _apply_deaths_and_clear_params(self, grid_ref: np.ndarray, rand_prey: np.nda
970957 self .grid [pred_death_mask ] = 0
971958
972959 # Clear per-cell parameters for dead individuals
973- for pname in self ._evolve_info :
974- if pname .startswith ("prey_" ):
975- arr = self .cell_params .get (pname )
976- if isinstance (arr , np .ndarray ) and arr .shape == self .grid .shape :
977- arr [prey_death_mask ] = np .nan
978- elif pname .startswith ("predator_" ):
979- arr = self .cell_params .get (pname )
980- if isinstance (arr , np .ndarray ) and arr .shape == self .grid .shape :
981- arr [pred_death_mask ] = np .nan
960+ for pname , meta in self ._evolve_info .items ():
961+ # determine species from metadata or infer from name
962+ species = None
963+ if isinstance (meta , dict ) and "species" in meta :
964+ species = int (meta .get ("species" ))
965+ else :
966+ species = self ._infer_species_from_param_name (pname )
967+ if species is None :
968+ # cannot determine species; skip clearing for safety
969+ continue
970+ arr = self .cell_params .get (pname )
971+ if not (isinstance (arr , np .ndarray ) and arr .shape == self .grid .shape ):
972+ continue
973+ if species == 1 :
974+ arr [prey_death_mask ] = np .nan
975+ elif species == 2 :
976+ arr [pred_death_mask ] = np .nan
982977
983978 def _neighbor_shifts (self ) -> Tuple [np .ndarray , np .ndarray , int ]:
984979 """Return neighbor shift arrays (dr_arr, dc_arr, n_shifts) for the configured neighborhood."""
@@ -1007,16 +1002,17 @@ def _inherit_params_on_birth(self, chosen_rs: np.ndarray, chosen_cs: np.ndarray,
10071002 an array of parent coordinates with same length.
10081003 """
10091004 for pname , meta in self ._evolve_info .items ():
1010- # determine species this parameter belongs to
1011- if pname .startswith ("prey_" ):
1012- target_species = 1
1013- elif pname .startswith ("predator_" ):
1014- target_species = 2
1005+ # determine species this parameter belongs to via metadata or inference
1006+ species = None
1007+ if isinstance (meta , dict ) and "species" in meta :
1008+ species = int (meta .get ("species" ))
10151009 else :
1016- continue
1010+ species = self ._infer_species_from_param_name (pname )
1011+ if species is None :
1012+ raise ValueError (f"_evolve_info contains unexpected key '{ pname } ' without species metadata and could not infer" )
10171013
10181014 # if new_state is not the species for this param, clear at targets
1019- if new_state_val != target_species :
1015+ if new_state_val != species :
10201016 arr = self .cell_params .get (pname )
10211017 if isinstance (arr , np .ndarray ) and arr .shape == self .grid .shape :
10221018 arr [chosen_rs , chosen_cs ] = np .nan
@@ -1037,8 +1033,14 @@ def _inherit_params_on_birth(self, chosen_rs: np.ndarray, chosen_cs: np.ndarray,
10371033 sd = float (meta ["sd" ])
10381034 mn = float (meta ["min" ])
10391035 mx = float (meta ["max" ])
1040- mut = parent_vals + self .generator .normal (0.0 , sd , size = parent_vals .shape )
1041- mut = np .clip (mut , mn , mx )
1036+ # If evolution has been stopped, inheritance is deterministic: copy
1037+ # parent values directly without Gaussian mutation so we can observe
1038+ # which parameter values survive.
1039+ if getattr (self , "_evolution_stopped" , False ):
1040+ mut = parent_vals .copy ()
1041+ else :
1042+ mut = parent_vals + self .generator .normal (0.0 , sd , size = parent_vals .shape )
1043+ mut = np .clip (mut , mn , mx )
10421044 # If an array exists but has wrong shape, raise an informative error
10431045 existing = self .cell_params .get (pname )
10441046 if isinstance (existing , np .ndarray ) and existing .shape != self .grid .shape :
0 commit comments