@@ -50,26 +50,6 @@ def densities(self) -> Tuple[float, ...]:
5050 def n_species (self ) -> int :
5151 return int (getattr (self , "_n_species" ))
5252
53- def validate (self ) -> None :
54- """Validate core CA invariants.
55-
56- Checks that `neighborhood` is valid, that `self.grid` has the
57- texpected shape `(rows, cols)`, and that any numpy arrays in
58- `self.cell_params` have matching shapes. Raises `ValueError` on
59- validation failure.
60- """
61- if self .neighborhood not in ("neumann" , "moore" ):
62- raise ValueError ("neighborhood must be 'neumann' or 'moore'" )
63-
64- expected_shape = (int (getattr (self , "_rows" )), int (getattr (self , "_cols" )))
65- if self .grid .shape != expected_shape :
66- raise ValueError (f"grid shape { self .grid .shape } does not match expected { expected_shape } " )
67-
68- # Ensure any array in cell_params matches grid shape
69- for k , v in (self .cell_params or {}).items ():
70- if isinstance (v , np .ndarray ) and v .shape != expected_shape :
71- raise ValueError (f"cell_params['{ k } '] must have shape equal to grid" )
72-
7353 def __init__ (
7454 self ,
7555 rows : int ,
@@ -113,6 +93,16 @@ def __init__(
11393 self ._densities : Tuple [float , ...] = tuple (densities )
11494 self .params : Dict [str , object ] = dict (params ) if params is not None else {}
11595 self .cell_params : Dict [str , object ] = dict (cell_params ) if cell_params is not None else {}
96+
97+ # per-parameter evolve metadata and evolution state
98+ # maps parameter name -> dict with keys 'sd','min','max','species'
99+ self ._evolve_info : Dict [str , Dict [str , float ]] = {}
100+ # when True, inheritance uses deterministic copy from parent (no mutation)
101+ self ._evolution_stopped : bool = False
102+
103+ # human-readable species names (useful for visualization). Default
104+ # generates generic names based on n_species; subclasses may override.
105+ self .species_names : Tuple [str , ...] = tuple (f"species{ i + 1 } " for i in range (self ._n_species ))
116106 self .neighborhood : str = neighborhood
117107 self .generator : np .random .Generator = np .random .default_rng (seed )
118108
@@ -136,6 +126,67 @@ def __init__(
136126 c = chosen % cols
137127 self .grid [r , c ] = i + 1
138128
129+ def validate (self ) -> None :
130+ """Validate core CA invariants.
131+
132+ Checks that `neighborhood` is valid, that `self.grid` has the
133+ texpected shape `(rows, cols)`, and that any numpy arrays in
134+ `self.cell_params` have matching shapes. Raises `ValueError` on
135+ validation failure.
136+ """
137+ if self .neighborhood not in ("neumann" , "moore" ):
138+ raise ValueError ("neighborhood must be 'neumann' or 'moore'" )
139+
140+ expected_shape = (int (getattr (self , "_rows" )), int (getattr (self , "_cols" )))
141+ if self .grid .shape != expected_shape :
142+ raise ValueError (f"grid shape { self .grid .shape } does not match expected { expected_shape } " )
143+
144+ # Ensure any array in cell_params matches grid shape
145+ for k , v in (self .cell_params or {}).items ():
146+ if isinstance (v , np .ndarray ) and v .shape != expected_shape :
147+ raise ValueError (f"cell_params['{ k } '] must have shape equal to grid" )
148+
149+ def _infer_species_from_param_name (self , param_name : str ) -> Optional [int ]:
150+ """Infer species index (1-based) from a parameter name using `species_names`.
151+
152+ Returns the 1-based species index if a matching prefix is found,
153+ otherwise `None`.
154+ """
155+ if not isinstance (param_name , str ):
156+ return None
157+ for idx , name in enumerate (self .species_names or ()): # type: ignore
158+ if isinstance (name , str ) and param_name .startswith (f"{ name } _" ):
159+ return idx + 1
160+ return None
161+
162+ def evolve (self , param : str , species : Optional [int ] = None , sd : float = 0.05 , min_val : Optional [float ] = None , max_val : Optional [float ] = None ) -> None :
163+ """Enable per-cell evolution for `param` on `species`.
164+
165+ If `species` is None, attempt to infer the species using
166+ `_infer_species_from_param_name(param)` which matches against
167+ `self.species_names`. This keeps `CA` free of domain-specific
168+ (predator/prey) logic while preserving backward compatibility when
169+ subclasses set `species_names` (e.g. `('prey','predator')`).
170+ """
171+ if min_val is None :
172+ min_val = 0.01
173+ if max_val is None :
174+ max_val = 0.99
175+ if param not in self .params :
176+ raise ValueError (f"Unknown parameter '{ param } '" )
177+ if species is None :
178+ species = self ._infer_species_from_param_name (param )
179+ if species is None :
180+ raise ValueError ("species must be provided or inferable from param name and species_names" )
181+ if not isinstance (species , int ) or species <= 0 or species > self ._n_species :
182+ raise ValueError ("species must be an integer between 1 and n_species" )
183+
184+ arr = np .full (self .grid .shape , np .nan , dtype = float )
185+ mask = (self .grid == int (species ))
186+ arr [mask ] = float (self .params [param ])
187+ self .cell_params [param ] = arr
188+ self ._evolve_info [param ] = {"sd" : float (sd ), "min" : float (min_val ), "max" : float (max_val ), "species" : int (species )}
189+
139190 def count_neighbors (self ) -> Tuple [np .ndarray , ...]:
140191 """Count neighbors for each non-zero state.
141192
@@ -200,14 +251,6 @@ def run(self, steps: int, stop_evolution_at: Optional[int] = None, snapshot_iter
200251 """
201252 assert isinstance (steps , int ) and steps >= 0 , "steps must be a non-negative integer"
202253
203- # NOTE: validation of `cell_params` and evolved parameters has been
204- # moved to the `validate()` method on the class. The run loop no
205- # longer performs per-cell validation for performance; call
206- # `validate()` explicitly when needed.
207-
208- # normalize snapshot iteration list
209- snapshot_set = set (snapshot_iters ) if snapshot_iters is not None else set ()
210-
211254 # normalize snapshot iteration list
212255 snapshot_set = set (snapshot_iters ) if snapshot_iters is not None else set ()
213256
@@ -241,8 +284,9 @@ def run(self, steps: int, stop_evolution_at: Optional[int] = None, snapshot_iter
241284
242285 # stop evolution at specified time-step (disable further evolution)
243286 if stop_evolution_at is not None and (i + 1 ) == int (stop_evolution_at ):
244- # disable further evolution
245- self ._evolve_info = {}
287+ # mark evolution as stopped; do not erase evolve metadata so
288+ # deterministic inheritance can still use parent values
289+ self ._evolution_stopped = True
246290
247291 def visualize (
248292 self ,
@@ -775,6 +819,7 @@ def _viz_update(self, iteration: int) -> None:
775819
776820class PP (CA ):
777821 """Predator-prey CA.
822+
778823 States: 0 = empty, 1 = prey, 2 = predator
779824
780825 Parameters (in `params` dict). Allowed keys and defaults:
@@ -834,51 +879,11 @@ def __init__(
834879 super ().__init__ (rows , cols , densities , neighborhood , merged_params , cell_params , seed )
835880
836881 self .synchronous : bool = bool (synchronous )
837-
838- # Information about which parameters are being evolved and their mutation specs
839- # Maps parameter name -> dict with keys: 'sd', 'min', 'max'
840- self ._evolve_info : Dict [str , Dict [str , float ]] = {}
882+ # set human-friendly species names for PP
883+ self .species_names = ("prey" , "predator" )
841884
842885
843- def evolve (self , param : str , sd : float = 0.05 , min_val : Optional [float ] = None , max_val : Optional [float ] = None ) -> None :
844- """Enable per-cell evolution for a given parameter.
845-
846- Creates a per-cell array in `self.cell_params[param]` with the same
847- shape as the grid. Cells currently occupied by the relevant species are
848- initialized to the global value in `self.params[param]`; other cells are
849- set to NaN. Mutation metadata (sd, min, max) are stored in
850- `self._evolve_info[param]`.
851-
852- Args:
853- - param: one of the keys in `self.params` (e.g. 'prey_death')
854- - sd: standard deviation for Gaussian mutations
855- - min: minimum clipped value after mutation
856- - max: maximum clipped value after mutation
857- """
858- # Note: deprecated keyword names `min` and `max` have been removed.
859- # Callers must use `min_val` and `max_val` explicitly.
860- # Provide sensible defaults when not specified
861- if min_val is None :
862- min_val = 0.01
863- if max_val is None :
864- max_val = 0.99
865- if param not in self .params :
866- raise ValueError (f"Unknown parameter '{ param } '" )
867-
868- # determine target species for this parameter
869- if param .startswith ("prey_" ):
870- species = 1
871- elif param .startswith ("predator_" ):
872- species = 2
873- else :
874- raise ValueError ("Parameter must start with 'prey_' or 'predator_' to evolve" )
875-
876- # create per-cell float array with NaNs for non-relevant cells
877- arr = np .full (self .grid .shape , np .nan , dtype = float )
878- mask = (self .grid == species )
879- arr [mask ] = float (self .params [param ])
880- self .cell_params [param ] = arr
881- self ._evolve_info [param ] = {"sd" : float (sd ), "min" : float (min_val ), "max" : float (max_val )}
886+ # Remove PP-specific evolve wrapper; use CA.evolve with optional species
882887
883888 def validate (self ) -> None :
884889 """Validate PP-specific invariants in addition to base CA checks.
@@ -907,8 +912,15 @@ def validate(self) -> None:
907912 # shape already checked in super().validate(), but be explicit
908913 if arr .shape != self .grid .shape :
909914 raise ValueError (f"cell_params['{ pname } '] must match grid shape" )
910- # expected non-NaN positions correspond to species in grid
911- species = 1 if pname .startswith ("prey_" ) else 2
915+ # expected non-NaN positions correspond to species stored in metadata
916+ species = None
917+ if isinstance (meta , dict ) and "species" in meta :
918+ species = int (meta .get ("species" ))
919+ else :
920+ # try to infer species from parameter name using species_names
921+ species = self ._infer_species_from_param_name (pname )
922+ if species is None :
923+ raise ValueError (f"cell_params['{ pname } '] missing species metadata and could not infer from name" )
912924 nonnan = ~ np .isnan (arr )
913925 expected = (self .grid == species )
914926 if not np .array_equal (nonnan , expected ):
@@ -945,15 +957,23 @@ def _apply_deaths_and_clear_params(self, grid_ref: np.ndarray, rand_prey: np.nda
945957 self .grid [pred_death_mask ] = 0
946958
947959 # Clear per-cell parameters for dead individuals
948- for pname in self ._evolve_info :
949- if pname .startswith ("prey_" ):
950- arr = self .cell_params .get (pname )
951- if isinstance (arr , np .ndarray ) and arr .shape == self .grid .shape :
952- arr [prey_death_mask ] = np .nan
953- elif pname .startswith ("predator_" ):
954- arr = self .cell_params .get (pname )
955- if isinstance (arr , np .ndarray ) and arr .shape == self .grid .shape :
956- arr [pred_death_mask ] = np .nan
960+ for pname , meta in self ._evolve_info .items ():
961+ # determine species from metadata or infer from name
962+ species = None
963+ if isinstance (meta , dict ) and "species" in meta :
964+ species = int (meta .get ("species" ))
965+ else :
966+ species = self ._infer_species_from_param_name (pname )
967+ if species is None :
968+ # cannot determine species; skip clearing for safety
969+ continue
970+ arr = self .cell_params .get (pname )
971+ if not (isinstance (arr , np .ndarray ) and arr .shape == self .grid .shape ):
972+ continue
973+ if species == 1 :
974+ arr [prey_death_mask ] = np .nan
975+ elif species == 2 :
976+ arr [pred_death_mask ] = np .nan
957977
958978 def _neighbor_shifts (self ) -> Tuple [np .ndarray , np .ndarray , int ]:
959979 """Return neighbor shift arrays (dr_arr, dc_arr, n_shifts) for the configured neighborhood."""
@@ -982,16 +1002,17 @@ def _inherit_params_on_birth(self, chosen_rs: np.ndarray, chosen_cs: np.ndarray,
9821002 an array of parent coordinates with same length.
9831003 """
9841004 for pname , meta in self ._evolve_info .items ():
985- # determine species this parameter belongs to
986- if pname .startswith ("prey_" ):
987- target_species = 1
988- elif pname .startswith ("predator_" ):
989- target_species = 2
1005+ # determine species this parameter belongs to via metadata or inference
1006+ species = None
1007+ if isinstance (meta , dict ) and "species" in meta :
1008+ species = int (meta .get ("species" ))
9901009 else :
991- continue
1010+ species = self ._infer_species_from_param_name (pname )
1011+ if species is None :
1012+ raise ValueError (f"_evolve_info contains unexpected key '{ pname } ' without species metadata and could not infer" )
9921013
9931014 # if new_state is not the species for this param, clear at targets
994- if new_state_val != target_species :
1015+ if new_state_val != species :
9951016 arr = self .cell_params .get (pname )
9961017 if isinstance (arr , np .ndarray ) and arr .shape == self .grid .shape :
9971018 arr [chosen_rs , chosen_cs ] = np .nan
@@ -1012,8 +1033,14 @@ def _inherit_params_on_birth(self, chosen_rs: np.ndarray, chosen_cs: np.ndarray,
10121033 sd = float (meta ["sd" ])
10131034 mn = float (meta ["min" ])
10141035 mx = float (meta ["max" ])
1015- mut = parent_vals + self .generator .normal (0.0 , sd , size = parent_vals .shape )
1016- mut = np .clip (mut , mn , mx )
1036+ # If evolution has been stopped, inheritance is deterministic: copy
1037+ # parent values directly without Gaussian mutation so we can observe
1038+ # which parameter values survive.
1039+ if getattr (self , "_evolution_stopped" , False ):
1040+ mut = parent_vals .copy ()
1041+ else :
1042+ mut = parent_vals + self .generator .normal (0.0 , sd , size = parent_vals .shape )
1043+ mut = np .clip (mut , mn , mx )
10171044 # If an array exists but has wrong shape, raise an informative error
10181045 existing = self .cell_params .get (pname )
10191046 if isinstance (existing , np .ndarray ) and existing .shape != self .grid .shape :
0 commit comments