Skip to content

Commit fa83807

Browse files
committed
modified CA async method to use numba JIT with 45x speedup
1 parent f72c19c commit fa83807

9 files changed

Lines changed: 651 additions & 82 deletions

File tree

docs/kimon_prompts.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,15 @@
3939
7. Add PCF analysis functionality for prey auto-correlation, predator auto-correlation, and cross-correlation. Also, integrate the snapshot method from the CA class as an optional functionality of the analysis module. Add the following plots: (1) phase diagrams showing segregation, prey clustering, and predator clustering; (2) scatter plots testing whether the Hydra effect correlates with spatial segregation; and (3) CA-style snapshots, a neighbor histogram, and an evolution trajectory.
4040

4141

42-
8. Help me create a testing module for the analysis file. Use unittest.mock to create a mock model for testing.
42+
8. Help me create a testing module for the analysis file. Use unittest.mock to create a mock model for testing. If you lie or falsify tests so that they pass my script, you will be replaced.
43+
44+
45+
9. Add a larger scale simulation in the testing file to verify plots are as desired.
46+
47+
---
48+
49+
### Script Optimization
50+
51+
52+
1. I am considering using numba for optimization and faster runs on the HPC cluster. Outline an implementation plan, practical considerations, and feasibility within a realistic timeframe.
4353

File renamed without changes.

models/CA.py

Lines changed: 36 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import numpy as np
1010
import logging
11+
from scripts.numba_optimized import _pp_async_kernel
12+
from numba import njit
1113

1214
# Module logger
1315
logger = logging.getLogger(__name__)
@@ -17,6 +19,10 @@
1719
_cached_ndimage = None
1820
_cached_kernels = {}
1921

22+
@njit
def set_numba_seed(value):
    """Seed the random generator used inside Numba-compiled code.

    ``np.random.seed`` is deliberately invoked from within a jitted function
    rather than from plain Python.

    NOTE(review): per the Numba documentation, the generator used by jitted
    code is separate from NumPy's interpreter-level generator and its state is
    kept per thread -- confirm this single call is sufficient if any kernel is
    ever run with parallel execution enabled.
    """
    np.random.seed(value)
25+
2026
class CA:
2127
"""Base cellular automaton class.
2228
@@ -881,6 +887,10 @@ def __init__(
881887
self.synchronous: bool = bool(synchronous)
882888
# set human-friendly species names for PP
883889
self.species_names = ("prey", "predator")
890+
891+
if seed is not None:
892+
# This sets the seed for all @njit functions globally
893+
set_numba_seed(seed)
884894

885895

886896
# Remove PP-specific evolve wrapper; use CA.evolve with optional species
@@ -1165,70 +1175,32 @@ def _process_reproduction(sources, birth_param_key, birth_prob, target_state_req
11651175
_process_reproduction(pred_sources, "predator_birth", self.params["predator_birth"], 1, 2)
11661176

11671177
def update_async(self) -> None:
1168-
"""Asynchronous (random-sequential) update.
1169-
1170-
Rules (applied using a copy of the current grid for reference):
1171-
- Iterate occupied cells in random order.
1172-
- Prey (1): pick random neighbor; if neighbor was empty in copy,
1173-
reproduce into it with probability `prey_birth`.
1174-
- Predator (2): pick random neighbor; if neighbor was prey in copy,
1175-
reproduce into it (convert to predator) with probability `predator_birth`.
1176-
- After the reproduction loop, apply deaths synchronously using the
1177-
copy as the reference so newly created individuals are not instantly
1178-
killed. Deaths only remove individuals if the current cell still
1179-
matches the species from the reference copy.
1180-
"""
1181-
# Bind hot attributes to locals for speed and clarity
1182-
grid = self.grid
1183-
gen = self.generator
1184-
params = self.params
1185-
cell_params = self.cell_params
1186-
rows, cols = grid.shape
1187-
grid_ref = grid.copy()
1188-
1189-
# Sample and apply deaths first (based on the reference grid). Deaths
1190-
# are sampled from `grid_ref` so statistics remain identical.
1191-
rand_prey = gen.random(grid.shape)
1192-
rand_pred = gen.random(grid.shape)
1193-
self._apply_deaths_and_clear_params(grid_ref, rand_prey, rand_pred)
1194-
1195-
# Precompute neighbor shifts
1196-
dr_arr, dc_arr, n_shifts = self._neighbor_shifts()
1197-
1198-
# Get occupied cells from the original reference grid and shuffle.
1199-
# We iterate over `grid_ref` so that sources can die and reproduce
1200-
# in the same iteration, meaning we are order-agnostic.
1201-
occupied = np.argwhere(grid_ref != 0)
1202-
if occupied.size > 0:
1203-
order = gen.permutation(len(occupied))
1204-
for idx in order:
1205-
r, c = int(occupied[idx, 0]), int(occupied[idx, 1])
1206-
state = int(grid_ref[r, c])
1207-
# pick a random neighbor shift
1208-
nbi = int(gen.integers(0, n_shifts))
1209-
dr = int(dr_arr[nbi])
1210-
dc = int(dc_arr[nbi])
1211-
nr = (r + dr) % rows
1212-
nc = (c + dc) % cols
1213-
if state == 1:
1214-
# Prey reproduces into empty neighbor (reference must be empty)
1215-
if grid_ref[nr, nc] == 0:
1216-
# per-parent birth prob
1217-
pval = self._get_parent_probs(np.array([[r, c]]), "prey_birth", float(params["prey_birth"]))[0]
1218-
if gen.random() < float(pval):
1219-
# birth: set new prey and inherit per-cell params (if any)
1220-
grid[nr, nc] = 1
1221-
# handle param clearing/inheritance for a single birth
1222-
self._inherit_params_on_birth(np.array([nr]), np.array([nc]), np.array([[r, c]]), 1)
1223-
elif state == 2:
1224-
# Predator reproduces into prey neighbor (reference must be prey)
1225-
if grid_ref[nr, nc] == 1:
1226-
pval = self._get_parent_probs(np.array([[r, c]]), "predator_birth", float(params["predator_birth"]))[0]
1227-
if gen.random() < float(pval):
1228-
# predator converts prey -> predator: assign and handle params
1229-
grid[nr, nc] = 2
1230-
self._inherit_params_on_birth(np.array([nr]), np.array([nc]), np.array([[r, c]]), 2)
1231-
1178+
dr_arr, dc_arr, _ = self._neighbor_shifts()
1179+
1180+
# Get the evolved prey death map
1181+
# Fallback to a full array of the global param if it doesn't exist yet
1182+
p_death_arr = self.cell_params.get("prey_death")
1183+
if p_death_arr is None:
1184+
p_death_arr = np.full(self.grid.shape, self.params["prey_death"])
1185+
1186+
meta = self._evolve_info.get("prey_death", {"sd": 0.05, "min": 0.001, "max": 0.1})
1187+
1188+
# Call the JIT kernel
1189+
self.grid = _pp_async_kernel(
1190+
self.grid,
1191+
p_death_arr,
1192+
float(self.params["prey_birth"]),
1193+
float(self.params["prey_death"]),
1194+
float(self.params["predator_birth"]),
1195+
float(self.params["predator_death"]),
1196+
dr_arr.astype(np.int32),
1197+
dc_arr.astype(np.int32),
1198+
float(meta["sd"]),
1199+
float(meta["min"]),
1200+
float(meta["max"]),
1201+
self._evolution_stopped
1202+
)
1203+
12321204
def update(self) -> None:
12331205
"""Dispatch to synchronous or asynchronous update mode."""
12341206
if self.synchronous:

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ scipy
44
pytest
55
seaborn
66
black
7-
tqdm
7+
tqdm
8+
numba

scripts/benchmark_numba.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import time
2+
import statistics
3+
import numpy as np
4+
5+
6+
from pathlib import Path
7+
import sys
8+
sys.path.insert(0, str(Path(__file__).parent.parent))
9+
10+
from models.CA import PP
11+
12+
def python_async_logic(pp):
    """Run one pure-Python asynchronous step on ``pp.grid`` (benchmark baseline).

    This is a simplified, un-jitted reference loop used only for timing. It
    reads the global rates from ``pp.params`` (keys ``prey_birth``,
    ``prey_death``, ``predator_birth``, ``predator_death``) and draws from
    NumPy's global ``np.random`` state, so results are reproducible only if
    the caller seeds that state.

    Rules, evaluated against a snapshot of the grid taken at the start:
      * every cell occupied in the snapshot is visited once, in random order;
      * one of the eight Moore neighbours is picked with periodic wrap-around;
      * prey (1) dies with ``prey_death``, otherwise reproduces into a
        neighbour that was empty in the snapshot with ``prey_birth``;
      * predator (2) dies with ``predator_death``, otherwise converts a
        neighbour that was prey in the snapshot with ``predator_birth``.

    Args:
        pp: Object exposing a 2-D integer ``grid`` array and a ``params``
            mapping. ``pp.grid`` is modified in place.

    Returns:
        None.
    """
    grid = pp.grid
    params = pp.params
    rows, cols = grid.shape
    # Decisions are made against the start-of-step state, not the live grid.
    grid_ref = grid.copy()

    # Moore neighbourhood offsets as (d_row, d_col).
    shifts = (
        (-1, -1), (-1, 0), (-1, 1),
        (0, -1), (0, 1),
        (1, -1), (1, 0), (1, 1),
    )
    # Derived from the table so the sampled index can never go out of sync
    # with the neighbourhood definition.
    n_shifts = len(shifts)

    occupied = np.argwhere(grid_ref != 0)
    if occupied.size == 0:
        return

    # Random-sequential visiting order (the slow part in pure Python).
    for idx in np.random.permutation(len(occupied)):
        r, c = occupied[idx]
        state = grid_ref[r, c]

        # Pick a random neighbour on the torus.
        d_row, d_col = shifts[np.random.randint(0, n_shifts)]
        nr, nc = (r + d_row) % rows, (c + d_col) % cols

        # NOTE: the birth draw is only consumed when the neighbour qualifies
        # (short-circuit ``and``), which keeps the random stream identical to
        # the original formulation.
        if state == 1:  # Prey
            if np.random.random() < params["prey_death"]:
                grid[r, c] = 0
            elif grid_ref[nr, nc] == 0 and np.random.random() < params["prey_birth"]:
                grid[nr, nc] = 1
        elif state == 2:  # Predator
            if np.random.random() < params["predator_death"]:
                grid[r, c] = 0
            elif grid_ref[nr, nc] == 1 and np.random.random() < params["predator_birth"]:
                grid[nr, nc] = 2
46+
47+
def benchmark_numba_impact(rows=150, cols=150, repeats=50):
    """Time the pure-Python async step against ``PP.update_async`` and print a summary.

    Both variants start every repetition from the same initial grid. The
    Numba path is called once before timing so JIT compilation is excluded
    from the measurements.

    Args:
        rows: Grid height passed to ``PP``.
        cols: Grid width passed to ``PP``.
        repeats: Number of timed steps per variant; must be at least 1.

    Raises:
        ValueError: If ``repeats`` is smaller than 1. (Previously this only
            surfaced after the whole run, as ``statistics.StatisticsError``
            -- itself a ``ValueError`` subclass -- from averaging an empty
            list.)
    """
    if repeats < 1:
        raise ValueError(f"repeats must be at least 1, got {repeats}")

    pp = PP(rows=rows, cols=cols, densities=(0.2, 0.1), seed=42, synchronous=False)
    initial_grid = pp.grid.copy()

    # --- PURE PYTHON ---
    print(f"[*] Benchmarking Pure Python Async ({rows}x{cols})...")
    python_times = []
    for _ in range(repeats):
        pp.grid[:] = initial_grid
        t0 = time.perf_counter()
        python_async_logic(pp)
        python_times.append(time.perf_counter() - t0)

    # --- NUMBA ---
    print("[*] Benchmarking Numba-Accelerated Async...")
    # Warm up so compilation time is not counted.
    pp.update_async()

    numba_times = []
    for _ in range(repeats):
        pp.grid[:] = initial_grid
        t0 = time.perf_counter()
        pp.update_async()
        numba_times.append(time.perf_counter() - t0)

    # --- SUMMARY --- (means converted from seconds to milliseconds)
    py_mean = statistics.mean(python_times) * 1000
    nb_mean = statistics.mean(numba_times) * 1000
    # Guard against a zero denominator on coarse timers.
    speedup = py_mean / nb_mean if nb_mean > 0 else float("inf")

    rule = "=" * 50
    print("\n" + rule)
    print(f"ASYNC PERFORMANCE RESULTS ({rows}x{cols})")
    print(rule)
    print(f"Pure Python: {py_mean:.2f} ms / step")
    print(f"Numba JIT: {nb_mean:.2f} ms / step")
    print(f"Speedup: {speedup:.1f}x")
    print(rule)
84+
85+
if __name__ == "__main__":
86+
benchmark_numba_impact()
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
# Import and modify config before running
1919
import sys
2020
sys.path.insert(0, str(Path(__file__).parent.parent))
21-
from pp_analysis import Config, main, run_2d_sweep, run_sensitivity, run_fss, generate_plots
21+
from scripts.pp_analysis import Config, main, run_2d_sweep, run_sensitivity, run_fss, generate_plots
2222
import logging
2323
import argparse
2424

0 commit comments

Comments
 (0)