Skip to content
Merged
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
1fea4aa
started implementation of ImageWidget, nothing is tested yet
kushalkolar Dec 2, 2022
1bc65a0
basic stuff finished for single plot with a slider, need to test
kushalkolar Dec 2, 2022
6592cc3
basic image widget works
kushalkolar Dec 2, 2022
658f298
docs
kushalkolar Dec 2, 2022
5cc91ae
splitting imagewidget into two classes
kushalkolar Dec 4, 2022
7a6a136
split ImageWidget into ImageWidgetSingle and later ImageWidgetGrid
kushalkolar Dec 4, 2022
7cce91f
combined single and grid ImageWidget into single class
kushalkolar Dec 4, 2022
56a9b1c
simple and grid image widget works, tested with simple args, need to …
kushalkolar Dec 4, 2022
79f4452
catch another user error
kushalkolar Dec 4, 2022
e4f6b12
fix type annotation
kushalkolar Dec 4, 2022
5028ce7
docstrings, started slice_avg implementation
kushalkolar Dec 4, 2022
5ef32eb
slice averaging on single and multiple dimensions works perfectly
kushalkolar Dec 5, 2022
eca6599
is_array() checks for and attr, better error messages
kushalkolar Dec 8, 2022
0e3cf16
rename axis -> dims
kushalkolar Dec 10, 2022
570c076
make most imagewidget methods private, most attributes as read-only p…
kushalkolar Dec 10, 2022
4764b84
quick_min_max() returns pre-computed min max if int or float, imagewi…
kushalkolar Dec 10, 2022
afc0378
vmin vmax for gridplot
kushalkolar Dec 10, 2022
fc3365b
Merge branch 'master' into high-level-widgets
kushalkolar Dec 11, 2022
b1e922c
update Image -> ImageGraphic
kushalkolar Dec 11, 2022
e6c5d3a
refactor, window_funcs now works very well, slow with multiple dims b…
kushalkolar Dec 11, 2022
64faffd
vminmax works, also added names for subplots, everything works
kushalkolar Dec 11, 2022
8fc63b3
proper-ish image widget example
kushalkolar Dec 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
slice averaging on single and multiple dimensions works perfectly
  • Loading branch information
kushalkolar committed Dec 5, 2022
commit 5ef32ebe902630c5e4942a0ba97160ecaf819c79
137 changes: 114 additions & 23 deletions fastplotlib/widgets/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
data: Union[np.ndarray, List[np.ndarray]],
axes_order: Union[str, Dict[np.ndarray, str]] = None,
slider_axes: Union[str, int, List[Union[str, int]]] = None,
slice_avg: Union[int, Dict[Union[int, str], int]] = None,
slice_avg: Union[int, Dict[str, int]] = None,
frame_apply: Union[callable, Dict[Union[int, str], callable]] = None,
grid_shape: Tuple[int, int] = None,
**kwargs
Expand Down Expand Up @@ -75,10 +75,11 @@ def __init__(

slice_avg: Dict[Union[int, str], int]
| average one or more dimensions using a given window
| dict mapping of ``{dimension: window_size}``
| if a slider exists for only one dimension this can be an ``int``.
| if multiple sliders exist, then it must be a `dict`` mapping in the form of: ``{dimension: window_size}``
| dimension/axes can be specified using ``str`` such as "t", "z" etc. or ``int`` that indexes the dimension
| if window_size is not an odd number, adds 1
| use ``window_size = 0`` to disable averaging for a dimension, example: ``{"t": 5, "z": 0}``
| use ``None`` to disable averaging for a dimension, example: ``{"t": 5, "z": None}``

frame_apply
grid_shape: Optional[Tuple[int, int]]
Expand Down Expand Up @@ -210,17 +211,17 @@ def __init__(
elif isinstance(slider_axes, list):
self.slider_axes: List[str] = list()

if slice_avg is not None:
if not isinstance(slice_avg, dict):
raise TypeError(
f"`slice_avg` must be a <dict> if multiple `slider_axes` are provided. You must specify the "
f"window for each dimension."
)
if not isinstance(frame_apply, dict):
raise TypeError(
f"`frame_apply` must be a <dict> if multiple `slider_axes` are provided. You must specify a "
f"function for each dimension."
)
# make sure slice_avg and frame_apply are dicts if multiple sliders are desired
if (not isinstance(slice_avg, dict)) and (slice_avg is not None):
raise TypeError(
f"`slice_avg` must be a <dict> if multiple `slider_axes` are provided. You must specify the "
f"window for each dimension."
)
if (not isinstance(frame_apply, dict)) and (frame_apply is not None):
raise TypeError(
f"`frame_apply` must be a <dict> if multiple `slider_axes` are provided. You must specify a "
f"function for each dimension."
)

for sax in slider_axes:
if isinstance(sax, int):
Expand Down Expand Up @@ -249,13 +250,22 @@ def __init__(
else:
raise TypeError(f"`slider_axes` must a <int>, <str> or <list>, you have passed a: {type(slider_axes)}")

self._slice_avg = None
self.slice_avg = slice_avg

self.sliders = list()
self.vertical_sliders = list()
self.horizontal_sliders = list()

# current_index stores {dimension_index: slice_index} for every dimension
self.current_index: Dict[str, int] = {sax: 0 for sax in self.slider_axes}

# get max bound for all data arrays for all dimensions
self.axes_max_bounds: Dict[str, int] = {k: np.inf for k in self.slider_axes}
for axis in list(self.axes_max_bounds.keys()):
for array, order in zip(self.data, self.axes_order):
self.axes_max_bounds[axis] = min(self.axes_max_bounds[axis], array.shape[order.index(axis)])

if self.plot_type == "single":
self.plot: Plot = Plot()

Expand All @@ -278,12 +288,6 @@ def __init__(

self.plot.renderer.add_event_handler(self.set_slider_layout, "resize")

# get max bound for all sliders using the max index for that dim from all arrays
slider_axes_max = {k: np.inf for k in self.slider_axes}
for axis in list(slider_axes_max.keys()):
for array, order in zip(self.data, self.axes_order):
slider_axes_max[axis] = min(slider_axes_max[axis], array.shape[order.index(axis)])

for sax in self.slider_axes:
if sax == "z":
# TODO: once ipywidgets plays nicely with HBox and jupyter-rfb can use vertical
Expand All @@ -294,7 +298,7 @@ def __init__(

slider = IntSlider(
min=0,
max=slider_axes_max[sax] - 1,
max=self.axes_max_bounds[sax] - 1,
step=1,
value=0,
description=f"Axis: {sax}",
Expand Down Expand Up @@ -334,6 +338,47 @@ def __init__(
# else:
# self.widget = VBox([self.plot.canvas, *self.horizontal_sliders])

@property
def slice_avg(self) -> Union[int, Dict[str, int]]:
return self._slice_avg

@slice_avg.setter
def slice_avg(self, sa: Union[int, Dict[str, int]]):
if sa is None:
self._slice_avg = None
return

# for a single dim
elif isinstance(sa, int):
if sa < 3:
self._slice_avg = None
warn(f"Invalid ``slice_avg`` value, setting ``slice_avg = None``. Valid values are integers >= 3.")
return
if sa % 2 == 0:
self._slice_avg = sa + 1
else:
self._slice_avg = sa
# for multiple dims
elif isinstance(sa, dict):
self._slice_avg = dict()
for k in list(sa.keys()):
if sa[k] is None:
self._slice_avg[k] = None
elif (sa[k] < 3):
warn(
f"Invalid ``slice_avg`` value, setting ``slice_avg = None``. Valid values are integers >= 3."
)
self._slice_avg[k] = None
elif sa[k] % 2 == 0:
self._slice_avg[k] = sa[k] + 1
else:
self._slice_avg[k] = sa[k]
else:
raise TypeError(
f"`slice_avg` must be of type `int` if using a single slider or a dict if using multiple sliders. "
f"You have passed a {type(sa)}. See the docstring."
)

def get_2d_slice(
self,
array: np.ndarray,
Expand Down Expand Up @@ -366,6 +411,7 @@ def get_2d_slice(
"""
indexer = [slice(None)] * self.ndim

numerical_dims = list()
for dim in list(slice_indices.keys()):
if isinstance(dim, str):
data_ix = None
Expand All @@ -382,9 +428,54 @@ def get_2d_slice(
else:
numerical_dim = dim

indexer[numerical_dim] = slice_indices[dim]
indices_dim = slice_indices[dim]

# takes care of averaging if it was specified
indices_dim = self._process_dim_index(data_ix, numerical_dim, indices_dim)

return array[tuple(indexer)]
# set the indices for this dimension
indexer[numerical_dim] = indices_dim

numerical_dims.append(numerical_dim)

if self.slice_avg is not None:
a = array
for i, dim in enumerate(sorted(numerical_dims)):
dim = dim - i # since we loose a dimension every iteration
_indexer = [slice(None)] * (self.ndim - i)
_indexer[dim] = indexer[dim + i]
if isinstance(_indexer[dim], int):
a = a[tuple(_indexer)]
else:
a = np.mean(a[tuple(_indexer)], axis=dim)
return a
else:
return array[tuple(indexer)]

def _process_dim_index(self, data_ix, dim, indices_dim):
if self.slice_avg is None:
return indices_dim

else:
ix = indices_dim

# if there is only a single dimension for averaging
if isinstance(self.slice_avg, int):
sa = self.slice_avg
dim_str = self.axes_order[0][dim]

# if there are multiple dims to average, get the avg for the current dim in the loop
elif isinstance(self.slice_avg, dict):
dim_str = self.axes_order[data_ix][dim]
sa = self.slice_avg[dim_str]
if (sa == 0) or (sa is None):
return indices_dim

hw = int((sa - 1) / 2) # half-window size
# get the max bound for that dimension
max_bound = self.axes_max_bounds[dim_str]
indices_dim = range(max(0, ix - hw), min(max_bound, ix + hw))
return indices_dim

def slider_value_changed(
self,
Expand Down