Skip to content

Commit c330648

Browse files
hasanrashidCopilottimhoffmrcomer
authored
Improving error message for width and position type mismatch in violinplot (#30752)
* Improving error message for width and position type mismatch in violinplot * Improving error message for width and position type mismatch in violinplot * Fix violin plot statistics in test data * Trigger CI pipeline * Trigger CI pipeline * Update lib/matplotlib/axes/_axes.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update lib/matplotlib/axes/_axes.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update lib/matplotlib/tests/test_axes.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Store original widths parameter before conversion for accurate type validation - Improve pytest match strings with proper line continuation formatting - Enhanced error messages provide clear examples for correct usage * Update lib/matplotlib/axes/_axes.py Co-authored-by: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> * Update lib/matplotlib/axes/_axes.py Co-authored-by: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> * Remove low-information comment Co-authored-by: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> Co-authored-by: Ruth Comer <10599679+rcomer@users.noreply.github.com>
1 parent 5dee947 commit c330648

File tree

2 files changed

+96
-2
lines changed

2 files changed

+96
-2
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import itertools
33
import logging
44
import math
5+
import datetime
56
from numbers import Integral, Number, Real
67

78
import re
@@ -9144,14 +9145,28 @@ def violin(self, vpstats, positions=None, vert=None,
91449145
positions = range(1, N + 1)
91459146
elif len(positions) != N:
91469147
raise ValueError(datashape_message.format("positions"))
9147-
91489148
# Validate widths
91499149
if np.isscalar(widths):
91509150
widths = [widths] * N
91519151
elif len(widths) != N:
91529152
raise ValueError(datashape_message.format("widths"))
91539153

9154-
# Validate side
9154+
# For usability / better error message:
9155+
# Validate that datetime-like positions have timedelta-like widths.
9156+
# Checking only the first element is good enough for standard misuse cases
9157+
if N > 0: # No need to validate if there is no data
9158+
pos0 = positions[0]
9159+
width0 = widths[0]
9160+
if (isinstance(pos0, (datetime.datetime, datetime.date))
9161+
and not isinstance(width0, datetime.timedelta)):
9162+
raise TypeError(
9163+
"datetime/date 'position' values require timedelta 'widths'. "
9164+
"For example, use positions=[datetime.date(2024, 1, 1)] "
9165+
"and widths=[datetime.timedelta(days=1)].")
9166+
elif (isinstance(pos0, np.datetime64)
9167+
and not isinstance(width0, np.timedelta64)):
9168+
raise TypeError(
9169+
"np.datetime64 'position' values require np.timedelta64 'widths'")
91559170
_api.check_in_list(["both", "low", "high"], side=side)
91569171

91579172
# Calculate ranges for statistics lines (shape (2, N)).

lib/matplotlib/tests/test_axes.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4172,6 +4172,85 @@ def test_violinplot_sides():
41724172
showextrema=True, showmedians=True, side=side)
41734173

41744174

4175+
def violin_plot_stats():
4176+
datetimes = [
4177+
datetime.datetime(2023, 2, 10),
4178+
datetime.datetime(2023, 5, 18),
4179+
datetime.datetime(2023, 6, 6)
4180+
]
4181+
return [{
4182+
'coords': datetimes,
4183+
'vals': [1.2, 2.8, 1.5],
4184+
'mean': 1.84,
4185+
'median': 1.5,
4186+
'min': 1.2,
4187+
'max': 2.8,
4188+
'quantiles': [1.2, 1.5, 2.8]
4189+
}, {
4190+
'coords': datetimes,
4191+
'vals': [0.8, 1.1, 0.9],
4192+
'mean': 0.94,
4193+
'median': 0.9,
4194+
'min': 0.8,
4195+
'max': 1.1,
4196+
'quantiles': [0.8, 0.9, 1.1]
4197+
}]
4198+
4199+
4200+
def test_datetime_positions_with_datetime64():
4201+
"""Test that datetime positions with float widths raise TypeError."""
4202+
fig, ax = plt.subplots()
4203+
positions = [np.datetime64('2020-01-01'), np.datetime64('2021-01-01')]
4204+
widths = [0.5, 1.0]
4205+
with pytest.raises(TypeError,
4206+
match=("np.datetime64 'position' values require "
4207+
"np.timedelta64 'widths'")):
4208+
ax.violin(violin_plot_stats(), positions=positions, widths=widths)
4209+
4210+
4211+
def test_datetime_positions_with_float_widths_raises():
4212+
"""Test that datetime positions with float widths raise TypeError."""
4213+
fig, ax = plt.subplots()
4214+
positions = [datetime.datetime(2020, 1, 1), datetime.datetime(2021, 1, 1)]
4215+
widths = [0.5, 1.0]
4216+
with pytest.raises(TypeError,
4217+
match=("datetime/date 'position' values require "
4218+
"timedelta 'widths'")):
4219+
ax.violin(violin_plot_stats(), positions=positions, widths=widths)
4220+
4221+
4222+
def test_datetime_positions_with_scalar_float_width_raises():
4223+
"""Test that datetime positions with scalar float width raise TypeError."""
4224+
fig, ax = plt.subplots()
4225+
positions = [datetime.datetime(2020, 1, 1), datetime.datetime(2021, 1, 1)]
4226+
widths = 0.75
4227+
with pytest.raises(TypeError,
4228+
match=("datetime/date 'position' values require "
4229+
"timedelta 'widths'")):
4230+
ax.violin(violin_plot_stats(), positions=positions, widths=widths)
4231+
4232+
4233+
def test_numeric_positions_with_float_widths_ok():
4234+
"""Test that numeric positions with float widths work."""
4235+
fig, ax = plt.subplots()
4236+
positions = [1.0, 2.0]
4237+
widths = [0.5, 1.0]
4238+
ax.violin(violin_plot_stats(), positions=positions, widths=widths)
4239+
4240+
4241+
def test_mixed_positions_datetime_and_numeric_raises():
4242+
"""Test that mixed datetime and numeric positions
4243+
with float widths raise TypeError.
4244+
"""
4245+
fig, ax = plt.subplots()
4246+
positions = [datetime.datetime(2020, 1, 1), 2.0]
4247+
widths = [0.5, 1.0]
4248+
with pytest.raises(TypeError,
4249+
match=("datetime/date 'position' values require "
4250+
"timedelta 'widths'")):
4251+
ax.violin(violin_plot_stats(), positions=positions, widths=widths)
4252+
4253+
41754254
def test_violinplot_bad_positions():
41764255
ax = plt.axes()
41774256
# First 9 digits of frac(sqrt(47))

0 commit comments

Comments
 (0)