1+ from datetime import date
12import enum
23import re
3- from typing import Optional , List , Tuple , Literal , Any
4+ from typing import Optional , List , Tuple , Literal , Any , Union
45
6+ import matplotlib
57import pandas
68from matplotlib .axes import Axes
79from matplotlib .collections import PathCollection
810from matplotlib .lines import Line2D
911from matplotlib .patches import Rectangle , Wedge , PathPatch
1012from matplotlib .pyplot import Figure
13+ from matplotlib .dates import _SwitchableDateConverter
1114import IPython
1215
1316from IPython .core .formatters import BaseFormatter
1417from matplotlib .text import Text
15- from pydantic import BaseModel , Field
18+ from pydantic import BaseModel , Field , field_validator
1619from traitlets .traitlets import Unicode , ObjectName
1720
1821
@@ -87,13 +90,29 @@ def _change_orientation(self):
8790
8891class PointData (BaseModel ):
8992 label : str
90- points : List [Tuple [float , float ]]
93+ points : List [Tuple [Union [str , int , float ], Union [str , int , float ]]]
94+
95+ @field_validator ("points" , mode = "before" )
96+ @classmethod
97+ def transform_points (
98+ cls , value
99+ ) -> List [Tuple [Union [float , str ], Union [float , str ]]]:
100+ parsed_value = []
101+ for x , y in value :
102+ if isinstance (x , date ):
103+ x = x .isoformat ()
104+
105+ if isinstance (y , date ):
106+ y = y .isoformat ()
107+
108+ parsed_value .append ((x , y ))
109+ return parsed_value
91110
92111
93112class PointGraph (Graph2D ):
94- x_ticks : List [float ] = Field (default_factory = list )
113+ x_ticks : List [Union [ str , int , float ] ] = Field (default_factory = list )
95114 x_tick_labels : List [str ] = Field (default_factory = list )
96- y_ticks : List [float ] = Field (default_factory = list )
115+ y_ticks : List [Union [ str , int , float ] ] = Field (default_factory = list )
97116 y_tick_labels : List [str ] = Field (default_factory = list )
98117
99118 elements : List [PointData ] = Field (default_factory = list )
@@ -103,11 +122,27 @@ def _extract_info(self, ax: Axes) -> None:
103122 Function to extract information for PointGraph
104123 """
105124 super ()._extract_info (ax )
106- self . x_ticks = [ float ( tick ) for tick in ax . get_xticks ()]
125+
107126 self .x_tick_labels = [label .get_text () for label in ax .get_xticklabels ()]
127+ self .x_ticks = self ._extract_ticks_info (ax .xaxis .converter , ax .get_xticks ())
108128
109- self .y_ticks = [float (tick ) for tick in ax .get_yticks ()]
110129 self .y_tick_labels = [label .get_text () for label in ax .get_yticklabels ()]
130+ self .y_ticks = self ._extract_ticks_info (ax .yaxis .converter , ax .get_yticks ())
131+
132+ @staticmethod
133+ def _extract_ticks_info (converter : Any , ticks : list ) -> list :
134+ example_tick = ticks [0 ]
135+
136+ if isinstance (converter , _SwitchableDateConverter ):
137+ return [matplotlib .dates .num2date (tick ).isoformat () for tick in ticks ]
138+ else :
139+ ticks_type = type (example_tick ).__name__
140+ if ticks_type == "float" :
141+ return [float (tick ) for tick in ticks ]
142+ elif ticks_type == "int" :
143+ return [int (tick ) for tick in ticks ]
144+ else :
145+ return ticks
111146
112147
113148class LineGraph (PointGraph ):
0 commit comments