diff --git a/fastplotlib/graphics/features/_vectors.py b/fastplotlib/graphics/features/_vectors.py index 729562b06..a4533372b 100644 --- a/fastplotlib/graphics/features/_vectors.py +++ b/fastplotlib/graphics/features/_vectors.py @@ -82,11 +82,8 @@ def set_value(self, graphic, value: np.ndarray): else: self._positions[:] = value - for i in range(self._positions.shape[0]): - # only need to update the translation vector - graphic.world_object.instance_buffer.data["matrix"][i][3, 0:3] = ( - self._positions[i] - ) + # Only need to update the translation vector + graphic.world_object.instance_buffer.data["matrix"][:, 3, 0:3] = self._positions[:] graphic.world_object.instance_buffer.update_full() @@ -171,15 +168,164 @@ def set_value(self, graphic, value: np.ndarray): # vector determines the size of the vector magnitudes = np.linalg.norm(self._directions, axis=1, ord=2) - for i in range(self._directions.shape[0]): + # for i in range(self._directions.shape[0]): # get quaternion to rotate vector to new direction - rotation = la.quat_from_vecs(self.init_direction, self._directions[i]) - # get the new transform - transform = la.mat_compose(graphic.positions[i], rotation, magnitudes[i]) - # set the buffer - graphic.world_object.instance_buffer.data["matrix"][i] = transform.T + rotation = quat_from_vecs(self.init_direction, self._directions[:]) + # get the new transform + transform = mat_compose(graphic.positions[:], rotation, magnitudes[:]) + # set the buffer + graphic.world_object.instance_buffer.data["matrix"][:] = transform.transpose(0, 2, 1) graphic.world_object.instance_buffer.update_full() event = GraphicFeatureEvent(type="directions", info={"value": value}) self._call_event_handlers(event) + + + +def quat_from_vecs(source, target, out=None, dtype=None) -> np.ndarray: + source = np.asarray(source, dtype=float) + if source.ndim == 1: + source = source[None, :] + target = np.asarray(target, dtype=float) + if target.ndim == 1: + target = target[None, :] + + num_vecs = target.shape[0] + result_shape = (num_vecs, 4) + if out is None: + out = np.empty(result_shape, dtype=dtype) + + axis = np.cross(source, target) # (num_pts, 3) + axis_norm = np.linalg.norm(axis, axis=-1) # (num_pts,) + angle = np.arctan2(axis_norm, np.sum(source * target, axis=-1)) # (num_pts,) + + # Handle degenerate case: source and target are parallel (axis is zero vector). + # Pick any axis orthogonal to source as a replacement. + use_fallback = axis_norm == 0 + if np.any(use_fallback): + t = np.broadcast_to(source, (num_vecs, 3))[use_fallback] + + # Better case split: + y_zero = t[:, 1] == 0 + z_zero = t[:, 2] == 0 + neither_zero = ~y_zero & ~z_zero + + fb = np.empty((y_zero.shape[0], 3), dtype=float) + fb[y_zero] = (0., 1., 0.) + fb[~y_zero & z_zero] = (0., 0., 1.) + fb[neither_zero, 0] = 0. + fb[neither_zero, 1] = -t[neither_zero, 2] + fb[neither_zero, 2] = t[neither_zero, 1] + + axis[use_fallback] = fb + + return quat_from_axis_angle(axis, angle, out=out) + + +def quat_from_axis_angle(axis, angle, out=None, dtype=None) -> np.ndarray: + """Quaternion from axis-angle pair. + + Create a quaternion representing the rotation of an given angle + about a given unit vector + + Parameters + ---------- + axis : ndarray, [3] + Unit vector + angle : number + The angle (in radians) to rotate about axis + out : ndarray, optional + A location into which the result is stored. If provided, it + must have a shape that the inputs broadcast to. If not provided or + None, a freshly-allocated array is returned. A tuple must have + length equal to the number of outputs. + dtype : data-type, optional + Overrides the data type of the result. + + Returns + ------- + ndarray, [4] + Quaternion. + """ + + axis = np.asarray(axis, dtype=float) + angle = np.asarray(angle, dtype=float) + + if out is None: + out_shape = np.broadcast_shapes(axis.shape[:-1], angle.shape) + out = np.empty((*out_shape, 4), dtype=dtype) + + # result should be independent of the length of the given axis + lengths_shape = (*axis.shape[:-1], 1) + axis = axis / np.linalg.norm(axis, axis=-1).reshape(lengths_shape) + + out[..., :3] = axis * np.sin(angle / 2).reshape(lengths_shape) + out[..., 3] = np.cos(angle / 2) + + return out + + +def mat_compose(translation, rotation, scaling, /, *, out=None, dtype=None) -> np.ndarray: + """ + Compose transformation matrices given translation vectors, quaternions, + and scaling vectors. + + Parameters + ---------- + translation : ndarray, [3] or [num_vectors, 3] + rotation : ndarray, [4] or [num_vectors, 4] + scaling : ndarray, [3] or [num_vectors, 3] + + Returns + ------- + ndarray, [num_vectors, 4, 4] + """ + rotation = np.asarray(rotation, dtype=float) + translation = np.asarray(translation, dtype=float) + scaling = np.asarray(scaling, dtype=float) + + if rotation.ndim == 1: + rotation = rotation[None, :] + if translation.ndim == 1: + translation = translation[None, :] + if scaling.ndim == 0: + scaling = np.full((1, 3), scaling) + elif scaling.ndim == 1 and scaling.shape[0] == 3: + scaling = scaling[None, :] + elif scaling.ndim == 1: + scaling = scaling[:, None] * np.ones(3) + + num_vectors = max(rotation.shape[0], translation.shape[0], scaling.shape[0]) + + if out is None: + out = np.zeros((num_vectors, 4, 4), dtype=dtype) + else: + out[..., :, :] = 0 + + x, y, z, w = rotation[:, 0], rotation[:, 1], rotation[:, 2], rotation[:, 3] + + x2, y2, z2 = x + x, y + y, z + z + xx, xy, xz = x * x2, x * y2, x * z2 + yy, yz, zz = y * y2, y * z2, z * z2 + wx, wy, wz = w * x2, w * y2, w * z2 + + sx, sy, sz = scaling[:, 0], scaling[:, 1], scaling[:, 2] + + + out[:, 0, 0] = (1 - (yy + zz)) * sx + out[:, 1, 0] = (xy + wz) * sx + out[:, 2, 0] = (xz - wy) * sx + + out[:, 0, 1] = (xy - wz) * sy + out[:, 1, 1] = (1 - (xx + zz)) * sy + out[:, 2, 1] = (yz + wx) * sy + + out[:, 0, 2] = (xz + wy) * sz + out[:, 1, 2] = (yz - wx) * sz + out[:, 2, 2] = (1 - (xx + yy)) * sz + + out[:, 0:3, 3] = translation + out[:, 3, 3] = 1 + + return out \ No newline at end of file diff --git a/fastplotlib/widgets/nd_widget/__init__.py b/fastplotlib/widgets/nd_widget/__init__.py index 65f448b54..4ebca556a 100644 --- a/fastplotlib/widgets/nd_widget/__init__.py +++ b/fastplotlib/widgets/nd_widget/__init__.py @@ -5,6 +5,7 @@ from ._base import NDProcessor, NDGraphic from ._nd_positions import NDPositions, NDPositionsProcessor, ndp_extras from ._nd_image import NDImageProcessor, NDImage + from ._nd_vector import NDVectorProcessor, NDVector from ._ndwidget import NDWidget else: diff --git a/fastplotlib/widgets/nd_widget/_nd_vector.py b/fastplotlib/widgets/nd_widget/_nd_vector.py new file mode 100644 index 000000000..f713275bd --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_vector.py @@ -0,0 +1,354 @@ +from collections.abc import Sequence, Generator, Callable +from typing import Any + +import numpy as np +from numpy.typing import ArrayLike + +from ...layouts import Subplot +from ...utils import subsample_array, ARRAY_LIKE_ATTRS, ArrayProtocol +from ...graphics import VectorsGraphic +from ._base import NDProcessor, NDGraphic, WindowFuncCallable, block_reentrance, AwaitedArray +from ._index import ReferenceIndex +from ._async import start_coroutine + + + +class NDVectorProcessor(NDProcessor): + def __init__( + self, + data: ArrayProtocol | None, + dims: Sequence[str], + spatial_dims: tuple[str, str, str], # must be in order, last dim must be 4 + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, + window_order: tuple[int, ...] = None, + spatial_func: Callable[[ArrayLike], ArrayLike] = None, + slider_dim_transforms=None, + ): + """ + ``NDProcessor`` subclass for n-dimensional vector data + + Produces (num_vectors, 4) slices for a ``VectorsGraphic``. The first two columns describe position of each vector, + the second two describe the orientation + + Parameters + ---------- + data: ArrayProtocol + Shape [num_vectors, 2, 2] or [num_vectors, 3, 2]. data[:, 0, :] gives the positions, data[:, 1, :] gives directions + n-dimension image data array + + dims: Sequence[str] + names for each dimension in ``data``. Dimensions not listed in + ``spatial_dims`` are treated as slider dimensions and **must** appear as + keys in the parent ``NDWidget``'s ``ref_ranges`` + Examples:: + ``("time", "depth", "row", "col")`` + ``("channels", "time", "xy")`` + ``("keypoints", "time", "xyz")`` + + A custom subclass's ``data`` object doesn't necessarily need to have these dims, but the ``get()`` method + must operate as if these dimensions exist and return an array that matches the spatial dimensions. + + dims: Sequence[str] + names for each dimension in ``data``. Dimensions not listed in + ``spatial_dims`` are treated as slider dimensions and **must** appear as + keys in the parent ``NDWidget``'s ``ref_ranges`` + Examples:: + ``("time", "depth", "row", "col")`` + ``("row", "col")`` + ``("other_dim", "depth", "time", "row", "col")`` + + dims in the array do not need to be in the order that you want to display them, for example you can have a + weird array where the dims are interpreted as: + ``("col", "depth", "row", "time")``, and then specify spatial_dims as ``("row", "col")``. + + spatial_dims : tuple[str, str] | tuple[str, str, str] + For NDVectors, this is always the last two dimensions without exception + + slider_dim_transforms : dict, optional + See :class:`NDProcessor`. + + window_funcs : dict, optional + See :class:`NDProcessor`. + + window_order : tuple, optional + See :class:`NDProcessor`. + + spatial_func : callable, optional + See :class:`NDProcessor`. + + See Also + -------- + NDProcessor : Base class with full parameter documentation. + NDImage : The ``NDGraphic`` that wraps this processor. + """ + + super().__init__( + data=data, + dims=dims, + spatial_dims=spatial_dims, + slider_dim_transforms=slider_dim_transforms, + window_funcs=window_funcs, + window_order=window_order, + spatial_func=spatial_func, + ) + + + @property + def data(self) -> ArrayProtocol | None: + """ + get or set managed data. If setting with new data, the new data is interpreted + to have the same dims (i.e. same dim names and ordering of dims). + """ + return self._data + + @data.setter + def data(self, data: ArrayProtocol): + if not isinstance(data, ArrayProtocol): + # check that it's generally array-like + raise TypeError( + f"`data` arrays must have all of the following attributes to be sufficiently array-like:\n" + f"{ARRAY_LIKE_ATTRS}, or they must be `None`" + ) + + if data.ndim < 3: + raise ValueError( + f"Shape must be (..., num_vecs, 2, [2 or 3]) you passed an array of shape {data.shape}" + ) + + self._data = data + + @property + def spatial_dims(self) -> tuple[str, str]: + """ + Spatial dims, **in display order**. + Dimensions in order are num_vectors, position/direction, xy[z], so the shape is [num_vectors, 2, 2 or 3] + """ + return self._spatial_dims + + @spatial_dims.setter + def spatial_dims(self, sdims: tuple[str, str, str]): + for dim in sdims: + if dim not in self.dims: + raise KeyError + + if len(sdims) != 3: + raise ValueError( + f"There must be exactly 3 spatial dims for vectors indicating [num_vectors, 2, 2] or [num_vectors, 2, 3] " + ) + + self._spatial_dims = tuple(sdims) + + if self.shape[self.spatial_dims[-2]] != 2 or self.shape[self.spatial_dims[-1]] not in (2, 3): + raise ValueError( + f"Spatial dimensions must haves shape (num_vecs, 2, [2 or 3]) you passed an array of shape {data.shape}" + ) + + + + def get(self, indices: dict[str, Any]) -> AwaitedArray: + """ + Get the data at the given index, process data through the window functions. + + Note that we do not use __getitem__ here since the index is a tuple specifying a single integer + index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here. + + Parameters + ---------- + indices: tuple[int, ...] + Get the processed data at this index. Must provide a value for each dimension. + Example: get((100, 5)) + + """ + # this will be squeezed output, with dims in the order of the user set spatial dims + window_output = yield from self.get_window_output(indices) + + # apply spatial_func + if self.spatial_func is not None: + spatial_out = self._spatial_func(window_output) + if spatial_out.ndim != len(self.spatial_dims): + raise ValueError + + return spatial_out + + return window_output + + +class NDVector(NDGraphic): + def __init__( + self, + ref_index: ReferenceIndex, + subplot: Subplot, + data: ArrayProtocol | None, + dims: Sequence[str], + spatial_dims: tuple[str, str, str], # must be in order! [rows, cols] | [z, rows, cols] + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, + window_order: tuple[int, ...] = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] = None, + slider_dim_transforms=None, + name: str = None, + ): + """ + ``NDGraphic`` subclass for n-dimensional image rendering. + + Wraps an :class:`NDImageProcessor` and manages either an ``ImageGraphic`` or``ImageVolumeGraphic``. + swaps automatically when :attr:`spatial_dims` is reassigned at runtime. Also + owns a ``HistogramLUTTool`` for interactive vmin, vmax adjustment. + + Every dimension that is *not* listed in ``spatial_dims`` becomes a slider + dimension. Each slider dim must have a ``ReferenceRange`` defined in the + ``ReferenceIndex`` of the parent ``NDWidget``. The widget uses this to direct + a change in the ``ReferenceIndex`` and update the graphics. + + Parameters + ---------- + ref_index : ReferenceIndex + The shared reference index that delivers slider updates to this graphic. + + subplot : Subplot + parent subplot the NDGraphic is in + + data : array-like or None + Shape [num_vectors, 2, 2] or [num_vectors, 3, 2]. data[:, :, 0] gives the positions, data[:, :, 1] gives directions + n-dimension image data array + + dims : sequence of hashable + Name for every dimension of ``data``, in order. Non-spatial dims must + match keys in ``ref_index``. + + ex: ``("time", "depth", "row", "col")`` — ``"time"`` and ``"depth"`` must + be present in ``ref_index``. + + spatial_dims : tuple[str, str] | tuple[str, str, str] + Spatial dimensions **in order**: These dims are either (num_points, 2, 2) or (num_points, 3, 2) + + window_funcs : dict, optional + See :class:`NDProcessor`. + + window_order : tuple, optional + See :class:`NDProcessor`. + + spatial_func : callable, optional + See :class:`NDProcessor`. + + slider_dim_transforms : dict, optional + See :class:`NDProcessor`. + + name : str, optional + Name for the underlying graphic. + + See Also + -------- + NDImageProcessor : The processor that backs this graphic. + + """ + + if not (set(dims) - set(spatial_dims)).issubset(ref_index.dims): + raise IndexError( + f"all specified `dims` must either be a spatial dim or a slider dim " + f"specified in the NDWidget ref_ranges, provided dims: {dims}, " + f"spatial_dims: {spatial_dims}. Specified NDWidget ref_ranges: {ref_index.dims}" + ) + + super().__init__(subplot, name) + + self._ref_index = ref_index + + self._processor = NDVectorProcessor( + data, + dims=dims, + spatial_dims=spatial_dims, + window_funcs=window_funcs, + window_order=window_order, + spatial_func=spatial_func, + slider_dim_transforms=slider_dim_transforms, + ) + + self._graphic: VectorsGraphic | None = None + + # create a graphic + self._create_graphic() + + @property + def processor(self) -> NDVectorProcessor: + """NDProcessor that manages the data and produces data slices to display""" + return self._processor + + @property + def graphic( + self, + ) -> VectorsGraphic: + """Underlying Graphic object used to display the current data slice""" + return self._graphic + + @start_coroutine + def _create_graphic(self): + # Creates an ``ImageGraphic`` or ``ImageVolumeGraphic`` based on the number of spatial dims, + # adds it to the subplot, and resets the camera and histogram. + + if self.processor.data is None: + # no graphic if data is None, useful for initializing in null states when we want to set data later + return + + # get the data slice for this index + # this will only have the dims specified by ``spatial_dims`` + + data_slice = yield from self._get_data_slice(self.indices) + + + old_graphic = self._graphic + # check if we are replacing a graphic + if old_graphic is not None: + # delete the old graphic + self._subplot.delete_graphic(old_graphic) + + # create the new graphic + self._graphic = self._subplot.add_vectors(positions=data_slice[:, 0, :], + directions=data_slice[:, 1, :]) + + self._subplot.add_graphic(self._graphic) + + @property + def spatial_dims(self) -> tuple[str, str, str]: + """ + get or set the spatial dims **in order**. + Spatial dim shape here is [num_vectors, position/dimension (2), xy[z] (2 or 3)] + """ + return self.processor.spatial_dims + + @spatial_dims.setter + def spatial_dims(self, dims: tuple[str, str, str]): + self.processor.spatial_dims = dims + + # shape has probably changed, recreate graphic + self._create_graphic() + + @property + def indices(self) -> dict[str, Any]: + """get or set the indices, managed by the ReferenceIndex, users usually don't want to set this manually""" + return {d: self._ref_index[d] for d in self.processor.slider_dims} + + @block_reentrance + @start_coroutine + def set_indices( + self, indices: dict[str, Any], block: bool = True, timeout: float = 1.0 + ): + data_slice = yield from self._get_data_slice(indices) + + positions = data_slice[:, 0, :] + directions = data_slice[:, 1, :] + + self.graphic.positions = positions + self.graphic.directions = directions + + @property + def spatial_func(self) -> Callable[[ArrayProtocol], ArrayProtocol] | None: + """get or set the spatial_func, see docstring for details""" + # this is here even though it's the same in the base class since we can't create the image specific setter + # without also defining the property in this subclass. + return self.processor.spatial_func + + @spatial_func.setter + def spatial_func( + self, func: Callable[[ArrayProtocol], ArrayProtocol] + ) -> Callable | None: + self.processor.spatial_func = func \ No newline at end of file diff --git a/fastplotlib/widgets/nd_widget/_ndw_subplot.py b/fastplotlib/widgets/nd_widget/_ndw_subplot.py index 6666b3fc1..82b336a60 100644 --- a/fastplotlib/widgets/nd_widget/_ndw_subplot.py +++ b/fastplotlib/widgets/nd_widget/_ndw_subplot.py @@ -3,10 +3,10 @@ import numpy as np -from ... import ScatterCollection, ScatterStack, LineCollection, LineStack, ImageGraphic +from ... import ScatterCollection, ScatterStack, LineCollection, LineStack, ImageGraphic, VectorsGraphic from ...layouts import Subplot from ...utils import ArrayProtocol -from . import NDImage, NDPositions +from . import NDImage, NDPositions, NDVector from ._base import NDGraphic, WindowFuncCallable @@ -73,6 +73,32 @@ def add_nd_image( self._nd_graphics.append(nd) return nd + def add_nd_vector(self, + data: ArrayProtocol | None, + dims: Sequence[str], + spatial_dims: tuple[str, str, str], + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, + window_order: tuple[int, ...] = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] = None, + slider_dim_transforms=None, + name: str = None, + ): + nd = NDVector( + self.ndw.indices, + self._subplot, + data=data, + dims=dims, + spatial_dims=spatial_dims, + window_funcs=window_funcs, + window_order=window_order, + spatial_func=spatial_func, + slider_dim_transforms=slider_dim_transforms, + name=name + ) + + self._nd_graphics.append(nd) + return nd + def add_nd_scatter(self, *args, **kwargs): # TODO: better func signature here, send all kwargs to processor_kwargs nd = NDPositions(