|
8 | 8 |
|
9 | 9 | import numpy as np |
10 | 10 | import logging |
| 11 | +from scripts.numba_optimized import _pp_async_kernel |
| 12 | +from numba import njit |
11 | 13 |
|
12 | 14 | # Module logger |
13 | 15 | logger = logging.getLogger(__name__) |
|
17 | 19 | _cached_ndimage = None |
18 | 20 | _cached_kernels = {} |
19 | 21 |
|
| 22 | +@njit |
| 23 | +def set_numba_seed(value): |
| 24 | + np.random.seed(value) |
| 25 | + |
20 | 26 | class CA: |
21 | 27 | """Base cellular automaton class. |
22 | 28 |
|
@@ -881,6 +887,10 @@ def __init__( |
881 | 887 | self.synchronous: bool = bool(synchronous) |
882 | 888 | # set human-friendly species names for PP |
883 | 889 | self.species_names = ("prey", "predator") |
| 890 | + |
| 891 | + if seed is not None: |
| 892 | + # This sets the seed for all @njit functions globally |
| 893 | + set_numba_seed(seed) |
884 | 894 |
|
885 | 895 |
|
886 | 896 | # Remove PP-specific evolve wrapper; use CA.evolve with optional species |
@@ -1165,70 +1175,32 @@ def _process_reproduction(sources, birth_param_key, birth_prob, target_state_req |
1165 | 1175 | _process_reproduction(pred_sources, "predator_birth", self.params["predator_birth"], 1, 2) |
1166 | 1176 |
|
1167 | 1177 | def update_async(self) -> None: |
1168 | | - """Asynchronous (random-sequential) update. |
1169 | | -
|
1170 | | - Rules (applied using a copy of the current grid for reference): |
1171 | | - - Iterate occupied cells in random order. |
1172 | | - - Prey (1): pick random neighbor; if neighbor was empty in copy, |
1173 | | - reproduce into it with probability `prey_birth`. |
1174 | | - - Predator (2): pick random neighbor; if neighbor was prey in copy, |
1175 | | - reproduce into it (convert to predator) with probability `predator_birth`. |
1176 | | - - After the reproduction loop, apply deaths synchronously using the |
1177 | | - copy as the reference so newly created individuals are not instantly |
1178 | | - killed. Deaths only remove individuals if the current cell still |
1179 | | - matches the species from the reference copy. |
1180 | | - """ |
1181 | | - # Bind hot attributes to locals for speed and clarity |
1182 | | - grid = self.grid |
1183 | | - gen = self.generator |
1184 | | - params = self.params |
1185 | | - cell_params = self.cell_params |
1186 | | - rows, cols = grid.shape |
1187 | | - grid_ref = grid.copy() |
1188 | | - |
1189 | | - # Sample and apply deaths first (based on the reference grid). Deaths |
1190 | | - # are sampled from `grid_ref` so statistics remain identical. |
1191 | | - rand_prey = gen.random(grid.shape) |
1192 | | - rand_pred = gen.random(grid.shape) |
1193 | | - self._apply_deaths_and_clear_params(grid_ref, rand_prey, rand_pred) |
1194 | | - |
1195 | | - # Precompute neighbor shifts |
1196 | | - dr_arr, dc_arr, n_shifts = self._neighbor_shifts() |
1197 | | - |
1198 | | - # Get occupied cells from the original reference grid and shuffle. |
1199 | | - # We iterate over `grid_ref` so that sources can die and reproduce |
1200 | | - # in the same iteration, meaning we are order-agnostic. |
1201 | | - occupied = np.argwhere(grid_ref != 0) |
1202 | | - if occupied.size > 0: |
1203 | | - order = gen.permutation(len(occupied)) |
1204 | | - for idx in order: |
1205 | | - r, c = int(occupied[idx, 0]), int(occupied[idx, 1]) |
1206 | | - state = int(grid_ref[r, c]) |
1207 | | - # pick a random neighbor shift |
1208 | | - nbi = int(gen.integers(0, n_shifts)) |
1209 | | - dr = int(dr_arr[nbi]) |
1210 | | - dc = int(dc_arr[nbi]) |
1211 | | - nr = (r + dr) % rows |
1212 | | - nc = (c + dc) % cols |
1213 | | - if state == 1: |
1214 | | - # Prey reproduces into empty neighbor (reference must be empty) |
1215 | | - if grid_ref[nr, nc] == 0: |
1216 | | - # per-parent birth prob |
1217 | | - pval = self._get_parent_probs(np.array([[r, c]]), "prey_birth", float(params["prey_birth"]))[0] |
1218 | | - if gen.random() < float(pval): |
1219 | | - # birth: set new prey and inherit per-cell params (if any) |
1220 | | - grid[nr, nc] = 1 |
1221 | | - # handle param clearing/inheritance for a single birth |
1222 | | - self._inherit_params_on_birth(np.array([nr]), np.array([nc]), np.array([[r, c]]), 1) |
1223 | | - elif state == 2: |
1224 | | - # Predator reproduces into prey neighbor (reference must be prey) |
1225 | | - if grid_ref[nr, nc] == 1: |
1226 | | - pval = self._get_parent_probs(np.array([[r, c]]), "predator_birth", float(params["predator_birth"]))[0] |
1227 | | - if gen.random() < float(pval): |
1228 | | - # predator converts prey -> predator: assign and handle params |
1229 | | - grid[nr, nc] = 2 |
1230 | | - self._inherit_params_on_birth(np.array([nr]), np.array([nc]), np.array([[r, c]]), 2) |
1231 | | - |
| 1178 | + dr_arr, dc_arr, _ = self._neighbor_shifts() |
| 1179 | + |
| 1180 | + # Get the evolved prey death map |
| 1181 | + # Fallback to a full array of the global param if it doesn't exist yet |
| 1182 | + p_death_arr = self.cell_params.get("prey_death") |
| 1183 | + if p_death_arr is None: |
| 1184 | + p_death_arr = np.full(self.grid.shape, self.params["prey_death"]) |
| 1185 | + |
| 1186 | + meta = self._evolve_info.get("prey_death", {"sd": 0.05, "min": 0.001, "max": 0.1}) |
| 1187 | + |
| 1188 | + # Call the JIT kernel |
| 1189 | + self.grid = _pp_async_kernel( |
| 1190 | + self.grid, |
| 1191 | + p_death_arr, |
| 1192 | + float(self.params["prey_birth"]), |
| 1193 | + float(self.params["prey_death"]), |
| 1194 | + float(self.params["predator_birth"]), |
| 1195 | + float(self.params["predator_death"]), |
| 1196 | + dr_arr.astype(np.int32), |
| 1197 | + dc_arr.astype(np.int32), |
| 1198 | + float(meta["sd"]), |
| 1199 | + float(meta["min"]), |
| 1200 | + float(meta["max"]), |
| 1201 | + self._evolution_stopped |
| 1202 | + ) |
| 1203 | + |
1232 | 1204 | def update(self) -> None: |
1233 | 1205 | """Dispatch to synchronous or asynchronous update mode.""" |
1234 | 1206 | if self.synchronous: |
|
0 commit comments