@@ -72,13 +72,13 @@ def test_get_data_pandas(self):
7272 self .assertEqual (data .shape [1 ], len (self .titanic .features ))
7373 self .assertEqual (data .shape [0 ], 1309 )
7474 col_dtype = {
75- "pclass" : "float64 " ,
75+ "pclass" : "uint8 " ,
7676 "survived" : "category" ,
7777 "name" : "object" ,
7878 "sex" : "category" ,
7979 "age" : "float64" ,
80- "sibsp" : "float64 " ,
81- "parch" : "float64 " ,
80+ "sibsp" : "uint8 " ,
81+ "parch" : "uint8 " ,
8282 "ticket" : "object" ,
8383 "fare" : "float64" ,
8484 "cabin" : "object" ,
@@ -118,21 +118,29 @@ def test_get_data_no_str_data_for_nparrays(self):
118118 with pytest .raises (PyOpenMLError , match = err_msg ):
119119 self .titanic .get_data (dataset_format = "array" )
120120
121+ def _check_expected_type (self , dtype , is_cat , col ):
122+ if is_cat :
123+ expected_type = 'category'
124+ elif not col .isna ().any () and (col .astype ('uint8' ) == col ).all ():
125+ expected_type = 'uint8'
126+ else :
127+ expected_type = 'float64'
128+
129+ self .assertEqual (dtype .name , expected_type )
130+
121131 def test_get_data_with_rowid (self ):
122132 self .dataset .row_id_attribute = "condition"
123133 rval , _ , categorical , _ = self .dataset .get_data (include_row_id = True )
124134 self .assertIsInstance (rval , pd .DataFrame )
125- for (dtype , is_cat ) in zip (rval .dtypes , categorical ):
126- expected_type = "category" if is_cat else "float64"
127- self .assertEqual (dtype .name , expected_type )
135+ for (dtype , is_cat , col ) in zip (rval .dtypes , categorical , rval ):
136+ self ._check_expected_type (dtype , is_cat , rval [col ])
128137 self .assertEqual (rval .shape , (898 , 39 ))
129138 self .assertEqual (len (categorical ), 39 )
130139
131140 rval , _ , categorical , _ = self .dataset .get_data ()
132141 self .assertIsInstance (rval , pd .DataFrame )
133- for (dtype , is_cat ) in zip (rval .dtypes , categorical ):
134- expected_type = "category" if is_cat else "float64"
135- self .assertEqual (dtype .name , expected_type )
142+ for (dtype , is_cat , col ) in zip (rval .dtypes , categorical , rval ):
143+ self ._check_expected_type (dtype , is_cat , rval [col ])
136144 self .assertEqual (rval .shape , (898 , 38 ))
137145 self .assertEqual (len (categorical ), 38 )
138146
@@ -149,9 +157,8 @@ def test_get_data_with_target_array(self):
149157 def test_get_data_with_target_pandas (self ):
150158 X , y , categorical , attribute_names = self .dataset .get_data (target = "class" )
151159 self .assertIsInstance (X , pd .DataFrame )
152- for (dtype , is_cat ) in zip (X .dtypes , categorical ):
153- expected_type = "category" if is_cat else "float64"
154- self .assertEqual (dtype .name , expected_type )
160+ for (dtype , is_cat , col ) in zip (X .dtypes , categorical , X ):
161+ self ._check_expected_type (dtype , is_cat , X [col ])
155162 self .assertIsInstance (y , pd .Series )
156163 self .assertEqual (y .dtype .name , "category" )
157164
@@ -174,16 +181,14 @@ def test_get_data_rowid_and_ignore_and_target(self):
174181 def test_get_data_with_ignore_attributes (self ):
175182 self .dataset .ignore_attribute = ["condition" ]
176183 rval , _ , categorical , _ = self .dataset .get_data (include_ignore_attribute = True )
177- for (dtype , is_cat ) in zip (rval .dtypes , categorical ):
178- expected_type = "category" if is_cat else "float64"
179- self .assertEqual (dtype .name , expected_type )
184+ for (dtype , is_cat , col ) in zip (rval .dtypes , categorical , rval ):
185+ self ._check_expected_type (dtype , is_cat , rval [col ])
180186 self .assertEqual (rval .shape , (898 , 39 ))
181187 self .assertEqual (len (categorical ), 39 )
182188
183189 rval , _ , categorical , _ = self .dataset .get_data (include_ignore_attribute = False )
184- for (dtype , is_cat ) in zip (rval .dtypes , categorical ):
185- expected_type = "category" if is_cat else "float64"
186- self .assertEqual (dtype .name , expected_type )
190+ for (dtype , is_cat , col ) in zip (rval .dtypes , categorical , rval ):
191+ self ._check_expected_type (dtype , is_cat , rval [col ])
187192 self .assertEqual (rval .shape , (898 , 38 ))
188193 self .assertEqual (len (categorical ), 38 )
189194
0 commit comments