Skip to content

Commit f426df5

Browse files
authored
Merge pull request #7 from codegithubka/storm
Hydra effect observed
2 parents 33b6500 + 0c610c8 commit f426df5

2 files changed

Lines changed: 120 additions & 118 deletions

File tree

models/CA.py

Lines changed: 119 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -50,26 +50,6 @@ def densities(self) -> Tuple[float, ...]:
5050
def n_species(self) -> int:
5151
return int(getattr(self, "_n_species"))
5252

53-
def validate(self) -> None:
54-
"""Validate core CA invariants.
55-
56-
Checks that `neighborhood` is valid, that `self.grid` has the
57-
texpected shape `(rows, cols)`, and that any numpy arrays in
58-
`self.cell_params` have matching shapes. Raises `ValueError` on
59-
validation failure.
60-
"""
61-
if self.neighborhood not in ("neumann", "moore"):
62-
raise ValueError("neighborhood must be 'neumann' or 'moore'")
63-
64-
expected_shape = (int(getattr(self, "_rows")), int(getattr(self, "_cols")))
65-
if self.grid.shape != expected_shape:
66-
raise ValueError(f"grid shape {self.grid.shape} does not match expected {expected_shape}")
67-
68-
# Ensure any array in cell_params matches grid shape
69-
for k, v in (self.cell_params or {}).items():
70-
if isinstance(v, np.ndarray) and v.shape != expected_shape:
71-
raise ValueError(f"cell_params['{k}'] must have shape equal to grid")
72-
7353
def __init__(
7454
self,
7555
rows: int,
@@ -113,6 +93,16 @@ def __init__(
11393
self._densities: Tuple[float, ...] = tuple(densities)
11494
self.params: Dict[str, object] = dict(params) if params is not None else {}
11595
self.cell_params: Dict[str, object] = dict(cell_params) if cell_params is not None else {}
96+
97+
# per-parameter evolve metadata and evolution state
98+
# maps parameter name -> dict with keys 'sd','min','max','species'
99+
self._evolve_info: Dict[str, Dict[str, float]] = {}
100+
# when True, inheritance uses deterministic copy from parent (no mutation)
101+
self._evolution_stopped: bool = False
102+
103+
# human-readable species names (useful for visualization). Default
104+
# generates generic names based on n_species; subclasses may override.
105+
self.species_names: Tuple[str, ...] = tuple(f"species{i+1}" for i in range(self._n_species))
116106
self.neighborhood: str = neighborhood
117107
self.generator: np.random.Generator = np.random.default_rng(seed)
118108

@@ -136,6 +126,67 @@ def __init__(
136126
c = chosen % cols
137127
self.grid[r, c] = i + 1
138128

129+
def validate(self) -> None:
130+
"""Validate core CA invariants.
131+
132+
Checks that `neighborhood` is valid, that `self.grid` has the
133+
texpected shape `(rows, cols)`, and that any numpy arrays in
134+
`self.cell_params` have matching shapes. Raises `ValueError` on
135+
validation failure.
136+
"""
137+
if self.neighborhood not in ("neumann", "moore"):
138+
raise ValueError("neighborhood must be 'neumann' or 'moore'")
139+
140+
expected_shape = (int(getattr(self, "_rows")), int(getattr(self, "_cols")))
141+
if self.grid.shape != expected_shape:
142+
raise ValueError(f"grid shape {self.grid.shape} does not match expected {expected_shape}")
143+
144+
# Ensure any array in cell_params matches grid shape
145+
for k, v in (self.cell_params or {}).items():
146+
if isinstance(v, np.ndarray) and v.shape != expected_shape:
147+
raise ValueError(f"cell_params['{k}'] must have shape equal to grid")
148+
149+
def _infer_species_from_param_name(self, param_name: str) -> Optional[int]:
150+
"""Infer species index (1-based) from a parameter name using `species_names`.
151+
152+
Returns the 1-based species index if a matching prefix is found,
153+
otherwise `None`.
154+
"""
155+
if not isinstance(param_name, str):
156+
return None
157+
for idx, name in enumerate(self.species_names or ()): # type: ignore
158+
if isinstance(name, str) and param_name.startswith(f"{name}_"):
159+
return idx + 1
160+
return None
161+
162+
def evolve(self, param: str, species: Optional[int] = None, sd: float = 0.05, min_val: Optional[float] = None, max_val: Optional[float] = None) -> None:
163+
"""Enable per-cell evolution for `param` on `species`.
164+
165+
If `species` is None, attempt to infer the species using
166+
`_infer_species_from_param_name(param)` which matches against
167+
`self.species_names`. This keeps `CA` free of domain-specific
168+
(predator/prey) logic while preserving backward compatibility when
169+
subclasses set `species_names` (e.g. `('prey','predator')`).
170+
"""
171+
if min_val is None:
172+
min_val = 0.01
173+
if max_val is None:
174+
max_val = 0.99
175+
if param not in self.params:
176+
raise ValueError(f"Unknown parameter '{param}'")
177+
if species is None:
178+
species = self._infer_species_from_param_name(param)
179+
if species is None:
180+
raise ValueError("species must be provided or inferable from param name and species_names")
181+
if not isinstance(species, int) or species <= 0 or species > self._n_species:
182+
raise ValueError("species must be an integer between 1 and n_species")
183+
184+
arr = np.full(self.grid.shape, np.nan, dtype=float)
185+
mask = (self.grid == int(species))
186+
arr[mask] = float(self.params[param])
187+
self.cell_params[param] = arr
188+
self._evolve_info[param] = {"sd": float(sd), "min": float(min_val), "max": float(max_val), "species": int(species)}
189+
139190
def count_neighbors(self) -> Tuple[np.ndarray, ...]:
140191
"""Count neighbors for each non-zero state.
141192
@@ -200,14 +251,6 @@ def run(self, steps: int, stop_evolution_at: Optional[int] = None, snapshot_iter
200251
"""
201252
assert isinstance(steps, int) and steps >= 0, "steps must be a non-negative integer"
202253

203-
# NOTE: validation of `cell_params` and evolved parameters has been
204-
# moved to the `validate()` method on the class. The run loop no
205-
# longer performs per-cell validation for performance; call
206-
# `validate()` explicitly when needed.
207-
208-
# normalize snapshot iteration list
209-
snapshot_set = set(snapshot_iters) if snapshot_iters is not None else set()
210-
211254
# normalize snapshot iteration list
212255
snapshot_set = set(snapshot_iters) if snapshot_iters is not None else set()
213256

@@ -241,32 +284,9 @@ def run(self, steps: int, stop_evolution_at: Optional[int] = None, snapshot_iter
241284

242285
# stop evolution at specified time-step (disable further evolution)
243286
if stop_evolution_at is not None and (i + 1) == int(stop_evolution_at):
244-
# disable further evolution
245-
self._evolve_info = {}
246-
247-
# create snapshots if requested at this iteration
248-
if (i + 1) in snapshot_set:
249-
try:
250-
# create snapshot folder if not present
251-
if not hasattr(self, "_viz_snapshot_dir") or self._viz_snapshot_dir is None:
252-
import os, time
253-
254-
base = "results"
255-
ts = int(time.time())
256-
run_folder = f"run-{ts}"
257-
full = os.path.join(base, run_folder)
258-
os.makedirs(full, exist_ok=True)
259-
self._viz_snapshot_dir = full
260-
self._viz_save_snapshot(i + 1)
261-
except Exception:
262-
pass
263-
264-
# stop evolution at specified time-step (disable further evolution)
265-
if stop_evolution_at is not None and (i + 1) == int(stop_evolution_at):
266-
try:
267-
self._evolve_info = {}
268-
except Exception:
269-
pass
287+
# mark evolution as stopped; do not erase evolve metadata so
288+
# deterministic inheritance can still use parent values
289+
self._evolution_stopped = True
270290

271291
def visualize(
272292
self,
@@ -859,51 +879,11 @@ def __init__(
859879
super().__init__(rows, cols, densities, neighborhood, merged_params, cell_params, seed)
860880

861881
self.synchronous: bool = bool(synchronous)
882+
# set human-friendly species names for PP
883+
self.species_names = ("prey", "predator")
862884

863-
# Information about which parameters are being evolved and their mutation specs
864-
# Maps parameter name -> dict with keys: 'sd', 'min', 'max'
865-
self._evolve_info: Dict[str, Dict[str, float]] = {}
866-
867-
868-
def evolve(self, param: str, sd: float = 0.05, min_val: Optional[float] = None, max_val: Optional[float] = None) -> None:
869-
"""Enable per-cell evolution for a given parameter.
870-
871-
Creates a per-cell array in `self.cell_params[param]` with the same
872-
shape as the grid. Cells currently occupied by the relevant species are
873-
initialized to the global value in `self.params[param]`; other cells are
874-
set to NaN. Mutation metadata (sd, min, max) are stored in
875-
`self._evolve_info[param]`.
876-
877-
Args:
878-
- param: one of the keys in `self.params` (e.g. 'prey_death')
879-
- sd: standard deviation for Gaussian mutations
880-
- min: minimum clipped value after mutation
881-
- max: maximum clipped value after mutation
882-
"""
883-
# Note: deprecated keyword names `min` and `max` have been removed.
884-
# Callers must use `min_val` and `max_val` explicitly.
885-
# Provide sensible defaults when not specified
886-
if min_val is None:
887-
min_val = 0.01
888-
if max_val is None:
889-
max_val = 0.99
890-
if param not in self.params:
891-
raise ValueError(f"Unknown parameter '{param}'")
892-
893-
# determine target species for this parameter
894-
if param.startswith("prey_"):
895-
species = 1
896-
elif param.startswith("predator_"):
897-
species = 2
898-
else:
899-
raise ValueError("Parameter must start with 'prey_' or 'predator_' to evolve")
900885

901-
# create per-cell float array with NaNs for non-relevant cells
902-
arr = np.full(self.grid.shape, np.nan, dtype=float)
903-
mask = (self.grid == species)
904-
arr[mask] = float(self.params[param])
905-
self.cell_params[param] = arr
906-
self._evolve_info[param] = {"sd": float(sd), "min": float(min_val), "max": float(max_val)}
886+
# Remove PP-specific evolve wrapper; use CA.evolve with optional species
907887

908888
def validate(self) -> None:
909889
"""Validate PP-specific invariants in addition to base CA checks.
@@ -932,8 +912,15 @@ def validate(self) -> None:
932912
# shape already checked in super().validate(), but be explicit
933913
if arr.shape != self.grid.shape:
934914
raise ValueError(f"cell_params['{pname}'] must match grid shape")
935-
# expected non-NaN positions correspond to species in grid
936-
species = 1 if pname.startswith("prey_") else 2
915+
# expected non-NaN positions correspond to species stored in metadata
916+
species = None
917+
if isinstance(meta, dict) and "species" in meta:
918+
species = int(meta.get("species"))
919+
else:
920+
# try to infer species from parameter name using species_names
921+
species = self._infer_species_from_param_name(pname)
922+
if species is None:
923+
raise ValueError(f"cell_params['{pname}'] missing species metadata and could not infer from name")
937924
nonnan = ~np.isnan(arr)
938925
expected = (self.grid == species)
939926
if not np.array_equal(nonnan, expected):
@@ -970,15 +957,23 @@ def _apply_deaths_and_clear_params(self, grid_ref: np.ndarray, rand_prey: np.nda
970957
self.grid[pred_death_mask] = 0
971958

972959
# Clear per-cell parameters for dead individuals
973-
for pname in self._evolve_info:
974-
if pname.startswith("prey_"):
975-
arr = self.cell_params.get(pname)
976-
if isinstance(arr, np.ndarray) and arr.shape == self.grid.shape:
977-
arr[prey_death_mask] = np.nan
978-
elif pname.startswith("predator_"):
979-
arr = self.cell_params.get(pname)
980-
if isinstance(arr, np.ndarray) and arr.shape == self.grid.shape:
981-
arr[pred_death_mask] = np.nan
960+
for pname, meta in self._evolve_info.items():
961+
# determine species from metadata or infer from name
962+
species = None
963+
if isinstance(meta, dict) and "species" in meta:
964+
species = int(meta.get("species"))
965+
else:
966+
species = self._infer_species_from_param_name(pname)
967+
if species is None:
968+
# cannot determine species; skip clearing for safety
969+
continue
970+
arr = self.cell_params.get(pname)
971+
if not (isinstance(arr, np.ndarray) and arr.shape == self.grid.shape):
972+
continue
973+
if species == 1:
974+
arr[prey_death_mask] = np.nan
975+
elif species == 2:
976+
arr[pred_death_mask] = np.nan
982977

983978
def _neighbor_shifts(self) -> Tuple[np.ndarray, np.ndarray, int]:
984979
"""Return neighbor shift arrays (dr_arr, dc_arr, n_shifts) for the configured neighborhood."""
@@ -1007,16 +1002,17 @@ def _inherit_params_on_birth(self, chosen_rs: np.ndarray, chosen_cs: np.ndarray,
10071002
an array of parent coordinates with same length.
10081003
"""
10091004
for pname, meta in self._evolve_info.items():
1010-
# determine species this parameter belongs to
1011-
if pname.startswith("prey_"):
1012-
target_species = 1
1013-
elif pname.startswith("predator_"):
1014-
target_species = 2
1005+
# determine species this parameter belongs to via metadata or inference
1006+
species = None
1007+
if isinstance(meta, dict) and "species" in meta:
1008+
species = int(meta.get("species"))
10151009
else:
1016-
continue
1010+
species = self._infer_species_from_param_name(pname)
1011+
if species is None:
1012+
raise ValueError(f"_evolve_info contains unexpected key '{pname}' without species metadata and could not infer")
10171013

10181014
# if new_state is not the species for this param, clear at targets
1019-
if new_state_val != target_species:
1015+
if new_state_val != species:
10201016
arr = self.cell_params.get(pname)
10211017
if isinstance(arr, np.ndarray) and arr.shape == self.grid.shape:
10221018
arr[chosen_rs, chosen_cs] = np.nan
@@ -1037,8 +1033,14 @@ def _inherit_params_on_birth(self, chosen_rs: np.ndarray, chosen_cs: np.ndarray,
10371033
sd = float(meta["sd"])
10381034
mn = float(meta["min"])
10391035
mx = float(meta["max"])
1040-
mut = parent_vals + self.generator.normal(0.0, sd, size=parent_vals.shape)
1041-
mut = np.clip(mut, mn, mx)
1036+
# If evolution has been stopped, inheritance is deterministic: copy
1037+
# parent values directly without Gaussian mutation so we can observe
1038+
# which parameter values survive.
1039+
if getattr(self, "_evolution_stopped", False):
1040+
mut = parent_vals.copy()
1041+
else:
1042+
mut = parent_vals + self.generator.normal(0.0, sd, size=parent_vals.shape)
1043+
mut = np.clip(mut, mn, mx)
10421044
# If an array exists but has wrong shape, raise an informative error
10431045
existing = self.cell_params.get(pname)
10441046
if isinstance(existing, np.ndarray) and existing.shape != self.grid.shape:

scripts/visualize_pp_evolution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def main():
3535

3636
# Run the simulation (ensure the plot stays open afterwards)
3737
try:
38-
pp.run(2500)
38+
pp.run(2500, stop_evolution_at=1000)
3939
finally:
4040
# Block and show the final figure so the user can inspect it.
4141
# Turn off interactive mode (visualize() enabled it) and show blocking.

0 commit comments

Comments
 (0)