Skip to content

Commit 560e952

Browse files
mfeurer and PGijsbers authored
Cache dataset features and qualities as pickle (#979)
* cache dataset features and qualities as pickle
* incorporate feedback
* Fix unit tests
* black, pep8 etc
* Remove unused imports

Co-authored-by: PGijsbers <p.gijsbers@tue.nl>
1 parent accde88 commit 560e952

5 files changed

Lines changed: 184 additions & 265 deletions

File tree

openml/datasets/data_feature.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# License: BSD 3-Clause
22

3+
from typing import List
4+
35

46
class OpenMLDataFeature(object):
57
"""
@@ -20,7 +22,14 @@ class OpenMLDataFeature(object):
2022

2123
LEGAL_DATA_TYPES = ["nominal", "numeric", "string", "date"]
2224

23-
def __init__(self, index, name, data_type, nominal_values, number_missing_values):
25+
def __init__(
26+
self,
27+
index: int,
28+
name: str,
29+
data_type: str,
30+
nominal_values: List[str],
31+
number_missing_values: int,
32+
):
2433
if type(index) != int:
2534
raise ValueError("Index is of wrong datatype")
2635
if data_type not in self.LEGAL_DATA_TYPES:

openml/datasets/dataset.py

Lines changed: 96 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from collections import OrderedDict
44
import re
55
import gzip
6-
import io
76
import logging
87
import os
98
import pickle
@@ -13,6 +12,7 @@
1312
import numpy as np
1413
import pandas as pd
1514
import scipy.sparse
15+
import xmltodict
1616

1717
from openml.base import OpenMLBase
1818
from .data_feature import OpenMLDataFeature
@@ -125,8 +125,8 @@ def __init__(
125125
update_comment=None,
126126
md5_checksum=None,
127127
data_file=None,
128-
features=None,
129-
qualities=None,
128+
features_file: Optional[str] = None,
129+
qualities_file: Optional[str] = None,
130130
dataset=None,
131131
):
132132
def find_invalid_characters(string, pattern):
@@ -188,7 +188,7 @@ def find_invalid_characters(string, pattern):
188188
self.default_target_attribute = default_target_attribute
189189
self.row_id_attribute = row_id_attribute
190190
if isinstance(ignore_attribute, str):
191-
self.ignore_attribute = [ignore_attribute]
191+
self.ignore_attribute = [ignore_attribute] # type: Optional[List[str]]
192192
elif isinstance(ignore_attribute, list) or ignore_attribute is None:
193193
self.ignore_attribute = ignore_attribute
194194
else:
@@ -202,33 +202,25 @@ def find_invalid_characters(string, pattern):
202202
self.update_comment = update_comment
203203
self.md5_checksum = md5_checksum
204204
self.data_file = data_file
205-
self.features = None
206-
self.qualities = None
207205
self._dataset = dataset
208206

209-
if features is not None:
210-
self.features = {}
211-
for idx, xmlfeature in enumerate(features["oml:feature"]):
212-
nr_missing = xmlfeature.get("oml:number_of_missing_values", 0)
213-
feature = OpenMLDataFeature(
214-
int(xmlfeature["oml:index"]),
215-
xmlfeature["oml:name"],
216-
xmlfeature["oml:data_type"],
217-
xmlfeature.get("oml:nominal_value"),
218-
int(nr_missing),
219-
)
220-
if idx != feature.index:
221-
raise ValueError("Data features not provided " "in right order")
222-
self.features[feature.index] = feature
207+
if features_file is not None:
208+
self.features = _read_features(
209+
features_file
210+
) # type: Optional[Dict[int, OpenMLDataFeature]]
211+
else:
212+
self.features = None
223213

224-
self.qualities = _check_qualities(qualities)
214+
if qualities_file:
215+
self.qualities = _read_qualities(qualities_file) # type: Optional[Dict[str, float]]
216+
else:
217+
self.qualities = None
225218

226219
if data_file is not None:
227-
(
228-
self.data_pickle_file,
229-
self.data_feather_file,
230-
self.feather_attribute_file,
231-
) = self._create_pickle_in_cache(data_file)
220+
rval = self._create_pickle_in_cache(data_file)
221+
self.data_pickle_file = rval[0] # type: Optional[str]
222+
self.data_feather_file = rval[1] # type: Optional[str]
223+
self.feather_attribute_file = rval[2] # type: Optional[str]
232224
else:
233225
self.data_pickle_file, self.data_feather_file, self.feather_attribute_file = (
234226
None,
@@ -357,7 +349,7 @@ def decode_arff(fh):
357349
with gzip.open(filename) as fh:
358350
return decode_arff(fh)
359351
else:
360-
with io.open(filename, encoding="utf8") as fh:
352+
with open(filename, encoding="utf8") as fh:
361353
return decode_arff(fh)
362354

363355
def _parse_data_from_arff(
@@ -405,12 +397,10 @@ def _parse_data_from_arff(
405397
# can be encoded into integers
406398
pd.factorize(type_)[0]
407399
except ValueError:
408-
raise ValueError(
409-
"Categorical data needs to be numeric when " "using sparse ARFF."
410-
)
400+
raise ValueError("Categorical data needs to be numeric when using sparse ARFF.")
411401
# string can only be supported with pandas DataFrame
412402
elif type_ == "STRING" and self.format.lower() == "sparse_arff":
413-
raise ValueError("Dataset containing strings is not supported " "with sparse ARFF.")
403+
raise ValueError("Dataset containing strings is not supported with sparse ARFF.")
414404

415405
# infer the dtype from the ARFF header
416406
if isinstance(type_, list):
@@ -743,7 +733,7 @@ def get_data(
743733
to_exclude.extend(self.ignore_attribute)
744734

745735
if len(to_exclude) > 0:
746-
logger.info("Going to remove the following attributes:" " %s" % to_exclude)
736+
logger.info("Going to remove the following attributes: %s" % to_exclude)
747737
keep = np.array(
748738
[True if column not in to_exclude else False for column in attribute_names]
749739
)
@@ -810,6 +800,10 @@ def retrieve_class_labels(self, target_name: str = "class") -> Union[None, List[
810800
-------
811801
list
812802
"""
803+
if self.features is None:
804+
raise ValueError(
805+
"retrieve_class_labels can only be called if feature information is available."
806+
)
813807
for feature in self.features.values():
814808
if (feature.name == target_name) and (feature.data_type == "nominal"):
815809
return feature.nominal_values
@@ -938,18 +932,73 @@ def _to_dict(self) -> "OrderedDict[str, OrderedDict]":
938932
return data_container
939933

940934

941-
def _check_qualities(qualities):
942-
if qualities is not None:
943-
qualities_ = {}
944-
for xmlquality in qualities:
945-
name = xmlquality["oml:name"]
946-
if xmlquality.get("oml:value", None) is None:
947-
value = float("NaN")
948-
elif xmlquality["oml:value"] == "null":
949-
value = float("NaN")
950-
else:
951-
value = float(xmlquality["oml:value"])
952-
qualities_[name] = value
953-
return qualities_
954-
else:
955-
return None
935+
def _read_features(features_file: str) -> Dict[int, OpenMLDataFeature]:
    """Load the feature descriptions of a dataset, using a pickle cache.

    The parsed features are cached in a pickle file whose path is given by
    ``_get_features_pickle_file(features_file)``. If that cache can be loaded
    it is returned directly; otherwise the XML file is parsed and the cache
    is (re)written.

    Parameters
    ----------
    features_file : str
        Path to the features XML file (read as UTF-8).

    Returns
    -------
    dict
        Mapping from feature index to ``OpenMLDataFeature``.

    Raises
    ------
    ValueError
        If the features in the XML are not listed in index order.
    """
    features_pickle_file = _get_features_pickle_file(features_file)
    try:
        # NOTE(review): pickle.load executes arbitrary code from the file; this
        # presumably only ever reads a cache written by this function - confirm
        # the cache directory is trusted.
        with open(features_pickle_file, "rb") as fh_binary:
            features = pickle.load(fh_binary)
    except Exception:
        # Any failure to load the cache (missing file, truncated or stale
        # pickle, ...) is treated as a cache miss. ``Exception`` rather than a
        # bare ``except`` so KeyboardInterrupt / SystemExit still propagate.
        with open(features_file, encoding="utf8") as fh:
            features_xml_string = fh.read()
        # force_list keeps single-element entries as lists so the loop below
        # and ``nominal_values`` always see a list.
        xml_dict = xmltodict.parse(
            features_xml_string, force_list=("oml:feature", "oml:nominal_value")
        )
        features_xml = xml_dict["oml:data_features"]

        features = {}
        for idx, xmlfeature in enumerate(features_xml["oml:feature"]):
            nr_missing = xmlfeature.get("oml:number_of_missing_values", 0)
            feature = OpenMLDataFeature(
                int(xmlfeature["oml:index"]),
                xmlfeature["oml:name"],
                xmlfeature["oml:data_type"],
                xmlfeature.get("oml:nominal_value"),
                int(nr_missing),
            )
            # The position in the XML must match the declared feature index.
            if idx != feature.index:
                raise ValueError("Data features not provided in right order")
            features[feature.index] = feature

        with open(features_pickle_file, "wb") as fh_binary:
            pickle.dump(features, fh_binary)
    return features
965+
966+
967+
def _get_features_pickle_file(features_file: str) -> str:
968+
"""This function only exists so it can be mocked during unit testing"""
969+
return features_file + ".pkl"
970+
971+
972+
def _read_qualities(qualities_file: str) -> Dict[str, float]:
    """Load the qualities (meta-features) of a dataset, using a pickle cache.

    The parsed qualities are cached in a pickle file whose path is given by
    ``_get_qualities_pickle_file(qualities_file)``. If that cache can be
    loaded it is returned directly; otherwise the XML file is parsed, converted
    with ``_check_qualities`` and the cache is (re)written.

    Parameters
    ----------
    qualities_file : str
        Path to the qualities XML file (read as UTF-8).

    Returns
    -------
    dict
        Mapping from quality name to its float value.
    """
    qualities_pickle_file = _get_qualities_pickle_file(qualities_file)
    try:
        # NOTE(review): pickle.load executes arbitrary code from the file; this
        # presumably only ever reads a cache written by this function - confirm
        # the cache directory is trusted.
        with open(qualities_pickle_file, "rb") as fh_binary:
            qualities = pickle.load(fh_binary)
    except Exception:
        # Any failure to load the cache (missing file, truncated or stale
        # pickle, ...) is treated as a cache miss. ``Exception`` rather than a
        # bare ``except`` so KeyboardInterrupt / SystemExit still propagate.
        with open(qualities_file, encoding="utf8") as fh:
            qualities_xml = fh.read()
        # force_list keeps a single quality entry as a one-element list.
        xml_as_dict = xmltodict.parse(qualities_xml, force_list=("oml:quality",))
        raw_qualities = xml_as_dict["oml:data_qualities"]["oml:quality"]
        qualities = _check_qualities(raw_qualities)
        with open(qualities_pickle_file, "wb") as fh_binary:
            pickle.dump(qualities, fh_binary)
    return qualities
986+
987+
988+
def _get_qualities_pickle_file(qualities_file: str) -> str:
989+
"""This function only exists so it can be mocked during unit testing"""
990+
return qualities_file + ".pkl"
991+
992+
993+
def _check_qualities(qualities: List[Dict[str, str]]) -> Dict[str, float]:
994+
qualities_ = {}
995+
for xmlquality in qualities:
996+
name = xmlquality["oml:name"]
997+
if xmlquality.get("oml:value", None) is None:
998+
value = float("NaN")
999+
elif xmlquality["oml:value"] == "null":
1000+
value = float("NaN")
1001+
else:
1002+
value = float(xmlquality["oml:value"])
1003+
qualities_[name] = value
1004+
return qualities_

0 commit comments

Comments
 (0)