Skip to content

Commit 15ebbff

Browse files
committed
fixed numba seeding (see updates in the prompts)
1 parent 6a71451 commit 15ebbff

6 files changed

Lines changed: 1067 additions & 2 deletions

File tree

docs/kimon_prompts.md

Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,292 @@
6767
9. Help me write a benchmark script to test the updated version against a non-numba optimized version. Create a summary txt with benchmark results.
6868

6969
10. Create unit tests for pp_analysis and numba_optimized files. Test all methods and make sure the scripts are HPC ready. Provide me with an outline of tests.
70+
71+
72+
11. Explain to me how seeding works with numba and multiple processes. One of the tests fails, which may lead to irreproducible results. There is a potential bug in the way runs are seeded.
73+
74+
75+
OUTPUT:
76+
77+
(This was really interesting and might be useful in the future of our project as well)
78+
79+
================================================================================
80+
NUMBA RNG SEEDING FIX - Implementation Guide
81+
================================================================================
82+
83+
PROBLEM:
84+
--------
85+
Numba's JIT-compiled functions use their own internal RNG (a Mersenne Twister),
86+
which is completely separate from NumPy's RNG. Calling np.random.seed()
87+
from Python only seeds NumPy's RNG, not Numba's.
88+
89+
SOLUTION:
90+
---------
91+
Create a @njit function that seeds the RNG from within the JIT context.
92+
93+
================================================================================
94+
CHANGE 1: Add to numba_optimized.py (after line ~35, after the dummy decorators)
95+
================================================================================
96+
97+
# --- ADD THIS BLOCK AFTER THE IMPORTS AND DUMMY DECORATORS ---
98+
99+
@njit(cache=True)
100+
def set_numba_seed(seed: int) -> None:
101+
"""
102+
Seed Numba's internal RNG from within a JIT context.
103+
104+
IMPORTANT: This must be called to get reproducible results from
105+
Numba-accelerated functions. Calling np.random.seed() from Python
106+
only affects NumPy's RNG, not Numba's internal Mersenne Twister RNG.
107+
108+
Args:
109+
seed: Integer seed value
110+
111+
Usage:
112+
from numba_optimized import set_numba_seed
113+
set_numba_seed(42) # Now Numba functions will be deterministic
114+
"""
115+
np.random.seed(seed)
116+
117+
# --- END OF ADDITION ---
118+
119+
120+
================================================================================
121+
CHANGE 2: Update the exports/imports
122+
================================================================================
123+
124+
In numba_optimized.py, update the module docstring to include set_numba_seed:
125+
126+
"""
127+
...
128+
Usage:
129+
from scripts.numba_optimized import (
130+
PPKernel,
131+
compute_all_pcfs_fast,
132+
measure_cluster_sizes_fast,
133+
set_numba_seed, # <-- ADD THIS
134+
NUMBA_AVAILABLE
135+
)
136+
137+
# Seed Numba's RNG for reproducibility
138+
set_numba_seed(42)
139+
140+
# Create kernel once, reuse for all updates
141+
kernel = PPKernel(rows, cols)
142+
...
143+
"""
144+
145+
146+
================================================================================
147+
CHANGE 3: Update pp_analysis.py - Import set_numba_seed
148+
================================================================================
149+
150+
Find the import block (around lines 20-30) and add set_numba_seed:
151+
152+
# BEFORE:
153+
from scripts.numba_optimized import (
154+
PPKernel,
155+
compute_all_pcfs_fast,
156+
measure_cluster_sizes_fast,
157+
warmup_numba_kernels,
158+
NUMBA_AVAILABLE,
159+
)
160+
161+
# AFTER:
162+
from scripts.numba_optimized import (
163+
PPKernel,
164+
compute_all_pcfs_fast,
165+
measure_cluster_sizes_fast,
166+
warmup_numba_kernels,
167+
set_numba_seed, # <-- ADD THIS
168+
NUMBA_AVAILABLE,
169+
)
170+
171+
172+
================================================================================
173+
CHANGE 4: Update run_single_simulation() in pp_analysis.py
174+
================================================================================
175+
176+
Find the run_single_simulation function and add set_numba_seed call at the start:
177+
178+
def run_single_simulation(
179+
prey_birth: float,
180+
prey_death: float,
181+
grid_size: int,
182+
seed: int,
183+
with_evolution: bool,
184+
cfg: Config,
185+
) -> Dict[str, Any]:
186+
"""Run a single simulation and return results."""
187+
188+
# --- ADD THESE LINES AT THE VERY START OF THE FUNCTION ---
189+
# Seed both NumPy and Numba RNGs for full reproducibility
190+
np.random.seed(seed)
191+
if NUMBA_AVAILABLE:
192+
set_numba_seed(seed)
193+
# --- END OF ADDITION ---
194+
195+
# ... rest of the function remains unchanged ...
196+
197+
198+
================================================================================
199+
CHANGE 5: Update run_single_simulation_fss() in pp_analysis.py (if it exists)
200+
================================================================================
201+
202+
Same pattern - add seeding at the start:
203+
204+
def run_single_simulation_fss(...):
205+
"""Run FSS simulation."""
206+
207+
# Seed both RNGs
208+
np.random.seed(seed)
209+
if NUMBA_AVAILABLE:
210+
set_numba_seed(seed)
211+
212+
# ... rest unchanged ...
213+
214+
215+
================================================================================
216+
CHANGE 6: Update warmup_numba_kernels() in numba_optimized.py
217+
================================================================================
218+
219+
Add a deterministic seed during warmup to avoid variability:
220+
221+
def warmup_numba_kernels(grid_size: int = 100):
222+
"""
223+
Pre-compile all Numba kernels.
224+
"""
225+
if not NUMBA_AVAILABLE:
226+
return
227+
228+
# --- ADD THIS LINE ---
229+
set_numba_seed(0) # Deterministic warmup
230+
# --- END OF ADDITION ---
231+
232+
# Dummy data
233+
grid = np.zeros((grid_size, grid_size), dtype=np.int32)
234+
# ... rest unchanged ...
235+
236+
237+
================================================================================
238+
COMPLETE UPDATED numba_optimized.py (key sections only)
239+
================================================================================
240+
241+
Here's how the top of your file should look after changes:
242+
243+
```python
244+
#!/usr/bin/env python3
245+
"""
246+
Numba-optimized kernels for predator-prey cellular automaton.
247+
248+
...
249+
250+
Usage:
251+
from scripts.numba_optimized import (
252+
PPKernel,
253+
compute_all_pcfs_fast,
254+
measure_cluster_sizes_fast,
255+
set_numba_seed,
256+
NUMBA_AVAILABLE
257+
)
258+
259+
# Seed Numba's RNG for reproducibility
260+
set_numba_seed(42)
261+
262+
# Create kernel once, reuse for all updates
263+
kernel = PPKernel(rows, cols)
264+
for step in range(n_steps):
265+
kernel.update(grid, prey_death_arr, params...)
266+
"""
267+
268+
import numpy as np
269+
from typing import Tuple, Dict, Optional
270+
271+
try:
272+
from numba import njit, prange
273+
NUMBA_AVAILABLE = True
274+
except ImportError:
275+
NUMBA_AVAILABLE = False
276+
def njit(*args, **kwargs):
277+
def decorator(func):
278+
return func
279+
return decorator
280+
def prange(*args):
281+
return range(*args)
282+
283+
284+
# ============================================================================
285+
# RNG SEEDING
286+
# ============================================================================
287+
288+
@njit(cache=True)
289+
def set_numba_seed(seed: int) -> None:
290+
"""
291+
Seed Numba's internal RNG from within a JIT context.
292+
293+
IMPORTANT: This must be called to get reproducible results from
294+
Numba-accelerated functions. Calling np.random.seed() from Python
295+
only affects NumPy's RNG, not Numba's internal Mersenne Twister RNG.
296+
297+
Args:
298+
seed: Integer seed value
299+
"""
300+
np.random.seed(seed)
301+
302+
303+
# ============================================================================
304+
# ASYNC KERNEL
305+
# ============================================================================
306+
307+
@njit(cache=True)
308+
def _pp_async_kernel(...):
309+
# ... unchanged ...
310+
```
311+
312+
313+
================================================================================
314+
TESTING THE FIX
315+
================================================================================
316+
317+
After making these changes, this test should pass:
318+
319+
```python
320+
def test_kernel_deterministic_with_seed():
321+
"""Same seed should produce same results."""
322+
from numba_optimized import PPKernel, set_numba_seed
323+
324+
results = []
325+
326+
for _ in range(2):
327+
# Seed BOTH RNGs
328+
np.random.seed(12345)
329+
set_numba_seed(12345) # <-- This is the key!
330+
331+
grid = np.random.choice([0, 1, 2], (30, 30), p=[0.5, 0.3, 0.2]).astype(np.int32)
332+
prey_death = np.full((30, 30), 0.05, dtype=np.float64)
333+
prey_death[grid != 1] = np.nan
334+
335+
kernel = PPKernel(30, 30, "moore")
336+
for _ in range(10):
337+
kernel.update(grid, prey_death, 0.2, 0.05, 0.2, 0.1)
338+
339+
results.append(grid.copy())
340+
341+
assert np.array_equal(results[0], results[1]), "Results should be deterministic"
342+
```
343+
344+
345+
================================================================================
346+
WHY THIS WORKS
347+
================================================================================
348+
349+
1. When you call `set_numba_seed(seed)` from Python, it invokes the @njit function
350+
2. Inside the JIT context, `np.random.seed(seed)` seeds Numba's internal RNG
351+
3. All subsequent calls to `np.random.random()`, `np.random.randint()`, etc.
352+
inside @njit functions will use this seeded state
353+
4. The RNG state persists across JIT function calls until re-seeded
354+
355+
Note: Each worker process in parallel execution needs its own seed call.
356+
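For parallel simulations, use different seeds per worker (e.g., base_seed + worker_id). Numba's RNG state is also per-thread, so a seed set from the main thread does not seed the worker threads used by parallel=True / prange kernels.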
357+
358+

models/mean_field.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#!/usr/bin/env python3
2+
13
import numpy as np
24
import matplotlib.pyplot as plt
35
from typing import Tuple, List, Dict, Optional

scripts/numba_optimized.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

2+
#!/usr/bin/env python3
23
"""
34
Numba-optimized kernels for predator-prey cellular automaton.
45
@@ -38,6 +39,25 @@ def decorator(func):
3839
def prange(*args):
3940
return range(*args)
4041

42+
43+
# ============================================================================
44+
# RNG SEEDING
45+
# ============================================================================
46+
47+
@njit(cache=True)
48+
def set_numba_seed(seed: int) -> None:
49+
"""
50+
Seed Numba's internal RNG from within a JIT context.
51+
52+
IMPORTANT: This must be called to get reproducible results from
53+
Numba-accelerated functions. Calling np.random.seed() from Python
54+
only affects NumPy's RNG, not Numba's internal Mersenne Twister RNG.
55+
56+
Args:
57+
seed: Integer seed value
58+
"""
59+
np.random.seed(seed)
60+
4161
@njit(cache=True)
4262
def _pp_async_kernel(
4363
grid: np.ndarray,
@@ -528,6 +548,8 @@ def warmup_numba_kernels(grid_size: int = 100):
528548
if not NUMBA_AVAILABLE:
529549
return
530550

551+
set_numba_seed(0)
552+
531553
# Dummy data
532554
grid = np.zeros((grid_size, grid_size), dtype=np.int32)
533555
grid[::3, ::3] = 1 # Sparse prey

scripts/pp_analysis.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#!/usr/bin/env python3
12
"""
23
Prey-predator evolutionary analysis - Snellius HPC Version (Optimized)
34
@@ -48,6 +49,7 @@
4849
compute_all_pcfs_fast,
4950
measure_cluster_sizes_fast,
5051
warmup_numba_kernels,
52+
set_numba_seed,
5153
NUMBA_AVAILABLE,
5254
)
5355
USE_NUMBA = NUMBA_AVAILABLE
@@ -298,6 +300,10 @@ def run_single_simulation(
298300
Dictionary with simulation results
299301
"""
300302
from models.CA import PP
303+
# Seed both RNGs
304+
np.random.seed(seed)
305+
if NUMBA_AVAILABLE:
306+
set_numba_seed(seed)
301307

302308
# Set evolution parameters
303309
if evolve_sd is None:
@@ -451,6 +457,10 @@ def run_single_simulation_fss(
451457
"""FSS-specific simulation with size-scaled equilibration time."""
452458
from models.CA import PP
453459

460+
np.random.seed(seed)
461+
if NUMBA_AVAILABLE:
462+
set_numba_seed(seed)
463+
454464
params = {
455465
"prey_birth": prey_birth,
456466
"prey_death": prey_death,

0 commit comments

Comments
 (0)