From fe966062827059f53c1b4eccd14a5c38086e06ca Mon Sep 17 00:00:00 2001 From: apasarkar Date: Wed, 15 Apr 2026 11:18:31 +0800 Subject: [PATCH 1/8] Includes nd vector code that works --- fastplotlib/widgets/nd_widget/__init__.py | 1 + fastplotlib/widgets/nd_widget/_nd_vector.py | 351 ++++++++++++++++++ fastplotlib/widgets/nd_widget/_ndw_subplot.py | 30 +- 3 files changed, 380 insertions(+), 2 deletions(-) create mode 100644 fastplotlib/widgets/nd_widget/_nd_vector.py diff --git a/fastplotlib/widgets/nd_widget/__init__.py b/fastplotlib/widgets/nd_widget/__init__.py index 65f448b54..4ebca556a 100644 --- a/fastplotlib/widgets/nd_widget/__init__.py +++ b/fastplotlib/widgets/nd_widget/__init__.py @@ -5,6 +5,7 @@ from ._base import NDProcessor, NDGraphic from ._nd_positions import NDPositions, NDPositionsProcessor, ndp_extras from ._nd_image import NDImageProcessor, NDImage + from ._nd_vector import NDVectorProcessor, NDVector from ._ndwidget import NDWidget else: diff --git a/fastplotlib/widgets/nd_widget/_nd_vector.py b/fastplotlib/widgets/nd_widget/_nd_vector.py new file mode 100644 index 000000000..44f79d2c2 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_vector.py @@ -0,0 +1,351 @@ +from collections.abc import Sequence, Generator +from typing import Callable, Any + +import numpy as np +from numpy.typing import ArrayLike + +from ...layouts import Subplot +from ...utils import subsample_array, ARRAY_LIKE_ATTRS, ArrayProtocol +from ...graphics import VectorsGraphic +from ._base import NDProcessor, NDGraphic, WindowFuncCallable, block_reentrance, AwaitedArray +from ._index import ReferenceIndex +from ._async import start_coroutine + + + +class NDVectorProcessor(NDProcessor): + def __init__( + self, + data: ArrayProtocol | None, + dims: Sequence[str], + spatial_dims: tuple[str, str, str], # must be in order, last dim must be 4 + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, + window_order: tuple[int, ...] = None, + spatial_func: Callable[[ArrayLike], ArrayLike] = None, + slider_dim_transforms=None, + ): + """ + ``NDProcessor`` subclass for n-dimensional vector data + + Produces (num_vectors, 4) slices for a ``VectorsGraphic``. The first two columns describe position of each vector, + the second two describe the orientation + + Parameters + ---------- + data: ArrayProtocol + Shape [num_vectors, 2, 2] or [num_vectors, 3, 2]. data[:, :, 0] gives the positions, data[:, :, 1] gives directions + n-dimension image data array + + dims: Sequence[str] + names for each dimension in ``data``. Dimensions not listed in + ``spatial_dims`` are treated as slider dimensions and **must** appear as + keys in the parent ``NDWidget``'s ``ref_ranges`` + Examples:: + ``("time", "depth", "row", "col")`` + ``("channels", "time", "xy")`` + ``("keypoints", "time", "xyz")`` + + A custom subclass's ``data`` object doesn't necessarily need to have these dims, but the ``get()`` method + must operate as if these dimensions exist and return an array that matches the spatial dimensions. + + dims: Sequence[str] + names for each dimension in ``data``. Dimensions not listed in + ``spatial_dims`` are treated as slider dimensions and **must** appear as + keys in the parent ``NDWidget``'s ``ref_ranges`` + Examples:: + ``("time", "depth", "row", "col")`` + ``("row", "col")`` + ``("other_dim", "depth", "time", "row", "col")`` + + dims in the array do not need to be in the order that you want to display them, for example you can have a + weird array where the dims are interpreted as: + ``("col", "depth", "row", "time")``, and then specify spatial_dims as ``("row", "col")``. + + spatial_dims : tuple[str, str] | tuple[str, str, str] + For NDVectors, this is always the last two dimensions without exception + + slider_dim_transforms : dict, optional + See :class:`NDProcessor`. + + window_funcs : dict, optional + See :class:`NDProcessor`. + + window_order : tuple, optional + See :class:`NDProcessor`. + + spatial_func : callable, optional + See :class:`NDProcessor`. + + See Also + -------- + NDProcessor : Base class with full parameter documentation. + NDImage : The ``NDGraphic`` that wraps this processor. + """ + + super().__init__( + data=data, + dims=dims, + spatial_dims=spatial_dims, + slider_dim_transforms=slider_dim_transforms, + window_funcs=window_funcs, + window_order=window_order, + spatial_func=spatial_func, + ) + + + @property + def data(self) -> ArrayProtocol | None: + """ + get or set managed data. If setting with new data, the new data is interpreted + to have the same dims (i.e. same dim names and ordering of dims). + """ + return self._data + + @data.setter + def data(self, data: ArrayProtocol): + if not isinstance(data, ArrayProtocol): + # check that it's generally array-like + raise TypeError( + f"`data` arrays must have all of the following attributes to be sufficiently array-like:\n" + f"{ARRAY_LIKE_ATTRS}, or they must be `None`" + ) + + if data.ndim < 3 or data.shape[-1] != 2 or data.shape[-2] not in (2, 3): + raise ValueError( + f"Final dimension must be , indicating spatial dimensions and magnitude of vector, you passed an array of shape {data.shape}" + ) + + self._data = data + + @property + def spatial_dims(self) -> tuple[str, str]: + """ + Spatial dims, **in display order**. + + [num_vectors, 2, 2] + """ + return self._spatial_dims + + @spatial_dims.setter + def spatial_dims(self, sdims: tuple[str, str, str]): + for dim in sdims: + if dim not in self.dims: + raise KeyError + + if len(sdims) != 3: + raise ValueError( + f"There must be exactly 3 spatial dims for vectors indicating [num_vectors, 2, 2] or [num_vectors, 3, 2] " + ) + + self._spatial_dims = tuple(sdims) + + + def get(self, indices: dict[str, Any]) -> AwaitedArray: + """ + Get the data at the given index, process data through the window functions. + + Note that we do not use __getitem__ here since the index is a tuple specifying a single integer + index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here. + + Parameters + ---------- + indices: tuple[int, ...] + Get the processed data at this index. Must provide a value for each dimension. + Example: get((100, 5)) + + """ + # this will be squeezed output, with dims in the order of the user set spatial dims + window_output = yield from self.get_window_output(indices) + + # apply spatial_func + if self.spatial_func is not None: + spatial_out = self._spatial_func(window_output) + if spatial_out.ndim != len(self.spatial_dims): + raise ValueError + + return spatial_out + + return window_output + + +class NDVector(NDGraphic): + def __init__( + self, + ref_index: ReferenceIndex, + subplot: Subplot, + data: ArrayProtocol | None, + dims: Sequence[str], + spatial_dims: tuple[str, str, str], # must be in order! [rows, cols] | [z, rows, cols] + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, + window_order: tuple[int, ...] = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] = None, + slider_dim_transforms=None, + name: str = None, + ): + """ + ``NDGraphic`` subclass for n-dimensional image rendering. + + Wraps an :class:`NDImageProcessor` and manages either an ``ImageGraphic`` or``ImageVolumeGraphic``. + swaps automatically when :attr:`spatial_dims` is reassigned at runtime. Also + owns a ``HistogramLUTTool`` for interactive vmin, vmax adjustment. + + Every dimension that is *not* listed in ``spatial_dims`` becomes a slider + dimension. Each slider dim must have a ``ReferenceRange`` defined in the + ``ReferenceIndex`` of the parent ``NDWidget``. The widget uses this to direct + a change in the ``ReferenceIndex`` and update the graphics. + + Parameters + ---------- + ref_index : ReferenceIndex + The shared reference index that delivers slider updates to this graphic. + + subplot : Subplot + parent subplot the NDGraphic is in + + data : array-like or None + Shape [num_vectors, 2, 2] or [num_vectors, 3, 2]. data[:, :, 0] gives the positions, data[:, :, 1] gives directions + n-dimension image data array + + dims : sequence of hashable + Name for every dimension of ``data``, in order. Non-spatial dims must + match keys in ``ref_index``. + + ex: ``("time", "depth", "row", "col")`` — ``"time"`` and ``"depth"`` must + be present in ``ref_index``. + + spatial_dims : tuple[str, str] | tuple[str, str, str] + Spatial dimensions **in order**: These dims are either (num_points, 2, 2) or (num_points, 3, 2) + + window_funcs : dict, optional + See :class:`NDProcessor`. + + window_order : tuple, optional + See :class:`NDProcessor`. + + spatial_func : callable, optional + See :class:`NDProcessor`. + + slider_dim_transforms : dict, optional + See :class:`NDProcessor`. + + name : str, optional + Name for the underlying graphic. + + See Also + -------- + NDImageProcessor : The processor that backs this graphic. + + """ + + if not (set(dims) - set(spatial_dims)).issubset(ref_index.dims): + raise IndexError( + f"all specified `dims` must either be a spatial dim or a slider dim " + f"specified in the NDWidget ref_ranges, provided dims: {dims}, " + f"spatial_dims: {spatial_dims}. Specified NDWidget ref_ranges: {ref_index.dims}" + ) + + super().__init__(subplot, name) + + self._ref_index = ref_index + + self._processor = NDVectorProcessor( + data, + dims=dims, + spatial_dims=spatial_dims, + window_funcs=window_funcs, + window_order=window_order, + spatial_func=spatial_func, + slider_dim_transforms=slider_dim_transforms, + ) + + self._graphic: VectorsGraphic | None = None + + # create a graphic + self._create_graphic() + + @property + def processor(self) -> NDVectorProcessor: + """NDProcessor that manages the data and produces data slices to display""" + return self._processor + + @property + def graphic( + self, + ) -> VectorsGraphic: + """Underlying Graphic object used to display the current data slice""" + return self._graphic + + @start_coroutine + def _create_graphic(self): + # Creates an ``ImageGraphic`` or ``ImageVolumeGraphic`` based on the number of spatial dims, + # adds it to the subplot, and resets the camera and histogram. + + if self.processor.data is None: + # no graphic if data is None, useful for initializing in null states when we want to set data later + return + + # get the data slice for this index + # this will only have the dims specified by ``spatial_dims`` + + data_slice = yield from self._get_data_slice(self.indices) + + + old_graphic = self._graphic + # check if we are replacing a graphic + # ex: swapping from 2D <-> 3D representation after ``spatial_dims`` was changed + if old_graphic is not None: + # delete the old graphic + self._subplot.delete_graphic(old_graphic) + + # create the new graphic + self._graphic = self._subplot.add_vectors(positions=data_slice[:, :, 0], + directions=data_slice[:, :, 1]) + + self._subplot.add_graphic(self._graphic) + + @property + def spatial_dims(self) -> tuple[str, str] | tuple[str, str, str]: + """ + get or set the spatial dims **in order** + + [row_dim, col_dim] or [row_dim, col_dim, rgb(a) dim] + """ + return self.processor.spatial_dims + + @spatial_dims.setter + def spatial_dims(self, dims: tuple[str, str] | tuple[str, str, str]): + self.processor.spatial_dims = dims + + # shape has probably changed, recreate graphic + self._create_graphic() + + @property + def indices(self) -> dict[str, Any]: + """get or set the indices, managed by the ReferenceIndex, users usually don't want to set this manually""" + return {d: self._ref_index[d] for d in self.processor.slider_dims} + + @block_reentrance + @start_coroutine + def set_indices( + self, indices: dict[str, Any], block: bool = True, timeout: float = 1.0 + ): + data_slice = yield from self._get_data_slice(indices) + + positions = data_slice[:, :, 0] + directions = data_slice[:, :, 1] + + self.graphic.positions = positions + self.graphic.directions = directions + + @property + def spatial_func(self) -> Callable[[ArrayProtocol], ArrayProtocol] | None: + """get or set the spatial_func, see docstring for details""" + # this is here even though it's the same in the base class since we can't create the image specific setter + # without also defining the property in this subclass. + return self.processor.spatial_func + + @spatial_func.setter + def spatial_func( + self, func: Callable[[ArrayProtocol], ArrayProtocol] + ) -> Callable | None: + self.processor.spatial_func = func \ No newline at end of file diff --git a/fastplotlib/widgets/nd_widget/_ndw_subplot.py b/fastplotlib/widgets/nd_widget/_ndw_subplot.py index 6666b3fc1..f63fad152 100644 --- a/fastplotlib/widgets/nd_widget/_ndw_subplot.py +++ b/fastplotlib/widgets/nd_widget/_ndw_subplot.py @@ -3,10 +3,10 @@ import numpy as np -from ... import ScatterCollection, ScatterStack, LineCollection, LineStack, ImageGraphic +from ... import ScatterCollection, ScatterStack, LineCollection, LineStack, ImageGraphic, VectorsGraphic from ...layouts import Subplot from ...utils import ArrayProtocol -from . import NDImage, NDPositions +from . import NDImage, NDPositions, NDVector from ._base import NDGraphic, WindowFuncCallable @@ -73,6 +73,32 @@ def add_nd_image( self._nd_graphics.append(nd) return nd + def add_nd_vector(self, + data: ArrayProtocol | None, + dims: Sequence[str], + spatial_dims: tuple[str, str, str], # must be in order! [rows, cols] | [z, rows, cols] + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, + window_order: tuple[int, ...] = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] = None, + slider_dim_transforms=None, + name: str = None, + ): + nd = NDVector( + self.ndw.indices, + self._subplot, + data=data, + dims=dims, + spatial_dims=spatial_dims, + window_funcs=window_funcs, + window_order=window_order, + spatial_func=spatial_func, + slider_dim_transforms=slider_dim_transforms, + name=name + ) + + self._nd_graphics.append(nd) + return nd + def add_nd_scatter(self, *args, **kwargs): # TODO: better func signature here, send all kwargs to processor_kwargs nd = NDPositions( From 2b1f79c8c3ba6fe8a34dc229584387aee4a4ccc7 Mon Sep 17 00:00:00 2001 From: apasarkar Date: Wed, 15 Apr 2026 12:08:22 +0800 Subject: [PATCH 2/8] Faster position assignment, no more for loop --- fastplotlib/graphics/features/_vectors.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/fastplotlib/graphics/features/_vectors.py b/fastplotlib/graphics/features/_vectors.py index 729562b06..d75f2a3fc 100644 --- a/fastplotlib/graphics/features/_vectors.py +++ b/fastplotlib/graphics/features/_vectors.py @@ -82,12 +82,7 @@ def set_value(self, graphic, value: np.ndarray): else: self._positions[:] = value - for i in range(self._positions.shape[0]): - # only need to update the translation vector - graphic.world_object.instance_buffer.data["matrix"][i][3, 0:3] = ( - self._positions[i] - ) - + graphic.world_object.instance_buffer.data["matrix"][:, 3, 0:3] = self._positions[:] graphic.world_object.instance_buffer.update_full() event = GraphicFeatureEvent(type="positions", info={"value": value}) @@ -183,3 +178,4 @@ def set_value(self, graphic, value: np.ndarray): event = GraphicFeatureEvent(type="directions", info={"value": value}) self._call_event_handlers(event) + From d508a32eecf25d15b3d0da608eb78eaa4072816f Mon Sep 17 00:00:00 2001 From: apasarkar Date: Wed, 15 Apr 2026 13:02:26 +0800 Subject: [PATCH 3/8] Batched computations for vector set function --- fastplotlib/graphics/features/_vectors.py | 160 +++++++++++++++++++++- 1 file changed, 154 insertions(+), 6 deletions(-) diff --git a/fastplotlib/graphics/features/_vectors.py b/fastplotlib/graphics/features/_vectors.py index d75f2a3fc..51f39dee8 100644 --- a/fastplotlib/graphics/features/_vectors.py +++ b/fastplotlib/graphics/features/_vectors.py @@ -166,16 +166,164 @@ def set_value(self, graphic, value: np.ndarray): # vector determines the size of the vector magnitudes = np.linalg.norm(self._directions, axis=1, ord=2) - for i in range(self._directions.shape[0]): + # for i in range(self._directions.shape[0]): # get quaternion to rotate vector to new direction - rotation = la.quat_from_vecs(self.init_direction, self._directions[i]) - # get the new transform - transform = la.mat_compose(graphic.positions[i], rotation, magnitudes[i]) - # set the buffer - graphic.world_object.instance_buffer.data["matrix"][i] = transform.T + rotation = quat_from_vecs(self.init_direction, self._directions[:]) + # get the new transform + transform = mat_compose(graphic.positions[:], rotation, magnitudes[:]) + # set the buffer + graphic.world_object.instance_buffer.data["matrix"][:] = transform.transpose(0, 2, 1) graphic.world_object.instance_buffer.update_full() event = GraphicFeatureEvent(type="directions", info={"value": value}) self._call_event_handlers(event) + + +def quat_from_vecs(source, target, out=None, dtype=None) -> np.ndarray: + source = np.asarray(source, dtype=float) + if source.ndim == 1: + source = source[None, :] + target = np.asarray(target, dtype=float) + if target.ndim == 1: + target = target[None, :] + + num_vecs = target.shape[0] + result_shape = (num_vecs, 4) + if out is None: + out = np.empty(result_shape, dtype=dtype) + + axis = np.cross(source, target) # (num_pts, 3) + axis_norm = np.linalg.norm(axis, axis=-1) # (num_pts,) + angle = np.arctan2(axis_norm, np.sum(source * target, axis=-1)) # (num_pts,) + + # Handle degenerate case: source and target are parallel (axis is zero vector). + # Pick any axis orthogonal to source as a replacement. + use_fallback = axis_norm == 0 + if np.any(use_fallback): + t = np.broadcast_to(source, (num_vecs, 3))[use_fallback] + + # Better case split: + y_zero = t[:, 1] == 0 + z_zero = t[:, 2] == 0 + neither_zero = ~y_zero & ~z_zero + + fb = np.empty((y_zero.shape[0], 3), dtype=float) + fb[y_zero] = (0., 1., 0.) + fb[~y_zero & z_zero] = (0., 0., 1.) + fb[neither_zero, 0] = 0. + fb[neither_zero, 1] = -t[neither_zero, 2] + fb[neither_zero, 2] = t[neither_zero, 1] + + axis[use_fallback] = fb + + return quat_from_axis_angle(axis, angle, out=out) + + +def quat_from_axis_angle(axis, angle, out=None, dtype=None) -> np.ndarray: + """Quaternion from axis-angle pair. + + Create a quaternion representing the rotation of an given angle + about a given unit vector + + Parameters + ---------- + axis : ndarray, [3] + Unit vector + angle : number + The angle (in radians) to rotate about axis + out : ndarray, optional + A location into which the result is stored. If provided, it + must have a shape that the inputs broadcast to. If not provided or + None, a freshly-allocated array is returned. A tuple must have + length equal to the number of outputs. + dtype : data-type, optional + Overrides the data type of the result. + + Returns + ------- + ndarray, [4] + Quaternion. + """ + + axis = np.asarray(axis, dtype=float) + angle = np.asarray(angle, dtype=float) + + if out is None: + out_shape = np.broadcast_shapes(axis.shape[:-1], angle.shape) + out = np.empty((*out_shape, 4), dtype=dtype) + + # result should be independent of the length of the given axis + lengths_shape = (*axis.shape[:-1], 1) + axis = axis / np.linalg.norm(axis, axis=-1).reshape(lengths_shape) + + out[..., :3] = axis * np.sin(angle / 2).reshape(lengths_shape) + out[..., 3] = np.cos(angle / 2) + + return out + + +def mat_compose(translation, rotation, scaling, /, *, out=None, dtype=None) -> np.ndarray: + """ + Compose transformation matrices given translation vectors, quaternions, + and scaling vectors. + + Parameters + ---------- + translation : ndarray, [3] or [num_vectors, 3] + rotation : ndarray, [4] or [num_vectors, 4] + scaling : ndarray, [3] or [num_vectors, 3] + + Returns + ------- + ndarray, [num_vectors, 4, 4] + """ + rotation = np.asarray(rotation, dtype=float) + translation = np.asarray(translation, dtype=float) + scaling = np.asarray(scaling, dtype=float) + + if rotation.ndim == 1: + rotation = rotation[None, :] + if translation.ndim == 1: + translation = translation[None, :] + if scaling.ndim == 0: + scaling = np.full((1, 3), scaling) + elif scaling.ndim == 1 and scaling.shape[0] == 3: + scaling = scaling[None, :] + elif scaling.ndim == 1: + scaling = scaling[:, None] * np.ones(3) + + num_vectors = max(rotation.shape[0], translation.shape[0], scaling.shape[0]) + + if out is None: + out = np.zeros((num_vectors, 4, 4), dtype=dtype) + else: + out[..., :, :] = 0 + + x, y, z, w = rotation[:, 0], rotation[:, 1], rotation[:, 2], rotation[:, 3] + + x2, y2, z2 = x + x, y + y, z + z + xx, xy, xz = x * x2, x * y2, x * z2 + yy, yz, zz = y * y2, y * z2, z * z2 + wx, wy, wz = w * x2, w * y2, w * z2 + + sx, sy, sz = scaling[:, 0], scaling[:, 1], scaling[:, 2] + + + out[:, 0, 0] = (1 - (yy + zz)) * sx + out[:, 1, 0] = (xy + wz) * sx + out[:, 2, 0] = (xz - wy) * sx + + out[:, 0, 1] = (xy - wz) * sy + out[:, 1, 1] = (1 - (xx + zz)) * sy + out[:, 2, 1] = (yz + wx) * sy + + out[:, 0, 2] = (xz + wy) * sz + out[:, 1, 2] = (yz - wx) * sz + out[:, 2, 2] = (1 - (xx + yy)) * sz + + out[:, 0:3, 3] = translation + out[:, 3, 3] = 1 + + return out \ No newline at end of file From a87ed4c4289720fd2201018edcb542823b4a4231 Mon Sep 17 00:00:00 2001 From: apasarkar Date: Wed, 15 Apr 2026 13:08:26 +0800 Subject: [PATCH 4/8] Formatting updates --- fastplotlib/graphics/features/_vectors.py | 2 ++ fastplotlib/widgets/nd_widget/_nd_vector.py | 4 ++-- fastplotlib/widgets/nd_widget/_ndw_subplot.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/fastplotlib/graphics/features/_vectors.py b/fastplotlib/graphics/features/_vectors.py index 51f39dee8..a4533372b 100644 --- a/fastplotlib/graphics/features/_vectors.py +++ b/fastplotlib/graphics/features/_vectors.py @@ -82,7 +82,9 @@ def set_value(self, graphic, value: np.ndarray): else: self._positions[:] = value + # Only need to update the translation vector graphic.world_object.instance_buffer.data["matrix"][:, 3, 0:3] = self._positions[:] + graphic.world_object.instance_buffer.update_full() event = GraphicFeatureEvent(type="positions", info={"value": value}) diff --git a/fastplotlib/widgets/nd_widget/_nd_vector.py b/fastplotlib/widgets/nd_widget/_nd_vector.py index 44f79d2c2..1d5a82a1a 100644 --- a/fastplotlib/widgets/nd_widget/_nd_vector.py +++ b/fastplotlib/widgets/nd_widget/_nd_vector.py @@ -1,5 +1,5 @@ -from collections.abc import Sequence, Generator -from typing import Callable, Any +from collections.abc import Sequence, Generator, Callable +from typing import Any import numpy as np from numpy.typing import ArrayLike diff --git a/fastplotlib/widgets/nd_widget/_ndw_subplot.py b/fastplotlib/widgets/nd_widget/_ndw_subplot.py index f63fad152..82b336a60 100644 --- a/fastplotlib/widgets/nd_widget/_ndw_subplot.py +++ b/fastplotlib/widgets/nd_widget/_ndw_subplot.py @@ -76,7 +76,7 @@ def add_nd_image( def add_nd_vector(self, data: ArrayProtocol | None, dims: Sequence[str], - spatial_dims: tuple[str, str, str], # must be in order! [rows, cols] | [z, rows, cols] + spatial_dims: tuple[str, str, str], window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, window_order: tuple[int, ...] = None, spatial_func: Callable[[ArrayProtocol], ArrayProtocol] = None, From f49b56998827c7127893e36d8aa042b472441a0d Mon Sep 17 00:00:00 2001 From: apasarkar Date: Wed, 15 Apr 2026 13:31:57 +0800 Subject: [PATCH 5/8] Includes improved annotations and changes ordering of the data slice from the vectors graphic --- fastplotlib/widgets/nd_widget/_nd_vector.py | 35 +++++++++++---------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_vector.py b/fastplotlib/widgets/nd_widget/_nd_vector.py index 1d5a82a1a..f713275bd 100644 --- a/fastplotlib/widgets/nd_widget/_nd_vector.py +++ b/fastplotlib/widgets/nd_widget/_nd_vector.py @@ -33,7 +33,7 @@ def __init__( Parameters ---------- data: ArrayProtocol - Shape [num_vectors, 2, 2] or [num_vectors, 3, 2]. data[:, :, 0] gives the positions, data[:, :, 1] gives directions + Shape [num_vectors, 2, 2] or [num_vectors, 3, 2]. data[:, 0, :] gives the positions, data[:, 1, :] gives directions n-dimension image data array dims: Sequence[str] @@ -110,9 +110,9 @@ def data(self, data: ArrayProtocol): f"{ARRAY_LIKE_ATTRS}, or they must be `None`" ) - if data.ndim < 3 or data.shape[-1] != 2 or data.shape[-2] not in (2, 3): + if data.ndim < 3: raise ValueError( - f"Final dimension must be , indicating spatial dimensions and magnitude of vector, you passed an array of shape {data.shape}" + f"Shape must be (..., num_vecs, 2, [2 or 3]) you passed an array of shape {data.shape}" ) self._data = data @@ -121,8 +121,7 @@ def data(self, data: ArrayProtocol): def spatial_dims(self) -> tuple[str, str]: """ Spatial dims, **in display order**. - - [num_vectors, 2, 2] + Dimensions in order are num_vectors, position/direction, xy[z], so the shape is [num_vectors, 2, 2 or 3] """ return self._spatial_dims @@ -134,11 +133,17 @@ def spatial_dims(self, sdims: tuple[str, str, str]): if len(sdims) != 3: raise ValueError( - f"There must be exactly 3 spatial dims for vectors indicating [num_vectors, 2, 2] or [num_vectors, 3, 2] " + f"There must be exactly 3 spatial dims for vectors indicating [num_vectors, 2, 2] or [num_vectors, 2, 3] " ) self._spatial_dims = tuple(sdims) + if self.shape[self.spatial_dims[-2]] != 2 or self.shape[self.spatial_dims[-1]] not in (2, 3): + raise ValueError( + f"Spatial dimensions must haves shape (num_vecs, 2, [2 or 3]) you passed an array of shape {data.shape}" + ) + + def get(self, indices: dict[str, Any]) -> AwaitedArray: """ @@ -292,28 +297,26 @@ def _create_graphic(self): old_graphic = self._graphic # check if we are replacing a graphic - # ex: swapping from 2D <-> 3D representation after ``spatial_dims`` was changed if old_graphic is not None: # delete the old graphic self._subplot.delete_graphic(old_graphic) # create the new graphic - self._graphic = self._subplot.add_vectors(positions=data_slice[:, :, 0], - directions=data_slice[:, :, 1]) + self._graphic = self._subplot.add_vectors(positions=data_slice[:, 0, :], + directions=data_slice[:, 1, :]) self._subplot.add_graphic(self._graphic) @property - def spatial_dims(self) -> tuple[str, str] | tuple[str, str, str]: + def spatial_dims(self) -> tuple[str, str, str]: """ - get or set the spatial dims **in order** - - [row_dim, col_dim] or [row_dim, col_dim, rgb(a) dim] + get or set the spatial dims **in order**. + Spatial dim shape here is [num_vectors, position/dimension (2), xy[z] (2 or 3)] """ return self.processor.spatial_dims @spatial_dims.setter - def spatial_dims(self, dims: tuple[str, str] | tuple[str, str, str]): + def spatial_dims(self, dims: tuple[str, str, str]): self.processor.spatial_dims = dims # shape has probably changed, recreate graphic @@ -331,8 +334,8 @@ def set_indices( ): data_slice = yield from self._get_data_slice(indices) - positions = data_slice[:, :, 0] - directions = data_slice[:, :, 1] + positions = data_slice[:, 0, :] + directions = data_slice[:, 1, :] self.graphic.positions = positions self.graphic.directions = directions From 5b880bd32d77e6d421794a47dc6991dec6f4ac8e Mon Sep 17 00:00:00 2001 From: apasarkar Date: Wed, 15 Apr 2026 23:39:40 +0800 Subject: [PATCH 6/8] Some further improvements to the pylinalg code --- fastplotlib/graphics/features/_vectors.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/fastplotlib/graphics/features/_vectors.py b/fastplotlib/graphics/features/_vectors.py index a4533372b..d3210817f 100644 --- a/fastplotlib/graphics/features/_vectors.py +++ b/fastplotlib/graphics/features/_vectors.py @@ -198,7 +198,7 @@ def quat_from_vecs(source, target, out=None, dtype=None) -> np.ndarray: axis = np.cross(source, target) # (num_pts, 3) axis_norm = np.linalg.norm(axis, axis=-1) # (num_pts,) - angle = np.arctan2(axis_norm, np.sum(source * target, axis=-1)) # (num_pts,) + angle = np.arctan2(axis_norm, (target @ source.T).squeeze(1)) # (num_pts,) # Handle degenerate case: source and target are parallel (axis is zero vector). # Pick any axis orthogonal to source as a replacement. @@ -231,9 +231,9 @@ def quat_from_axis_angle(axis, angle, out=None, dtype=None) -> np.ndarray: Parameters ---------- - axis : ndarray, [3] + axis : ndarray, [num_vectors, 3] or [3] Unit vector - angle : number + angle : number or np.ndarray of shape [num_pts,] The angle (in radians) to rotate about axis out : ndarray, optional A location into which the result is stored. If provided, it @@ -245,7 +245,7 @@ def quat_from_axis_angle(axis, angle, out=None, dtype=None) -> np.ndarray: Returns ------- - ndarray, [4] + ndarray, [num_pts, 4] or [4] Quaternion. """ @@ -263,7 +263,7 @@ def quat_from_axis_angle(axis, angle, out=None, dtype=None) -> np.ndarray: out[..., :3] = axis * np.sin(angle / 2).reshape(lengths_shape) out[..., 3] = np.cos(angle / 2) - return out + return out.squeeze(0) if out.shape[0] == 1 else out def mat_compose(translation, rotation, scaling, /, *, out=None, dtype=None) -> np.ndarray: @@ -279,7 +279,7 @@ def mat_compose(translation, rotation, scaling, /, *, out=None, dtype=None) -> n Returns ------- - ndarray, [num_vectors, 4, 4] + ndarray, [num_vectors, 4, 4] or [4, 4] """ rotation = np.asarray(rotation, dtype=float) translation = np.asarray(translation, dtype=float) @@ -328,4 +328,4 @@ def mat_compose(translation, rotation, scaling, /, *, out=None, dtype=None) -> n out[:, 0:3, 3] = translation out[:, 3, 3] = 1 - return out \ No newline at end of file + return out.squeeze(0) if out.shape[0] == 1 else out \ No newline at end of file From 6680492bcad32c48bac8707e26caffb38f234cae Mon Sep 17 00:00:00 2001 From: apasarkar Date: Thu, 16 Apr 2026 00:39:31 +0800 Subject: [PATCH 7/8] Fixes remaining formatting and naming issues --- fastplotlib/graphics/features/_vectors.py | 2 - fastplotlib/widgets/nd_widget/__init__.py | 2 +- .../{_nd_vector.py => _nd_vectors.py} | 53 ++++++++++--------- fastplotlib/widgets/nd_widget/_ndw_subplot.py | 44 +++++++++------ 4 files changed, 57 insertions(+), 44 deletions(-) rename fastplotlib/widgets/nd_widget/{_nd_vector.py => _nd_vectors.py} (89%) diff --git a/fastplotlib/graphics/features/_vectors.py b/fastplotlib/graphics/features/_vectors.py index d3210817f..82767ca21 100644 --- a/fastplotlib/graphics/features/_vectors.py +++ b/fastplotlib/graphics/features/_vectors.py @@ -168,8 +168,6 @@ def set_value(self, graphic, value: np.ndarray): # vector determines the size of the vector magnitudes = np.linalg.norm(self._directions, axis=1, ord=2) - # for i in range(self._directions.shape[0]): - # get quaternion to rotate vector to new direction rotation = quat_from_vecs(self.init_direction, self._directions[:]) # get the new transform transform = mat_compose(graphic.positions[:], rotation, magnitudes[:]) diff --git a/fastplotlib/widgets/nd_widget/__init__.py b/fastplotlib/widgets/nd_widget/__init__.py index 4ebca556a..8416288c7 100644 --- a/fastplotlib/widgets/nd_widget/__init__.py +++ b/fastplotlib/widgets/nd_widget/__init__.py @@ -5,7 +5,7 @@ from ._base import NDProcessor, NDGraphic from ._nd_positions import NDPositions, NDPositionsProcessor, ndp_extras from ._nd_image import NDImageProcessor, NDImage - from ._nd_vector import NDVectorProcessor, NDVector + from ._nd_vectors import NDVectorsProcessor, NDVectors from ._ndwidget import NDWidget else: diff --git a/fastplotlib/widgets/nd_widget/_nd_vector.py b/fastplotlib/widgets/nd_widget/_nd_vectors.py similarity index 89% rename from fastplotlib/widgets/nd_widget/_nd_vector.py rename to fastplotlib/widgets/nd_widget/_nd_vectors.py index f713275bd..3b09a6d04 100644 --- a/fastplotlib/widgets/nd_widget/_nd_vector.py +++ b/fastplotlib/widgets/nd_widget/_nd_vectors.py @@ -7,13 +7,18 @@ from ...layouts import Subplot from ...utils import subsample_array, ARRAY_LIKE_ATTRS, ArrayProtocol from ...graphics import VectorsGraphic -from ._base import NDProcessor, NDGraphic, WindowFuncCallable, block_reentrance, AwaitedArray +from ._base import ( + NDProcessor, + NDGraphic, + WindowFuncCallable, + block_reentrance, + AwaitedArray, +) from ._index import ReferenceIndex from ._async import start_coroutine - -class NDVectorProcessor(NDProcessor): +class NDVectorsProcessor(NDProcessor): def __init__( self, data: ArrayProtocol | None, @@ -27,14 +32,13 @@ def __init__( """ ``NDProcessor`` subclass for n-dimensional vector data - Produces (num_vectors, 4) slices for a ``VectorsGraphic``. The first two columns describe position of each vector, - the second two describe the orientation + Produces (num_vectors, 2, [2 or 3]) slices for a ``VectorsGraphic``. The last two dimensions describe the + position/direction and the 2D/3D spatial coordinate, respectively. Parameters ---------- data: ArrayProtocol - Shape [num_vectors, 2, 2] or [num_vectors, 3, 2]. data[:, 0, :] gives the positions, data[:, 1, :] gives directions - n-dimension image data array + Shape [..., num_vectors, 2, 2] or [..., num_vectors, 2, 3]. data[..., 0, :] gives the positions, data[..., 1, :] gives directions dims: Sequence[str] names for each dimension in ``data``. Dimensions not listed in @@ -92,7 +96,6 @@ def __init__( spatial_func=spatial_func, ) - @property def data(self) -> ArrayProtocol | None: """ @@ -138,13 +141,13 @@ def spatial_dims(self, sdims: tuple[str, str, str]): self._spatial_dims = tuple(sdims) - if self.shape[self.spatial_dims[-2]] != 2 or self.shape[self.spatial_dims[-1]] not in (2, 3): + if self.shape[self.spatial_dims[-2]] != 2 or self.shape[ + self.spatial_dims[-1] + ] not in (2, 3): raise ValueError( f"Spatial dimensions must haves shape (num_vecs, 2, [2 or 3]) you passed an array of shape {data.shape}" ) - - def get(self, indices: dict[str, Any]) -> AwaitedArray: """ Get the data at the given index, process data through the window functions. @@ -173,14 +176,16 @@ def get(self, indices: dict[str, Any]) -> AwaitedArray: return window_output -class NDVector(NDGraphic): +class NDVectors(NDGraphic): def __init__( self, ref_index: ReferenceIndex, subplot: Subplot, data: ArrayProtocol | None, dims: Sequence[str], - spatial_dims: tuple[str, str, str], # must be in order! [rows, cols] | [z, rows, cols] + spatial_dims: tuple[ + str, str, str + ], # must be in order! [rows, cols] | [z, rows, cols] window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, window_order: tuple[int, ...] = None, spatial_func: Callable[[ArrayProtocol], ArrayProtocol] = None, @@ -188,11 +193,9 @@ def __init__( name: str = None, ): """ - ``NDGraphic`` subclass for n-dimensional image rendering. + ``NDGraphic`` subclass for n-dimensional vector rendering - Wraps an :class:`NDImageProcessor` and manages either an ``ImageGraphic`` or``ImageVolumeGraphic``. - swaps automatically when :attr:`spatial_dims` is reassigned at runtime. Also - owns a ``HistogramLUTTool`` for interactive vmin, vmax adjustment. + Wraps an :class:`VectorGraphic` Every dimension that is *not* listed in ``spatial_dims`` becomes a slider dimension. Each slider dim must have a ``ReferenceRange`` defined in the @@ -253,7 +256,7 @@ def __init__( self._ref_index = ref_index - self._processor = NDVectorProcessor( + self._processor = NDVectorsProcessor( data, dims=dims, spatial_dims=spatial_dims, @@ -269,7 +272,7 @@ def __init__( self._create_graphic() @property - def processor(self) -> NDVectorProcessor: + def processor(self) -> NDVectorsProcessor: """NDProcessor that manages the data and produces data slices to display""" return self._processor @@ -294,7 +297,6 @@ def _create_graphic(self): data_slice = yield from self._get_data_slice(self.indices) - old_graphic = self._graphic # check if we are replacing a graphic if old_graphic is not None: @@ -302,8 +304,9 @@ def _create_graphic(self): self._subplot.delete_graphic(old_graphic) # create the new graphic - self._graphic = self._subplot.add_vectors(positions=data_slice[:, 0, :], - directions=data_slice[:, 1, :]) + self._graphic = self._subplot.add_vectors( + positions=data_slice[:, 0], directions=data_slice[:, 1] + ) self._subplot.add_graphic(self._graphic) @@ -334,8 +337,8 @@ def set_indices( ): data_slice = yield from self._get_data_slice(indices) - positions = data_slice[:, 0, :] - directions = data_slice[:, 1, :] + positions = data_slice[:, 0] + directions = data_slice[:, 1] self.graphic.positions = positions self.graphic.directions = directions @@ -351,4 +354,4 @@ def spatial_func(self) -> Callable[[ArrayProtocol], ArrayProtocol] | None: def spatial_func( self, func: Callable[[ArrayProtocol], ArrayProtocol] ) -> Callable | None: - self.processor.spatial_func = func \ No newline at end of file + self.processor.spatial_func = func diff --git a/fastplotlib/widgets/nd_widget/_ndw_subplot.py b/fastplotlib/widgets/nd_widget/_ndw_subplot.py index 82b336a60..a684c32cc 100644 --- a/fastplotlib/widgets/nd_widget/_ndw_subplot.py +++ b/fastplotlib/widgets/nd_widget/_ndw_subplot.py @@ -3,10 +3,17 @@ import numpy as np -from ... import ScatterCollection, ScatterStack, LineCollection, LineStack, ImageGraphic, VectorsGraphic +from ... import ( + ScatterCollection, + ScatterStack, + LineCollection, + LineStack, + ImageGraphic, + VectorsGraphic, +) from ...layouts import Subplot from ...utils import ArrayProtocol -from . import NDImage, NDPositions, NDVector +from . import NDImage, NDPositions, NDVectors from ._base import NDGraphic, WindowFuncCallable @@ -20,6 +27,7 @@ class NDWSubplot: Note: ``NDWSubplot`` is not meant to be constructed directly, it only exists as part of an ``NDWidget`` """ + def __init__(self, ndw, subplot: Subplot): self.ndw = ndw self._subplot = subplot @@ -58,7 +66,10 @@ def add_nd_image( slider_dim_transforms=None, name: str = None, ): - nd = NDImage(self.ndw.indices, self._subplot, data=data, + nd = NDImage( + self.ndw.indices, + self._subplot, + data=data, dims=dims, spatial_dims=spatial_dims, rgb_dim=rgb_dim, @@ -68,22 +79,23 @@ def add_nd_image( compute_histogram=compute_histogram, slider_dim_transforms=slider_dim_transforms, name=name, - ) + ) self._nd_graphics.append(nd) return nd - def add_nd_vector(self, - data: ArrayProtocol | None, - dims: Sequence[str], - spatial_dims: tuple[str, str, str], - window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, - window_order: tuple[int, ...] = None, - spatial_func: Callable[[ArrayProtocol], ArrayProtocol] = None, - slider_dim_transforms=None, - name: str = None, - ): - nd = NDVector( + def add_nd_vectors( + self, + data: ArrayProtocol | None, + dims: Sequence[str], + spatial_dims: tuple[str, str, str], + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, + window_order: tuple[int, ...] = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] = None, + slider_dim_transforms=None, + name: str = None, + ) -> NDVectors: + nd = NDVectors( self.ndw.indices, self._subplot, data=data, @@ -93,7 +105,7 @@ def add_nd_vector(self, window_order=window_order, spatial_func=spatial_func, slider_dim_transforms=slider_dim_transforms, - name=name + name=name, ) self._nd_graphics.append(nd) From 0289e17bf7942bac5ef036109302aa3a1aedd2d8 Mon Sep 17 00:00:00 2001 From: Kushal Kolar Date: Wed, 15 Apr 2026 14:31:34 -0400 Subject: [PATCH 8/8] Apply suggestions from code review Co-authored-by: Kushal Kolar --- fastplotlib/widgets/nd_widget/_nd_vectors.py | 21 +++++--------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_vectors.py b/fastplotlib/widgets/nd_widget/_nd_vectors.py index 3b09a6d04..85ff19e13 100644 --- a/fastplotlib/widgets/nd_widget/_nd_vectors.py +++ b/fastplotlib/widgets/nd_widget/_nd_vectors.py @@ -45,28 +45,17 @@ def __init__( ``spatial_dims`` are treated as slider dimensions and **must** appear as keys in the parent ``NDWidget``'s ``ref_ranges`` Examples:: - ``("time", "depth", "row", "col")`` - ``("channels", "time", "xy")`` - ``("keypoints", "time", "xyz")`` A custom subclass's ``data`` object doesn't necessarily need to have these dims, but the ``get()`` method must operate as if these dimensions exist and return an array that matches the spatial dimensions. - dims: Sequence[str] - names for each dimension in ``data``. Dimensions not listed in - ``spatial_dims`` are treated as slider dimensions and **must** appear as - keys in the parent ``NDWidget``'s ``ref_ranges`` - Examples:: - ``("time", "depth", "row", "col")`` - ``("row", "col")`` - ``("other_dim", "depth", "time", "row", "col")`` dims in the array do not need to be in the order that you want to display them, for example you can have a weird array where the dims are interpreted as: ``("col", "depth", "row", "time")``, and then specify spatial_dims as ``("row", "col")``. spatial_dims : tuple[str, str] | tuple[str, str, str] - For NDVectors, this is always the last two dimensions without exception + The dim names that indicate [n_vectors, positions & directions, xy(z)], **in that order** slider_dim_transforms : dict, optional See :class:`NDProcessor`. @@ -83,7 +72,7 @@ def __init__( See Also -------- NDProcessor : Base class with full parameter documentation. - NDImage : The ``NDGraphic`` that wraps this processor. + NDVectors : The ``NDGraphic`` that uses this processor by default. """ super().__init__( @@ -123,7 +112,7 @@ def data(self, data: ArrayProtocol): @property def spatial_dims(self) -> tuple[str, str]: """ - Spatial dims, **in display order**. + Spatial dims, **in order** Dimensions in order are num_vectors, position/direction, xy[z], so the shape is [num_vectors, 2, 2 or 3] """ return self._spatial_dims @@ -185,7 +174,7 @@ def __init__( dims: Sequence[str], spatial_dims: tuple[ str, str, str - ], # must be in order! [rows, cols] | [z, rows, cols] + ], # must be in order! window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, window_order: tuple[int, ...] = None, spatial_func: Callable[[ArrayProtocol], ArrayProtocol] = None, @@ -222,7 +211,7 @@ def __init__( be present in ``ref_index``. spatial_dims : tuple[str, str] | tuple[str, str, str] - Spatial dimensions **in order**: These dims are either (num_points, 2, 2) or (num_points, 3, 2) + Spatial dimensions **in order**: These dims are either [n_vectors, 2, 2] or [n_vectors, 2, 3], indicating [n_vectors, positions & directions, xy(z)] window_funcs : dict, optional See :class:`NDProcessor`.