Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
cleanup ColorFeature
  • Loading branch information
kushalkolar committed Dec 19, 2022
commit 865cbf0e2403b0e35c8d6953a0de5e9179581e7f
50 changes: 41 additions & 9 deletions fastplotlib/graphics/_graphic_attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,40 @@ def __init__(self, parent, data):
data = parent.geometry.colors.data
super(ColorFeature, self).__init__(parent, data)

self._bounds = data.shape[0]
self._upper_bound = data.shape[0]

def __setitem__(self, key, value):
if abs(key.start) > self._bounds or abs(key.stop) > self._bounds:
raise IndexError

# parse numerical slice indices
if isinstance(key, slice):
start = key.start
stop = key.stop
step = key.step
for attr in [start, stop, step]:
if attr is None:
continue
if attr < 0:
raise IndexError("Negative indexing not supported.")

if start is None:
start = 0

if stop is None:
stop = self._upper_bound

elif stop > self._upper_bound:
raise IndexError("Index out of bounds")

step = key.step
if step is None:
step = 1

indices = range(start, stop, step)
key = slice(start, stop, step)
indices = range(key.start, key.stop, key.step)

# or single numerical index
elif isinstance(key, int):
if key > self._upper_bound:
raise IndexError("Index out of bounds")
indices = [key]

else:
Expand All @@ -51,17 +69,31 @@ def __setitem__(self, key, value):
new_data_size = len(indices)

if not isinstance(value, np.ndarray):
new_colors = np.repeat(np.array([Color(value)]), new_data_size, axis=0)

color = np.array(Color(value)) # pygfx color parser
# make it of shape [n_colors_modify, 4]
new_colors = np.repeat(
np.array([color]).astype(np.float32),
new_data_size,
axis=0
)

# if already a numpy array
elif isinstance(value, np.ndarray):
# if a single color provided as numpy array
if value.shape == (4,):
new_colors = value.astype(np.float32)
# if there are more than 1 datapoint color to modify
if new_data_size > 1:
new_colors = np.repeat(np.array([new_colors]), new_data_size, axis=0)
new_colors = np.repeat(
np.array([new_colors]).astype(np.float32),
new_data_size,
axis=0
)

elif value.shape[1] == 4 and value.ndim == 2:
if not value.shape[0] == new_data_size:
if value.shape[0] != new_data_size:
raise ValueError("numpy array passed to color must be of shape (4,) or (n_colors_modify, 4)")
# if there is a single datapoint to change color of but user has provided shape [1, 4]
if new_data_size == 1:
new_colors = value.ravel().astype(np.float32)
else:
Expand Down