Skip to content

Commit b918626

Browse files
authored
async NDProcessor (#1026)
* async NDProcessor established, NOT TESTED * ASYNC NDPROC IS WORKING :D :D CELEBRATE * comments * type annot * fix * fix * polish async integration, cuda also integrated * no longer using xarray, allow simpler ArrayProtocol * comments * comments * docs
1 parent bbd03f6 commit b918626

File tree

16 files changed

+515
-249
lines changed

16 files changed

+515
-249
lines changed

docs/source/user_guide/guide.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ With jupyterlab support.
2121
2222
pip install -U "fastplotlib[notebook,imgui]"
2323
24+
.. note:: ``imgui-bundle`` is required for the ``NDWidget``
25+
2426
Without imgui
2527
^^^^^^^^^^^^^
2628

fastplotlib/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@
44

55
# this must be the first import for auto-canvas detection
66
from .utils import loop # noqa
7+
from .utils import (
8+
config,
9+
enumerate_adapters,
10+
select_adapter,
11+
print_wgpu_report,
12+
protocols,
13+
)
714
from .graphics import *
815
from .graphics.features import GraphicFeatureEvent
916
from .graphics.selectors import *
@@ -20,7 +27,6 @@
2027
from .layouts import Figure
2128

2229
from .widgets import NDWidget, ImageWidget
23-
from .utils import config, enumerate_adapters, select_adapter, print_wgpu_report
2430

2531

2632
if len(enumerate_adapters()) < 1:

fastplotlib/graphics/features/_image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def _fix_data(self, data):
106106
)
107107

108108
if data.itemsize == 8:
109-
warn(f"casting {array.dtype} array to float32")
109+
warn(f"casting {data.dtype} array to float32")
110110
return data.astype(np.float32)
111111

112112
return data

fastplotlib/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .gpu import enumerate_adapters, select_adapter, print_wgpu_report
77
from ._plot_helpers import *
88
from .enums import *
9-
from ._protocols import ArrayProtocol, ARRAY_LIKE_ATTRS
9+
from .protocols import ARRAY_LIKE_ATTRS, ArrayProtocol, FutureProtocol, CudaArrayProtocol
1010

1111

1212
@dataclass

fastplotlib/utils/_protocols.py

Lines changed: 0 additions & 33 deletions
This file was deleted.

fastplotlib/utils/functions.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from pygfx import Texture, Color
88

9+
from .protocols import CudaArrayProtocol
10+
911

1012
cmap_catalog = cmap_lib.Catalog()
1113

@@ -405,9 +407,22 @@ def parse_cmap_values(
405407
return colors
406408

407409

410+
def cuda_to_numpy(arr: CudaArrayProtocol) -> np.ndarray:
411+
try:
412+
import cupy
413+
except ImportError:
414+
raise ImportError(
415+
"`cupy` is required to work with GPU arrays\npip install cupy"
416+
)
417+
418+
return cupy.asnumpy(arr)
419+
420+
408421
def subsample_array(
409-
arr: np.ndarray, max_size: int = 1e6, ignore_dims: Sequence[int] | None = None
410-
):
422+
arr: CudaArrayProtocol,
423+
max_size: int = 1e6,
424+
ignore_dims: Sequence[int] | None = None,
425+
) -> np.ndarray:
411426
"""
412427
Subsamples an input array while preserving its relative dimensional proportions.
413428
@@ -476,7 +491,12 @@ def subsample_array(
476491

477492
slices = tuple(slices)
478493

479-
return np.asarray(arr[slices])
494+
arr_sliced = arr[slices]
495+
496+
if isinstance(arr_sliced, CudaArrayProtocol):
497+
return cuda_to_numpy(arr_sliced)
498+
499+
return arr_sliced
480500

481501

482502
def heatmap_to_positions(heatmap: np.ndarray, xvals: np.ndarray) -> np.ndarray:

fastplotlib/utils/protocols.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Callable
4+
from typing import Any, Protocol, runtime_checkable
5+
6+
7+
ARRAY_LIKE_ATTRS = [
8+
"__array__",
9+
"__array_ufunc__",
10+
"dtype",
11+
"shape",
12+
"ndim",
13+
"__getitem__",
14+
]
15+
16+
17+
@runtime_checkable
18+
class ArrayProtocol(Protocol):
19+
"""an object that is sufficiently array-like"""
20+
21+
def __array__(self) -> ArrayProtocol: ...
22+
23+
@property
24+
def dtype(self) -> Any: ...
25+
26+
@property
27+
def ndim(self) -> int: ...
28+
29+
@property
30+
def shape(self) -> tuple[int, ...]: ...
31+
32+
def __getitem__(self, key) -> ArrayProtocol: ...
33+
34+
35+
@runtime_checkable
36+
class CudaArrayProtocol(Protocol):
37+
"""an object that can be converted to a cupy array"""
38+
39+
def __cuda_array_interface__(self) -> CudaArrayProtocol: ...
40+
41+
42+
@runtime_checkable
43+
class FutureProtocol(Protocol):
44+
"""An object that is sufficiently Future-like"""
45+
46+
def cancel(self): ...
47+
48+
def cancelled(self): ...
49+
50+
def running(self): ...
51+
52+
def done(self): ...
53+
54+
def add_done_callback(self, fn: Callable): ...
55+
56+
def result(self, timeout: float | None): ...
57+
58+
def exception(self, timeout: float | None): ...
59+
60+
def set_result(self, array: ArrayProtocol): ...
61+
62+
def set_exception(self, exception): ...
Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,7 @@
11
from ...layouts import IMGUI
22

3-
try:
4-
import imgui_bundle
5-
except ImportError:
6-
HAS_XARRAY = False
7-
else:
8-
HAS_XARRAY = True
9-
103

11-
if IMGUI and HAS_XARRAY:
4+
if IMGUI:
125
from ._base import NDProcessor, NDGraphic
136
from ._nd_positions import NDPositions, NDPositionsProcessor, ndp_extras
147
from ._nd_image import NDImageProcessor, NDImage
@@ -19,6 +12,6 @@
1912
class NDWidget:
2013
def __init__(self, *args, **kwargs):
2114
raise ModuleNotFoundError(
22-
"NDWidget requires `imgui-bundle` and `xarray` to be installed.\n"
15+
"NDWidget requires `imgui-bundle` to be installed.\n"
2316
"pip install imgui-bundle"
2417
)
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from collections.abc import Generator
2+
from concurrent.futures import Future
3+
4+
from ...utils import ArrayProtocol, FutureProtocol, CudaArrayProtocol, cuda_to_numpy
5+
6+
7+
class FutureArray(Future):
8+
def __init__(self, shape, dtype, timeout: float = 1.0):
9+
self._shape = shape
10+
self._dtype = dtype
11+
self._timeout = timeout
12+
13+
super().__init__()
14+
15+
@property
16+
def shape(self) -> tuple[int, ...]:
17+
return self._shape
18+
19+
@property
20+
def ndim(self) -> int:
21+
return len(self.shape)
22+
23+
@property
24+
def dtype(self) -> str:
25+
return self._dtype
26+
27+
def __getitem__(self, item) -> ArrayProtocol:
28+
return self.result(self._timeout)[item]
29+
30+
def __array__(self) -> ArrayProtocol:
31+
return self.result(self._timeout)
32+
33+
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
34+
raise NotImplementedError
35+
36+
def __array_function__(self, func, types, *args, **kwargs):
37+
raise NotImplementedError
38+
39+
40+
# inspired by https://www.dabeaz.com/coroutines/
41+
def start_coroutine(func):
42+
"""
43+
Starts coroutines for async arrays wrapped by NDProcessor.
44+
Used by all NDGraphic.set_indices and NDGraphic._create_graphic.
45+
46+
It also immediately starts coroutines unless block=False is provided. It handles all the triage of possible
47+
sync vs. async (Future-like) objects.
48+
49+
The only time when block=False is when ReferenceIndex._render_indices uses it to loop through setting all
50+
indices, and then collect and send the results back down to NDProcessor.get().
51+
"""
52+
53+
def start(
54+
self, *args, **kwargs
55+
) -> tuple[Generator, ArrayProtocol | CudaArrayProtocol | FutureProtocol] | None:
56+
cr = func(self, *args, **kwargs)
57+
try:
58+
# begin coroutine
59+
to_resolve: FutureProtocol | ArrayProtocol | CudaArrayProtocol = cr.send(
60+
None
61+
)
62+
except StopIteration:
63+
# NDProcessor.get() has no `yield` expression, not async, nothing to return
64+
return None
65+
66+
block = kwargs.get("block", True)
67+
timeout = kwargs.get("timeout", 1.0)
68+
69+
if block: # resolve Future immediately
70+
try:
71+
if isinstance(to_resolve, FutureProtocol):
72+
# array is async, resolve future and send
73+
cr.send(to_resolve.result(timeout=timeout))
74+
elif isinstance(to_resolve, CudaArrayProtocol):
75+
# array is on GPU, it is technically and on GPU, convert to numpy array on CPU
76+
cr.send(cuda_to_numpy(to_resolve))
77+
else:
78+
# not async, just send the array
79+
cr.send(to_resolve)
80+
except StopIteration:
81+
pass
82+
83+
else: # no block, probably resolving multiple futures simultaneously
84+
if isinstance(to_resolve, FutureProtocol):
85+
# data is async, return coroutine generator and future
86+
# ReferenceIndex._render_indices() will manage them and wait to gather all futures
87+
return cr, to_resolve
88+
elif isinstance(to_resolve, CudaArrayProtocol):
89+
# it is async technically, but it's a GPU array, ReferenceIndex._render_indices will manage it
90+
return cr, to_resolve
91+
else:
92+
# not async, just send the array
93+
try:
94+
cr.send(to_resolve)
95+
except (
96+
StopIteration
97+
): # has to be here because of the yield expression, i.e. it's a generator
98+
pass
99+
100+
return start

0 commit comments

Comments
 (0)