|
11 | 11 | # IMPORTS =============================================================================== |
12 | 12 |
|
13 | 13 | import inspect |
| 14 | +from uuid import uuid4 |
14 | 15 | from functools import lru_cache |
15 | 16 |
|
16 | 17 | from ..utils.deprecation import deprecated |
@@ -84,6 +85,9 @@ class definition for other blocks to be inherited. |
84 | 85 |
|
85 | 86 | def __init__(self): |
86 | 87 |
|
| 88 | + #unique identifier for checkpointing and diagnostics |
| 89 | + self.id = uuid4().hex |
| 90 | + |
87 | 91 | #registers to hold input and output values |
88 | 92 | self.inputs = Register( |
89 | 93 | mapping=self.input_port_labels and self.input_port_labels.copy() |
@@ -524,6 +528,93 @@ def state(self, val): |
524 | 528 | self.engine.state = val |
525 | 529 |
|
526 | 530 |
|
| 531 | + # checkpoint methods ---------------------------------------------------------------- |
| 532 | + |
| 533 | + def to_checkpoint(self, recordings=False): |
| 534 | + """Serialize block state for checkpointing. |
| 535 | +
|
| 536 | + Parameters |
| 537 | + ---------- |
| 538 | + recordings : bool |
| 539 | + include recording data (for Scope blocks) |
| 540 | +
|
| 541 | + Returns |
| 542 | + ------- |
| 543 | + json_data : dict |
| 544 | + JSON-serializable metadata |
| 545 | + npz_data : dict |
| 546 | + numpy arrays keyed by path |
| 547 | + """ |
| 548 | + prefix = self.id |
| 549 | + |
| 550 | + json_data = { |
| 551 | + "id": self.id, |
| 552 | + "type": self.__class__.__name__, |
| 553 | + "active": self._active, |
| 554 | + } |
| 555 | + |
| 556 | + npz_data = { |
| 557 | + f"{prefix}/inputs": self.inputs.to_array(), |
| 558 | + f"{prefix}/outputs": self.outputs.to_array(), |
| 559 | + } |
| 560 | + |
| 561 | + #solver state |
| 562 | + if self.engine: |
| 563 | + e_json, e_npz = self.engine.to_checkpoint(f"{prefix}/engine") |
| 564 | + json_data["engine"] = e_json |
| 565 | + npz_data.update(e_npz) |
| 566 | + |
| 567 | + #internal events |
| 568 | + if self.events: |
| 569 | + evt_jsons = [] |
| 570 | + for event in self.events: |
| 571 | + e_json, e_npz = event.to_checkpoint() |
| 572 | + evt_jsons.append(e_json) |
| 573 | + npz_data.update(e_npz) |
| 574 | + json_data["events"] = evt_jsons |
| 575 | + |
| 576 | + return json_data, npz_data |
| 577 | + |
| 578 | + |
| 579 | + def load_checkpoint(self, json_data, npz): |
| 580 | + """Restore block state from checkpoint. |
| 581 | +
|
| 582 | + Parameters |
| 583 | + ---------- |
| 584 | + json_data : dict |
| 585 | + block metadata from checkpoint JSON |
| 586 | + npz : dict-like |
| 587 | + numpy arrays from checkpoint NPZ |
| 588 | + """ |
| 589 | + prefix = json_data["id"] |
| 590 | + |
| 591 | + #verify type |
| 592 | + if json_data["type"] != self.__class__.__name__: |
| 593 | + raise ValueError( |
| 594 | + f"Checkpoint type mismatch: expected '{self.__class__.__name__}', " |
| 595 | + f"got '{json_data['type']}'" |
| 596 | + ) |
| 597 | + |
| 598 | + self._active = json_data["active"] |
| 599 | + |
| 600 | + #restore registers |
| 601 | + inp_key = f"{prefix}/inputs" |
| 602 | + out_key = f"{prefix}/outputs" |
| 603 | + if inp_key in npz: |
| 604 | + self.inputs.update_from_array(npz[inp_key]) |
| 605 | + if out_key in npz: |
| 606 | + self.outputs.update_from_array(npz[out_key]) |
| 607 | + |
| 608 | + #restore solver state |
| 609 | + if self.engine and "engine" in json_data: |
| 610 | + self.engine.load_checkpoint(json_data["engine"], npz, f"{prefix}/engine") |
| 611 | + |
| 612 | + #restore internal events |
| 613 | + if self.events and "events" in json_data: |
| 614 | + for event, evt_data in zip(self.events, json_data["events"]): |
| 615 | + event.load_checkpoint(evt_data, npz) |
| 616 | + |
| 617 | + |
527 | 618 | # methods for block output and state updates ---------------------------------------- |
528 | 619 |
|
529 | 620 | def update(self, t): |
|
0 commit comments