")
self._size = value
+ self._set_rect()
@property
def location(self) -> str:
@@ -153,6 +163,7 @@ def height(self) -> int:
def _set_rect(self, *args):
self._x, self._y, self._width, self._height = self.get_rect()
+ self._figure._fpl_reset_layout()
def get_rect(self) -> tuple[int, int, int, int]:
"""
@@ -192,25 +203,203 @@ def get_rect(self) -> tuple[int, int, int, int]:
return x_pos, y_pos, width, height
+ def _draw_resize_handle(self):
+ if self._location == "bottom":
+ imgui.set_cursor_pos((0, 0))
+ imgui.invisible_button("##resize_handle", imgui.ImVec2(imgui.get_window_width(), self._separator_thickness))
+
+ hovered = imgui.is_item_hovered()
+ active = imgui.is_item_active()
+
+ # Get the actual screen rect of the button after it's been laid out
+ rect_min = imgui.get_item_rect_min()
+ rect_max = imgui.get_item_rect_max()
+
+ elif self._location == "right":
+ imgui.set_cursor_pos((0, 0))
+ screen_pos = imgui.get_cursor_screen_pos()
+ win_height = imgui.get_window_height()
+ mouse_pos = imgui.get_mouse_pos()
+
+ rect_min = imgui.ImVec2(screen_pos.x, screen_pos.y)
+ rect_max = imgui.ImVec2(screen_pos.x + self._separator_thickness, screen_pos.y + win_height)
+
+ hovered = (
+ rect_min.x <= mouse_pos.x <= rect_max.x
+ and rect_min.y <= mouse_pos.y <= rect_max.y
+ )
+
+ if hovered and imgui.is_mouse_clicked(0):
+ self._right_gui_resizing = True
+
+ if not imgui.is_mouse_down(0):
+ self._right_gui_resizing = False
+
+ active = self._right_gui_resizing
+
+ imgui.set_cursor_pos((self._separator_thickness, 0))
+
+ if hovered and imgui.is_mouse_double_clicked(0):
+ if not self._collapsed:
+ self._old_size = self.size
+ if self._location == "bottom":
+ self.size = int(self._separator_thickness)
+ elif self._location == "right":
+ self.size = int(self._separator_thickness)
+ self._collapsed = True
+ else:
+ self.size = self._old_size
+ self._collapsed = False
+
+ if hovered or active:
+ if not self._resize_cursor_set:
+ if self._location == "bottom":
+ self._figure.canvas.set_cursor("ns_resize")
+
+ elif self._location == "right":
+ self._figure.canvas.set_cursor("ew_resize")
+
+ self._resize_cursor_set = True
+ imgui.set_tooltip("Drag to resize, double click to expand/collapse")
+
+ elif self._resize_cursor_set:
+ self._figure.canvas.set_cursor("default")
+ self._resize_cursor_set = False
+
+ if active and imgui.is_mouse_dragging(0):
+ if self._location == "bottom":
+ delta = imgui.get_mouse_drag_delta(0).y
+
+ elif self._location == "right":
+ delta = imgui.get_mouse_drag_delta(0).x
+
+ imgui.reset_mouse_drag_delta(0)
+ px, py, pw, ph = self._figure.get_pygfx_render_area()
+
+ if self._location == "bottom":
+ new_render_size = ph + delta
+ elif self._location == "right":
+ new_render_size = pw + delta
+
+ # check if the new size would make the pygfx render area too small
+ if (delta < 0) and (new_render_size < 150):
+ print("not enough render area")
+ self._resize_blocked = True
+
+ if self._resize_blocked:
+ # check if cursor has returned
+ if self._location == "bottom":
+ _min, pos, _max = rect_min.y, imgui.get_mouse_pos().y, rect_max.y
+
+ elif self._location == "right":
+ _min, pos, _max = rect_min.x, imgui.get_mouse_pos().x, rect_max.x
+
+ if ((_min - 5) <= pos <= (_max + 5)) and delta > 0:
+ # if the mouse cursor is back on the bar and the delta > 0, i.e. render area increasing
+ self._resize_blocked = False
+
+ if not self._resize_blocked:
+ self.size = max(30, round(self.size - delta))
+ self._collapsed = False
+
+ draw_list = imgui.get_window_draw_list()
+
+ line_color = (
+ imgui.get_color_u32(imgui.ImVec4(0.9, 0.9, 0.9, 1.0))
+ if (hovered or active)
+ else imgui.get_color_u32(imgui.ImVec4(0.5, 0.5, 0.5, 0.8))
+ )
+ bg_color = (
+ imgui.get_color_u32(imgui.ImVec4(0.2, 0.2, 0.2, 0.8))
+ if (hovered or active)
+ else imgui.get_color_u32(imgui.ImVec4(0.15, 0.15, 0.15, 0.6))
+ )
+
+ # Background bar
+ draw_list.add_rect_filled(
+ imgui.ImVec2(rect_min.x, rect_min.y),
+ imgui.ImVec2(rect_max.x, rect_max.y),
+ bg_color,
+ )
+
+ # Three grip dots centered on the line
+ dot_spacing = 7.0
+ dot_radius = 2
+ if self._location == "bottom":
+ mid_y = (rect_min.y + rect_max.y) * 0.5
+ center_x = (rect_min.x + rect_max.x) * 0.5
+ for i in (-1, 0, 1):
+ cx = center_x + i * dot_spacing
+ draw_list.add_circle_filled(imgui.ImVec2(cx, mid_y), dot_radius, line_color)
+
+ imgui.set_cursor_pos((0, imgui.get_cursor_pos_y() - imgui.get_style().item_spacing.y))
+
+ elif self._location == "right":
+ mid_x = (rect_min.x + rect_max.x) * 0.5
+ center_y = (rect_min.y + rect_max.y) * 0.5
+ for i in (-1, 0, 1):
+ cy = center_y + i * dot_spacing
+ draw_list.add_circle_filled(
+ imgui.ImVec2(mid_x, cy), dot_radius, line_color
+ )
+
+ def _draw_title(self, title: str):
+ padding = imgui.ImVec2(10, 4)
+ text_size = imgui.calc_text_size(title)
+ win_width = imgui.get_window_width()
+ box_size = imgui.ImVec2(win_width, text_size.y + padding.y * 2)
+
+ box_screen_pos = imgui.get_cursor_screen_pos()
+
+ draw_list = imgui.get_window_draw_list()
+
+ # Background — use imgui's default title bar color
+ draw_list.add_rect_filled(
+ imgui.ImVec2(box_screen_pos.x, box_screen_pos.y),
+ imgui.ImVec2(box_screen_pos.x + box_size.x, box_screen_pos.y + box_size.y),
+ imgui.get_color_u32(imgui.Col_.title_bg_active),
+ )
+
+ # Centered text
+ text_pos = imgui.ImVec2(
+ box_screen_pos.x + (win_width - text_size.x) * 0.5,
+ box_screen_pos.y + padding.y,
+ )
+ draw_list.add_text(
+ text_pos, imgui.get_color_u32(imgui.ImVec4(1, 1, 1, 1)), title
+ )
+
+ imgui.dummy(imgui.ImVec2(win_width, box_size.y))
+
def draw_window(self):
"""helps simplify using imgui by managing window creation & position, and pushing/popping the ID"""
# window position & size
x, y, w, h = self.get_rect()
imgui.set_next_window_size((self.width, self.height))
imgui.set_next_window_pos((self.x, self.y))
- # imgui.set_next_window_pos((x, y))
- # imgui.set_next_window_size((w, h))
flags = self._window_flags
# begin window
imgui.begin(self._title, p_open=None, flags=flags)
+ self._draw_resize_handle()
+
# push ID to prevent conflict between multiple figs with same UI
imgui.push_id(self._id_counter)
+ # collapse the UI if the separator state is collapsed
+ # otherwise the UI renders partially on the separator for "right" guis and it looks weird
+ main_height = 1.0 if self._collapsed else 0.0
+ imgui.begin_child("##main_ui", imgui.ImVec2(0, main_height))
+
+ self._draw_title(self._title)
+
+ imgui.indent(6.0)
# draw stuff from subclass into window
self.update()
+ imgui.end_child()
+
# pop ID
imgui.pop_id()
diff --git a/fastplotlib/ui/right_click_menus/_colormap_picker.py b/fastplotlib/ui/right_click_menus/_colormap_picker.py
index a80e5b2aa..9df26dcdc 100644
--- a/fastplotlib/ui/right_click_menus/_colormap_picker.py
+++ b/fastplotlib/ui/right_click_menus/_colormap_picker.py
@@ -154,7 +154,8 @@ def update(self):
self._texture_height = (imgui.get_font_size()) - 2
if imgui.menu_item("Reset vmin-vmax", "", False)[0]:
- self._lut_tool.images[0].reset_vmin_vmax()
+ for image in self._lut_tool.images:
+ image.reset_vmin_vmax()
# add all the cmap options
for cmap_type in COLORMAP_NAMES.keys():
diff --git a/fastplotlib/ui/right_click_menus/_standard_menu.py b/fastplotlib/ui/right_click_menus/_standard_menu.py
index bb9e5bdef..9c659f4a7 100644
--- a/fastplotlib/ui/right_click_menus/_standard_menu.py
+++ b/fastplotlib/ui/right_click_menus/_standard_menu.py
@@ -31,6 +31,8 @@ def __init__(self, figure):
# whether the right click menu is currently open or not
self.is_open: bool = False
+ self._controller_window_open: bool | PlotArea = False
+
def get_subplot(self) -> PlotArea | bool | None:
"""get the subplot that a click occurred in"""
if self._last_right_click_pos is None:
@@ -47,6 +49,10 @@ def cleanup(self):
"""called when the popup disappears"""
self.is_open = False
+ def _extra_menu(self):
+ # extra menu items, optional, implement in subclass
+ pass
+
def update(self):
if imgui.is_mouse_down(1) and not self._mouse_down:
# mouse button was pressed down, store this position
@@ -147,39 +153,55 @@ def update(self):
imgui.separator()
# controller options
- if imgui.begin_menu("Controller"):
- _, enabled = imgui.menu_item(
- "Enabled", "", self.get_subplot().controller.enabled
- )
+ if imgui.menu_item("Controller Options", "", False)[0]:
+ self._controller_window_open = self.get_subplot()
- self.get_subplot().controller.enabled = enabled
+ self._extra_menu()
- changed, damping = imgui.slider_float(
- "Damping",
- v=self.get_subplot().controller.damping,
- v_min=0.0,
- v_max=10.0,
- )
-
- if changed:
- self.get_subplot().controller.damping = damping
-
- imgui.separator()
- imgui.text("Controller type:")
- # switching between different controllers
- for name, controller_type_iter in controller_types.items():
- current_type = type(self.get_subplot().controller)
+ imgui.end_popup()
- clicked, _ = imgui.menu_item(
- label=name,
- shortcut="",
- p_selected=current_type is controller_type_iter,
- )
+ if self._controller_window_open:
+ self._draw_controller_window()
+
+ def _draw_controller_window(self):
+ subplot = self._controller_window_open
+
+ imgui.set_next_window_size((0, 0))
+ _, keep_open = imgui.begin(f"Controller", True)
+ imgui.text(f"subplot: {subplot.name}")
+ _, enabled = imgui.menu_item(
+ "Enabled", "", subplot.controller.enabled
+ )
+
+ subplot.controller.enabled = enabled
+
+ changed, damping = imgui.slider_float(
+ "Damping",
+ v=subplot.controller.damping,
+ v_min=0.0,
+ v_max=10.0,
+ )
+
+ if changed:
+ subplot.controller.damping = damping
+
+ imgui.separator()
+ imgui.text("Controller type:")
+ # switching between different controllers
+ for name, controller_type_iter in controller_types.items():
+ current_type = type(subplot.controller)
+
+ clicked, _ = imgui.menu_item(
+ label=name,
+ shortcut="",
+ p_selected=current_type is controller_type_iter,
+ )
- if clicked and (current_type is not controller_type_iter):
- # menu item was clicked and the desired controller isn't the current one
- self.get_subplot().controller = name
+ if clicked and (current_type is not controller_type_iter):
+ # menu item was clicked and the desired controller isn't the current one
+ subplot.controller = name
- imgui.end_menu()
+ if not keep_open:
+ self._controller_window_open = False
- imgui.end_popup()
+ imgui.end()
diff --git a/fastplotlib/utils/__init__.py b/fastplotlib/utils/__init__.py
index dd527ca67..6f0059f6a 100644
--- a/fastplotlib/utils/__init__.py
+++ b/fastplotlib/utils/__init__.py
@@ -6,6 +6,7 @@
from .gpu import enumerate_adapters, select_adapter, print_wgpu_report
from ._plot_helpers import *
from .enums import *
+from ._protocols import ArrayProtocol, ARRAY_LIKE_ATTRS
@dataclass
diff --git a/fastplotlib/utils/_protocols.py b/fastplotlib/utils/_protocols.py
new file mode 100644
index 000000000..95d7d2763
--- /dev/null
+++ b/fastplotlib/utils/_protocols.py
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+from typing import Any, Protocol, runtime_checkable
+
+
+ARRAY_LIKE_ATTRS = [
+ "__array__",
+ "__array_ufunc__",
+ "dtype",
+ "shape",
+ "ndim",
+ "__getitem__",
+]
+
+
+@runtime_checkable
+class ArrayProtocol(Protocol):
+ def __array__(self) -> ArrayProtocol: ...
+
+ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): ...
+
+ def __array_function__(self, func, types, *args, **kwargs): ...
+
+ @property
+ def dtype(self) -> Any: ...
+
+ @property
+ def ndim(self) -> int: ...
+
+ @property
+ def shape(self) -> tuple[int, ...]: ...
+
+ def __getitem__(self, key): ...
diff --git a/fastplotlib/utils/functions.py b/fastplotlib/utils/functions.py
index a839ed9d0..34062824a 100644
--- a/fastplotlib/utils/functions.py
+++ b/fastplotlib/utils/functions.py
@@ -477,3 +477,38 @@ def subsample_array(
slices = tuple(slices)
return np.asarray(arr[slices])
+
+
+def heatmap_to_positions(heatmap: np.ndarray, xvals: np.ndarray) -> np.ndarray:
+ """
+
+ Convert a heatmap of shape [n_rows, n_datapoints] to timeseries x-y data of shape [n_rows, n_datapoints, xy]
+
+ Parameters
+ ----------
+ heatmap: np.ndarray, shape [n_rows, n_datapoints]
+ timeseries data with a heatmap representation, where each column represents a timepoint.
+
+ xvals: np.ndarray, shape: [n_datapoints,]
+ x-values for the columns in the heatmap
+
+ Returns
+ -------
+ np.ndarray, shape [n_rows, n_datapoints, 2]
+ timeseries data where the xy data are explicitly stored for every row
+
+ """
+ if heatmap.ndim != 2:
+ raise ValueError
+
+ if xvals.ndim != 1:
+ raise ValueError
+
+ if xvals.size != heatmap.shape[1]:
+ raise ValueError
+
+ ts = np.empty((*heatmap.shape, 2), dtype=np.float32)
+ ts[..., 0] = xvals
+ ts[..., 1] = heatmap
+
+ return ts
diff --git a/fastplotlib/widgets/__init__.py b/fastplotlib/widgets/__init__.py
index 766620ea6..4347f6c80 100644
--- a/fastplotlib/widgets/__init__.py
+++ b/fastplotlib/widgets/__init__.py
@@ -1,3 +1,12 @@
+from .nd_widget import (
+ NDWidget,
+ NDProcessor,
+ NDGraphic,
+ NDPositionsProcessor,
+ NDPositions,
+ NDImageProcessor,
+ NDImage,
+)
from .image_widget import ImageWidget
-__all__ = ["ImageWidget"]
+__all__ = ["NDWidget", "ImageWidget"]
diff --git a/fastplotlib/widgets/image_widget/_widget.py b/fastplotlib/widgets/image_widget/_widget.py
index 86a01b083..6d262678d 100644
--- a/fastplotlib/widgets/image_widget/_widget.py
+++ b/fastplotlib/widgets/image_widget/_widget.py
@@ -358,6 +358,11 @@ def __init__(
passed to each ImageGraphic in the ImageWidget figure subplots
"""
+ warn(
+ "`ImageWidget` is deprecated and will be removed in a"
+ " future release, please migrate to NDWidget",
+ DeprecationWarning
+ )
self._initialized = False
if figure_kwargs is None:
diff --git a/fastplotlib/widgets/nd_widget/__init__.py b/fastplotlib/widgets/nd_widget/__init__.py
new file mode 100644
index 000000000..378f7dfcd
--- /dev/null
+++ b/fastplotlib/widgets/nd_widget/__init__.py
@@ -0,0 +1,24 @@
+from ...layouts import IMGUI
+
+try:
+ import imgui_bundle
+except ImportError:
+ HAS_XARRAY = False
+else:
+ HAS_XARRAY = True
+
+
+if IMGUI and HAS_XARRAY:
+ from ._base import NDProcessor, NDGraphic
+ from ._nd_positions import NDPositions, NDPositionsProcessor, ndp_extras
+ from ._nd_image import NDImageProcessor, NDImage
+ from ._ndwidget import NDWidget
+
+else:
+
+ class NDWidget:
+ def __init__(self, *args, **kwargs):
+ raise ModuleNotFoundError(
+ "NDWidget requires `imgui-bundle` and `xarray` to be installed.\n"
+ "pip install imgui-bundle"
+ )
diff --git a/fastplotlib/widgets/nd_widget/_base.py b/fastplotlib/widgets/nd_widget/_base.py
new file mode 100644
index 000000000..932018f6e
--- /dev/null
+++ b/fastplotlib/widgets/nd_widget/_base.py
@@ -0,0 +1,738 @@
+from collections.abc import Callable, Hashable, Sequence
+from contextlib import contextmanager
+import inspect
+from numbers import Real
+from pprint import pformat
+import textwrap
+from typing import Literal, Any, Type
+from warnings import warn
+
+import xarray as xr
+import numpy as np
+from numpy.typing import ArrayLike
+
+from ...layouts import Subplot
+from ...utils import subsample_array, ArrayProtocol
+from ...graphics import Graphic
+from ._index import ReferenceIndex
+
+# must take arguments: array-like, `axis`: int, `keepdims`: bool
+WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike]
+
+
+def identity(index: int) -> int:
+ return round(index)
+
+
+class NDProcessor:
+ def __init__(
+ self,
+ data: Any,
+ dims: Sequence[Hashable],
+ spatial_dims: Sequence[Hashable] | None,
+ slider_dim_transforms: dict[Hashable, Callable[[Any], int] | ArrayLike] = None,
+ window_funcs: dict[
+ Hashable, tuple[WindowFuncCallable | None, int | float | None]
+ ] = None,
+ window_order: tuple[Hashable, ...] = None,
+ spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None,
+ ):
+ """
+ Base class for managing n-dimensional data and producing array slices.
+
+ By default, wraps input data into an ``xarray.DataArray`` and provides an interface
+ for indexing slider dimensions, applying window functions, spatial functions, and mapping
+ reference-space values to local array indices. Subclasses must implement
+ :meth:`get`, which is called whenever the :class:`ReferenceIndex` updates.
+
+ Subclasses can implement any type of data representation, they do not necessarily need to be compatible with
+ (they dot not have to be xarray compatible). However their ``get()`` method must still return a data slice that
+ corresponds to the graphical representation they map to.
+
+ Every dimension that is *not* listed in ``spatial_dims`` becomes a slider
+ dimension. Each slider dim must have a ``ReferenceRange`` defined in the
+ ``ReferenceIndex`` of the parent ``NDWidget``. The widget uses this to direct
+ a change in the ``ReferenceIndex`` and update the graphics.
+
+ Parameters
+ ----------
+ data: Any
+ data object that is managed, usually uses the ArrayProtocol. Custom subclasses can manage any kind of data
+ object but the corresponding :meth:`get` must return an array-like that maps to a graphical representation.
+
+ dims: Sequence[str]
+ names for each dimension in ``data``. Dimensions not listed in
+ ``spatial_dims`` are treated as slider dimensions and **must** appear as
+ keys in the parent ``NDWidget``'s ``ref_ranges``
+ Examples::
+ ``("time", "depth", "row", "col")``
+ ``("channels", "time", "xy")``
+ ``("keypoints", "time", "xyz")``
+
+ A custom subclass's ``data`` object doesn't necessarily need to have these dims, but the ``get()`` method
+ must operate as if these dimensions exist and return an array that matches the spatial dimensions.
+
+ spatial_dims: Sequence[str]
+ Subset of ``dims`` that are spatial (rendered) dimensions **in order**. All remaining dims are treated as
+ slider dims. See subclass for specific info.
+
+ slider_dim_transforms: dict mapping dim_name -> Callable, an ArrayLike, or None
+ Per-slider-dim mapping from reference-space values to local array indices.
+
+ You may also provide an array of reference values for the slider dims, ``searchsorted`` is then used
+ as the transform (ex: a timestamps array).
+
+ If ``None`` and identity mapping is used, i.e. rounds the current reference index value to the nearest
+ integer for array indexing.
+
+ If a transform is not provided for a dim then the identity mapping is used.
+
+ window_funcs: dict[
+ Hashable, tuple[WindowFuncCallable | None, int | float | None]
+ ]
+ Per-slider-dim window functions applied around the current slider position. Ex: {"time": (np.mean, 2.5)}.
+ Each value is a ``(func, window_size)`` pair where:
+
+ * *func* must accept ``axis: int`` and ``keepdims: bool`` kwargs
+ (ex: ``np.mean``, ``np.max``). The window function **must** return an array that has the same dimensions
+ as specified in the NDProcessor, therefore the size of any dim along which a window_func was applied
+ should reduce to ``1``. These dims must not be removed by the window_func.
+
+ * *window_size* is in reference-space units (ex: 2.5 seconds).
+
+
+ window_order: tuple[Hashable, ...]
+ Order in which window functions are applied across dims. Only dims listed
+ here have their window function applied. window_funcs are ignored for any
+ dims not specified in ``window_order``
+
+ spatial_func:
+ A function applied to the spatial slice *after* window_funcs right before rendering.
+
+ """
+ self._dims = tuple(dims)
+ self._data = self._validate_data(data)
+ self.spatial_dims = spatial_dims
+
+ self.slider_dim_transforms = slider_dim_transforms
+
+ self.window_funcs = window_funcs
+ self.window_order = window_order
+ self.spatial_func = spatial_func
+
+ @property
+ def data(self) -> xr.DataArray:
+ """
+ get or set managed data. If setting with new data, the new data is interpreted
+ to have the same dims (i.e. same dim names and ordering of dims).
+ """
+ return self._data
+
+ @data.setter
+ def data(self, data: ArrayProtocol):
+ self._data = self._validate_data(data)
+
+ def _validate_data(self, data: ArrayProtocol):
+ # does some basic validation
+ if data is None:
+ # we allow data to be None, in this case no ndgraphic is rendered
+ # useful when we want to initialize an NDWidget with no traces for example
+ # and populate it as components/channels are selected
+ return None
+
+ if not isinstance(data, ArrayProtocol):
+ # This is required for xarray compatibility and general array-like requirements
+ raise TypeError("`data` must implement the ArrayProtocol")
+
+ if data.ndim != len(self.dims):
+ raise IndexError("must specify a dim for every dimension in the data array")
+
+ # data can be set, but the dims must still match/have the same meaning
+ return xr.DataArray(data, dims=self.dims)
+
+ @property
+ def shape(self) -> dict[Hashable, int]:
+ """interpreted shape of the data"""
+ return {d: n for d, n in zip(self.dims, self.data.shape)}
+
+ @property
+ def ndim(self) -> int:
+ """number of dims"""
+ return self.data.ndim
+
+ @property
+ def dims(self) -> tuple[Hashable, ...]:
+ """dim names"""
+ # these are read-only and cannot be set after it's created
+ # the user should create a new NDGraphic if they need different dims
+ # I can't think of a usecase where we'd want to change the dims, and
+ # I think that would be complicated and probably and anti-pattern
+ return self._dims
+
+ @property
+ def spatial_dims(self) -> tuple[Hashable, ...]:
+ """Spatial dims, **in order**"""
+ return self._spatial_dims
+
+ @spatial_dims.setter
+ def spatial_dims(self, sdims: Sequence[Hashable]):
+ for dim in sdims:
+ if dim not in self.dims:
+ raise KeyError
+
+ self._spatial_dims = tuple(sdims)
+
+ @property
+ def tooltip(self) -> bool:
+ """
+ whether or not a custom tooltip formatter method exists
+ """
+ return False
+
+ def tooltip_format(self, *args) -> str | None:
+ """
+ Override in subclass to format custom tooltips
+ """
+ return None
+
+ @property
+ def slider_dims(self) -> set[Hashable]:
+ """Slider dim names, ``set(dims) - set(spatial_dims)"""
+ return set(self.dims) - set(self.spatial_dims)
+
+ @property
+ def n_slider_dims(self):
+ """number of slider dims, i.e. len(slider_dims)"""
+ return len(self.slider_dims)
+
+ @property
+ def window_funcs(
+ self,
+ ) -> dict[Hashable, tuple[WindowFuncCallable | None, int | float | None]]:
+ """get or set window functions, see docstring for details"""
+ return self._window_funcs
+
+ @window_funcs.setter
+ def window_funcs(
+ self,
+ window_funcs: (
+ dict[Hashable, tuple[WindowFuncCallable | None, int | float | None] | None]
+ | None
+ ),
+ ):
+ if window_funcs is None:
+ # tuple of (None, None) makes the checks easier in _apply_window_funcs
+ self._window_funcs = {d: (None, None) for d in self.slider_dims}
+ return
+
+ for k in window_funcs.keys():
+ if k not in self.slider_dims:
+ raise KeyError
+
+ func = window_funcs[k][0]
+ size = window_funcs[k][1]
+
+ if func is None:
+ pass
+ elif callable(func):
+ sig = inspect.signature(func)
+
+ if "axis" not in sig.parameters or "keepdims" not in sig.parameters:
+ raise TypeError(
+ f"Each window function must take an `axis` and `keepdims` argument, "
+ f"you passed: {func} with the following function signature: {sig}"
+ )
+ else:
+ raise TypeError(
+ f"`window_funcs` must be a dict mapping dim names to a tuple of the window function callable and "
+ f"window size, {'name': (func, size), ...}.\nYou have passed: {window_funcs}"
+ )
+
+ if size is None:
+ pass
+
+ elif not isinstance(size, Real):
+ raise TypeError
+
+ elif size < 0:
+ raise ValueError
+
+ # fill in rest with None
+ for d in self.slider_dims:
+ if d not in window_funcs.keys():
+ window_funcs[d] = (None, None)
+
+ self._window_funcs = window_funcs
+
+ @property
+ def window_order(self) -> tuple[Hashable, ...]:
+ """get or set dimension order in which window functions are applied"""
+ return self._window_order
+
+ @window_order.setter
+ def window_order(self, order: tuple[Hashable] | None):
+ if order is None:
+ self._window_order = tuple()
+ return
+
+ if not set(order).issubset(self.slider_dims):
+ raise ValueError(
+ f"each dimension in `window_order` must be a slider dim. You passed order: {order} "
+ f"and the slider dims are: {self.slider_dims}"
+ )
+
+ self._window_order = tuple(order)
+
+ @property
+ def spatial_func(self) -> Callable[[xr.DataArray], xr.DataArray] | None:
+ """get or set the spatial function which is applied on the data slice after the window functions"""
+ return self._spatial_func
+
+ @spatial_func.setter
+ def spatial_func(
+ self, func: Callable[[xr.DataArray], xr.DataArray]
+ ) -> Callable | None:
+ if not callable(func) and func is not None:
+ raise TypeError
+
+ self._spatial_func = func
+
+ @property
+ def slider_dim_transforms(self) -> dict[Hashable, Callable[[Any], int]]:
+ """get or set the slider_dim_transforms, see docstring for details"""
+ return self._index_mappings
+
+ @slider_dim_transforms.setter
+ def slider_dim_transforms(
+ self, maps: dict[Hashable, Callable[[Any], int] | ArrayLike | None] | None
+ ):
+ if maps is None:
+ self._index_mappings = {d: identity for d in self.dims}
+ return
+
+ for d in maps.keys():
+ if d not in self.dims:
+ raise KeyError(
+ f"`index_mapping` provided for non-existent dimension: {d}, existing dims are: {self.dims}"
+ )
+
+ if isinstance(maps[d], ArrayProtocol):
+ # create a searchsorted mapping function automatically
+ maps[d] = maps[d].searchsorted
+
+ elif maps[d] is None:
+ # assign identity mapping
+ maps[d] = identity
+
+ for d in self.dims:
+ # fill in any unspecified maps with identity
+ if d not in maps.keys():
+ maps[d] = identity
+
+ self._index_mappings = maps
+
+ def _ref_index_to_array_index(self, dim: str, ref_index: Any) -> int:
+ # wraps slider_dim_transforms, clamps between 0 and the array size in this dim
+
+ # ref-space -> local-array-index transform
+ index = self.slider_dim_transforms[dim](ref_index)
+
+ # clamp between 0 and array size in this dim
+ return max(min(index, self.shape[dim] - 1), 0)
+
+ def _get_slider_dims_indexer(self, indices: dict[Hashable, Any]) -> dict[Hashable, slice]:
+ """
+ Creates an xarray-compatible indexer dict mapping each slider_dim -> slice object.
+
+ - If a window_func is defined for a dim and the dim appears in ``window_order``,
+ the slice is defined as:
+ start: index - half_window
+ stop: index + half_window
+ step: 1
+
+ It then applies the slider_dim_transform to the start and stop to map these values from reference-space to
+ the local array index, and then finally produces the slice object in local array indices.
+
+ ex: if we have indices = {"time": 50.0}, a window size of 5.0s and the ``slider_dim_transform``
+ for time is based on a sampling rate of 10Hz, the window in ref units is [45.0, 55.0], and the final
+ slice object would be ``slice(450, 550, 1)``.
+
+ - If no window func is specified, the final slice just corresponds to that index as an int array-index.
+
+ This exists separate from ``_apply_window_functions()`` because it is useful for debugging purposes.
+
+ Parameters
+ ----------
+ indices : dict[Hashable, Any], {dim: ref_value}
+ Reference-space values for each slider dim. Must contain an entry
+ for every slider dim; raises ``IndexError`` otherwise.
+ ex: {"time": 46.397, "depth": 23.24}
+
+ Returns
+ -------
+ dict[Hashable, slice]
+ Indexer compatible for ``xr.DataArray.isel()``, with one ``slice`` per
+ slider dim. These are array indices mapped from the reference space using
+ the given ``slider_dim_transform``.
+
+ Raises
+ ------
+ IndexError
+ If ``indices`` are not provided for every ``slider_dim``
+ """
+
+ if set(indices.keys()) != set(self.slider_dims):
+ raise IndexError(
+ f"Must provide an index for all slider dims: {self.slider_dims}, you have provided: {indices.keys()}"
+ )
+
+ indexer = dict()
+
+ # get only slider dims which are not also spatial dims (example: p dim for positional data)
+ # since `p` dim windowing is dealt with separately for positional data
+ slider_dims = set(self.slider_dims) - set(self.spatial_dims)
+ # go through each slider dim and accumulate slice objects
+ for dim in slider_dims:
+ # index for this dim in reference space
+ index_ref = indices[dim]
+
+ if dim not in self.window_funcs.keys():
+ wf, ws = None, None
+ else:
+ # get window func and size in reference units
+ wf, ws = self.window_funcs[dim]
+
+ # if a window function exists for this dim, and it's specified in the window order
+ if (wf is not None) and (ws is not None) and (dim in self.window_order):
+ # half window in reference units
+ hw = ws / 2
+
+ # start in reference units
+ start_ref = index_ref - hw
+ # stop in ref units
+ stop_ref = index_ref + hw
+
+ # map start and stop ref to array indices
+ start = self.slider_dim_transforms[dim](start_ref)
+ stop = self.slider_dim_transforms[dim](stop_ref)
+
+ # clamp within array bounds
+ start = max(min(self.shape[dim] - 1, start), 0)
+ stop = max(min(self.shape[dim] - 1, stop), 0)
+ indexer[dim] = slice(start, stop, 1)
+ else:
+ # no window func for this dim, direct indexing
+ # index mapped to array index
+ index = self.slider_dim_transforms[dim](index_ref)
+
+ # clamp within the bounds
+ start = max(min(self.shape[dim] - 1, index), 0)
+
+ # stop index is just the start index + 1
+ indexer[dim] = slice(start, start + 1, 1)
+
+ return indexer
+
+ def _apply_window_functions(self, indices: dict[Hashable, Any]) -> xr.DataArray:
+ """
+ Slice the data at the given indices and apply window functions in the order specified by
+ ``window_order``.
+
+ Parameters
+ ----------
+ indices : dict[Hashable, Any], {dim: ref_value}
+ Reference-space values for each slider dim.
+ ex: {"time": 46.397, "depth": 23.24}
+
+ Returns
+ -------
+ xr.DataArray
+ Data slice after windowed indexing and window function application,
+ with the same dims as the original data. Dims of size ``1`` are not
+ squeezed.
+
+ """
+ indexer = self._get_slider_dims_indexer(indices)
+
+ # get the data slice w.r.t. the desired windows, and get the underlying numpy array
+ # ``.values`` gives the numpy array
+ # there is significant overhead with passing xarray objects to numpy for things like np.mean()
+ # so convert to numpy, apply window functions, then convert back to xarray
+ # creating an xarray object from a numpy array has very little overhead, ~10 microseconds
+ array = self.data.isel(indexer).values
+
+ # apply window funcs in the specified order
+ for dim in self.window_order:
+ if self.window_funcs[dim] is None:
+ continue
+
+ func, _ = self.window_funcs[dim]
+ # ``keepdims=True`` is critical, any "collapsed" dims will be of size ``1``.
+ # Ex: if `array` is of shape [10, 512, 512] and we applied the np.mean() window func on the first dim
+ # ``keepdims`` means the resultant shape is [1, 512, 512] and NOT [512, 512]
+ # this is necessary for applying window functions on multiple dims separately and so that the
+ # dims names correspond after all the window funcs are applied.
+ array = func(array, axis=self.dims.index(dim), keepdims=True)
+
+ return xr.DataArray(array, dims=self.dims)
+
+ def get(self, indices: dict[Hashable, Any]):
+ raise NotImplementedError
+
+ # TODO: html and pretty text repr #
+ # def _repr_html_(self) -> str:
+ # return ndp_fmt_html(self)
+ #
+ # def _repr_mimebundle_(self, **kwargs) -> dict:
+ # return {
+ # "text/plain": self._repr_text_(),
+ # "text/html": self._repr_html_(),
+ # }
+
+ def _repr_text_(self):
+ if self.data is None:
+ return (
+ f"{self.__class__.__name__}\n"
+ f"data is None, dims: {self.dims}"
+ )
+ tab = "\t"
+
+ wf = {k: v for k, v in self.window_funcs.items() if v != (None, None)}
+
+ r = (
+ f"{self.__class__.__name__}\n"
+ f"shape:\n\t{self.shape}\n"
+ f"dims:\n\t{self.dims}\n"
+ f"spatial_dims:\n\t{self.spatial_dims}\n"
+ f"slider_dims:\n\t{self.slider_dims}\n"
+ f"slider_dim_transforms:\n{textwrap.indent(pformat(self.slider_dim_transforms, width=120), prefix=tab)}\n"
+ )
+
+ if len(wf) > 0:
+ r += (
+ f"window_funcs:\n{textwrap.indent(pformat(wf, width=120), prefix=tab)}\n"
+ f"window_order:\n\t{self.window_order}\n"
+ )
+
+ if self.spatial_func is not None:
+ r += f"spatial_func:\n\t{self.spatial_func}\n"
+
+ return r
+
+
+class NDGraphic:
+ def __init__(
+ self,
+ subplot: Subplot,
+ name: str | None,
+ ):
+ self._subplot = subplot
+ self._name = name
+ self._graphic: Graphic | None = None
+
+ # used to indicate that the NDGraphic should ignore any requests to update the indices
+ # used by block_indices_ctx context manager, usecase is when the LinearSelector on timeseries
+ # NDGraphic changes the selection, it shouldn't change the graphic that it is on top of! Would
+ # also cause recursion
+ # It is also used by the @block_reentrance decorator which is on the ``NDGraphic.indices`` property setter
+ # this is also to block recursion
+ self._block_indices = False
+
+ # user settable bool to make the graphic unresponsive to change in the ReferenceIndex
+ self._pause = False
+
+
+ def _create_graphic(self):
+ raise NotImplementedError
+
+ @property
+ def pause(self) -> bool:
+ """if True, changes in the reference until it is set back to False"""
+ return self._pause
+
+ @pause.setter
+ def pause(self, val: bool):
+ self._pause = bool(val)
+
+ @property
+ def name(self) -> str | None:
+ """name given to the NDGraphic"""
+ return self._name
+
+ @property
+ def processor(self) -> NDProcessor:
+ raise NotImplementedError
+
+ @property
+ def graphic(self) -> Graphic:
+ raise NotImplementedError
+
+ @property
+ def indices(self) -> dict[Hashable, Any]:
+ raise NotImplementedError
+
+ @indices.setter
+ def indices(self, new: dict[Hashable, Any]):
+ raise NotImplementedError
+
+ # aliases for easier access to processor properties
+ @property
+ def data(self) -> Any:
+ """
+ get or set managed data. If setting with new data, the new data is interpreted
+ to have the same dims (i.e. same dim names and ordering of dims).
+ """
+ return self.processor.data
+
+ @data.setter
+ def data(self, data: Any):
+ self.processor.data = data
+ # create a new graphic when data has changed
+ if self.graphic is not None:
+ # it is already None if NDGraphic was initialized with no data
+ self._subplot.delete_graphic(self.graphic)
+ self._graphic = None
+
+ self._create_graphic()
+
+ # force a render
+ self.indices = self.indices
+
+ @property
+ def shape(self) -> dict[Hashable, int]:
+ """interpreted shape of the data"""
+ return self.processor.shape
+
+ @property
+ def ndim(self) -> int:
+ """number of dims"""
+ return self.processor.ndim
+
+ @property
+ def dims(self) -> tuple[Hashable, ...]:
+ """dim names"""
+ return self.processor.dims
+
+ @property
+ def spatial_dims(self) -> tuple[str, ...]:
+ # number of spatial dims for positional data is always 3
+ # for image is 2 or 3, so it must be implemented in subclass
+ raise NotImplementedError
+
+ @property
+ def slider_dims(self) -> set[Hashable]:
+ """the slider dims"""
+ return self.processor.slider_dims
+
+ @property
+ def slider_dim_transforms(self) -> dict[Hashable, Callable[[Any], int]]:
+ return self.processor.slider_dim_transforms
+
+ @slider_dim_transforms.setter
+ def slider_dim_transforms(
+ self, maps: dict[Hashable, Callable[[Any], int] | ArrayLike | None] | None
+ ):
+ """get or set the slider_dim_transforms, see docstring for details"""
+ self.processor.slider_dim_transforms = maps
+ # force a render
+ self.indices = self.indices
+
+ @property
+ def window_funcs(
+ self,
+ ) -> dict[Hashable, tuple[WindowFuncCallable | None, int | float | None]]:
+ """get or set window functions, see docstring for details"""
+ return self.processor.window_funcs
+
+ @window_funcs.setter
+ def window_funcs(
+ self,
+ window_funcs: (
+ dict[Hashable, tuple[WindowFuncCallable | None, int | float | None] | None]
+ | None
+ ),
+ ):
+ self.processor.window_funcs = window_funcs
+ # force a render
+ self.indices = self.indices
+
+ @property
+ def window_order(self) -> tuple[Hashable, ...]:
+ """get or set dimension order in which window functions are applied"""
+ return self.processor.window_order
+
+ @window_order.setter
+ def window_order(self, order: tuple[Hashable] | None):
+ self.processor.window_order = order
+ # force a render
+ self.indices = self.indices
+
+ @property
+ def spatial_func(self) -> Callable[[xr.DataArray], xr.DataArray] | None:
+ return self.processor.spatial_func
+
+ @spatial_func.setter
+ def spatial_func(
+ self, func: Callable[[xr.DataArray], xr.DataArray]
+ ) -> Callable | None:
+ """get or set the spatial_func, see docstring for details"""
+ self.processor.spatial_func = func
+ # force a render
+ self.indices = self.indices
+
+ # def _repr_text_(self) -> str:
+ # return ndg_fmt_text(self)
+ #
+ # def _repr_html_(self) -> str:
+ # return ndg_fmt_html(self)
+ #
+ # def _repr_mimebundle_(self, **kwargs) -> dict:
+ # return {
+ # "text/plain": self._repr_text_(),
+ # "text/html": self._repr_html_(),
+ # }
+
+ def _repr_text_(self):
+ return f"graphic: {self.graphic.__class__.__name__}\n" f"processor:\n{self.processor}"
+
+
+@contextmanager
+def block_indices_ctx(ndgraphic: NDGraphic):
+ """
+ Context manager for pausing an NDGraphic from updating indices
+ """
+ ndgraphic._block_indices = True
+
+ try:
+ yield
+ except Exception as e:
+ raise e from None # indices setter has raised, the line above and the lines below are probably more relevant!
+ finally:
+ ndgraphic._block_indices = False
+
+
+def block_reentrance(setter):
+ # decorator to block re-entrance of indices setter
+ def set_indices_wrapper(self: NDGraphic, new_indices):
+ """
+ wraps NDGraphic.indices
+
+ self: NDGraphic instance
+
+ new_indices: new indices to set
+ """
+ # set_value is already in the middle of an execution, block re-entrance
+ if self._block_indices:
+ return
+ try:
+ # block re-execution of set_value until it has *fully* finished executing
+ self._block_indices = True
+ setter(self, new_indices)
+ except Exception as exc:
+ # raise original exception
+ raise exc # set_value has raised. The line above and the lines 2+ steps below are probably more relevant!
+ finally:
+ # set_value has finished executing, now allow future executions
+ self._block_indices = False
+
+ return set_indices_wrapper
diff --git a/fastplotlib/widgets/nd_widget/_index.py b/fastplotlib/widgets/nd_widget/_index.py
new file mode 100644
index 000000000..fc51c345c
--- /dev/null
+++ b/fastplotlib/widgets/nd_widget/_index.py
@@ -0,0 +1,329 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from numbers import Number
+from typing import Sequence, Any, Callable
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from ._ndwidget import NDWidget
+
+
+@dataclass
+class RangeContinuous:
+ """
+ A continuous reference range for a single slider dimension.
+
+ Stores the (start, stop, step) in scientific units (ex: seconds, micrometers,
+ Hz). The imgui slider for this dimension uses these values to determine its
+ minimum and maximum bounds. The step size is used for the "next" and "previous" buttons.
+
+ Parameters
+ ----------
+ start : int or float
+ Minimum value of the range, inclusive.
+
+ stop : int or float
+ Maximum value of the range, exclusive upper bound.
+
+ step : int or float
+ Step size used for imgui step next/previous buttons
+
+ Raises
+ ------
+ IndexError
+ If ``start >= stop``.
+
+ Examples
+ --------
+ A time axis sampled at 1 ms resolution over 10 seconds:
+
+ RangeContinuous(start=0, stop=10_000, step=1)
+
+ A depth axis in micrometers with 0.5 µm steps:
+
+ RangeContinuous(start=0.0, stop=500.0, step=0.5)
+ """
+
+ start: int | float
+ stop: int | float
+ step: int | float
+
+ def __post_init__(self):
+ if self.start >= self.stop:
+ raise IndexError(
+ f"start must be less than stop, {self.start} !< {self.stop}"
+ )
+
+ def __getitem__(self, index: int):
+ """return the value at the index w.r.t. the step size"""
+ # if index is negative, turn to positive index
+ if index < 0:
+ raise ValueError("negative indexing not supported")
+
+ val = self.start + (self.step * index)
+ if not self.start <= val <= self.stop:
+ raise IndexError(
+ f"index: {index} value: {val} out of bounds: [{self.start}, {self.stop}]"
+ )
+
+ return val
+
+ @property
+ def range(self) -> int | float:
+ return self.stop - self.start
+
+
+@dataclass
+class RangeDiscrete:
+ # TODO: not implemented yet, placeholder until we have a clear usecase
+ options: Sequence[Any]
+
+ def __getitem__(self, index: int):
+ if index > len(self.options):
+ raise IndexError
+
+ return self.options[index]
+
+ def __len__(self):
+ return len(self.options)
+
+
+class ReferenceIndex:
+ def __init__(
+ self,
+ ref_ranges: dict[
+ str,
+ tuple[Number, Number, Number] | tuple[Any] | RangeContinuous,
+ ],
+ ):
+ """
+ Manages the shared reference index for one or more ``NDWidget`` instances.
+
+ Stores the current index for each named slider dimension in reference-space
+ units (ex: seconds, depth in µm, Hz). Whenever an index is updated, every
+ ``NDGraphic`` in the manged ``NDWidgets`` are requested to render data at
+ the new indices.
+
+ Each key in ``ref_ranges`` defines a slider dimension. When adding an
+ ``NDGraphic``, every dimension listed in ``dims`` must be either a spatial
+ dimension (listed in ``spatial_dims``) or a key in ``ref_ranges``.
+ If a dim is not spatial, it must have a corresponding reference range,
+ otherwise an error will be raised.
+
+ You can also define conceptually identical but *independent* reference spaces
+ by using distinct names, ex: ``"time-1"`` and ``"time-2"`` for two recordings
+ that should be sycned independently. Each ``NDGraphic`` then declares the
+ specific "time-n" space that corresponds to its data, so the widget keeps the
+ two timelines decoupled.
+
+ Parameters
+ ----------
+ ref_ranges : dict[str, tuple], or a RangeContinuous
+ Mapping of dimension names to range specifications. A 3-tuple
+ ``(start, stop, step)`` creates a :class:`RangeContinuous`. A 1-tuple
+ ``(options,)`` creates a :class:`RangeDiscrete`.
+
+ Attributes
+ ----------
+ ref_ranges : dict[str, RangeContinuous | RangeDiscrete]
+ The reference range for each registered slider dimension.
+
+ dims: set[str]
+ the set of "slider dims"
+
+ Examples
+ --------
+ Single shared time axis:
+
+ ri = ReferenceIndex(ref_ranges={"time": (0, 1000, 1), "depth": (15, 35, 0.5)})
+ ri["time"] = 500 # update one dim and re-render
+ ri.set({"time": 500, "depth": 10}) # update several dims atomically
+
+ Two independent time axes for data from two different recording sessions:
+
+ ri = ReferenceIndex({
+ "time-1": (0, 3600, 1), # session 1 — 1 h at 1 s resolution
+ "time-s": (0, 1800, 1), # session 2 — 30 min at 1 s resolution
+ })
+
+ Each ``NDGraphic`` declares matching names for slider dims to indicate that these should be
+ synced across graphics.
+
+ ndw[0, 0].add_nd_image(data_s1, ("time-s1", "row", "col"), ("row", "col"))
+ ndw[0, 1].add_nd_image(data_s2, ("time-s2", "row", "col"), ("row", "col"))
+
+ """
+ self._ref_ranges = dict()
+ self.push_dims(ref_ranges)
+
+ # starting index for all dims
+ self._indices: dict[str, int | float | Any] = {
+ name: rr.start for name, rr in self._ref_ranges.items()
+ }
+
+ self._indices_changed_handlers = set()
+
+ self._ndwidgets: list[NDWidget] = list()
+
+ @property
+ def ref_ranges(self) -> dict[str, RangeContinuous | RangeDiscrete]:
+ return self._ref_ranges
+
+ @property
+ def dims(self) -> set[str]:
+ return set(self.ref_ranges.keys())
+
+ def _add_ndwidget_(self, ndw: NDWidget):
+ from ._ndwidget import NDWidget
+
+ if not isinstance(ndw, NDWidget):
+ raise TypeError
+
+ self._ndwidgets.append(ndw)
+
+ def set(self, indices: dict[str, Any]):
+ for dim, value in indices.items():
+ self._indices[dim] = self._clamp(dim, value)
+
+ self._render_indices()
+ self._indices_changed()
+
+ def _clamp(self, dim, value):
+ if isinstance(self.ref_ranges[dim], RangeContinuous):
+ return max(
+ min(value, self.ref_ranges[dim].stop - self.ref_ranges[dim].step),
+ self.ref_ranges[dim].start,
+ )
+
+ return value
+
+ def _render_indices(self):
+ for ndw in self._ndwidgets:
+ for g in ndw.ndgraphics:
+ if g.data is None or g.pause:
+ continue
+ # only provide slider indices to the graphic
+ g.indices = {d: self._indices[d] for d in g.processor.slider_dims}
+
+ def __getitem__(self, dim):
+ self._check_has_dim(dim)
+ return self._indices[dim]
+
+ def __setitem__(self, dim, value):
+ self._check_has_dim(dim)
+ # set index for given dim and render
+ self._indices[dim] = self._clamp(dim, value)
+ self._render_indices()
+ self._indices_changed()
+
+ def _check_has_dim(self, dim):
+ if dim not in self.dims:
+ raise KeyError(
+ f"provided dimension: {dim} has no associated ReferenceRange in this ReferenceIndex, valid dims in this ReferenceIndex are: {self.dims}"
+ )
+
+ def pop_dim(self):
+ pass
+
+ def push_dims(self, ref_ranges: dict[
+ str,
+ tuple[Number, Number, Number] | tuple[Any] | RangeContinuous,
+ ],):
+
+ for name, r in ref_ranges.items():
+ if isinstance(r, (RangeContinuous, RangeDiscrete)):
+ self._ref_ranges[name] = r
+
+ elif len(r) == 3:
+ # assume start, stop, step
+ self._ref_ranges[name] = RangeContinuous(*r)
+
+ elif len(r) == 1:
+ # assume just options
+ self._ref_ranges[name] = RangeDiscrete(*r)
+
+ else:
+ raise ValueError(
+ f"ref_ranges must be a mapping of dimension names to range specifications, "
+ f"see the docstring, you have passed: {ref_ranges}"
+ )
+
+ def add_event_handler(self, handler: Callable, event: str = "indices"):
+ """
+ Register an event handler that is called whenever the indices change.
+
+ Parameters
+ ----------
+ handler: Callable
+ callback function, must take a tuple of int as the only argument. This tuple will be the `indices`
+
+ event: str, "indices"
+ the only supported valid is "indices"
+
+ Example
+ -------
+
+ .. code-block:: py
+
+ def my_handler(indices):
+ print(indices)
+ # example prints: {"t": 100, "z": 15} if the index has 2 reference spaces "t" and "z"
+
+ # create an NDWidget
+ ndw = NDWidget(...)
+
+ # add event handler
+ ndw.indices.add_event_handler(my_handler)
+
+ """
+ if event != "indices":
+ raise ValueError("`indices` is the only event supported by `GlobalIndex`")
+
+ self._indices_changed_handlers.add(handler)
+
+ def remove_event_handler(self, handler: Callable):
+ """Remove a registered event handler"""
+ self._indices_changed_handlers.remove(handler)
+
+ def clear_event_handlers(self):
+ """Clear all registered event handlers"""
+ self._indices_changed_handlers.clear()
+
+ def _indices_changed(self):
+ for f in self._indices_changed_handlers:
+ f(self._indices)
+
+ def __iter__(self):
+ for index in self._indices.items():
+ yield index
+
+ def __len__(self):
+ return len(self._indices)
+
+ def __eq__(self, other):
+ return self._indices == other
+
+ def __repr__(self):
+ return f"Global Index: {self._indices}"
+
+ def __str__(self):
+ return str(self._indices)
+
+
+# TODO: Not sure if we'll actually do this here, just a placeholder for now
+class SelectionVector:
+ @property
+ def selection(self):
+ pass
+
+ @property
+ def graphics(self):
+ pass
+
+ def add_graphic(self):
+ pass
+
+ def remove_graphic(self):
+ pass
diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py
new file mode 100644
index 000000000..be319942d
--- /dev/null
+++ b/fastplotlib/widgets/nd_widget/_nd_image.py
@@ -0,0 +1,549 @@
+from collections.abc import Hashable, Sequence
+import inspect
+from typing import Callable, Any
+
+import numpy as np
+from numpy.typing import ArrayLike
+import xarray as xr
+
+from ...layouts import Subplot
+from ...utils import subsample_array, ArrayProtocol, ARRAY_LIKE_ATTRS
+from ...graphics import ImageGraphic, ImageVolumeGraphic
+from ...tools import HistogramLUTTool
+from ._base import NDProcessor, NDGraphic, WindowFuncCallable
+from ._index import ReferenceIndex
+
+
+class NDImageProcessor(NDProcessor):
+ def __init__(
+ self,
+ data: ArrayProtocol | None,
+ dims: Sequence[Hashable],
+ spatial_dims: (
+ tuple[str, str] | tuple[str, str, str]
+ ), # must be in order! [rows, cols] | [z, rows, cols]
+ rgb_dim: str | None = None,
+ window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None,
+ window_order: tuple[int, ...] = None,
+ spatial_func: Callable[[ArrayLike], ArrayLike] = None,
+ compute_histogram: bool = True,
+ slider_dim_transforms=None,
+ ):
+ """
+ ``NDProcessor`` subclass for n-dimensional image data.
+
+ Produces 2-D or 3-D spatial slices for an ``ImageGraphic`` or ``ImageVolumeGraphic``.
+
+ Parameters
+ ----------
+ data: ArrayProtocol
+ array-like data, must have 2 or more dimensions
+
+ dims: Sequence[str]
+ names for each dimension in ``data``. Dimensions not listed in
+ ``spatial_dims`` are treated as slider dimensions and **must** appear as
+ keys in the parent ``NDWidget``'s ``ref_ranges``
+ Examples::
+ ``("time", "depth", "row", "col")``
+ ``("channels", "time", "xy")``
+ ``("keypoints", "time", "xyz")``
+
+ A custom subclass's ``data`` object doesn't necessarily need to have these dims, but the ``get()`` method
+ must operate as if these dimensions exist and return an array that matches the spatial dimensions.
+
+ dims: Sequence[str]
+ names for each dimension in ``data``. Dimensions not listed in
+ ``spatial_dims`` are treated as slider dimensions and **must** appear as
+ keys in the parent ``NDWidget``'s ``ref_ranges``
+ Examples::
+ ``("time", "depth", "row", "col")``
+ ``("row", "col")``
+ ``("other_dim", "depth", "time", "row", "col")``
+
+ dims in the array do not need to be in order, for example you can have a weird array where the dims are
+ interpreted as: ``("col", "depth", "row", "time")``, and then specify spatial_dims as ``("row", "col")``
+ thanks to xarray magic =D.
+
+ spatial_dims : tuple[str, str] | tuple[str, str, str]
+ The 2 or 3 spatial dimensions **in order**: ``(rows, cols)`` or ``(z, rows, cols)``.
+ This also determines whether an ``ImageGraphic`` or ``ImageVolumeGraphic`` is used for rendering.
+ The ordering determines how the Image/Volume is rendered. For example, if
+ you specify ``spatial_dims = ("rows", "cols")`` and then change it to ``("cols", "rows")``, it will display
+ the transpose.
+
+ rgb_dim : str, optional
+ Name of an RGB(A) dimension, if present.
+
+ compute_histogram: bool, default True
+ Compute a histogram of the data, disable if random-access of data is not blazing-fast (ex: data that uses
+ video codecs), or if histograms are not useful for this data.
+
+ slider_dim_transforms : dict, optional
+ See :class:`NDProcessor`.
+
+ window_funcs : dict, optional
+ See :class:`NDProcessor`.
+
+ window_order : tuple, optional
+ See :class:`NDProcessor`.
+
+ spatial_func : callable, optional
+ See :class:`NDProcessor`.
+
+ See Also
+ --------
+ NDProcessor : Base class with full parameter documentation.
+ NDImage : The ``NDGraphic`` that wraps this processor.
+ """
+
+ # set as False until data, window funcs stuff and spatial func is all set
+ self._compute_histogram = False
+
+ # make sure rgb dim is size 3 or 4
+ if rgb_dim is not None:
+ dim_index = dims.index(rgb_dim)
+ if data.shape[dim_index] not in (3, 4):
+ raise IndexError(
+ f"The size of the RGB(A) dim must be 3 | 4. You have specified an array of shape: {data.shape}, "
+ f"with dims: {dims}, and specified the ``rgb_dim`` name as: {rgb_dim} which has size "
+ f"{data.shape[dim_index]} != 3 | 4"
+ )
+
+ super().__init__(
+ data=data,
+ dims=dims,
+ spatial_dims=spatial_dims,
+ slider_dim_transforms=slider_dim_transforms,
+ window_funcs=window_funcs,
+ window_order=window_order,
+ spatial_func=spatial_func,
+ )
+
+ self.rgb_dim = rgb_dim
+ self._compute_histogram = compute_histogram
+ self._recompute_histogram()
+
+ @property
+ def data(self) -> xr.DataArray | None:
+ """
+ get or set managed data. If setting with new data, the new data is interpreted
+ to have the same dims (i.e. same dim names and ordering of dims).
+ """
+ return self._data
+
+ @data.setter
+ def data(self, data: ArrayProtocol):
+ self._data = self._validate_data(data)
+ self._recompute_histogram()
+
+ def _validate_data(self, data: ArrayProtocol):
+ if not isinstance(data, ArrayProtocol):
+ # check that it's compatible with array and generally array-like
+ raise TypeError(
+ f"`data` arrays must have all of the following attributes to be sufficiently array-like:\n"
+ f"{ARRAY_LIKE_ATTRS}, or they must be `None`"
+ )
+
+ if data.ndim < 2:
+ # ndim < 2 makes no sense for image data
+ raise IndexError(
+ f"Image data must have a minimum of 2 dimensions, you have passed an array of shape: {data.shape}"
+ )
+
+ return xr.DataArray(data, dims=self.dims)
+
+ @property
+ def rgb_dim(self) -> str | None:
+ """
+ get or set the RGB(A) dim name, ``None`` if no RGB(A) dim exists
+ """
+ return self._rgb
+
+ @rgb_dim.setter
+ def rgb_dim(self, rgb: str | None):
+ if rgb is not None:
+ if rgb not in self.dims:
+ raise KeyError
+
+ self._rgb = rgb
+
+ @property
+ def compute_histogram(self) -> bool:
+ """get or set whether or not to compute the histogram"""
+ return self._compute_histogram
+
+ @compute_histogram.setter
+ def compute_histogram(self, compute: bool):
+ if compute:
+ if not self._compute_histogram:
+ # compute a histogram
+ self._recompute_histogram()
+ self._compute_histogram = True
+ else:
+ self._compute_histogram = False
+ self._histogram = None
+
+ @property
+ def histogram(self) -> tuple[np.ndarray, np.ndarray] | None:
+ """
+ an estimate of the histogram of the data, (histogram_values, bin_edges).
+
+ returns `None` if `compute_histogram` is `False`
+ """
+ return self._histogram
+
+ def get(self, indices: dict[str, Any]) -> ArrayLike | None:
+ """
+ Get the data at the given index, process data through the window functions.
+
+ Note that we do not use __getitem__ here since the index is a tuple specifying a single integer
+ index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here.
+
+ Parameters
+ ----------
+ indices: tuple[int, ...]
+ Get the processed data at this index. Must provide a value for each dimension.
+ Example: get((100, 5))
+
+ """
+ if len(self.slider_dims) > 0:
+ # there are dims in addition to the spatial dims
+ window_output = self._apply_window_functions(indices).squeeze()
+ else:
+ # no slider dims, use all the data
+ window_output = self.data
+
+ if window_output.ndim != len(self.spatial_dims):
+ raise ValueError
+
+ # apply spatial_func
+ if self.spatial_func is not None:
+ spatial_out = self._spatial_func(window_output)
+ if spatial_out.ndim != len(self.spatial_dims):
+ raise ValueError
+
+ return spatial_out.transpose(*self.spatial_dims).values
+
+ return window_output.transpose(*self.spatial_dims).values
+
+ def _recompute_histogram(self):
+ """
+
+ Returns
+ -------
+ (histogram_values, bin_edges)
+
+ """
+ if not self._compute_histogram or self.data is None:
+ self._histogram = None
+ return
+
+ if self.spatial_func is not None:
+ # don't subsample spatial dims if a spatial function is used
+ # spatial functions often operate on the spatial dims, ex: a gaussian kernel
+ # so their results require the full spatial resolution, the histogram of a
+ # spatially subsampled image will be very different
+ ignore_dims = [self.dims.index(dim) for dim in self.spatial_dims]
+ else:
+ ignore_dims = None
+
+ # TODO: account for window funcs
+
+ sub = subsample_array(self.data, ignore_dims=ignore_dims)
+
+ if isinstance(sub, xr.DataArray):
+ # can't do the isnan and isinf boolean indexing below on xarray
+ sub = sub.values
+
+ sub_real = sub[~(np.isnan(sub) | np.isinf(sub))]
+
+ self._histogram = np.histogram(sub_real, bins=100)
+
+
+class NDImage(NDGraphic):
+ def __init__(
+ self,
+ ref_index: ReferenceIndex,
+ subplot: Subplot,
+ data: ArrayProtocol | None,
+ dims: Sequence[str],
+ spatial_dims: (
+ tuple[str, str] | tuple[str, str, str]
+ ), # must be in order! [rows, cols] | [z, rows, cols]
+ rgb_dim: str | None = None,
+ window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None,
+ window_order: tuple[int, ...] = None,
+ spatial_func: Callable[[ArrayLike], ArrayLike] = None,
+ compute_histogram: bool = True,
+ slider_dim_transforms=None,
+ name: str = None,
+ ):
+ """
+ ``NDGraphic`` subclass for n-dimensional image rendering.
+
+ Wraps an :class:`NDImageProcessor` and manages either an ``ImageGraphic`` or``ImageVolumeGraphic``.
+ swaps automatically when :attr:`spatial_dims` is reassigned at runtime. Also
+ owns a ``HistogramLUTTool`` for interactive vmin, vmax adjustment.
+
+ Every dimension that is *not* listed in ``spatial_dims`` becomes a slider
+ dimension. Each slider dim must have a ``ReferenceRange`` defined in the
+ ``ReferenceIndex`` of the parent ``NDWidget``. The widget uses this to direct
+ a change in the ``ReferenceIndex`` and update the graphics.
+
+ Parameters
+ ----------
+ ref_index : ReferenceIndex
+ The shared reference index that delivers slider updates to this graphic.
+
+ subplot : Subplot
+ parent subplot the NDGraphic is in
+
+ data : array-like or None
+ n-dimension image data array
+
+ dims : sequence of hashable
+ Name for every dimension of ``data``, in order. Non-spatial dims must
+ match keys in ``ref_index``.
+
+ ex: ``("time", "depth", "row", "col")`` — ``"time"`` and ``"depth"`` must
+ be present in ``ref_index``.
+
+ spatial_dims : tuple[str, str] | tuple[str, str, str]
+ Spatial dimensions **in order**: ``(rows, cols)`` for 2-D images or
+ ``(z, rows, cols)`` for volumes. Controls whether an ``ImageGraphic`` or
+ ``ImageVolumeGraphic`` is used.
+
+ rgb_dim : str, optional
+ Name of the RGB or channel dimension, if present.
+
+ window_funcs : dict, optional
+ See :class:`NDProcessor`.
+
+ window_order : tuple, optional
+ See :class:`NDProcessor`.
+
+ spatial_func : callable, optional
+ See :class:`NDProcessor`.
+
+ compute_histogram : bool, default ``True``
+ Whether to initialize the ``HistogramLUTTool``.
+
+ slider_dim_transforms : dict, optional
+ See :class:`NDProcessor`.
+
+ name : str, optional
+ Name for the underlying graphic.
+
+ See Also
+ --------
+ NDImageProcessor : The processor that backs this graphic.
+
+ """
+
+ if not (set(dims) - set(spatial_dims)).issubset(ref_index.dims):
+ raise IndexError(
+ f"all specified `dims` must either be a spatial dim or a slider dim "
+ f"specified in the NDWidget ref_ranges, provided dims: {dims}, "
+ f"spatial_dims: {spatial_dims}. Specified NDWidget ref_ranges: {ref_index.dims}"
+ )
+
+ super().__init__(subplot, name)
+
+ self._ref_index = ref_index
+
+ self._processor = NDImageProcessor(
+ data,
+ dims=dims,
+ spatial_dims=spatial_dims,
+ rgb_dim=rgb_dim,
+ window_funcs=window_funcs,
+ window_order=window_order,
+ spatial_func=spatial_func,
+ compute_histogram=compute_histogram,
+ slider_dim_transforms=slider_dim_transforms,
+ )
+
+ self._graphic: ImageGraphic | None = None
+ self._histogram_widget: HistogramLUTTool | None = None
+
+ # create a graphic
+ self._create_graphic()
+
+ @property
+ def processor(self) -> NDImageProcessor:
+ """NDProcessor that manages the data and produces data slices to display"""
+ return self._processor
+
+ @property
+ def graphic(
+ self,
+ ) -> ImageGraphic | ImageVolumeGraphic:
+ """Underlying Graphic object used to display the current data slice"""
+ return self._graphic
+
+ def _create_graphic(self):
+ # Creates an ``ImageGraphic`` or ``ImageVolumeGraphic`` based on the number of spatial dims,
+ # adds it to the subplot, and resets the camera and histogram.
+
+ if self.processor.data is None:
+ # no graphic if data is None, useful for initializing in null states when we want to set data later
+ return
+
+ # determine if we need a 2d image or 3d volume
+ # remove RGB spatial dim, ex: if we have an RGBA image of shape [512, 512, 4] we want to interpet this as
+ # 2D for images
+ # [30, 512, 512, 4] with an rgb dim is an RGBA volume which is also supported
+ match len(self.processor.spatial_dims) - int(bool(self.processor.rgb_dim)):
+ case 2:
+ cls = ImageGraphic
+ case 3:
+ cls = ImageVolumeGraphic
+
+ # get the data slice for this index
+ # this will only have the dims specified by ``spatial_dims``
+ data_slice = self.processor.get(self.indices)
+
+ # create the new graphic
+ new_graphic = cls(data_slice)
+
+ old_graphic = self._graphic
+ # check if we are replacing a graphic
+ # ex: swapping from 2D <-> 3D representation after ``spatial_dims`` was changed
+ if old_graphic is not None:
+ # carry over some attributes from old graphic
+ attrs = dict.fromkeys(["cmap", "interpolation", "cmap_interpolation"])
+ for k in attrs:
+ attrs[k] = getattr(old_graphic, k)
+
+ # delete the old graphic
+ self._subplot.delete_graphic(old_graphic)
+
+ # set any attributes that we're carrying over like cmap
+ for attr, val in attrs.items():
+ setattr(new_graphic, attr, val)
+
+ self._graphic = new_graphic
+
+ self._subplot.add_graphic(self._graphic)
+
+ self._reset_camera()
+ self._reset_histogram()
+
+ def _reset_histogram(self):
+ # reset histogram
+ if self.graphic is None:
+ return
+
+ if not self.processor.compute_histogram:
+ # hide right dock if histogram not desired
+ self._subplot.docks["right"].size = 0
+ return
+
+ if self.processor.histogram:
+ if self._histogram_widget:
+ # histogram widget exists, update it
+ self._histogram_widget.histogram = self.processor.histogram
+ self._histogram_widget.images = self.graphic
+ if self._subplot.docks["right"].size < 1:
+ self._subplot.docks["right"].size = 80
+ else:
+ # make hist tool
+ self._histogram_widget = HistogramLUTTool(
+ histogram=self.processor.histogram,
+ images=self.graphic,
+ name=f"hist-{hex(id(self.graphic))}",
+ )
+ self._subplot.docks["right"].add_graphic(self._histogram_widget)
+ self._subplot.docks["right"].size = 80
+
+ self.graphic.reset_vmin_vmax()
+
+ def _reset_camera(self):
+ # set camera to a nice position based on whether it's a 2D ImageGraphic or 3D ImageVolumeGraphic
+ if isinstance(self._graphic, ImageGraphic):
+ # set camera orthogonal to the xy plane, flip y axis
+ self._subplot.camera.set_state(
+ {
+ "position": [0, 0, -1],
+ "rotation": [0, 0, 0, 1],
+ "scale": [1, -1, 1],
+ "reference_up": [0, 1, 0],
+ "fov": 0, # orthographic projection
+ "depth_range": None,
+ }
+ )
+
+ self._subplot.controller = "panzoom"
+ self._subplot.axes.intersection = None
+ self._subplot.auto_scale()
+
+ else:
+ # It's not an ImageGraphic, set perspective projection
+ self._subplot.camera.fov = 50
+ self._subplot.controller = "orbit"
+
+ # set all 3D dimension camera scales to positive since positive scales
+ # are typically used for looking at volumes
+ for dim in ["x", "y", "z"]:
+ if getattr(self._subplot.camera.local, f"scale_{dim}") < 0:
+ setattr(self._subplot.camera.local, f"scale_{dim}", 1)
+
+ self._subplot.auto_scale()
+
+ @property
+ def spatial_dims(self) -> tuple[str, str] | tuple[str, str, str]:
+ """get or set the spatial dims, see docstring for details"""
+ return self.processor.spatial_dims
+
+ @spatial_dims.setter
+ def spatial_dims(self, dims: tuple[str, str] | tuple[str, str, str]):
+ self.processor.spatial_dims = dims
+
+ # shape has probably changed, recreate graphic
+ self._create_graphic()
+
+ @property
+ def indices(self) -> dict[Hashable, Any]:
+ """get or set the indices, managed by the ReferenceIndex, users usually don't want to set this manually"""
+ return {d: self._ref_index[d] for d in self.processor.slider_dims}
+
+ @indices.setter
+ def indices(self, indices):
+ data_slice = self.processor.get(indices)
+
+ self.graphic.data = data_slice
+
+ @property
+ def compute_histogram(self) -> bool:
+ """whether or not to compute the histogram and display the HistogramLUTTool"""
+ return self.processor.compute_histogram
+
+ @compute_histogram.setter
+ def compute_histogram(self, v: bool):
+ self.processor.compute_histogram = v
+ self._reset_histogram()
+
+ @property
+ def histogram_widget(self) -> HistogramLUTTool:
+ """The histogram lut tool associated with this NDGraphic"""
+ return self._histogram_widget
+
+ @property
+ def spatial_func(self) -> Callable[[xr.DataArray], xr.DataArray] | None:
+ """get or set the spatial_func, see docstring for details"""
+ return self.processor.spatial_func
+
+ @spatial_func.setter
+ def spatial_func(
+ self, func: Callable[[xr.DataArray], xr.DataArray]
+ ) -> Callable | None:
+ self.processor.spatial_func = func
+ self.processor._recompute_histogram()
+ self._reset_histogram()
+
+ def _tooltip_handler(self, graphic, pick_info):
+ # TODO: need to do this better
+ # get graphic within the collection
+ n_index = np.argwhere(self.graphic.graphics == graphic).item()
+ p_index = pick_info["vertex_index"]
+ return self.processor.tooltip_format(n_index, p_index)
diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/__init__.py b/fastplotlib/widgets/nd_widget/_nd_positions/__init__.py
new file mode 100644
index 000000000..60703f8c2
--- /dev/null
+++ b/fastplotlib/widgets/nd_widget/_nd_positions/__init__.py
@@ -0,0 +1,23 @@
+import importlib
+
+from ._nd_positions import NDPositions, NDPositionsProcessor
+
+class Extras:
+ pass
+
+ndp_extras = Extras()
+
+
+for optional in ["pandas", "zarr"]:
+ try:
+ importlib.import_module(optional)
+ except ImportError:
+ pass
+ else:
+ module = importlib.import_module(f"._{optional}", "fastplotlib.widgets.nd_widget._nd_positions")
+ cls = getattr(module, f"NDPP_{optional.capitalize()}")
+ setattr(
+ ndp_extras,
+ f"NDPP_{optional.capitalize()}",
+ cls
+ )
diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py
new file mode 100644
index 000000000..2d3ff2b9a
--- /dev/null
+++ b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py
@@ -0,0 +1,1144 @@
+from collections.abc import Callable, Hashable, Sequence
+from functools import partial
+from typing import Literal, Any, Type
+from warnings import warn
+
+import numpy as np
+from numpy.lib.stride_tricks import sliding_window_view
+from numpy.typing import ArrayLike
+import xarray as xr
+
+from ....layouts import Subplot
+from ....graphics import (
+ Graphic,
+ ImageGraphic,
+ LineGraphic,
+ LineStack,
+ LineCollection,
+ ScatterGraphic,
+ ScatterCollection,
+ ScatterStack,
+)
+from ....graphics.features.utils import parse_colors
+from ....graphics.utils import pause_events
+from ....graphics.selectors import LinearSelector
+from .._base import (
+ NDProcessor,
+ NDGraphic,
+ WindowFuncCallable,
+ block_reentrance,
+ block_indices_ctx,
+)
+from .._index import ReferenceIndex
+
+# types for the other features
+FeatureCallable = Callable[[np.ndarray, slice], np.ndarray]
+ColorsType = np.ndarray | FeatureCallable | None
+MarkersType = Sequence[str] | np.ndarray | FeatureCallable | None
+SizesType = Sequence[float] | np.ndarray | FeatureCallable | None
+
+
+def default_cmap_transform_each(p: int, data_slice: np.ndarray, s: slice):
+ # create a cmap transform based on the `p` dim size
+ n_displayed = data_slice.shape[1]
+
+ # linspace that's just normalized 0 - 1 within `p` dim size
+ return np.linspace(
+ start=s.start / p,
+ stop=s.stop / p,
+ num=n_displayed,
+ endpoint=False, # since we use a slice object for the displayed data, the last point isn't included
+ )
+
+
+class NDPositionsProcessor(NDProcessor):
+ _other_features = ["colors", "markers", "cmap_transform_each", "sizes"]
+
+ def __init__(
+ self,
+ data: Any,
+ dims: Sequence[Hashable],
+ # TODO: allow stack_dim to be None and auto-add new dim of size 1 in get logic
+ spatial_dims: tuple[
+ Hashable | None, Hashable, Hashable
+ ], # [stack_dim, n_datapoints, spatial_dim], IN ORDER!!
+ slider_dim_transforms: dict[str, Callable[[Any], int] | ArrayLike] = None,
+ display_window: int | float | None = 100, # window for n_datapoints dim only
+ max_display_datapoints: int = 1_000,
+ datapoints_window_func: tuple[Callable, str, int | float] | None = None,
+ colors: ColorsType = None,
+ markers: MarkersType = None,
+ cmap_transform_each: np.ndarray = None,
+ sizes: SizesType = None,
+ **kwargs,
+ ):
+ """
+ ``NDProcessor`` subclass for n-dimensional positional and timeseries data.
+
+
+ The *datapoints* dimension is
+ simultaneously a slider dim and a spatial dim and is handled by a dedicated
+ :attr:`datapoints_window_func` rather than the general ``window_funcs``
+ mechanism.
+
+
+ Parameters
+ ----------
+ data
+ dims
+ spatial_dims
+ slider_dim_transforms
+ display_window
+ max_display_datapoints: int, default 1_000
+ this is approximate since floor division is used to determine the step size of the current display window slice
+ datapoints_window_func:
+ Important note: if used, display_window is approximate and not exact due to padding from the window size
+ kwargs
+ """
+ self._display_window = display_window
+ self._max_display_datapoints = max_display_datapoints
+
+ super().__init__(
+ data=data,
+ dims=dims,
+ spatial_dims=spatial_dims,
+ slider_dim_transforms=slider_dim_transforms,
+ **kwargs,
+ )
+
+ self._datapoints_window_func = datapoints_window_func
+
+ self.colors = colors
+ self.markers = markers
+ self.cmap_transform_each = cmap_transform_each
+ self.sizes = sizes
+
+ def _check_shape_feature(
+ self, prop: str, check_shape: tuple[int, int]
+ ) -> tuple[int, int]:
+ # this function exists because it's used repeatedly for colors, markers, etc.
+ # shape for [l, p] dims must match, or l must be 1
+ shape = tuple([self.shape[dim] for dim in self.spatial_dims[:2]])
+
+ if check_shape[1] != shape[1]:
+ raise IndexError(
+ f"shape of first two dims of {prop} must must be [l, p] or [1, p].\n"
+ f"required `p` dim shape is: {shape[1]}, {check_shape[1]} was provided"
+ )
+
+ if check_shape[0] != 1 and check_shape[0] != shape[0]:
+ raise IndexError(
+ f"shape of first two dims of {prop} must must be [l, p] or [1, p]\n"
+ f"required `l` dim shape is {shape[0]} | 1, {check_shape[0]} was provided"
+ )
+
+ return shape
+
+ @property
+ def colors(self) -> ColorsType:
+ """
+ A callable that dynamically creates colors for the current display window, or array of colors per-datapoint.
+
+ Array must be of shape [l, p, 4] for unique colors per line/scatter, or [1, p, 4] for identical colors per
+ line/scatter.
+
+ Callable must return an array of shape [l, pw, 4] or [1, pw, 4], where pw is the number of currently displayed
+ datapoints given the current display window. The callable receives the current data slice array, as well as the
+ slice object that corresponds to the current display window.
+ """
+ return self._colors
+
+ @colors.setter
+ def colors(self, new):
+ if callable(new):
+ # custom callable that creates the colors
+ self._colors = new
+ return
+
+ if new is None:
+ self._colors = None
+ return
+
+ # as array so we can check shape
+ new = np.asarray(new)
+ if new.ndim == 2:
+ # only [p, 4] provided, broadcast to [1, p, 4]
+ new = new[None]
+
+ shape = self._check_shape_feature("colors", new.shape[:2])
+
+ if new.shape[0] == 1:
+ # same colors across all graphical elements
+ self._colors = parse_colors(new[0], n_colors=shape[1])[None]
+
+ else:
+ # colors specified for each individual line/scatter
+ new_ = np.zeros(shape=(*self.data.shape[:2], 4), dtype=np.float32)
+ for i in range(shape[0]):
+ new_[i] = parse_colors(new[i], n_colors=shape[1])
+
+ self._colors = new_
+
+ @property
+ def markers(self) -> MarkersType:
+ """
+ A callable that dynamically creates markers for the current display window, or array of markers per-datapoint.
+
+ Array must be of shape [l, p] for unique markers per line/scatter, or [p,] or [1, p] for identical markers per
+ line/scatter.
+
+ Callable must return an array of shape [l, pw], [1, pw], or [pw,] where pw is the number of currently displayed
+ datapoints given the current display window. The callable receives the current data slice array, as well as the
+ slice object that corresponds to the current display window.
+ """
+ return self._markers
+
+ @markers.setter
+ def markers(self, new: MarkersType):
+ if callable(new):
+ # custom callable that creates the markers dynamically
+ self._markers = new
+ return
+
+ if new is None:
+ self._markers = None
+ return
+
+ # as array so we can check shape
+ new = np.asarray(new)
+
+ # if 1-dim, assume it's specifying markers over `p` dim, so set `l` dim to 1
+ if new.ndim == 1:
+ new = new[None]
+
+ self._check_shape_feature("markers", new.shape[:2])
+
+ self._markers = np.asarray(new)
+
+ @property
+ def cmap_transform_each(self) -> np.ndarray | FeatureCallable | None:
+ return self._cmap_transform_each
+
+ @cmap_transform_each.setter
+ def cmap_transform_each(self, new: np.ndarray | FeatureCallable | None):
+ """
+ A callable that dynamically creates cmap transforms for the current display window, or array
+ of transforms per-datapoint.
+
+ Array must be of shape [l, p] for unique transforms per line/scatter, or [p,] or [1, p] for identical markers
+ per line/scatter.
+
+ Callable must return an array of shape [l, pw], [1, pw], or [pw,] where pw is the number of currently displayed
+ datapoints given the current display window. The callable receives the current data slice array, as well as the
+ slice object that corresponds to the current display window.
+ """
+ if callable(new):
+ self._cmap_transform_each = new
+ return
+
+ if new is None:
+ self._cmap_transform_each = None
+ return
+
+ new = np.asarray(new)
+
+ # if 1-dim, assume it's specifying sizes over `p` dim, set `l` dim to 1
+ if new.ndim == 1:
+ new = new[None]
+
+ self._check_shape_feature("cmap_transform_each", new.shape)
+
+ self._cmap_transform_each = new
+
+ @property
+ def sizes(self) -> SizesType:
+ return self._sizes
+
+ @sizes.setter
+ def sizes(self, new: SizesType):
+ """
+ A callable that dynamically creates sizes for the current display window, or array of sizes per-datapoint.
+
+ Array must be of shape [l, p] for unique sizes per line/scatter, or [p,] or [1, p] for identical markers per
+ line/scatter.
+
+ Callable must return an array of shape [l, pw], [1, pw], or [pw,] where pw is the number of currently displayed
+ datapoints given the current display window. The callable receives the current data slice array, as well as the
+ slice object that corresponds to the current display window.
+ """
+ if callable(new):
+ # custom callable
+ self._sizes = new
+ return
+
+ if new is None:
+ self._sizes = None
+ return
+
+ new = np.array(new)
+ # if 1-dim, assume it's specifying sizes over `p` dim, set `l` dim to 1
+ if new.ndim == 1:
+ new = new[None]
+
+ self._check_shape_feature("sizes", new.shape)
+
+ self._sizes = new
+
+ @property
+ def spatial_dims(self) -> tuple[str, str, str]:
+ return self._spatial_dims
+
+ @spatial_dims.setter
+ def spatial_dims(self, sdims: tuple[str, str, str]):
+ if len(sdims) != 3:
+ raise IndexError
+
+ if not all([d in self.dims for d in sdims]):
+ raise KeyError
+
+ self._spatial_dims = tuple(sdims)
+
+ @property
+ def slider_dims(self) -> set[Hashable]:
+ # append `p` dim to slider dims
+ return tuple([*super().slider_dims, self.spatial_dims[1]])
+
+ @property
+ def display_window(self) -> int | float | None:
+ """display window in the reference units for the n_datapoints dim"""
+ return self._display_window
+
+ @display_window.setter
+ def display_window(self, dw: int | float | None):
+ if dw is None:
+ self._display_window = None
+
+ elif not isinstance(dw, (int, float)):
+ raise TypeError
+
+ self._display_window = dw
+
+ @property
+ def max_display_datapoints(self) -> int:
+ return self._max_display_datapoints
+
+ @max_display_datapoints.setter
+ def max_display_datapoints(self, n: int):
+ if not isinstance(n, (int, np.integer)):
+ raise TypeError
+ if n < 2:
+ raise ValueError
+
+ self._max_display_datapoints = n
+
+ # TODO: validation for datapoints_window_func and size
+ @property
+ def datapoints_window_func(self) -> tuple[Callable, str, int | float] | None:
+ """
+ Callable, str indicating which dims to apply window function along, window_size in reference space:
+ 'all', 'x', 'y', 'z', 'xyz', 'xy', 'xz', 'yz'
+ '"""
+ return self._datapoints_window_func
+
+ @datapoints_window_func.setter
+ def datapoints_window_func(self, funcs: tuple[Callable, str, int | float]):
+ if len(funcs) != 3:
+ raise TypeError
+
+ self._datapoints_window_func = tuple(funcs)
+
+ def _get_dw_slice(self, indices: dict[str, Any]) -> slice:
+ # given indices, return slice required to obtain display window
+
+ # n_datapoints dim name
+ # display_window acts on this dim
+ p_dim = self.spatial_dims[1]
+
+ if self.display_window is None:
+ # just return everything
+ return slice(0, self.shape[p_dim])
+
+ if self.display_window == 0:
+ # just map p dimension at this index and return
+ index = self._ref_index_to_array_index(p_dim, indices[p_dim])
+ return slice(index, index + 1)
+
+ # half window size, in reference units
+ hw = self.display_window / 2
+
+ if self.datapoints_window_func is not None:
+ # add half datapoints_window_func size here, assumes the reference space is somewhat continuous
+ # and the display_window and datapoints window size map to their actual size values
+ hw += self.datapoints_window_func[2] / 2
+
+ # display window is in reference units, apply display window and then map to array indices
+ # start in reference units
+ start_ref = indices[p_dim] - hw
+ # stop in reference units
+ stop_ref = indices[p_dim] + hw
+
+ # map to array indices
+ start = self._ref_index_to_array_index(p_dim, start_ref)
+ stop = self._ref_index_to_array_index(p_dim, stop_ref)
+
+ if start >= stop:
+ stop = start + 1
+
+ w = stop - start
+
+ # get step size
+ step = max(1, w // self.max_display_datapoints)
+
+ return slice(start, stop, step)
+
+ def _apply_dw_window_func(
+ self, array: xr.DataArray | np.ndarray
+ ) -> xr.DataArray | np.ndarray:
+ """
+ Takes array where display window has already been applied and applies window functions on the `p` dim.
+
+ Parameters
+ ----------
+ array: np.ndarray
+ array of shape: [l, display_window, 2 | 3]
+
+ Returns
+ -------
+ np.ndarray
+ array with window functions applied along `p` dim
+ """
+ if self.display_window == 0:
+ # can't apply window func when there is only 1 datapoint
+ return array
+
+ p_dim = self.spatial_dims[1]
+
+ # display window in array index space
+ if self.display_window is not None:
+ dw = self.slider_dim_transforms[p_dim](self.display_window)
+
+ # step size based on max number of datapoints to render
+ step = max(1, dw // self.max_display_datapoints)
+
+ # apply window function on the `p` n_datapoints dim
+ if (
+ self.datapoints_window_func is not None
+ # if there are too many points to efficiently compute the window func, skip
+ # applying a window func also requires making a copy so that's a further performance hit
+ and (dw < self.max_display_datapoints * 2)
+ ):
+ # get windows
+
+ # graphic_data will be of shape: [n, p, 2 | 3]
+ # where:
+ # n - number of lines, scatters, heatmap rows
+ # p - number of datapoints/samples
+
+ # ws is in ref units
+ wf, apply_dims, ws = self.datapoints_window_func
+
+ # map ws in ref units to array index
+ # min window size is 3
+ ws = max(self._ref_index_to_array_index(p_dim, ws), 3)
+
+ if ws % 2 == 0:
+ # odd size windows are easier to handle
+ ws += 1
+
+ hw = ws // 2
+ start, stop = hw, array.shape[1] - hw
+
+ # apply user's window func
+ # result will be of shape [n, p, 2 | 3]
+ if apply_dims == "all":
+ # windows will be of shape [n, p, 1 | 2 | 3, ws]
+ windows = sliding_window_view(array, ws, axis=-2)
+ return wf(windows, axis=-1)[:, ::step]
+
+ # map user dims str to tuple of numerical dims
+ dims = tuple(map({"x": 0, "y": 1, "z": 2}.get, apply_dims))
+
+ # windows will be of shape [n, (p - ws + 1), 1 | 2 | 3, ws]
+ windows = sliding_window_view(array[..., dims], ws, axis=-2).squeeze()
+
+ # make a copy because we need to modify it
+ array = array[:, start:stop].copy()
+
+ # this reshape is required to reshape wf outputs of shape [n, p] -> [n, p, 1] only when necessary
+ array[..., dims] = wf(windows, axis=-1).reshape(
+ *array.shape[:-1], len(dims)
+ )
+
+ return array[:, ::step]
+
+ step = max(1, array.shape[1] // self.max_display_datapoints)
+
+ return array[:, ::step]
+
+ def _apply_spatial_func(
+ self, array: xr.DataArray | np.ndarray
+ ) -> xr.DataArray | np.ndarray:
+ if self.spatial_func is not None:
+ return self.spatial_func(array)
+
+ return array
+
+ def _finalize_(self, array: xr.DataArray | np.ndarray) -> xr.DataArray | np.ndarray:
+ return self._apply_spatial_func(self._apply_dw_window_func(array))
+
+ def _get_other_features(
+ self, data_slice: np.ndarray, dw_slice: slice
+ ) -> dict[str, np.ndarray]:
+ other = dict.fromkeys(self._other_features)
+ for attr in self._other_features:
+ val = getattr(self, attr)
+
+ if val is None:
+ continue
+
+ if callable(val):
+ # if it's a callable, give it the data and display window slice, it must return the appropriate
+ # type of array for that graphic feature
+ val_sliced = val(data_slice, dw_slice)
+
+ else:
+ # if no l dim, broadcast to [1, p]
+ if val.ndim == 1:
+ val = val[None]
+
+ # apply current display window slice
+ val_sliced = val[:, dw_slice]
+
+ # check if l dim size is 1
+ if val_sliced.shape[0] == 1:
+ # broadcast across all graphical elements
+ n_graphics = self.shape[self.spatial_dims[0]]
+ val_sliced = np.broadcast_to(
+ val_sliced, shape=(n_graphics, *val_sliced.shape[1:])
+ )
+
+ other[attr] = val_sliced
+
+ return other
+
+ def get(self, indices: dict[str, Any]) -> dict[str, np.ndarray]:
+ """
+ slices through all slider dims and outputs an array that can be used to set graphic data
+
+ Note that we do not use __getitem__ here since the index is a tuple specifying a single integer
+ index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here.
+ """
+
+ if len(self.slider_dims) > 1:
+ # there are slider dims in addition to the datapoints_dim
+ window_output = self._apply_window_functions(indices).squeeze()
+ else:
+ # no slider dims, use all the data
+ window_output = self.data
+
+ # verify window output only has the spatial dims
+ if not set(window_output.dims) == set(self.spatial_dims):
+ raise IndexError
+
+ # get slice obj for display window
+ dw_slice = self._get_dw_slice(indices)
+
+ # data that will be used for the graphical representation
+ # a copy is made, if there were no window functions then this is a view of the original data
+ p_dim = self.spatial_dims[1]
+
+ # slice the datapoints to be displayed in the graphic using the display window slice
+ # transpose to match spatial dims order, get numpy array, this is a view
+ graphic_data = window_output.isel({p_dim: dw_slice}).transpose(
+ *self.spatial_dims
+ )
+
+ data = self._finalize_(graphic_data).values
+ other = self._get_other_features(data, dw_slice)
+
+ return {
+ "data": data,
+ **other,
+ }
+
+
+class NDPositions(NDGraphic):
+ def __init__(
+ self,
+ ref_index: ReferenceIndex,
+ subplot: Subplot,
+ data: Any,
+ dims: Sequence[str],
+ spatial_dims: tuple[str, str, str],
+ *args,
+ graphic_type: Type[
+ LineGraphic
+ | LineCollection
+ | LineStack
+ | ScatterGraphic
+ | ScatterCollection
+ | ScatterStack
+ | ImageGraphic
+ ],
+ processor: type[NDPositionsProcessor] = NDPositionsProcessor,
+ display_window: int = 10,
+ window_funcs: tuple[WindowFuncCallable | None] | None = None,
+ slider_dim_transforms: tuple[Callable[[Any], int] | None] | None = None,
+ max_display_datapoints: int = 1_000,
+ linear_selector: bool = False,
+ x_range_mode: Literal["fixed", "auto"] | None = None,
+ colors: (
+ Sequence[str] | np.ndarray | Callable[[slice, np.ndarray], np.ndarray]
+ ) = None,
+ # TODO: cleanup how this cmap stuff works, require a cmap to be set per-graphic
+ # before allowing cmaps_transform, validate that stuff makes sense etc.
+ cmap: str = None, # across the line/scatter collection
+ cmap_each: Sequence[str] = None, # for each individual line/scatter
+ cmap_transform_each: np.ndarray = None, # for each individual line/scatter
+ markers: np.ndarray = None, # across the scatter collection, shape [l,]
+ markers_each: Sequence[str] = None, # for each individual scatter, shape [l, p]
+ sizes: np.ndarray = None, # across the scatter collection, shape [l,]
+ sizes_each: Sequence[float] = None, # for each individual scatter, shape [l, p]
+ thickness: np.ndarray = None, # for each line, shape [l,]
+ name: str = None,
+ timeseries: bool = False,
+ graphic_kwargs: dict = None,
+ processor_kwargs: dict = None,
+ ):
+ """
+ Wraps an :class:`NDPositionsProcessor` and supports four interchangeable
+ graphical representations: ``LineStack``, ``LineCollection``, ``ScatterStack``,
+ and ``ScatterCollection``, as well as a heatmap view. For timeseries use-cases
+ it also manages a linear selector and automatically adjusts the view according
+ to the current x-range of the displayed data.
+
+ Parameters
+ ----------
+ ref_index
+ subplot
+ data
+ dims
+ spatial_dims
+ args
+ graphic_type
+ processor
+ display_window
+ window_funcs
+ slider_dim_transforms
+ max_display_datapoints
+ linear_selector
+ x_range_mode
+ colors
+ cmap
+ cmap_each
+ cmap_transform_each
+ markers
+ markers_each
+ sizes
+ sizes_each
+ thickness
+ name
+ graphic_kwargs
+ processor_kwargs
+ """
+
+ super().__init__(subplot, name)
+
+ self._ref_index = ref_index
+
+ if processor_kwargs is None:
+ processor_kwargs = dict()
+
+ if graphic_kwargs is None:
+ self._graphic_kwargs = dict()
+ else:
+ self._graphic_kwargs = graphic_kwargs
+
+ self._processor = processor(
+ data,
+ dims,
+ spatial_dims,
+ *args,
+ display_window=display_window,
+ max_display_datapoints=max_display_datapoints,
+ window_funcs=window_funcs,
+ slider_dim_transforms=slider_dim_transforms,
+ colors=colors,
+ markers=markers_each,
+ cmap_transform_each=cmap_transform_each,
+ sizes=sizes_each,
+ **processor_kwargs,
+ )
+
+ self._cmap = cmap
+ self._sizes = sizes
+ self._markers = markers
+ self._thickness = thickness
+
+ self.cmap_each = cmap_each
+ self.cmap_transform_each = cmap_transform_each
+
+ self._graphic_type = graphic_type
+ self._create_graphic()
+
+ self._x_range_mode = None
+ self.x_range_mode = x_range_mode
+ self._last_x_range = np.array([0.0, 0.0], dtype=np.float32)
+
+ self._timeseries = timeseries
+ # TODO: I think this is messy af, NDTimeseriesSubclass???
+ if self._timeseries:
+ # makes some assumptions about positional data that apply only to timeseries representations
+ # probably don't want to maintain aspect
+ self._subplot.camera.maintain_aspect = False
+
+ # auto x range modes make no sense for non-timeseries data
+ self.x_range_mode = x_range_mode
+
+ if linear_selector:
+ self._linear_selector = LinearSelector(
+ 0, limits=(-np.inf, np.inf), edge_color="cyan"
+ )
+ self._linear_selector.add_event_handler(
+ self._linear_selector_handler, "selection"
+ )
+ self._subplot.add_graphic(self._linear_selector)
+ else:
+ self._linear_selector = None
+ else:
+ self._linear_selector = None
+
+ @property
+ def processor(self) -> NDPositionsProcessor:
+ return self._processor
+
+ @property
+ def graphic(
+ self,
+ ) -> (
+ LineGraphic
+ | LineCollection
+ | LineStack
+ | ScatterGraphic
+ | ScatterCollection
+ | ScatterStack
+ | ImageGraphic
+ | None
+ ):
+ """LineStack or ImageGraphic for heatmaps"""
+ return self._graphic
+
+ @property
+ def graphic_type(
+ self,
+ ) -> Type[
+ LineGraphic
+ | LineCollection
+ | LineStack
+ | ScatterGraphic
+ | ScatterCollection
+ | ScatterStack
+ | ImageGraphic
+ ]:
+ return self._graphic_type
+
+ @graphic_type.setter
+ def graphic_type(self, graphic_type):
+ if type(self.graphic) is graphic_type:
+ return
+
+ self._subplot.delete_graphic(self._graphic)
+ self._graphic_type = graphic_type
+ self._create_graphic()
+
+ @property
+ def spatial_dims(self) -> tuple[str, str, str]:
+ return self.processor.spatial_dims
+
+ @spatial_dims.setter
+ def spatial_dims(self, dims: tuple[str, str, str]):
+ self.processor.spatial_dims = dims
+ # force re-render
+ self.indices = self.indices
+
+ @property
+ def indices(self) -> dict[Hashable, Any]:
+ return {d: self._ref_index[d] for d in self.processor.slider_dims}
+
+ @indices.setter
+ @block_reentrance
+ def indices(self, indices):
+ if self.data is None:
+ return
+
+ new_features = self.processor.get(indices)
+ data_slice = new_features["data"]
+
+ # TODO: set other graphic features, colors, sizes, markers, etc.
+
+ if isinstance(self.graphic, (LineGraphic, ScatterGraphic)):
+ self.graphic.data[:, : data_slice.shape[-1]] = data_slice
+
+ elif isinstance(self.graphic, (LineCollection, ScatterCollection)):
+ for l, g in enumerate(self.graphic.graphics):
+ new_data = data_slice[l]
+ if g.data.value.shape[0] != new_data.shape[0]:
+ # will replace buffer internally
+ g.data = new_data
+ else:
+ # if data are only xy, set only xy
+ g.data[:, : new_data.shape[1]] = new_data
+
+ for feature in ["colors", "sizes", "markers"]:
+ value = new_features[feature]
+
+ match value:
+ case None:
+ pass
+ case _:
+ if feature == "colors":
+ g.color_mode = "vertex"
+
+ setattr(g, feature, value[l])
+
+ if self.cmap_each is not None:
+ match new_features["cmap_transform_each"]:
+ case None:
+ pass
+ case _:
+ setattr(
+ getattr(g, "cmap"), # ind_graphic.cmap
+ "transform",
+ new_features["cmap_transform_each"],
+ )
+
+ elif isinstance(self.graphic, ImageGraphic):
+ image_data, x0, x_scale = self._create_heatmap_data(data_slice)
+ self.graphic.data = image_data
+ self.graphic.offset = (x0, *self.graphic.offset[1:])
+ self.graphic.scale = (x_scale, *self.graphic.scale[1:])
+
+ # TODO: I think this is messy af, NDTimeseriesSubclass???
+ # x range of the data
+ xr = data_slice[0, 0, 0], data_slice[0, -1, 0]
+ if self.x_range_mode is not None:
+ self.graphic._plot_area.x_range = xr
+
+ # if the update_from_view is polling, this prevents it from being called by setting the new last xrange
+ # in theory, but this doesn't seem to fully work yet, not a big deal right now can check later
+ self._last_x_range[:] = self.graphic._plot_area.x_range
+
+ if self._linear_selector is not None:
+ with pause_events(self._linear_selector): # we don't want the linear selector change to update the indices
+ self._linear_selector.limits = xr
+ # linear selector acts on `p` dim
+ self._linear_selector.selection = indices[
+ self.processor.spatial_dims[1]
+ ]
+
+ def _linear_selector_handler(self, ev):
+ with block_indices_ctx(self):
+ # linear selector always acts on the `p` dim
+ self._ref_index[self.processor.spatial_dims[1]] = ev.info["value"]
+
+ def _tooltip_handler(self, graphic, pick_info):
+ if isinstance(self.graphic, (LineCollection, ScatterCollection)):
+ # get graphic within the collection
+ n_index = np.argwhere(self.graphic.graphics == graphic).item()
+ p_index = pick_info["vertex_index"]
+ return self.processor.tooltip_format(n_index, p_index)
+
+ def _create_graphic(self):
+ if self.data is None:
+ return
+
+ new_features = self.processor.get(self.indices)
+ data_slice = new_features["data"]
+
+ # store any cmap, sizes, thickness, etc. to assign to new graphic
+ graphic_attrs = dict()
+ for attr in ["cmap", "markers", "sizes", "thickness"]:
+ if attr in new_features.keys():
+ if new_features[attr] is not None:
+ # markers and sizes defined for each line via processor takes priority
+ continue
+
+ val = getattr(self, attr)
+ if val is not None:
+ graphic_attrs[attr] = val
+
+ if issubclass(self._graphic_type, ImageGraphic):
+ # `d` dim must only have xy data to be interpreted as a heatmap, xyz can't become a timeseries heatmap
+ if self.processor.shape[self.processor.spatial_dims[-1]] != 2:
+ raise ValueError
+
+ image_data, x0, x_scale = self._create_heatmap_data(data_slice)
+ self._graphic = self._graphic_type(
+ image_data, offset=(x0, 0, -1), scale=(x_scale, 1, 1)
+ )
+
+ else:
+ if issubclass(self._graphic_type, (LineStack, ScatterStack)):
+ kwargs = {"separation": 0.0, **self._graphic_kwargs}
+ else:
+ kwargs = self._graphic_kwargs
+ self._graphic = self._graphic_type(data_slice, **kwargs)
+
+ for attr in graphic_attrs.keys():
+ if hasattr(self._graphic, attr):
+ setattr(self._graphic, attr, graphic_attrs[attr])
+
+ if isinstance(self._graphic, (LineCollection, ScatterCollection)):
+ for l, g in enumerate(self.graphic.graphics):
+ for feature in ["colors", "sizes", "markers"]:
+ value = new_features[feature]
+
+ match value:
+ case None:
+ pass
+ case _:
+ if feature == "colors":
+ g.color_mode = "vertex"
+
+ setattr(g, feature, value[l])
+
+ if self.cmap_each is not None:
+ g.color_mode = "vertex"
+ g.cmap = self.cmap_each[l]
+ match new_features["cmap_transform_each"]:
+ case None:
+ pass
+ case _:
+ setattr(
+ getattr(g, "cmap"), # indv_graphic.cmap
+ "transform",
+ new_features["cmap_transform_each"],
+ )
+
+ if self.processor.tooltip:
+ if isinstance(self._graphic, (LineCollection, ScatterCollection)):
+ for g in self._graphic.graphics:
+ g.tooltip_format = partial(self._tooltip_handler, g)
+
+ self._subplot.add_graphic(self._graphic)
+
+ def _create_heatmap_data(self, data_slice) -> tuple[np.ndarray, float, float]:
+ """return [n_rows, n_cols] shape data from [n_timeseries, n_timepoints, xy] data"""
+ # assumes x vals in every row is the same, otherwise a heatmap representation makes no sense
+ # data slice is of shape [n_timeseries, n_timepoints, xy], where xy is x-y coordinates of each timeseries
+ x = data_slice[0, :, 0] # get x from just the first row
+
+ # check if we need to interpolate
+ norm = np.linalg.norm(np.diff(np.diff(x))) / x.size
+
+ if norm > 1e-6:
+ # x is not uniform upto float32 precision, must interpolate
+ x_uniform = np.linspace(x[0], x[-1], num=x.size)
+ y_interp = np.empty(shape=data_slice[..., 1].shape, dtype=np.float32)
+
+ # this for loop is actually slightly faster than numpy.apply_along_axis()
+ for i in range(data_slice.shape[0]):
+ y_interp[i] = np.interp(x_uniform, x, data_slice[i, :, 1])
+
+ else:
+ # x is sufficiently uniform
+ y_interp = data_slice[..., 1]
+
+ x0 = data_slice[0, 0, 0]
+
+ # assume all x values are the same across all lines
+ # otherwise a heatmap representation makes no sense anyways
+ x_stop = x[-1]
+ x_scale = (x_stop - x0) / data_slice.shape[1]
+
+ return y_interp, x0, x_scale
+
+ @property
+ def display_window(self) -> int | float | None:
+ """display window in the reference units for the n_datapoints dim"""
+ return self.processor.display_window
+
+ @display_window.setter
+ def display_window(self, dw: int | float | None):
+ self.processor.display_window = dw
+
+ # force re-render
+ self.indices = self.indices
+
+ @property
+ def datapoints_window_func(self) -> tuple[Callable, str, int | float] | None:
+ """
+ Callable, str indicating which dims to apply window function along, window_size in reference space:
+ 'all', 'x', 'y', 'z', 'xyz', 'xy', 'xz', 'yz'
+ '"""
+ return self.processor.datapoints_window_func
+
+ @datapoints_window_func.setter
+ def datapoints_window_func(self, funcs: tuple[Callable, str, int | float]):
+ self.processor.datapoints_window_func = funcs
+
+ @property
+ def x_range_mode(self) -> Literal["fixed", "auto"] | None:
+ """x-range using a fixed window from the display window, or by polling the camera (auto)"""
+ return self._x_range_mode
+
+ @x_range_mode.setter
+ def x_range_mode(self, mode: Literal[None, "fixed", "auto"]):
+ if self._x_range_mode == "auto":
+ # old mode was auto
+ self._subplot.remove_animation(self._update_from_view_range)
+
+ if mode == "auto":
+ self._subplot.add_animations(self._update_from_view_range)
+
+ self._x_range_mode = mode
+
+ def _update_from_view_range(self):
+ if self._graphic is None:
+ return
+
+ xr = self._subplot.x_range
+
+ # the floating point error near zero gets nasty here
+ if np.allclose(xr, self._last_x_range, atol=1e-14):
+ return
+
+ last_width = abs(self._last_x_range[1] - self._last_x_range[0])
+ self._last_x_range[:] = xr
+
+ new_width = abs(xr[1] - xr[0])
+ new_index = (xr[0] + xr[1]) / 2
+
+ if (new_index == self._ref_index[self.processor.spatial_dims[1]]) and (
+ last_width == new_width
+ ):
+ return
+
+ self.processor.display_window = new_width
+ # set the `p` dim on the global index vector
+ self._ref_index[self.processor.spatial_dims[1]] = new_index
+
+ @property
+ def cmap(self) -> str | None:
+ return self._cmap
+
+ @cmap.setter
+ def cmap(self, new: str | None):
+ if new is None:
+ # just set a default
+ if isinstance(self.graphic, (LineCollection, ScatterCollection)):
+ self.graphic.colors = "w"
+ else:
+ self.graphic.cmap = "plasma"
+
+ self._cmap = None
+ return
+
+ self._graphic.cmap = new
+ self._cmap = new
+ # force a re-render
+ self.indices = self.indices
+
+ @property
+ def cmap_each(self) -> np.ndarray[str] | None:
+ # per-line/scatter
+ return self._cmap_each
+
+ @cmap_each.setter
+ def cmap_each(self, new: Sequence[str] | None):
+ if new is None:
+ self._cmap_each = None
+ return
+
+ if isinstance(new, str):
+ new = [new]
+
+ new = np.asarray(new)
+
+ if new.ndim != 1:
+ raise ValueError
+
+ l_dim_size = self.processor.shape[self.processor.spatial_dims[0]]
+ # same cmap for all if size == 1, or specific cmap for each in `l` dim
+ if new.size != 1 and new.size != l_dim_size:
+ raise ValueError
+
+ self._cmap_each = np.broadcast_to(new, shape=(l_dim_size,))
+
+ @property
+ def cmap_transform_each(self) -> np.ndarray | None:
+ # PER line/scatter, only allowed after `cmaps` is set.
+ return self.processor.cmap_transform_each
+
+ @cmap_transform_each.setter
+ def cmap_transform_each(self, new: np.ndarray | FeatureCallable | None):
+ if new is None:
+ self.processor.cmap_transform_each = None
+
+ if self.cmap_each is None:
+ self.processor.cmap_transform_each = None
+ warn("must set `cmap_each` before `cmap_transform_each`")
+ return
+
+ if new is None and self.cmap_each is not None:
+ # default transform is just a transform based on the `p` dim size
+ new = partial(default_cmap_transform_each, self.shape[self.spatial_dims[1]])
+
+ self.processor.cmap_transform_each = new
+
+ @property
+ def markers(self) -> str | Sequence[str] | None:
+ return self._markers
+
+ @markers.setter
+ def markers(self, new: str | None):
+ if not isinstance(self.graphic, ScatterCollection):
+ self._markers = None
+ return
+
+ if new is None:
+ # just set a default
+ new = "circle"
+
+ self.graphic.markers = new
+ self._markers = new
+ # force a re-render
+ self.indices = self.indices
+
+ @property
+ def sizes(self) -> float | Sequence[float] | None:
+ return self._sizes
+
+ @sizes.setter
+ def sizes(self, new: float | Sequence[float] | None):
+ if not isinstance(self.graphic, ScatterCollection):
+ self._sizes = None
+ return
+
+ if new is None:
+ # just set a default
+ new = 5.0
+
+ self.graphic.sizes = new
+ self._sizes = new
+ # force a re-render
+ self.indices = self.indices
+
+ @property
+ def thickness(self) -> float | Sequence[float] | None:
+ return self._thickness
+
+ @thickness.setter
+ def thickness(self, new: float | Sequence[float] | None):
+ if not isinstance(self.graphic, LineCollection):
+ self._thickness = None
+ return
+
+ if new is None:
+ # just set a default
+ new = 2.0
+
+ self.graphic.thickness = new
+ self._thickness = new
+ # force a re-render
+ self.indices = self.indices
diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py
new file mode 100644
index 000000000..1b94e1cbc
--- /dev/null
+++ b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py
@@ -0,0 +1,98 @@
+from typing import Any
+
+import numpy as np
+import pandas as pd
+
+from ._nd_positions import NDPositionsProcessor
+
+
+class NDPP_Pandas(NDPositionsProcessor):
+ def __init__(
+ self,
+ data: pd.DataFrame,
+ spatial_dims: tuple[str, str, str], # [l, p, d] dims in order
+ columns: list[tuple[str, str] | tuple[str, str, str]],
+ tooltip_columns: list[str] = None,
+ **kwargs,
+ ):
+ self._columns = columns
+
+ if tooltip_columns is not None:
+ if len(tooltip_columns) != len(self.columns):
+ raise ValueError
+ self._tooltip_columns = tooltip_columns
+ self._tooltip = True
+ else:
+ self._tooltip_columns = None
+ self._tooltip = False
+
+ self._dims = spatial_dims
+
+ super().__init__(
+ data=data,
+ dims=spatial_dims,
+ spatial_dims=spatial_dims,
+ **kwargs,
+ )
+
+ self._dw_slice = None
+
+ @property
+ def data(self) -> pd.DataFrame:
+ return self._data
+
+ def _validate_data(self, data: pd.DataFrame):
+ if not isinstance(data, pd.DataFrame):
+ raise TypeError
+
+ return data
+
+ @property
+ def columns(self) -> list[tuple[str, str] | tuple[str, str, str]]:
+ return self._columns
+
+ @property
+ def dims(self) -> tuple[str, str, str]:
+ return self._dims
+
+ @property
+ def shape(self) -> dict[str, int]:
+ # n_graphical_elements, n_timepoints, 2
+ return {self.dims[0]: len(self.columns), self.dims[1]: self.data.index.size, self.dims[2]: 2}
+
+ @property
+ def ndim(self) -> int:
+ return len(self.shape)
+
+ @property
+ def tooltip(self) -> bool:
+ return self._tooltip
+
+ def tooltip_format(self, n: int, p: int):
+ # datapoint index w.r.t. full data
+ p += self._dw_slice.start
+ return str(self.data[self._tooltip_columns[n]][p])
+
+ def get(self, indices: dict[str, Any]) -> dict[str, np.ndarray]:
+ # TODO: LOD by using a step size according to max_p
+ # TODO: Also what to do if display_window is None and data
+ # hasn't changed when indices keeps getting set, cache?
+
+ # assume no additional slider dims
+ self._dw_slice = self._get_dw_slice(indices)
+ gdata_shape = len(self.columns), self._dw_slice.stop - self._dw_slice.start, 3
+
+ graphic_data = np.zeros(shape=gdata_shape, dtype=np.float32)
+
+ for i, col in enumerate(self.columns):
+ graphic_data[i, :, :len(col)] = np.column_stack(
+ [self.data[c][self._dw_slice] for c in col]
+ )
+
+ data = self._finalize_(graphic_data)
+ other = self._get_other_features(data, self._dw_slice)
+
+ return {
+ "data": data,
+ **other,
+ }
diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_zarr.py b/fastplotlib/widgets/nd_widget/_nd_positions/_zarr.py
new file mode 100644
index 000000000..fb3bb7015
--- /dev/null
+++ b/fastplotlib/widgets/nd_widget/_nd_positions/_zarr.py
@@ -0,0 +1,4 @@
+# placeholder
+
+class NDPP_Zarr:
+ pass
diff --git a/fastplotlib/widgets/nd_widget/_ndw_subplot.py b/fastplotlib/widgets/nd_widget/_ndw_subplot.py
new file mode 100644
index 000000000..6666b3fc1
--- /dev/null
+++ b/fastplotlib/widgets/nd_widget/_ndw_subplot.py
@@ -0,0 +1,122 @@
+from collections.abc import Callable
+from typing import Literal, Sequence, Hashable
+
+import numpy as np
+
+from ... import ScatterCollection, ScatterStack, LineCollection, LineStack, ImageGraphic
+from ...layouts import Subplot
+from ...utils import ArrayProtocol
+from . import NDImage, NDPositions
+from ._base import NDGraphic, WindowFuncCallable
+
+
+class NDWSubplot:
+ """
+ Entry point for adding ``NDGraphic`` objects to a subplot of an ``NDWidget``.
+
+ Accessed via ``ndw[row, col]`` or ``ndw["subplot_name"]``.
+ Each ``add_nd_<...>`` method constructs the appropriate ``NDGraphic``, registers it with the parent
+ ``ReferenceIndex``, appends it to this subplot and returns the ``NDGraphic`` instance to the user.
+
+ Note: ``NDWSubplot`` is not meant to be constructed directly, it only exists as part of an ``NDWidget``
+ """
+ def __init__(self, ndw, subplot: Subplot):
+ self.ndw = ndw
+ self._subplot = subplot
+
+ self._nd_graphics = list()
+
+ @property
+ def nd_graphics(self) -> tuple[NDGraphic]:
+ """all the NDGraphic instance in this subplot"""
+ return tuple(self._nd_graphics)
+
+ def __getitem__(self, key):
+ # get a specific NDGraphic by index or name
+ if isinstance(key, (int, np.integer)):
+ return self.nd_graphics[key]
+
+ for g in self.nd_graphics:
+ if g.name == key:
+ return g
+
+ else:
+ raise KeyError(f"NDGraphc with given key not found: {key}")
+
+ def add_nd_image(
+ self,
+ data: ArrayProtocol | None,
+ dims: Sequence[Hashable],
+ spatial_dims: (
+ tuple[str, str] | tuple[str, str, str]
+ ), # must be in order! [rows, cols] | [z, rows, cols]
+ rgb_dim: str | None = None,
+ window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None,
+ window_order: tuple[int, ...] = None,
+ spatial_func: Callable[[ArrayProtocol], ArrayProtocol] = None,
+ compute_histogram: bool = True,
+ slider_dim_transforms=None,
+ name: str = None,
+ ):
+ nd = NDImage(self.ndw.indices, self._subplot, data=data,
+ dims=dims,
+ spatial_dims=spatial_dims,
+ rgb_dim=rgb_dim,
+ window_funcs=window_funcs,
+ window_order=window_order,
+ spatial_func=spatial_func,
+ compute_histogram=compute_histogram,
+ slider_dim_transforms=slider_dim_transforms,
+ name=name,
+ )
+
+ self._nd_graphics.append(nd)
+ return nd
+
+ def add_nd_scatter(self, *args, **kwargs):
+ # TODO: better func signature here, send all kwargs to processor_kwargs
+ nd = NDPositions(
+ self.ndw.indices,
+ self._subplot,
+ *args,
+ graphic_type=ScatterCollection,
+ **kwargs,
+ )
+
+ self._nd_graphics.append(nd)
+ return nd
+
+ def add_nd_timeseries(
+ self,
+ *args,
+ graphic_type: type[
+ LineCollection | LineStack | ScatterStack | ImageGraphic
+ ] = LineStack,
+ x_range_mode: Literal["fixed", "auto"] | None = "auto",
+ **kwargs,
+ ):
+ nd = NDPositions(
+ self.ndw.indices,
+ self._subplot,
+ *args,
+ graphic_type=graphic_type,
+ linear_selector=True,
+ x_range_mode=x_range_mode,
+ timeseries=True,
+ **kwargs,
+ )
+
+ self._nd_graphics.append(nd)
+ return nd
+
+ def add_nd_lines(self, *args, **kwargs):
+ nd = NDPositions(
+ self.ndw.indices,
+ self._subplot,
+ *args,
+ graphic_type=LineCollection,
+ **kwargs,
+ )
+
+ self._nd_graphics.append(nd)
+ return nd
diff --git a/fastplotlib/widgets/nd_widget/_ndwidget.py b/fastplotlib/widgets/nd_widget/_ndwidget.py
new file mode 100644
index 000000000..9ddfa8986
--- /dev/null
+++ b/fastplotlib/widgets/nd_widget/_ndwidget.py
@@ -0,0 +1,66 @@
+from __future__ import annotations
+
+from typing import Any, Optional
+
+from ._index import RangeContinuous, RangeDiscrete, ReferenceIndex
+from ._ndw_subplot import NDWSubplot
+from ._ui import NDWidgetUI, RightClickMenu
+from ...layouts import ImguiFigure, Subplot
+
+
+class NDWidget:
+ def __init__(self, ref_ranges: dict[str, tuple], ref_index: Optional[ReferenceIndex] = None, **kwargs):
+ if ref_index is None:
+ self._indices = ReferenceIndex(ref_ranges)
+ else:
+ self._indices = ref_index
+
+ self._indices._add_ndwidget_(self)
+
+ self._figure = ImguiFigure(std_right_click_menu=RightClickMenu, **kwargs)
+ self._figure.std_right_click_menu.set_nd_widget(self)
+
+ self._subplots_nd: dict[Subplot, NDWSubplot] = dict()
+ for subplot in self.figure:
+ self._subplots_nd[subplot] = NDWSubplot(self, subplot)
+
+ # hard code the expected height so that the first render looks right in tests, docs etc.
+ ui_size = 57 + (50 * len(self.indices))
+
+ self._sliders_ui = NDWidgetUI(self.figure, ui_size, self)
+ self.figure.add_gui(self._sliders_ui)
+
+ @property
+ def figure(self) -> ImguiFigure:
+ return self._figure
+
+ @property
+ def indices(self) -> ReferenceIndex:
+ return self._indices
+
+ @indices.setter
+ def indices(self, new_indices: dict[str, int | float | Any]):
+ self._indices.set(new_indices)
+
+ @property
+ def ranges(self) -> dict[str, RangeContinuous | RangeDiscrete]:
+ return self._indices.ref_ranges
+
+ @property
+ def ndgraphics(self):
+ gs = list()
+ for subplot in self._subplots_nd.values():
+ gs.extend(subplot.nd_graphics)
+
+ return tuple(gs)
+
+ def __getitem__(self, key: str | tuple[int, int] | Subplot):
+ if not isinstance(key, Subplot):
+ key = self.figure[key]
+ return self._subplots_nd[key]
+
+ def show(self, **kwargs):
+ return self.figure.show(**kwargs)
+
+ def close(self):
+ self.figure.close()
diff --git a/fastplotlib/widgets/nd_widget/_repr_formatter.py b/fastplotlib/widgets/nd_widget/_repr_formatter.py
new file mode 100644
index 000000000..0569f1004
--- /dev/null
+++ b/fastplotlib/widgets/nd_widget/_repr_formatter.py
@@ -0,0 +1,599 @@
+from __future__ import annotations
+
+import html
+from collections.abc import Callable
+from typing import Any
+
+
+_RESET = "\033[0m"
+_BOLD = "\033[1m"
+_DIM = "\033[2m"
+
+_C = {
+ "title": "\033[38;5;75m", # sky-blue
+ "spatial": "\033[38;5;114m", # sage-green
+ "slider": "\033[38;5;215m", # soft-orange
+ "label": "\033[38;5;246m", # mid-grey
+ "value": "\033[38;5;252m", # near-white
+ "section": "\033[38;5;68m", # steel-blue
+ "muted": "\033[38;5;240m", # dark-grey
+ "warn": "\033[38;5;222m", # amber
+}
+
+
+def _c(key: str, text: str) -> str:
+ return f"{_C[key]}{text}{_RESET}"
+
+
+def _callable_name(f: Callable | None) -> str:
+ if f is None:
+ return "—"
+ module = getattr(f, "__module__", "") or ""
+ qname = getattr(f, "__qualname__", None) or getattr(f, "__name__", repr(f))
+ if module and not module.startswith("__"):
+ short = module.split(".")[-1]
+ return f"{short}.{qname}"
+ return qname
+
+
+def ndprocessor_fmt_txt(processor) -> str:
+ """
+ Returns a colored, ascii box
+ """
+ lines: list[str] = []
+
+ cls = type(processor).__name__
+ lines.append(_c("title", _BOLD + cls) + _RESET)
+ lines.append(_c("muted", "─" * 72))
+
+ lines.append(_c("section", " Dimensions"))
+
+ header = (
+ f" {'dim':<14}{'size':>6} {'role':<10} {'window_func size':<26} index_mapping"
+ )
+ lines.append(_c("label", header))
+ lines.append(_c("muted", " " + "─" * 70))
+
+ for dim in processor.dims:
+ size = processor.shape[dim]
+ is_sp = dim in processor.spatial_dims
+ role_s = (_c("spatial", f"{'spatial':<10}") if is_sp
+ else _c("slider", f"{'slider':<10}"))
+
+ # window_func - size column
+ if not is_sp:
+ wf, ws = processor.window_funcs.get(dim, (None, None))
+ if wf is not None and ws is not None:
+ win_s = _c("value", f"{_callable_name(wf)}") + _c("muted", f" - {ws}")
+ else:
+ win_s = _c("muted", "—")
+ else:
+ win_s = ""
+
+ # index_mapping column (slider dims only; skip identity)
+ if not is_sp:
+ imap = processor.index_mappings.get(dim)
+ iname = getattr(imap, "__name__", "") if imap is not None else ""
+ if iname != "identity" and imap is not None:
+ idx_s = _c("value", _callable_name(imap))
+ else:
+ idx_s = _c("muted", "—")
+ else:
+ idx_s = ""
+
+ # pad win_s to fixed visible width (strip ANSI for measuring)
+ import re
+ _ansi_re = re.compile(r"\033\[[^m]*m")
+ win_visible = len(_ansi_re.sub("", win_s))
+ win_pad = win_s + " " * max(0, 26 - win_visible)
+
+ line = (
+ f" {_c('value', f'{str(dim):<14}')}"
+ f"{_c('label', f'{size:>6}')} "
+ f"{role_s} {win_pad} {idx_s}"
+ )
+ lines.append(line)
+
+ # window order
+ if processor.window_order:
+ lines.append("")
+ order_s = " → ".join(str(d) for d in processor.window_order)
+ lines.append(f" {_c('section', 'Window order')} {_c('value', order_s)}")
+
+ # spatial func
+ if processor.spatial_func is not None:
+ lines.append("")
+ lines.append(
+ f" {_c('section', 'Spatial func')} "
+ f"{_c('value', _callable_name(processor.spatial_func))}"
+ )
+
+ lines.append(_c("muted", "─" * 72))
+ return "\n".join(lines)
+
+
+def ndgraphic_fmt_txt(ndg) -> str:
+ """Text repr for NDGraphic."""
+ cls = type(ndg).__name__
+ gcls = type(ndg.graphic).__name__ if ndg.graphic is not None else "—"
+ name = ndg.name or "—"
+
+ header = (
+ f"{_c('title', _BOLD + cls)}{_RESET} "
+ f"{_c('muted', '·')} "
+ f"{_c('section', 'graphic')} {_c('value', gcls)} "
+ f"{_c('muted', '·')} "
+ f"{_c('section', 'name')} {_c('value', name)}\n"
+ )
+
+ proc_block = ndprocessor_fmt_txt(ndg.processor)
+ # indent processor block
+ indented = "\n".join(" " + l for l in proc_block.splitlines())
+ return header + indented
+
+_CSS = """
+
+"""
+
+
+def _h(s: Any) -> str:
+ """html-escape a stringified value"""
+ return html.escape(str(s))
+
+
+def _badge(role: str) -> str:
+ cls = "fpl-badge-spatial" if role == "spatial" else "fpl-badge-slider"
+ return f'{role}'
+
+
+def _code(s: str) -> str:
+ return f"{_h(s)}"
+
+
+def _section(title: str, content_html: str, count: str = "", open_: bool = True) -> str:
+ open_attr = " open" if open_ else ""
+ count_badge = (
+ f'{_h(count)}' if count else ""
+ )
+ return (
+ f''
+ f''
+ f'{_h(title)}'
+ f'{count_badge}'
+ f'
'
+ f'{content_html}'
+ f' '
+ )
+
+
+def _dim_rows_html(proc) -> str:
+ rows = []
+ for dim in proc.dims:
+ size = proc.shape[dim]
+ is_sp = dim in proc.spatial_dims
+ badge = _badge("spatial" if is_sp else "slider")
+
+ # window_func - size column
+ if not is_sp:
+ wf, ws = proc.window_funcs.get(dim, (None, None))
+ if wf is not None and ws is not None:
+ win_td = (
+ f''
+ f'{_code(_callable_name(wf))}'
+ f'-'
+ f'{_code(str(ws))}'
+ f' | '
+ )
+ else:
+ win_td = '— | '
+ else:
+ win_td = ' | '
+
+ # index_mapping column (slider dims only; hide identity)
+ if not is_sp:
+ imap = proc.index_mappings.get(dim)
+ if imap is not None:
+ idx_td = f'{_code(_callable_name(imap))} | '
+ else:
+ idx_td = '— | '
+ else:
+ idx_td = ' | '
+
+ rows.append(
+ f''
+ f'| {_h(str(dim))} | '
+ f'{size:,} | '
+ f'{badge} | '
+ f'{win_td}'
+ f'{idx_td}'
+ f'
'
+ )
+
+ # column header row
+ header = (
+ f''
+ f'| dim | '
+ f'size | '
+ f'role | '
+ f'window_func - size | '
+ f'index_mapping | '
+ f'
'
+ )
+
+ table = (
+ ''
+ ''
+ ''
+ ''
+ ''
+ + header
+ + "".join(rows)
+ + "
"
+ )
+ return table
+
+
+def _footer_kv(pairs: list[tuple[str, str]]) -> str:
+ """Always-visible key/value rows rendered below the dim table."""
+ inner = ""
+ for k, v in pairs:
+ inner += (
+ f''
+ f''
+ )
+ return f''
+
+
+def _kv_list_html(pairs: list[tuple[str, str]]) -> str:
+ inner = ""
+ for k, v in pairs:
+ inner += (
+ f'{_h(k)}
'
+ f'{v}
'
+ )
+ return f'{inner}
'
+
+
+def _html_processor(proc) -> str:
+ cls = _h(type(proc).__name__)
+
+ # header
+ ndim_pill = (
+ f''
+ f'{proc.ndim}D'
+ )
+ header = (
+ f''
+ )
+
+ # dims section (always open)
+ dim_content = _dim_rows_html(proc)
+ sections = _section("Dimensions", dim_content,
+ count=str(proc.ndim), open_=True)
+
+ # always-visible footer rows
+ footer_pairs: list[tuple[str, str]] = []
+
+ if proc.window_order:
+ chain = " → ".join(
+ f'▶{_h(str(d))}'
+ if i > 0 else _h(str(d))
+ for i, d in enumerate(proc.window_order)
+ )
+ footer_pairs.append(("window order", f'{chain}'))
+
+ if proc.spatial_func is not None:
+ footer_pairs.append(("spatial func", _code(_callable_name(proc.spatial_func))))
+
+ if footer_pairs:
+ sections += _footer_kv(footer_pairs)
+
+ body = f'{sections}
'
+ return f'{_CSS}{header}{body}
'
+
+
+def ndgraphic_fmt_html(ndg) -> str:
+ cls = _h(type(ndg).__name__)
+ gcls = _h(type(ndg.graphic).__name__) if ndg.graphic is not None else "—"
+ name = _h(ndg.name or "—")
+
+ graphic_pill = f'graphic: {gcls}'
+ name_pill = f'name: {name}'
+
+ header = (
+ f''
+ )
+
+ # embed processor repr (without its own outer box) inside a section
+ proc_inner = _dim_rows_html(ndg.processor)
+ sections = _section("Processor · Dimensions", proc_inner, open_=True)
+
+ footer_pairs: list[tuple[str, str]] = []
+
+ if ndg.processor.window_order:
+ chain = " → ".join(
+ f'▶{_h(str(d))}'
+ if i > 0 else _h(str(d))
+ for i, d in enumerate(ndg.processor.window_order)
+ )
+ footer_pairs.append(("window order", f'{chain}'))
+
+ if ndg.processor.spatial_func is not None:
+ footer_pairs.append(("spatial func", _code(_callable_name(ndg.processor.spatial_func))))
+
+ if footer_pairs:
+ sections += _footer_kv(footer_pairs)
+
+ body = f'{sections}
'
+ return f'{_CSS}{header}{body}
'
+
+class ReprMixin:
+ """
+ Mixin that provides:
+ • __repr__ → coloured ANSI text (terminal / plain REPL)
+ • _repr_html_ → rich HTML (Jupyter)
+ • _repr_mimebundle_ → both, so Jupyter picks the richest format
+
+ Subclasses must implement _repr_text_() and _repr_html_() themselves OR
+ rely on the dispatch below which checks the concrete type.
+ """
+
+ def _repr_text_(self) -> str:
+ # lazy import avoids circular; swap for a direct call in your module
+ if _is_ndgraphic(self):
+ return ndgraphic_fmt_txt(self)
+ return ndprocessor_fmt_txt(self)
+
+ def _repr_html_(self) -> str:
+ return ndgraphic_fmt_html(self)
+ return _html_processor(self)
+
+ def __repr__(self) -> str:
+ return self._repr_text_()
+
+ def _repr_mimebundle_(self, **kwargs) -> dict:
+ return {
+ "text/plain": self._repr_text_(),
+ "text/html": self._repr_html_(),
+ }
+
+
+def _is_ndgraphic(obj) -> bool:
+ """duck-type check: does this object have a .graphic and .processor?"""
+ return hasattr(obj, "graphic") and hasattr(obj, "processor")
\ No newline at end of file
diff --git a/fastplotlib/widgets/nd_widget/_ui.py b/fastplotlib/widgets/nd_widget/_ui.py
new file mode 100644
index 000000000..2855c7063
--- /dev/null
+++ b/fastplotlib/widgets/nd_widget/_ui.py
@@ -0,0 +1,287 @@
+import os
+from time import perf_counter
+
+import numpy as np
+from imgui_bundle import imgui, imgui_ctx, icons_fontawesome_6 as fa
+
+from ...graphics import (
+ ScatterCollection,
+ ScatterStack,
+ LineCollection,
+ LineStack,
+ ImageGraphic,
+ ImageVolumeGraphic,
+)
+from ...utils import quick_min_max
+from ...layouts import Subplot
+from ...ui import EdgeWindow, StandardRightClickMenu
+from ._index import RangeContinuous
+from ._base import NDGraphic
+from ._nd_positions import NDPositions
+from ._nd_image import NDImage
+
+position_graphic_types = [ScatterCollection, ScatterStack, LineCollection, LineStack, ImageGraphic]
+
+
+class NDWidgetUI(EdgeWindow):
+ def __init__(self, figure, size, ndwidget):
+ super().__init__(
+ figure=figure,
+ size=size,
+ title="NDWidget controls",
+ location="bottom",
+ window_flags=imgui.WindowFlags_.no_collapse
+ | imgui.WindowFlags_.no_resize
+ | imgui.WindowFlags_.no_title_bar,
+ )
+ self._ndwidget = ndwidget
+
+ ref_ranges = self._ndwidget.ranges
+
+ # whether or not a dimension is in play mode
+ self._playing = {dim: False for dim in ref_ranges.keys()}
+
+ # approximate framerate for playing
+ self._fps = {dim: 20 for dim in ref_ranges.keys()}
+
+ # framerate converted to frame time
+ self._frame_time = {dim: 1 / 20 for dim in ref_ranges.keys()}
+
+ # last timepoint that a frame was displayed from a given dimension
+ self._last_frame_time = {dim: perf_counter() for dim in ref_ranges.keys()}
+
+ # loop playback
+ self._loop = {dim: False for dim in ref_ranges.keys()}
+
+ # auto-plays the ImageWidget's left-most dimension in docs galleries
+ if "DOCS_BUILD" in os.environ.keys():
+ if os.environ["DOCS_BUILD"] == "1":
+ self._playing[0] = True
+ self._loop = True
+
+ self._max_display_windows: dict[NDGraphic, float | int] = dict()
+
+ def _set_index(self, dim, index):
+ if index >= self._ndwidget.ranges[dim].stop:
+ if self._loop[dim]:
+ index = self._ndwidget.ranges[dim].start
+ else:
+ index = self._ndwidget.ranges[dim].stop
+ self._playing[dim] = False
+
+ self._ndwidget.indices[dim] = index
+
+ def update(self):
+ now = perf_counter()
+
+ for dim, current_index in self._ndwidget.indices:
+ # push id since we have the same buttons for each dim
+ imgui.push_id(f"{self._id_counter}_{dim}")
+
+ rr = self._ndwidget.ranges[dim]
+
+ if self._playing[dim]:
+ # show pause button if playing
+ if imgui.button(label=fa.ICON_FA_PAUSE):
+ # if pause button clicked, then set playing to false
+ self._playing[dim] = False
+
+ # if in play mode and enough time has elapsed w.r.t. the desired framerate, increment the index
+ if now - self._last_frame_time[dim] >= self._frame_time[dim]:
+ self._set_index(dim, current_index + rr.step)
+ self._last_frame_time[dim] = now
+
+ else:
+ # we are not playing, so display play button
+ if imgui.button(label=fa.ICON_FA_PLAY):
+ # if play button is clicked, set last frame time to 0 so that index increments on next render
+ self._last_frame_time[dim] = 0
+ # set playing to True since play button was clicked
+ self._playing[dim] = True
+
+ imgui.same_line()
+ # step back one frame button
+ if imgui.button(label=fa.ICON_FA_BACKWARD_STEP) and not self._playing[dim]:
+ self._set_index(dim, current_index - rr.step)
+
+ imgui.same_line()
+ # step forward one frame button
+ if imgui.button(label=fa.ICON_FA_FORWARD_STEP) and not self._playing[dim]:
+ self._set_index(dim, current_index + rr.step)
+
+ imgui.same_line()
+ # stop button
+ if imgui.button(label=fa.ICON_FA_STOP):
+ self._playing[dim] = False
+ self._last_frame_time[dim] = 0
+ self._ndwidget.indices[dim] = rr.start
+
+ imgui.same_line()
+ # loop checkbox
+ _, self._loop[dim] = imgui.checkbox(
+ label=fa.ICON_FA_ROTATE, v=self._loop[dim]
+ )
+ if imgui.is_item_hovered(0):
+ imgui.set_tooltip("loop playback")
+
+ imgui.same_line()
+ imgui.text("framerate :")
+ imgui.same_line()
+ imgui.set_next_item_width(100)
+ # framerate int entry
+ fps_changed, value = imgui.input_int(
+ label="fps", v=self._fps[dim], step_fast=5
+ )
+ if imgui.is_item_hovered(0):
+ imgui.set_tooltip(
+ "framerate is approximate and less reliable as it approaches your monitor refresh rate"
+ )
+ if fps_changed:
+ if value < 1:
+ value = 1
+ if value > 50:
+ value = 50
+ self._fps[dim] = value
+ self._frame_time[dim] = 1 / value
+
+ imgui.text(str(dim))
+ imgui.same_line()
+ # so that slider occupies full width
+ imgui.set_next_item_width(self.width * 0.85)
+
+ if isinstance(rr, RangeContinuous):
+ changed, new_index = imgui.slider_float(
+ v=current_index,
+ v_min=rr.start,
+ v_max=rr.stop - rr.step,
+ label=f"##{dim}",
+ )
+
+ # TODO: refactor all this stuff, make fully fledged UI
+ if changed:
+ self._ndwidget.indices[dim] = new_index
+
+ elif imgui.is_item_hovered():
+ if imgui.is_key_pressed(imgui.Key.right_arrow):
+ self._set_index(dim, current_index + rr.step)
+
+ elif imgui.is_key_pressed(imgui.Key.left_arrow):
+ self._set_index(dim, current_index - rr.step)
+
+ imgui.pop_id()
+
+
+class RightClickMenu(StandardRightClickMenu):
+ def __init__(self, figure):
+ self._ndwidget = None
+ self._ndgraphic_windows = set()
+
+ super().__init__(figure=figure)
+
+ def set_nd_widget(self, ndw):
+ self._ndwidget = ndw
+
+ def _extra_menu(self):
+ if self._ndwidget is None:
+ return
+
+ if imgui.begin_menu("ND Graphics"):
+ subplot = self.get_subplot()
+ for ndg in self._ndwidget[subplot].nd_graphics:
+ name = ndg.name if ndg.name is not None else hex(id(ndg))
+ if imgui.menu_item(
+ f"{name}", "", False
+ )[0]:
+ self._ndgraphic_windows.add(ndg)
+
+ imgui.end_menu()
+
+ def update(self):
+ super().update()
+
+ for ndg in list(self._ndgraphic_windows): # set -> list so we can change size during iteration
+ name = ndg.name if ndg.name is not None else hex(id(ndg))
+ subplot = ndg.graphic._plot_area
+ imgui.set_next_window_size((0, 0))
+ _, open = imgui.begin(f"subplot: {subplot.name}, {name}", True)
+
+ if isinstance(ndg, NDPositions):
+ self._draw_nd_pos_ui(subplot, ndg)
+
+ elif isinstance(ndg, NDImage):
+ self._draw_nd_image_ui(subplot, ndg)
+
+ _, ndg.pause = imgui.checkbox("pause", ndg.pause)
+
+ if not open:
+ self._ndgraphic_windows.remove(ndg)
+
+ imgui.end()
+
+ def _draw_nd_image_ui(self, subplot, nd_image: NDImage):
+ _min, _max = quick_min_max(nd_image.graphic.data.value)
+ changed, vmin = imgui.slider_float(
+ "vmin", nd_image.graphic.vmin, v_min=_min, v_max=_max
+ )
+ if changed:
+ nd_image.graphic.vmin = vmin
+
+ changed, vmax = imgui.slider_float(
+ "vmax", nd_image.graphic.vmax, v_min=_min, v_max=_max
+ )
+ if changed:
+ nd_image.graphic.vmax = vmax
+
+ changed, new_gamma = imgui.slider_float(
+ "gamma", nd_image.graphic._material.gamma, 0.01, 5
+ )
+ if changed:
+ nd_image.graphic._material.gamma = new_gamma
+
+ def _draw_nd_pos_ui(self, subplot: Subplot, nd_graphic: NDPositions):
+ for i, cls in enumerate(position_graphic_types):
+ if imgui.radio_button(cls.__name__, type(nd_graphic.graphic) is cls):
+ nd_graphic.graphic_type = cls
+ subplot.auto_scale()
+
+ changed, val = imgui.checkbox(
+ "use display window", nd_graphic.display_window is not None
+ )
+
+ p_dim = nd_graphic.processor.spatial_dims[1]
+
+ if changed:
+ if not val:
+ nd_graphic.display_window = None
+ else:
+ # pick a value 10% of the reference range
+ nd_graphic.display_window = self._ndwidget.ranges[p_dim].range * 0.1
+
+ if nd_graphic.display_window is not None:
+ if isinstance(nd_graphic.display_window, (int, np.integer)):
+ slider = imgui.slider_int
+ input_ = imgui.input_int
+ type_ = int
+ else:
+ slider = imgui.slider_float
+ input_ = imgui.input_float
+ type_ = float
+
+ changed, new = slider(
+ "display window",
+ v=nd_graphic.display_window,
+ v_min=type_(0),
+ v_max=type_(self._ndwidget.ranges[p_dim].stop * 0.1),
+ )
+
+ if changed:
+ nd_graphic.display_window = new
+
+ options = [None, "fixed", "auto"]
+ changed, option = imgui.combo(
+ "x-range mode",
+ options.index(nd_graphic.x_range_mode),
+ [str(o) for o in options],
+ )
+ if changed:
+ nd_graphic.x_range_mode = options[option]
diff --git a/pyproject.toml b/pyproject.toml
index 73dfd7ee3..b91b168c2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,7 @@ keywords = [
requires-python = ">= 3.10"
dependencies = [
"numpy>=1.23.0",
- "pygfx==0.15.3",
+ "pygfx==0.16.0",
"wgpu", # Let pygfx constrain the wgpu version
"cmap>=0.1.3",
# (this comment keeps this list multiline in VSCode)
@@ -46,6 +46,7 @@ notebook = [
"jupyter-rfb>=0.5.1",
"ipywidgets>=8.0.0,<9",
"sidecar",
+ "simplejpeg",
]
tests = [
"pytest",
@@ -58,7 +59,8 @@ tests = [
"ome-zarr",
]
imgui = ["wgpu[imgui]"]
-dev = ["fastplotlib[docs,notebook,tests,imgui]"]
+ndwidget = ["wgpu[imgui]", "xarray"]
+dev = ["fastplotlib[docs,notebook,tests,imgui,ndwidget]"]
[project.urls]
Homepage = "https://www.fastplotlib.org/"
diff --git a/tests/test_colors_buffer_manager.py b/tests/test_colors_buffer_manager.py
index 7b1aef16a..f9d56189e 100644
--- a/tests/test_colors_buffer_manager.py
+++ b/tests/test_colors_buffer_manager.py
@@ -48,10 +48,10 @@ def test_int(test_graphic):
data = generate_positions_spiral_data("xyz")
if test_graphic == "line":
- graphic = fig[0, 0].add_line(data=data)
+ graphic = fig[0, 0].add_line(data=data, color_mode="vertex")
elif test_graphic == "scatter":
- graphic = fig[0, 0].add_scatter(data=data)
+ graphic = fig[0, 0].add_scatter(data=data, color_mode="vertex")
colors = graphic.colors
global EVENT_RETURN_VALUE
@@ -98,10 +98,10 @@ def test_tuple(test_graphic, slice_method):
data = generate_positions_spiral_data("xyz")
if test_graphic == "line":
- graphic = fig[0, 0].add_line(data=data)
+ graphic = fig[0, 0].add_line(data=data, color_mode="vertex")
elif test_graphic == "scatter":
- graphic = fig[0, 0].add_scatter(data=data)
+ graphic = fig[0, 0].add_scatter(data=data, color_mode="vertex")
colors = graphic.colors
global EVENT_RETURN_VALUE
@@ -190,10 +190,10 @@ def test_slice(color_input, slice_method: dict, test_graphic: bool):
data = generate_positions_spiral_data("xyz")
if test_graphic == "line":
- graphic = fig[0, 0].add_line(data=data)
+ graphic = fig[0, 0].add_line(data=data, color_mode="vertex")
elif test_graphic == "scatter":
- graphic = fig[0, 0].add_scatter(data=data)
+ graphic = fig[0, 0].add_scatter(data=data, color_mode="vertex")
colors = graphic.colors
diff --git a/tests/test_markers_buffer_manager.py b/tests/test_markers_buffer_manager.py
index 65ead392e..488bed194 100644
--- a/tests/test_markers_buffer_manager.py
+++ b/tests/test_markers_buffer_manager.py
@@ -46,10 +46,10 @@ def test_create_buffer(test_graphic):
if test_graphic:
fig = fpl.Figure()
- scatter = fig[0, 0].add_scatter(data, markers=MARKERS1)
+ scatter = fig[0, 0].add_scatter(data, markers=MARKERS1, uniform_marker=False)
vertex_markers = scatter.markers
assert isinstance(vertex_markers, VertexMarkers)
- assert vertex_markers.buffer is scatter.world_object.geometry.markers
+ assert vertex_markers._fpl_buffer is scatter.world_object.geometry.markers
else:
vertex_markers = VertexMarkers(MARKERS1, len(data))
@@ -68,7 +68,7 @@ def test_int(test_graphic, index: int):
if test_graphic:
fig = fpl.Figure()
- scatter = fig[0, 0].add_scatter(data, markers=MARKERS1)
+ scatter = fig[0, 0].add_scatter(data, markers=MARKERS1, uniform_marker=False)
scatter.add_event_handler(event_handler, "markers")
vertex_markers = scatter.markers
else:
@@ -108,7 +108,7 @@ def test_slice(test_graphic, slice_method):
if test_graphic:
fig = fpl.Figure()
- scatter = fig[0, 0].add_scatter(data, markers=MARKERS1)
+ scatter = fig[0, 0].add_scatter(data, markers=MARKERS1, uniform_marker=False)
scatter.add_event_handler(event_handler, "markers")
vertex_markers = scatter.markers
diff --git a/tests/test_point_rotations_buffer_manager.py b/tests/test_point_rotations_buffer_manager.py
index ec5fdbe0f..50ee88984 100644
--- a/tests/test_point_rotations_buffer_manager.py
+++ b/tests/test_point_rotations_buffer_manager.py
@@ -35,7 +35,7 @@ def test_create_buffer(test_graphic):
scatter = fig[0, 0].add_scatter(data, point_rotation_mode="vertex", point_rotations=ROTATIONS1)
vertex_rotations = scatter.point_rotations
assert isinstance(vertex_rotations, VertexRotations)
- assert vertex_rotations.buffer is scatter.world_object.geometry.rotations
+ assert vertex_rotations._fpl_buffer is scatter.world_object.geometry.rotations
else:
vertex_rotations = VertexRotations(ROTATIONS1, len(data))
diff --git a/tests/test_positions_data_buffer_manager.py b/tests/test_positions_data_buffer_manager.py
index e2582d4ba..cc550abf0 100644
--- a/tests/test_positions_data_buffer_manager.py
+++ b/tests/test_positions_data_buffer_manager.py
@@ -57,7 +57,7 @@ def test_int(test_graphic):
graphic = fig[0, 0].add_scatter(data=data)
points = graphic.data
- assert graphic.data.buffer is graphic.world_object.geometry.positions
+ assert graphic.data._fpl_buffer is graphic.world_object.geometry.positions
global EVENT_RETURN_VALUE
graphic.add_event_handler(event_handler, "data")
else:
diff --git a/tests/test_positions_graphics.py b/tests/test_positions_graphics.py
index 31c001888..4bc93b626 100644
--- a/tests/test_positions_graphics.py
+++ b/tests/test_positions_graphics.py
@@ -37,12 +37,12 @@ def test_sizes_slice():
@pytest.mark.parametrize("graphic_type", ["line", "scatter"])
@pytest.mark.parametrize("colors", [None, *generate_color_inputs("b")])
-@pytest.mark.parametrize("uniform_color", [True, False])
-def test_uniform_color(graphic_type, colors, uniform_color):
+@pytest.mark.parametrize("color_mode", ["uniform", "vertex"])
+def test_color_mode(graphic_type, colors, color_mode):
fig = fpl.Figure()
kwargs = dict()
- for kwarg in ["colors", "uniform_color"]:
+ for kwarg in ["colors", "color_mode"]:
if locals()[kwarg] is not None:
# add to dict of arguments that will be passed
kwargs[kwarg] = locals()[kwarg]
@@ -54,7 +54,7 @@ def test_uniform_color(graphic_type, colors, uniform_color):
elif graphic_type == "scatter":
graphic = fig[0, 0].add_scatter(data=data, **kwargs)
- if uniform_color:
+ if color_mode == "uniform":
assert isinstance(graphic._colors, UniformColor)
assert isinstance(graphic.colors, pygfx.Color)
if colors is None:
@@ -130,17 +130,17 @@ def test_positions_graphics_data(
@pytest.mark.parametrize("graphic_type", ["line", "scatter"])
@pytest.mark.parametrize("colors", [None, *generate_color_inputs("r")])
-@pytest.mark.parametrize("uniform_color", [None, False])
+@pytest.mark.parametrize("color_mode", ["vertex"])
def test_positions_graphic_vertex_colors(
graphic_type,
colors,
- uniform_color,
+ color_mode,
):
# test different ways of passing vertex colors
fig = fpl.Figure()
kwargs = dict()
- for kwarg in ["colors", "uniform_color"]:
+ for kwarg in ["colors", "color_mode"]:
if locals()[kwarg] is not None:
# add to dict of arguments that will be passed
kwargs[kwarg] = locals()[kwarg]
@@ -153,10 +153,9 @@ def test_positions_graphic_vertex_colors(
graphic = fig[0, 0].add_scatter(data=data, **kwargs)
# color per vertex
- # uniform colors is default False, or set to False
- assert isinstance(graphic._colors, VertexColors)
- assert isinstance(graphic.colors, VertexColors)
- assert len(graphic.colors) == len(graphic.data)
+ assert isinstance(graphic._colors, VertexColors)
+ assert isinstance(graphic.colors, VertexColors)
+ assert len(graphic.colors) == len(graphic.data)
if colors is None:
# default
@@ -179,7 +178,7 @@ def test_positions_graphic_vertex_colors(
@pytest.mark.parametrize("graphic_type", ["line", "scatter"])
@pytest.mark.parametrize("colors", [None, *generate_color_inputs("r")])
-@pytest.mark.parametrize("uniform_color", [None, False])
+@pytest.mark.parametrize("color_mode", ["auto", "vertex"])
@pytest.mark.parametrize("cmap", ["jet"])
@pytest.mark.parametrize(
"cmap_transform", [None, [3, 5, 2, 1, 0, 6, 9, 7, 4, 8], np.arange(9, -1, -1)]
@@ -187,7 +186,7 @@ def test_positions_graphic_vertex_colors(
def test_cmap(
graphic_type,
colors,
- uniform_color,
+ color_mode,
cmap,
cmap_transform,
):
@@ -195,7 +194,7 @@ def test_cmap(
fig = fpl.Figure()
kwargs = dict()
- for kwarg in ["cmap", "cmap_transform", "colors", "uniform_color"]:
+ for kwarg in ["cmap", "cmap_transform", "colors", "color_mode"]:
if locals()[kwarg] is not None:
# add to dict of arguments that will be passed
kwargs[kwarg] = locals()[kwarg]
@@ -220,7 +219,8 @@ def test_cmap(
# make sure buffer is identical
# cmap overrides colors argument
- assert graphic.colors.buffer is graphic.cmap.buffer
+ # use __repr__.__self__ to get the real reference from the cmap feature instead of the weakref proxy
+ assert graphic.colors._fpl_buffer is graphic.cmap.buffer.__repr__.__self__
npt.assert_almost_equal(graphic.cmap.value, truth)
npt.assert_almost_equal(graphic.colors.value, truth)
@@ -261,14 +261,14 @@ def test_cmap(
"colors", [None, *generate_color_inputs("multi")]
) # cmap arg overrides colors
@pytest.mark.parametrize(
- "uniform_color", [True] # none of these will work with a uniform buffer
+ "color_mode", ["uniform"] # none of these will work with a uniform buffer
)
-def test_incompatible_cmap_color_args(graphic_type, cmap, colors, uniform_color):
+def test_incompatible_cmap_color_args(graphic_type, cmap, colors, color_mode):
# test incompatible cmap args
fig = fpl.Figure()
kwargs = dict()
- for kwarg in ["cmap", "colors", "uniform_color"]:
+ for kwarg in ["cmap", "colors", "color_mode"]:
if locals()[kwarg] is not None:
# add to dict of arguments that will be passed
kwargs[kwarg] = locals()[kwarg]
@@ -276,24 +276,24 @@ def test_incompatible_cmap_color_args(graphic_type, cmap, colors, uniform_color)
data = generate_positions_spiral_data("xy")
if graphic_type == "line":
- with pytest.raises(TypeError):
+ with pytest.raises(ValueError):
graphic = fig[0, 0].add_line(data=data, **kwargs)
elif graphic_type == "scatter":
- with pytest.raises(TypeError):
+ with pytest.raises(ValueError):
graphic = fig[0, 0].add_scatter(data=data, **kwargs)
@pytest.mark.parametrize("graphic_type", ["line", "scatter"])
@pytest.mark.parametrize("colors", [*generate_color_inputs("multi")])
@pytest.mark.parametrize(
- "uniform_color", [True] # none of these will work with a uniform buffer
+ "color_mode", ["uniform"] # none of these will work with a uniform buffer
)
-def test_incompatible_color_args(graphic_type, colors, uniform_color):
+def test_incompatible_color_args(graphic_type, colors, color_mode):
# test incompatible color args
fig = fpl.Figure()
kwargs = dict()
- for kwarg in ["colors", "uniform_color"]:
+ for kwarg in ["colors", "color_mode"]:
if locals()[kwarg] is not None:
# add to dict of arguments that will be passed
kwargs[kwarg] = locals()[kwarg]
@@ -301,16 +301,15 @@ def test_incompatible_color_args(graphic_type, colors, uniform_color):
data = generate_positions_spiral_data("xy")
if graphic_type == "line":
- with pytest.raises(TypeError):
+ with pytest.raises(ValueError):
graphic = fig[0, 0].add_line(data=data, **kwargs)
elif graphic_type == "scatter":
- with pytest.raises(TypeError):
+ with pytest.raises(ValueError):
graphic = fig[0, 0].add_scatter(data=data, **kwargs)
@pytest.mark.parametrize("sizes", [None, 5.0, np.linspace(3, 8, 10, dtype=np.float32)])
-@pytest.mark.parametrize("uniform_size", [None, False])
-def test_sizes(sizes, uniform_size):
+def test_sizes(sizes):
# test scatter sizes
fig = fpl.Figure()
@@ -322,7 +321,7 @@ def test_sizes(sizes, uniform_size):
data = generate_positions_spiral_data("xy")
- graphic = fig[0, 0].add_scatter(data=data, **kwargs)
+ graphic = fig[0, 0].add_scatter(data=data, uniform_size=False, **kwargs)
assert isinstance(graphic.sizes, VertexPointSizes)
assert isinstance(graphic._sizes, VertexPointSizes)
diff --git a/tests/test_replace_buffer.py b/tests/test_replace_buffer.py
new file mode 100644
index 000000000..a9d0ffe41
--- /dev/null
+++ b/tests/test_replace_buffer.py
@@ -0,0 +1,155 @@
+import gc
+import weakref
+
+import pytest
+import numpy as np
+from itertools import product
+
+import fastplotlib as fpl
+from .utils_textures import MAX_TEXTURE_SIZE, check_texture_array, check_image_graphic
+
+# These are only de-referencing tests for positions graphics, and ImageGraphic
+# they do not test that VRAM gets free, for now this can only be checked manually
+# with the tests in examples/misc/buffer_replace_gc.py
+
+
+@pytest.mark.parametrize("graphic_type", ["line", "scatter"])
+@pytest.mark.parametrize("new_buffer_size", [50, 150])
+def test_replace_positions_buffer(graphic_type, new_buffer_size):
+ fig = fpl.Figure()
+
+ # create some data with an initial shape
+ orig_datapoints = 100
+
+ xs = np.linspace(0, 2 * np.pi, orig_datapoints)
+ ys = np.sin(xs)
+ zs = np.cos(xs)
+
+ data = np.column_stack([xs, ys, zs])
+
+ # add add_line or add_scatter method
+ adder = getattr(fig[0, 0], f"add_{graphic_type}")
+
+ if graphic_type == "scatter":
+ kwargs = {
+ "markers": np.random.choice(list("osD+x^v<>*"), size=orig_datapoints),
+ "uniform_marker": False,
+ "sizes": np.abs(ys),
+ "uniform_size": False,
+ # TODO: skipping edge_colors for now since that causes a WGPU bind group error that we will figure out later
+ # anyways I think changing buffer sizes in combination with per-vertex edge colors is a literal edge-case
+ "point_rotations": zs * 180,
+ "point_rotation_mode": "vertex",
+ }
+ else:
+ kwargs = dict()
+
+ # add a line or scatter graphic
+ graphic = adder(data=data, colors=np.random.rand(orig_datapoints, 4), **kwargs)
+
+ fig.show()
+
+ # weakrefs to the original buffers
+ # these should raise a ReferenceError when the corresponding feature is replaced with data of a different shape
+ orig_data_buffer = weakref.proxy(graphic.data._fpl_buffer)
+ orig_colors_buffer = weakref.proxy(graphic.colors._fpl_buffer)
+
+ buffers = [orig_data_buffer, orig_colors_buffer]
+
+ # extra buffers for the scatters
+ if graphic_type == "scatter":
+ for attr in ["markers", "sizes", "point_rotations"]:
+ buffers.append(weakref.proxy(getattr(graphic, attr)._fpl_buffer))
+
+ # create some new data that requires a different buffer shape
+ xs = np.linspace(0, 15 * np.pi, new_buffer_size)
+ ys = np.sin(xs)
+ zs = np.cos(xs)
+
+ new_data = np.column_stack([xs, ys, zs])
+
+ # set data that requires a larger buffer and check that old buffer is no longer referenced
+ graphic.data = new_data
+ graphic.colors = np.random.rand(new_buffer_size, 4)
+
+ if graphic_type == "scatter":
+ # changes values so that new larger buffers must be allocated
+ graphic.markers = np.random.choice(list("osD+x^v<>*"), size=new_buffer_size)
+ graphic.sizes = np.abs(zs)
+ graphic.point_rotations = ys * 180
+
+ # make sure old original buffers are de-referenced
+ for i in range(len(buffers)):
+ with pytest.raises(ReferenceError) as fail:
+ buffers[i]
+ pytest.fail(
+ f"GC failed for buffer: {buffers[i]}, "
+ f"with referrers: {gc.get_referrers(buffers[i].__repr__.__self__)}"
+ )
+
+
+# test all combination of dims that require TextureArrays of shapes 1x1, 1x2, 1x3, 2x3, 3x3 etc.
+@pytest.mark.parametrize(
+ "new_buffer_size", list(product(*[[(500, 1), (1200, 2), (2200, 3)]] * 2))
+)
+def test_replace_image_buffer(new_buffer_size):
+ # make an image with some starting shape
+ orig_size = (1_500, 1_500)
+
+ data = np.random.rand(*orig_size)
+
+ fig = fpl.Figure()
+ image = fig[0, 0].add_image(data)
+
+ # the original Texture buffers that represent the individual image tiles
+ orig_buffers = [
+ weakref.proxy(image.data.buffer.ravel()[i])
+ for i in range(image.data.buffer.size)
+ ]
+ orig_shape = image.data.buffer.shape
+
+ fig.show()
+
+ # dimensions for a new image
+ new_dims = [v[0] for v in new_buffer_size]
+
+ # the number of tiles required in each dim/shape of the TextureArray
+ new_shape = tuple(v[1] for v in new_buffer_size)
+
+ # make the new data and set the image
+ new_data = np.random.rand(*new_dims)
+ image.data = new_data
+
+ # test that old Texture buffers are de-referenced
+ for i in range(len(orig_buffers)):
+ with pytest.raises(ReferenceError) as fail:
+ orig_buffers[i]
+ pytest.fail(
+ f"GC failed for buffer: {orig_buffers[i]}, of shape: {orig_shape}"
+ f"with referrers: {gc.get_referrers(orig_buffers[i].__repr__.__self__)}"
+ )
+
+ # check new texture array
+ check_texture_array(
+ data=new_data,
+ ta=image.data,
+ buffer_size=np.prod(new_shape),
+ buffer_shape=new_shape,
+ row_indices_size=new_shape[0],
+ col_indices_size=new_shape[1],
+ row_indices_values=np.array(
+ [
+ i * MAX_TEXTURE_SIZE
+ for i in range(0, 1 + (new_data.shape[0] - 1) // MAX_TEXTURE_SIZE)
+ ]
+ ),
+ col_indices_values=np.array(
+ [
+ i * MAX_TEXTURE_SIZE
+ for i in range(0, 1 + (new_data.shape[1] - 1) // MAX_TEXTURE_SIZE)
+ ]
+ ),
+ )
+
+ # check that new image tiles are arranged correctly
+ check_image_graphic(image.data, image)
diff --git a/tests/test_scatter_graphic.py b/tests/test_scatter_graphic.py
index a61681f24..930d8c495 100644
--- a/tests/test_scatter_graphic.py
+++ b/tests/test_scatter_graphic.py
@@ -133,7 +133,7 @@ def test_edge_colors(edge_colors):
npt.assert_almost_equal(scatter.edge_colors.value, MULTI_COLORS_TRUTH)
assert (
- scatter.edge_colors.buffer is scatter.world_object.geometry.edge_colors
+ scatter.edge_colors._fpl_buffer is scatter.world_object.geometry.edge_colors
)
# test changes, don't need to test extensively here since it's tested in the main VertexColors test
diff --git a/tests/test_texture_array.py b/tests/test_texture_array.py
index 6220f2fe5..01abb9a97 100644
--- a/tests/test_texture_array.py
+++ b/tests/test_texture_array.py
@@ -2,14 +2,9 @@
from numpy import testing as npt
import pytest
-import pygfx
-
import fastplotlib as fpl
from fastplotlib.graphics.features import TextureArray
-from fastplotlib.graphics.image import _ImageTile
-
-
-MAX_TEXTURE_SIZE = 1024
+from .utils_textures import MAX_TEXTURE_SIZE, check_texture_array, check_image_graphic
def make_data(n_rows: int, n_cols: int) -> np.ndarray:
@@ -25,50 +20,6 @@ def make_data(n_rows: int, n_cols: int) -> np.ndarray:
return np.vstack([sine * i for i in range(n_rows)]).astype(np.float32)
-def check_texture_array(
- data: np.ndarray,
- ta: TextureArray,
- buffer_size: int,
- buffer_shape: tuple[int, int],
- row_indices_size: int,
- col_indices_size: int,
- row_indices_values: np.ndarray,
- col_indices_values: np.ndarray,
-):
-
- npt.assert_almost_equal(ta.value, data)
-
- assert ta.buffer.size == buffer_size
- assert ta.buffer.shape == buffer_shape
-
- assert all([isinstance(texture, pygfx.Texture) for texture in ta.buffer.ravel()])
-
- assert ta.row_indices.size == row_indices_size
- assert ta.col_indices.size == col_indices_size
- npt.assert_array_equal(ta.row_indices, row_indices_values)
- npt.assert_array_equal(ta.col_indices, col_indices_values)
-
- # make sure chunking is correct
- for texture, chunk_index, data_slice in ta:
- assert ta.buffer[chunk_index] is texture
- chunk_row, chunk_col = chunk_index
-
- data_row_start_index = chunk_row * MAX_TEXTURE_SIZE
- data_col_start_index = chunk_col * MAX_TEXTURE_SIZE
-
- data_row_stop_index = min(
- data.shape[0], data_row_start_index + MAX_TEXTURE_SIZE
- )
- data_col_stop_index = min(
- data.shape[1], data_col_start_index + MAX_TEXTURE_SIZE
- )
-
- row_slice = slice(data_row_start_index, data_row_stop_index)
- col_slice = slice(data_col_start_index, data_col_stop_index)
-
- assert data_slice == (row_slice, col_slice)
-
-
def check_set_slice(data, ta, row_slice, col_slice):
ta[row_slice, col_slice] = 1
npt.assert_almost_equal(ta[row_slice, col_slice], 1)
@@ -85,17 +36,6 @@ def make_image_graphic(data) -> fpl.ImageGraphic:
return fig[0, 0].add_image(data)
-def check_image_graphic(texture_array, graphic):
- # make sure each ImageTile has the right texture
- for (texture, chunk_index, data_slice), img in zip(
- texture_array, graphic.world_object.children
- ):
- assert isinstance(img, _ImageTile)
- assert img.geometry.grid is texture
- assert img.world.x == data_slice[1].start
- assert img.world.y == data_slice[0].start
-
-
@pytest.mark.parametrize("test_graphic", [False, True])
def test_small_texture(test_graphic):
# tests TextureArray with dims that requires only 1 texture
@@ -162,15 +102,27 @@ def test_wide(test_graphic):
else:
ta = TextureArray(data)
+ ta_shape = (2, 3)
+
check_texture_array(
data,
ta=ta,
- buffer_size=6,
- buffer_shape=(2, 3),
- row_indices_size=2,
- col_indices_size=3,
- row_indices_values=np.array([0, MAX_TEXTURE_SIZE]),
- col_indices_values=np.array([0, MAX_TEXTURE_SIZE, 2 * MAX_TEXTURE_SIZE]),
+ buffer_size=np.prod(ta_shape),
+ buffer_shape=ta_shape,
+ row_indices_size=ta_shape[0],
+ col_indices_size=ta_shape[1],
+ row_indices_values=np.array(
+ [
+ i * MAX_TEXTURE_SIZE
+ for i in range(0, 1 + (data.shape[0] - 1) // MAX_TEXTURE_SIZE)
+ ]
+ ),
+ col_indices_values=np.array(
+ [
+ i * MAX_TEXTURE_SIZE
+ for i in range(0, 1 + (data.shape[1] - 1) // MAX_TEXTURE_SIZE)
+ ]
+ ),
)
if test_graphic:
@@ -189,15 +141,27 @@ def test_tall(test_graphic):
else:
ta = TextureArray(data)
+ ta_shape = (3, 2)
+
check_texture_array(
data,
ta=ta,
- buffer_size=6,
- buffer_shape=(3, 2),
- row_indices_size=3,
- col_indices_size=2,
- row_indices_values=np.array([0, MAX_TEXTURE_SIZE, 2 * MAX_TEXTURE_SIZE]),
- col_indices_values=np.array([0, MAX_TEXTURE_SIZE]),
+ buffer_size=np.prod(ta_shape),
+ buffer_shape=ta_shape,
+ row_indices_size=ta_shape[0],
+ col_indices_size=ta_shape[1],
+ row_indices_values=np.array(
+ [
+ i * MAX_TEXTURE_SIZE
+ for i in range(0, 1 + (data.shape[0] - 1) // MAX_TEXTURE_SIZE)
+ ]
+ ),
+ col_indices_values=np.array(
+ [
+ i * MAX_TEXTURE_SIZE
+ for i in range(0, 1 + (data.shape[1] - 1) // MAX_TEXTURE_SIZE)
+ ]
+ ),
)
if test_graphic:
@@ -216,15 +180,27 @@ def test_square(test_graphic):
else:
ta = TextureArray(data)
+ ta_shape = (3, 3)
+
check_texture_array(
data,
ta=ta,
- buffer_size=9,
- buffer_shape=(3, 3),
- row_indices_size=3,
- col_indices_size=3,
- row_indices_values=np.array([0, MAX_TEXTURE_SIZE, 2 * MAX_TEXTURE_SIZE]),
- col_indices_values=np.array([0, MAX_TEXTURE_SIZE, 2 * MAX_TEXTURE_SIZE]),
+ buffer_size=np.prod(ta_shape),
+ buffer_shape=ta_shape,
+ row_indices_size=ta_shape[0],
+ col_indices_size=ta_shape[1],
+ row_indices_values=np.array(
+ [
+ i * MAX_TEXTURE_SIZE
+ for i in range(0, 1 + (data.shape[0] - 1) // MAX_TEXTURE_SIZE)
+ ]
+ ),
+ col_indices_values=np.array(
+ [
+ i * MAX_TEXTURE_SIZE
+ for i in range(0, 1 + (data.shape[1] - 1) // MAX_TEXTURE_SIZE)
+ ]
+ ),
)
if test_graphic:
diff --git a/tests/utils_textures.py b/tests/utils_textures.py
new file mode 100644
index 000000000..f40a7371c
--- /dev/null
+++ b/tests/utils_textures.py
@@ -0,0 +1,64 @@
+import numpy as np
+import pygfx
+from numpy import testing as npt
+
+from fastplotlib.graphics.features import TextureArray
+from fastplotlib.graphics.image import _ImageTile
+
+
+MAX_TEXTURE_SIZE = 1024
+
+
+def check_texture_array(
+ data: np.ndarray,
+ ta: TextureArray,
+ buffer_size: int,
+ buffer_shape: tuple[int, int],
+ row_indices_size: int,
+ col_indices_size: int,
+ row_indices_values: np.ndarray,
+ col_indices_values: np.ndarray,
+):
+
+ npt.assert_almost_equal(ta.value, data)
+
+ assert ta.buffer.size == buffer_size
+ assert ta.buffer.shape == buffer_shape
+
+ assert all([isinstance(texture, pygfx.Texture) for texture in ta.buffer.ravel()])
+
+ assert ta.row_indices.size == row_indices_size
+ assert ta.col_indices.size == col_indices_size
+ npt.assert_array_equal(ta.row_indices, row_indices_values)
+ npt.assert_array_equal(ta.col_indices, col_indices_values)
+
+ # make sure chunking is correct
+ for texture, chunk_index, data_slice in ta:
+ assert ta.buffer[chunk_index] is texture
+ chunk_row, chunk_col = chunk_index
+
+ data_row_start_index = chunk_row * MAX_TEXTURE_SIZE
+ data_col_start_index = chunk_col * MAX_TEXTURE_SIZE
+
+ data_row_stop_index = min(
+ data.shape[0], data_row_start_index + MAX_TEXTURE_SIZE
+ )
+ data_col_stop_index = min(
+ data.shape[1], data_col_start_index + MAX_TEXTURE_SIZE
+ )
+
+ row_slice = slice(data_row_start_index, data_row_stop_index)
+ col_slice = slice(data_col_start_index, data_col_stop_index)
+
+ assert data_slice == (row_slice, col_slice)
+
+
+def check_image_graphic(texture_array, graphic):
+ # make sure each ImageTile has the right texture
+ for (texture, chunk_index, data_slice), img in zip(
+ texture_array, graphic.world_object.children
+ ):
+ assert isinstance(img, _ImageTile)
+ assert img.geometry.grid is texture
+ assert img.world.x == data_slice[1].start
+ assert img.world.y == data_slice[0].start