Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 102 additions & 18 deletions fastplotlib/widgets/image.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
import traceback
from datetime import datetime
from itertools import product
from typing import *
from warnings import warn
from functools import partial
from copy import deepcopy

import numpy as np
from ipywidgets.widgets import IntSlider, VBox, HBox, Layout, FloatRangeSlider, Button, BoundedIntText, Play, jslink

from ipywidgets import Dropdown
from wgpu.gui.jupyter import JupyterWgpuCanvas

from ..layouts._subplot import Subplot
from ..plot import Plot
from ..layouts import GridPlot
from ..graphics import ImageGraphic
from ..utils import quick_min_max
from ipywidgets.widgets import IntSlider, VBox, HBox, Layout, FloatRangeSlider, Button, BoundedIntText, Play, jslink
import numpy as np
from typing import *
from warnings import warn
from functools import partial
from copy import deepcopy


DEFAULT_DIMS_ORDER = \
{
Expand Down Expand Up @@ -846,14 +843,100 @@ def reset_vmin_vmax(self):
for i, ig in enumerate(self.image_graphics):
mm = self._get_vmin_vmax_range(ig.data())

state = {
"value": mm[0],
"step": mm[1] / 150,
"min": mm[2],
"max": mm[3]
}
if len(self.vmin_vmax_sliders) != 0:
state = {
"value": mm[0],
"step": mm[1] / 150,
"min": mm[2],
"max": mm[3]
}

self.vmin_vmax_sliders[i].set_state(state)
else:
ig.min, ig.max = mm

def set_data(
self,
new_data: Union[np.ndarray, List[np.ndarray]],
reset_vmin_vmax: bool = True,
reset_indices: bool = True
):
"""
Change data of widget. Note: sliders max currently update only for ``txy`` and ``tzxy`` data.

Parameters
----------
new_data: array-like or list of array-like
The new data to display in the widget

reset_vmin_vmax: bool, default ``True``
reset the vmin vmax levels based on the new data

reset_indices: bool, default ``True``
reset the current index for all dimensions to 0

"""

if reset_indices:
for key in self.current_index:
self.current_index[key] = 0
for key in self.sliders:
self.sliders[key].value = 0

# set slider max according to new data
max_lengths = {"t": np.inf, "z": np.inf}

# single plot
if isinstance(new_data, np.ndarray) and isinstance(self.plot, Plot):
if new_data.ndim != self._data[0].ndim:
raise ValueError(
f"new data ndim {new_data.ndim} does not equal current data ndim {self._data[0].ndim}"
)
self._data[0] = new_data

if new_data.ndim > 2:
# to set max of time slider, txy or tzxy
max_lengths["t"] = min(max_lengths["t"], new_data.shape[0] - 1)

if new_data.ndim > 3: # tzxy
max_lengths["z"] = min(max_lengths["z"], new_data.shape[1] - 1)

else: # gridplot
if len(self._data) != len(new_data):
raise ValueError(
f"number of new data arrays {len(new_data)} must match"
f" current number of data arrays {len(self._data)}"
)
# check all arrays
for i, (new_array, current_array) in enumerate(zip(new_data, self._data)):
if new_array.ndim != current_array.ndim:
raise ValueError(
f"new data ndim {new_array.ndim} at index {i} "
f"does not equal current data ndim {current_array.ndim}"
)

# if checks pass, update with new data
for i, (new_array, current_array, subplot) in enumerate(zip(new_data, self._data, self.plot)):
self._data[i] = new_array

if new_array.ndim > 2:
# to set max of time slider, txy or tzxy
max_lengths["t"] = min(max_lengths["t"], new_array.shape[0] - 1)

if new_array.ndim > 3: # tzxy
max_lengths["z"] = min(max_lengths["z"], new_array.shape[1] - 1)

# set slider maxes
# TODO: maybe make this stuff a property, like ndims, n_frames etc. and have it set the sliders
for key in self.sliders.keys():
self.sliders[key].max = max_lengths[key]
self._dims_max_bounds[key] = max_lengths[key]

# force graphics to update
self.current_index = self.current_index

self.vmin_vmax_sliders[i].set_state(state)
if reset_vmin_vmax:
self.reset_vmin_vmax()

def show(self, toolbar: bool = True):
"""
Expand Down Expand Up @@ -916,6 +999,7 @@ def __init__(self,
self.reset_vminvmax_button.on_click(self.reset_vminvmax)
self.step_size_setter.observe(self.change_stepsize, 'value')
jslink((self.play_button, 'value'), (self.iw.sliders["t"], 'value'))
jslink((self.play_button, "max"), (self.iw.sliders["t"], "max"))

def reset_vminvmax(self, obj):
if len(self.iw.vmin_vmax_sliders) != 0:
Expand Down