Skip to content

Commit 0d52b54

Browse files
committed
Add checkpoint save/load system with JSON+NPZ format
1 parent 012cc95 commit 0d52b54

11 files changed

Lines changed: 796 additions & 5 deletions

File tree

src/pathsim/blocks/_block.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# IMPORTS ===============================================================================
1212

1313
import inspect
14+
from uuid import uuid4
1415
from functools import lru_cache
1516

1617
from ..utils.deprecation import deprecated
@@ -84,6 +85,9 @@ class definition for other blocks to be inherited.
8485

8586
def __init__(self):
8687

88+
#unique identifier for checkpointing and diagnostics
89+
self.id = uuid4().hex
90+
8791
#registers to hold input and output values
8892
self.inputs = Register(
8993
mapping=self.input_port_labels and self.input_port_labels.copy()
@@ -524,6 +528,93 @@ def state(self, val):
524528
self.engine.state = val
525529

526530

531+
# checkpoint methods ----------------------------------------------------------------
532+
533+
def to_checkpoint(self, recordings=False):
534+
"""Serialize block state for checkpointing.
535+
536+
Parameters
537+
----------
538+
recordings : bool
539+
include recording data (for Scope blocks)
540+
541+
Returns
542+
-------
543+
json_data : dict
544+
JSON-serializable metadata
545+
npz_data : dict
546+
numpy arrays keyed by path
547+
"""
548+
prefix = self.id
549+
550+
json_data = {
551+
"id": self.id,
552+
"type": self.__class__.__name__,
553+
"active": self._active,
554+
}
555+
556+
npz_data = {
557+
f"{prefix}/inputs": self.inputs.to_array(),
558+
f"{prefix}/outputs": self.outputs.to_array(),
559+
}
560+
561+
#solver state
562+
if self.engine:
563+
e_json, e_npz = self.engine.to_checkpoint(f"{prefix}/engine")
564+
json_data["engine"] = e_json
565+
npz_data.update(e_npz)
566+
567+
#internal events
568+
if self.events:
569+
evt_jsons = []
570+
for event in self.events:
571+
e_json, e_npz = event.to_checkpoint()
572+
evt_jsons.append(e_json)
573+
npz_data.update(e_npz)
574+
json_data["events"] = evt_jsons
575+
576+
return json_data, npz_data
577+
578+
579+
def load_checkpoint(self, json_data, npz):
580+
"""Restore block state from checkpoint.
581+
582+
Parameters
583+
----------
584+
json_data : dict
585+
block metadata from checkpoint JSON
586+
npz : dict-like
587+
numpy arrays from checkpoint NPZ
588+
"""
589+
prefix = json_data["id"]
590+
591+
#verify type
592+
if json_data["type"] != self.__class__.__name__:
593+
raise ValueError(
594+
f"Checkpoint type mismatch: expected '{self.__class__.__name__}', "
595+
f"got '{json_data['type']}'"
596+
)
597+
598+
self._active = json_data["active"]
599+
600+
#restore registers
601+
inp_key = f"{prefix}/inputs"
602+
out_key = f"{prefix}/outputs"
603+
if inp_key in npz:
604+
self.inputs.update_from_array(npz[inp_key])
605+
if out_key in npz:
606+
self.outputs.update_from_array(npz[out_key])
607+
608+
#restore solver state
609+
if self.engine and "engine" in json_data:
610+
self.engine.load_checkpoint(json_data["engine"], npz, f"{prefix}/engine")
611+
612+
#restore internal events
613+
if self.events and "events" in json_data:
614+
for event, evt_data in zip(self.events, json_data["events"]):
615+
event.load_checkpoint(evt_data, npz)
616+
617+
527618
# methods for block output and state updates ----------------------------------------
528619

529620
def update(self, t):

src/pathsim/blocks/delay.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,41 @@ def reset(self):
142142
self._ring.extend([0.0] * self._n)
143143

144144

145+
def to_checkpoint(self, recordings=False):
146+
"""Serialize Delay state including buffer data."""
147+
json_data, npz_data = super().to_checkpoint(recordings=recordings)
148+
prefix = self.id
149+
150+
json_data["sampling_period"] = self.sampling_period
151+
152+
if self.sampling_period is None:
153+
#continuous mode: adaptive buffer
154+
npz_data.update(self._buffer.to_checkpoint(f"{prefix}/buffer"))
155+
else:
156+
#discrete mode: ring buffer
157+
npz_data[f"{prefix}/ring"] = np.array(list(self._ring))
158+
json_data["_sample_next_timestep"] = self._sample_next_timestep
159+
160+
return json_data, npz_data
161+
162+
163+
def load_checkpoint(self, json_data, npz):
164+
"""Restore Delay state including buffer data."""
165+
super().load_checkpoint(json_data, npz)
166+
prefix = json_data["id"]
167+
168+
if self.sampling_period is None:
169+
#continuous mode
170+
self._buffer.load_checkpoint(npz, f"{prefix}/buffer")
171+
else:
172+
#discrete mode
173+
ring_key = f"{prefix}/ring"
174+
if ring_key in npz:
175+
self._ring.clear()
176+
self._ring.extend(npz[ring_key].tolist())
177+
self._sample_next_timestep = json_data.get("_sample_next_timestep", False)
178+
179+
145180
def update(self, t):
146181
"""Evaluation of the buffer at different times
147182
via interpolation (continuous) or ring buffer lookup (discrete).

src/pathsim/blocks/scope.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -448,13 +448,49 @@ def save(self, path="scope.csv"):
448448
wrt.writerow(sample)
449449

450450

451+
def to_checkpoint(self, recordings=False):
452+
"""Serialize Scope state including optional recording data."""
453+
json_data, npz_data = super().to_checkpoint(recordings=recordings)
454+
prefix = self.id
455+
456+
json_data["_incremental_idx"] = self._incremental_idx
457+
if hasattr(self, '_sample_next_timestep'):
458+
json_data["_sample_next_timestep"] = self._sample_next_timestep
459+
460+
if recordings and self.recording_time:
461+
npz_data[f"{prefix}/recording_time"] = np.array(self.recording_time)
462+
npz_data[f"{prefix}/recording_data"] = np.array(self.recording_data)
463+
464+
return json_data, npz_data
465+
466+
467+
def load_checkpoint(self, json_data, npz):
468+
"""Restore Scope state including optional recording data."""
469+
super().load_checkpoint(json_data, npz)
470+
prefix = json_data["id"]
471+
472+
self._incremental_idx = json_data.get("_incremental_idx", 0)
473+
if hasattr(self, '_sample_next_timestep'):
474+
self._sample_next_timestep = json_data.get("_sample_next_timestep", False)
475+
476+
#restore recordings if present
477+
rt_key = f"{prefix}/recording_time"
478+
rd_key = f"{prefix}/recording_data"
479+
if rt_key in npz and rd_key in npz:
480+
self.recording_time = npz[rt_key].tolist()
481+
self.recording_data = [row for row in npz[rd_key]]
482+
else:
483+
self.recording_time = []
484+
self.recording_data = []
485+
486+
451487
def update(self, t):
452-
"""update system equation for fixed point loop,
488+
"""update system equation for fixed point loop,
453489
here just setting the outputs
454-
490+
455491
Note
456492
----
457-
Scope has no passthrough, so the 'update' method
493+
Scope has no passthrough, so the 'update' method
458494
is optimized for this case (does nothing)
459495
460496
Parameters

src/pathsim/blocks/spectrum.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,24 @@ def step(self, t, dt):
283283
return True, 0.0, None
284284

285285

286+
def to_checkpoint(self, recordings=False):
287+
"""Serialize Spectrum state including integration time."""
288+
json_data, npz_data = super().to_checkpoint(recordings=recordings)
289+
290+
json_data["time"] = self.time
291+
json_data["t_sample"] = self.t_sample
292+
293+
return json_data, npz_data
294+
295+
296+
def load_checkpoint(self, json_data, npz):
297+
"""Restore Spectrum state including integration time."""
298+
super().load_checkpoint(json_data, npz)
299+
300+
self.time = json_data.get("time", 0.0)
301+
self.t_sample = json_data.get("t_sample", 0.0)
302+
303+
286304
def sample(self, t, dt):
287305
"""sample time of successfull timestep for waiting period
288306

src/pathsim/blocks/switch.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,15 @@ def select(self, switch_state=0):
8282
self.switch_state = switch_state
8383

8484

85+
def to_checkpoint(self, recordings=False):
86+
json_data, npz_data = super().to_checkpoint(recordings=recordings)
87+
json_data["switch_state"] = self.switch_state
88+
return json_data, npz_data
89+
90+
def load_checkpoint(self, json_data, npz):
91+
super().load_checkpoint(json_data, npz)
92+
self.switch_state = json_data.get("switch_state", None)
93+
8594
def update(self, t):
8695
"""Update switch output depending on inputs and switch state.
8796

src/pathsim/events/_event.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
import numpy as np
1313

14+
from uuid import uuid4
15+
1416
from .. _constants import EVT_TOLERANCE
1517

1618

@@ -64,6 +66,9 @@ def __init__(
6466
tolerance=EVT_TOLERANCE
6567
):
6668

69+
#unique identifier for checkpointing and diagnostics
70+
self.id = uuid4().hex
71+
6772
#event detection function
6873
self.func_evt = func_evt
6974

@@ -201,4 +206,60 @@ def resolve(self, t):
201206

202207
#action function for event resolution
203208
if self.func_act is not None:
204-
self.func_act(t)
209+
self.func_act(t)
210+
211+
212+
# checkpoint methods ----------------------------------------------------------------
213+
214+
def to_checkpoint(self):
215+
"""Serialize event state for checkpointing.
216+
217+
Returns
218+
-------
219+
json_data : dict
220+
JSON-serializable metadata
221+
npz_data : dict
222+
numpy arrays keyed by path
223+
"""
224+
prefix = self.id
225+
226+
#extract history eval value
227+
hist_eval, hist_time = self._history
228+
if hist_eval is not None and hasattr(hist_eval, 'item'):
229+
hist_eval = float(hist_eval)
230+
231+
json_data = {
232+
"id": self.id,
233+
"type": self.__class__.__name__,
234+
"active": self._active,
235+
"history_eval": hist_eval,
236+
"history_time": hist_time,
237+
}
238+
239+
npz_data = {}
240+
if self._times:
241+
npz_data[f"{prefix}/times"] = np.array(self._times)
242+
243+
return json_data, npz_data
244+
245+
246+
def load_checkpoint(self, json_data, npz):
247+
"""Restore event state from checkpoint.
248+
249+
Parameters
250+
----------
251+
json_data : dict
252+
event metadata from checkpoint JSON
253+
npz : dict-like
254+
numpy arrays from checkpoint NPZ
255+
"""
256+
prefix = json_data["id"]
257+
258+
self._active = json_data["active"]
259+
self._history = json_data["history_eval"], json_data["history_time"]
260+
261+
times_key = f"{prefix}/times"
262+
if times_key in npz:
263+
self._times = npz[times_key].tolist()
264+
else:
265+
self._times = []

0 commit comments

Comments
 (0)