Skip to content

Commit 6afc880

Browse files
authored
Updated the way 'image features' are stored, updated old unit tests, added unit test, fixed typo (#983)
1 parent 3132dac commit 6afc880

File tree

4 files changed

+44
-19
lines changed

4 files changed

+44
-19
lines changed

doc/progress.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Changelog
88

99
0.11.1
1010
~~~~~~
11+
* MAINT #891: Changed the way that numerical features are stored. Numerical features that range from 0 to 255 are now stored as uint8, which reduces the storage space required as well as storing and loading times.
1112
* MAINT #671: Improved the performance of ``check_datasets_active`` by only querying the given list of datasets in contrast to querying all datasets. Modified the corresponding unit test.
1213

1314
0.11.0

openml/datasets/dataset.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def _parse_data_from_arff(
407407
categories_names = {}
408408
categorical = []
409409
for i, (name, type_) in enumerate(data["attributes"]):
410-
# if the feature is nominal and the a sparse matrix is
410+
# if the feature is nominal and a sparse matrix is
411411
# requested, the categories need to be numeric
412412
if isinstance(type_, list) and self.format.lower() == "sparse_arff":
413413
try:
@@ -456,6 +456,18 @@ def _parse_data_from_arff(
456456
col.append(
457457
self._unpack_categories(X[column_name], categories_names[column_name])
458458
)
459+
elif attribute_dtype[column_name] in ('floating',
460+
'integer'):
461+
X_col = X[column_name]
462+
if X_col.min() >= 0 and X_col.max() <= 255:
463+
try:
464+
X_col_uint = X_col.astype('uint8')
465+
if (X_col == X_col_uint).all():
466+
col.append(X_col_uint)
467+
continue
468+
except ValueError:
469+
pass
470+
col.append(X[column_name])
459471
else:
460472
col.append(X[column_name])
461473
X = pd.concat(col, axis=1)

tests/test_datasets/test_dataset.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,13 @@ def test_get_data_pandas(self):
7272
self.assertEqual(data.shape[1], len(self.titanic.features))
7373
self.assertEqual(data.shape[0], 1309)
7474
col_dtype = {
75-
"pclass": "float64",
75+
"pclass": "uint8",
7676
"survived": "category",
7777
"name": "object",
7878
"sex": "category",
7979
"age": "float64",
80-
"sibsp": "float64",
81-
"parch": "float64",
80+
"sibsp": "uint8",
81+
"parch": "uint8",
8282
"ticket": "object",
8383
"fare": "float64",
8484
"cabin": "object",
@@ -118,21 +118,29 @@ def test_get_data_no_str_data_for_nparrays(self):
118118
with pytest.raises(PyOpenMLError, match=err_msg):
119119
self.titanic.get_data(dataset_format="array")
120120

121+
def _check_expected_type(self, dtype, is_cat, col):
122+
if is_cat:
123+
expected_type = 'category'
124+
elif not col.isna().any() and (col.astype('uint8') == col).all():
125+
expected_type = 'uint8'
126+
else:
127+
expected_type = 'float64'
128+
129+
self.assertEqual(dtype.name, expected_type)
130+
121131
def test_get_data_with_rowid(self):
122132
self.dataset.row_id_attribute = "condition"
123133
rval, _, categorical, _ = self.dataset.get_data(include_row_id=True)
124134
self.assertIsInstance(rval, pd.DataFrame)
125-
for (dtype, is_cat) in zip(rval.dtypes, categorical):
126-
expected_type = "category" if is_cat else "float64"
127-
self.assertEqual(dtype.name, expected_type)
135+
for (dtype, is_cat, col) in zip(rval.dtypes, categorical, rval):
136+
self._check_expected_type(dtype, is_cat, rval[col])
128137
self.assertEqual(rval.shape, (898, 39))
129138
self.assertEqual(len(categorical), 39)
130139

131140
rval, _, categorical, _ = self.dataset.get_data()
132141
self.assertIsInstance(rval, pd.DataFrame)
133-
for (dtype, is_cat) in zip(rval.dtypes, categorical):
134-
expected_type = "category" if is_cat else "float64"
135-
self.assertEqual(dtype.name, expected_type)
142+
for (dtype, is_cat, col) in zip(rval.dtypes, categorical, rval):
143+
self._check_expected_type(dtype, is_cat, rval[col])
136144
self.assertEqual(rval.shape, (898, 38))
137145
self.assertEqual(len(categorical), 38)
138146

@@ -149,9 +157,8 @@ def test_get_data_with_target_array(self):
149157
def test_get_data_with_target_pandas(self):
150158
X, y, categorical, attribute_names = self.dataset.get_data(target="class")
151159
self.assertIsInstance(X, pd.DataFrame)
152-
for (dtype, is_cat) in zip(X.dtypes, categorical):
153-
expected_type = "category" if is_cat else "float64"
154-
self.assertEqual(dtype.name, expected_type)
160+
for (dtype, is_cat, col) in zip(X.dtypes, categorical, X):
161+
self._check_expected_type(dtype, is_cat, X[col])
155162
self.assertIsInstance(y, pd.Series)
156163
self.assertEqual(y.dtype.name, "category")
157164

@@ -174,16 +181,14 @@ def test_get_data_rowid_and_ignore_and_target(self):
174181
def test_get_data_with_ignore_attributes(self):
175182
self.dataset.ignore_attribute = ["condition"]
176183
rval, _, categorical, _ = self.dataset.get_data(include_ignore_attribute=True)
177-
for (dtype, is_cat) in zip(rval.dtypes, categorical):
178-
expected_type = "category" if is_cat else "float64"
179-
self.assertEqual(dtype.name, expected_type)
184+
for (dtype, is_cat, col) in zip(rval.dtypes, categorical, rval):
185+
self._check_expected_type(dtype, is_cat, rval[col])
180186
self.assertEqual(rval.shape, (898, 39))
181187
self.assertEqual(len(categorical), 39)
182188

183189
rval, _, categorical, _ = self.dataset.get_data(include_ignore_attribute=False)
184-
for (dtype, is_cat) in zip(rval.dtypes, categorical):
185-
expected_type = "category" if is_cat else "float64"
186-
self.assertEqual(dtype.name, expected_type)
190+
for (dtype, is_cat, col) in zip(rval.dtypes, categorical, rval):
191+
self._check_expected_type(dtype, is_cat, rval[col])
187192
self.assertEqual(rval.shape, (898, 38))
188193
self.assertEqual(len(categorical), 38)
189194

tests/test_datasets/test_dataset_functions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,13 @@ def test_get_dataset_by_name(self):
373373
openml.config.server = self.production_server
374374
self.assertRaises(OpenMLPrivateDatasetError, openml.datasets.get_dataset, 45)
375375

376+
def test_get_dataset_uint8_dtype(self):
377+
dataset = openml.datasets.get_dataset(1)
378+
self.assertEqual(type(dataset), OpenMLDataset)
379+
self.assertEqual(dataset.name, 'anneal')
380+
df, _, _, _ = dataset.get_data()
381+
self.assertEqual(df['carbon'].dtype, 'uint8')
382+
376383
def test_get_dataset(self):
377384
# This is the only non-lazy load to ensure default behaviour works.
378385
dataset = openml.datasets.get_dataset(1)

0 commit comments

Comments
 (0)