33from collections import OrderedDict
44import re
55import gzip
6- import io
76import logging
87import os
98import pickle
1312import numpy as np
1413import pandas as pd
1514import scipy .sparse
15+ import xmltodict
1616
1717from openml .base import OpenMLBase
1818from .data_feature import OpenMLDataFeature
@@ -125,8 +125,8 @@ def __init__(
125125 update_comment = None ,
126126 md5_checksum = None ,
127127 data_file = None ,
128- features = None ,
129- qualities = None ,
128+ features_file : Optional [ str ] = None ,
129+ qualities_file : Optional [ str ] = None ,
130130 dataset = None ,
131131 ):
132132 def find_invalid_characters (string , pattern ):
@@ -188,7 +188,7 @@ def find_invalid_characters(string, pattern):
188188 self .default_target_attribute = default_target_attribute
189189 self .row_id_attribute = row_id_attribute
190190 if isinstance (ignore_attribute , str ):
191- self .ignore_attribute = [ignore_attribute ]
191+ self .ignore_attribute = [ignore_attribute ] # type: Optional[List[str]]
192192 elif isinstance (ignore_attribute , list ) or ignore_attribute is None :
193193 self .ignore_attribute = ignore_attribute
194194 else :
@@ -202,33 +202,25 @@ def find_invalid_characters(string, pattern):
202202 self .update_comment = update_comment
203203 self .md5_checksum = md5_checksum
204204 self .data_file = data_file
205- self .features = None
206- self .qualities = None
207205 self ._dataset = dataset
208206
209- if features is not None :
210- self .features = {}
211- for idx , xmlfeature in enumerate (features ["oml:feature" ]):
212- nr_missing = xmlfeature .get ("oml:number_of_missing_values" , 0 )
213- feature = OpenMLDataFeature (
214- int (xmlfeature ["oml:index" ]),
215- xmlfeature ["oml:name" ],
216- xmlfeature ["oml:data_type" ],
217- xmlfeature .get ("oml:nominal_value" ),
218- int (nr_missing ),
219- )
220- if idx != feature .index :
221- raise ValueError ("Data features not provided " "in right order" )
222- self .features [feature .index ] = feature
207+ if features_file is not None :
208+ self .features = _read_features (
209+ features_file
210+ ) # type: Optional[Dict[int, OpenMLDataFeature]]
211+ else :
212+ self .features = None
223213
224- self .qualities = _check_qualities (qualities )
214+ if qualities_file :
215+ self .qualities = _read_qualities (qualities_file ) # type: Optional[Dict[str, float]]
216+ else :
217+ self .qualities = None
225218
226219 if data_file is not None :
227- (
228- self .data_pickle_file ,
229- self .data_feather_file ,
230- self .feather_attribute_file ,
231- ) = self ._create_pickle_in_cache (data_file )
220+ rval = self ._create_pickle_in_cache (data_file )
221+ self .data_pickle_file = rval [0 ] # type: Optional[str]
222+ self .data_feather_file = rval [1 ] # type: Optional[str]
223+ self .feather_attribute_file = rval [2 ] # type: Optional[str]
232224 else :
233225 self .data_pickle_file , self .data_feather_file , self .feather_attribute_file = (
234226 None ,
@@ -357,7 +349,7 @@ def decode_arff(fh):
357349 with gzip .open (filename ) as fh :
358350 return decode_arff (fh )
359351 else :
360- with io . open (filename , encoding = "utf8" ) as fh :
352+ with open (filename , encoding = "utf8" ) as fh :
361353 return decode_arff (fh )
362354
363355 def _parse_data_from_arff (
@@ -405,12 +397,10 @@ def _parse_data_from_arff(
405397 # can be encoded into integers
406398 pd .factorize (type_ )[0 ]
407399 except ValueError :
408- raise ValueError (
409- "Categorical data needs to be numeric when " "using sparse ARFF."
410- )
400+ raise ValueError ("Categorical data needs to be numeric when using sparse ARFF." )
411401 # string can only be supported with pandas DataFrame
412402 elif type_ == "STRING" and self .format .lower () == "sparse_arff" :
413- raise ValueError ("Dataset containing strings is not supported " " with sparse ARFF." )
403+ raise ValueError ("Dataset containing strings is not supported with sparse ARFF." )
414404
415405 # infer the dtype from the ARFF header
416406 if isinstance (type_ , list ):
@@ -743,7 +733,7 @@ def get_data(
743733 to_exclude .extend (self .ignore_attribute )
744734
745735 if len (to_exclude ) > 0 :
746- logger .info ("Going to remove the following attributes:" " %s" % to_exclude )
736+ logger .info ("Going to remove the following attributes: %s" % to_exclude )
747737 keep = np .array (
748738 [True if column not in to_exclude else False for column in attribute_names ]
749739 )
@@ -810,6 +800,10 @@ def retrieve_class_labels(self, target_name: str = "class") -> Union[None, List[
810800 -------
811801 list
812802 """
803+ if self .features is None :
804+ raise ValueError (
805+ "retrieve_class_labels can only be called if feature information is available."
806+ )
813807 for feature in self .features .values ():
814808 if (feature .name == target_name ) and (feature .data_type == "nominal" ):
815809 return feature .nominal_values
@@ -938,18 +932,73 @@ def _to_dict(self) -> "OrderedDict[str, OrderedDict]":
938932 return data_container
939933
940934
941- def _check_qualities (qualities ):
942- if qualities is not None :
943- qualities_ = {}
944- for xmlquality in qualities :
945- name = xmlquality ["oml:name" ]
946- if xmlquality .get ("oml:value" , None ) is None :
947- value = float ("NaN" )
948- elif xmlquality ["oml:value" ] == "null" :
949- value = float ("NaN" )
950- else :
951- value = float (xmlquality ["oml:value" ])
952- qualities_ [name ] = value
953- return qualities_
954- else :
955- return None
935+ def _read_features (features_file : str ) -> Dict [int , OpenMLDataFeature ]:
936+ features_pickle_file = _get_features_pickle_file (features_file )
937+ try :
938+ with open (features_pickle_file , "rb" ) as fh_binary :
939+ features = pickle .load (fh_binary )
940+ except : # noqa E722
941+ with open (features_file , encoding = "utf8" ) as fh :
942+ features_xml_string = fh .read ()
943+ xml_dict = xmltodict .parse (
944+ features_xml_string , force_list = ("oml:feature" , "oml:nominal_value" )
945+ )
946+ features_xml = xml_dict ["oml:data_features" ]
947+
948+ features = {}
949+ for idx , xmlfeature in enumerate (features_xml ["oml:feature" ]):
950+ nr_missing = xmlfeature .get ("oml:number_of_missing_values" , 0 )
951+ feature = OpenMLDataFeature (
952+ int (xmlfeature ["oml:index" ]),
953+ xmlfeature ["oml:name" ],
954+ xmlfeature ["oml:data_type" ],
955+ xmlfeature .get ("oml:nominal_value" ),
956+ int (nr_missing ),
957+ )
958+ if idx != feature .index :
959+ raise ValueError ("Data features not provided in right order" )
960+ features [feature .index ] = feature
961+
962+ with open (features_pickle_file , "wb" ) as fh_binary :
963+ pickle .dump (features , fh_binary )
964+ return features
965+
966+
967+ def _get_features_pickle_file (features_file : str ) -> str :
968+ """This function only exists so it can be mocked during unit testing"""
969+ return features_file + ".pkl"
970+
971+
972+ def _read_qualities (qualities_file : str ) -> Dict [str , float ]:
973+ qualities_pickle_file = _get_qualities_pickle_file (qualities_file )
974+ try :
975+ with open (qualities_pickle_file , "rb" ) as fh_binary :
976+ qualities = pickle .load (fh_binary )
977+ except : # noqa E722
978+ with open (qualities_file , encoding = "utf8" ) as fh :
979+ qualities_xml = fh .read ()
980+ xml_as_dict = xmltodict .parse (qualities_xml , force_list = ("oml:quality" ,))
981+ qualities = xml_as_dict ["oml:data_qualities" ]["oml:quality" ]
982+ qualities = _check_qualities (qualities )
983+ with open (qualities_pickle_file , "wb" ) as fh_binary :
984+ pickle .dump (qualities , fh_binary )
985+ return qualities
986+
987+
988+ def _get_qualities_pickle_file (qualities_file : str ) -> str :
989+ """This function only exists so it can be mocked during unit testing"""
990+ return qualities_file + ".pkl"
991+
992+
993+ def _check_qualities (qualities : List [Dict [str , str ]]) -> Dict [str , float ]:
994+ qualities_ = {}
995+ for xmlquality in qualities :
996+ name = xmlquality ["oml:name" ]
997+ if xmlquality .get ("oml:value" , None ) is None :
998+ value = float ("NaN" )
999+ elif xmlquality ["oml:value" ] == "null" :
1000+ value = float ("NaN" )
1001+ else :
1002+ value = float (xmlquality ["oml:value" ])
1003+ qualities_ [name ] = value
1004+ return qualities_
0 commit comments