2020NP_NEW = (LooseVersion (np .version .version ) >= LooseVersion ('1.7' ))
2121
2222
23- def to_array (data , maxlen = 100 ):
23+ def to_str_array (data , maxlen = 100 ):
2424 if NP_NEW :
2525 return np .array (data , dtype = np .unicode )
2626 if cbook .is_scalar_or_string (data ):
@@ -53,13 +53,13 @@ def convert(value, unit, axis):
5353 vmap = dict (zip (axis .unit_data .seq , axis .unit_data .locs ))
5454
5555 if isinstance (value , six .string_types ):
56- return vmap [ value ]
56+ return vmap . get ( value , None )
5757
58- vals = to_array (value )
58+ vals = to_str_array (value )
5959 for lab , loc in vmap .items ():
6060 vals [vals == lab ] = loc
6161
62- return vals .astype ('float ' )
62+ return vals .astype ('float64 ' )
6363
6464 @staticmethod
6565 def axisinfo (unit , axis ):
@@ -74,16 +74,20 @@ def axisinfo(unit, axis):
7474 return munits .AxisInfo (majloc = majloc , majfmt = majfmt )
7575
7676 @staticmethod
77- def default_units (data , axis ):
77+ def default_units (data , axis , sort = True ):
7878 """
7979 Create mapping between string categories in *data*
8080 and integers, then store in *axis.unit_data*
8181 """
82- if axis .unit_data is None :
83- axis .unit_data = UnitData (data )
84- else :
85- axis .unit_data .update (data )
86- return None
82+
83+ if axis and axis .unit_data :
84+ axis .unit_data .update (data , sort )
85+ return
86+
87+ unit_data = UnitData (data , sort )
88+ if axis :
89+ axis .unit_data = unit_data
90+ return unit_data
8791
8892
8993class StrCategoryLocator (mticker .FixedLocator ):
@@ -115,30 +119,26 @@ def __init__(self, categories):
115119 *categories*
116120 distinct values for mapping
117121
118- Out-of-range values are mapped to a value not in categories;
119- these are then converted to valid indices by :meth:`Colormap.__call__`.
122+ Out-of-range values are mapped to np.nan
120123 """
121- self .categories = categories
124+
125+ self .unit_data = StrCategoryConverter .default_units (categories ,
126+ None , sort = False )
127+ self .categories = to_str_array (categories )
122128 self .N = len (self .categories )
123- self .vmin = 0
124- self .vmax = self .N
125- self ._interp = False
129+ self .nvals = self . unit_data . locs
130+ self .vmin = min ( self .nvals )
131+ self .vmax = max ( self . nvals )
126132
127133 def __call__ (self , value , clip = None ):
128- if not cbook .iterable (value ):
129- value = [value ]
130-
131- value = np .asarray (value )
132- ret = np .ones (value .shape ) * np .nan
133-
134- for i , c in enumerate (self .categories ):
135- ret [value == c ] = i / (self .N * 1.0 )
136-
137- return np .ma .array (ret , mask = np .isnan (ret ))
138-
139- def inverse (self , value ):
140- # not quite sure what invertible means in this context
141- return ValueError ("CategoryNorm is not invertible" )
134+ # gonna have to go into imshow and undo casting
135+ value = np .asarray (value , dtype = int )
136+ ret = StrCategoryConverter .convert (value , None , self )
137+ # knock out values not in the norm
138+ mask = np .in1d (ret , self .unit_data .locs ).reshape (ret .shape )
139+ # normalize ret
140+ ret /= self .vmax
141+ return np .ma .array (ret , mask = ~ mask )
142142
143143
144144def colors_from_categories (codings ):
@@ -187,27 +187,40 @@ class UnitData(object):
187187 # debatable makes sense to special code missing values
188188 spdict = {'nan' : - 1.0 , 'inf' : - 2.0 , '-inf' : - 3.0 }
189189
190- def __init__ (self , data ):
190+ def __init__ (self , data , sort = True ):
191191 """Create mapping between unique categorical values
192192 and numerical identifier
193193 Paramters
194194 ---------
195195 data: iterable
196196 sequence of values
197+ sort: bool
198+ sort input data, default is True
199+ False preserves input order
197200 """
198201 self .seq , self .locs = [], []
199- self ._set_seq_locs (data , 0 )
202+ self ._set_seq_locs (data , 0 , sort )
203+ self .sort = sort
200204
201- def update (self , new_data ):
205+ def update (self , new_data , sort = None ):
206+ if sort :
207+ self .sort = sort
202208 # so as not to conflict with spdict
203209 value = max (max (self .locs ) + 1 , 0 )
204- self ._set_seq_locs (new_data , value )
210+ self ._set_seq_locs (new_data , value , self . sort )
205211
206- def _set_seq_locs (self , data , value ):
212+ def _set_seq_locs (self , data , value , sort ):
207213 # magic to make it work under np1.6
208- strdata = to_array (data )
214+ strdata = to_str_array (data )
215+
209216 # np.unique makes dateframes work
210- new_s = [d for d in np .unique (strdata ) if d not in self .seq ]
217+ if sort :
218+ unq = np .unique (strdata )
219+ else :
220+ _ , idx = np .unique (strdata , return_index = ~ sort )
221+ unq = strdata [np .sort (idx )]
222+
223+ new_s = [d for d in unq if d not in self .seq ]
211224 for ns in new_s :
212225 self .seq .append (convert_to_string (ns ))
213226 if ns in UnitData .spdict .keys ():
0 commit comments