1515import pandas as pd
1616import pytest
1717import ntpath
18- from feast .sdk .resources .feature import Feature , Granularity , ValueType , Datastore
18+ from feast .sdk .resources .feature import Feature , Granularity , ValueType , \
19+ Datastore
1920from feast .sdk .importer import _create_feature , Importer
2021from feast .sdk .utils .gs_utils import is_gs_path
2122from feast .types .Granularity_pb2 import Granularity as Granularity_pb2
@@ -30,56 +31,60 @@ def test_from_csv(self):
3031 staging_location = "gs://test-bucket"
3132 id_column = "driver_id"
3233 feature_columns = ["avg_distance_completed" ,
33- "avg_customer_distance_completed" ]
34+ "avg_customer_distance_completed" ]
3435 timestamp_column = "ts"
3536
36- importer = Importer .from_csv (path = csv_path ,
37- entity = entity_name ,
38- granularity = feature_granularity ,
39- owner = owner ,
40- staging_location = staging_location ,
41- id_column = id_column ,
42- feature_columns = feature_columns ,
43- timestamp_column = timestamp_column )
37+ importer = Importer .from_csv (path = csv_path ,
38+ entity = entity_name ,
39+ granularity = feature_granularity ,
40+ owner = owner ,
41+ staging_location = staging_location ,
42+ id_column = id_column ,
43+ feature_columns = feature_columns ,
44+ timestamp_column = timestamp_column )
4445
4546 self ._validate_csv_importer (importer , csv_path , entity_name ,
46- feature_granularity , owner , staging_location , id_column ,
47- feature_columns , timestamp_column )
47+ feature_granularity , owner ,
48+ staging_location , id_column ,
49+ feature_columns , timestamp_column )
4850
def test_from_csv_id_column_not_specified(self):
    """Check that from_csv raises ValueError when id_column is omitted.

    No id_column is passed, and the expected error message names the
    entity ("driver") as the column that could not be found.
    """
    from_csv_args = {
        "path": "tests/data/driver_features.csv",
        "entity": "driver",
        "granularity": Granularity.DAY,
        "owner": "owner@feast.com",
        "staging_location": "gs://test-bucket",
        "feature_columns": ["avg_distance_completed",
                            "avg_customer_distance_completed"],
        "timestamp_column": "ts",
    }

    # Only the call under test lives inside the raises-context.
    expected_error = "Column with name driver is not found"
    with pytest.raises(ValueError, match=expected_error):
        Importer.from_csv(**from_csv_args)
6264
def test_from_csv_timestamp_column_not_specified(self):
    """Build a CSV importer without a timestamp_column and validate it.

    The importer is created with explicit id and feature columns but no
    timestamp column, then handed to the shared CSV validation helper.
    """
    source_csv = "tests/data/driver_features.csv"
    driver_entity = "driver"
    day_granularity = Granularity.DAY
    feature_owner = "owner@feast.com"

    # Optional settings forwarded unchanged to both the importer factory
    # and the validation helper.
    optional_args = {
        "staging_location": "gs://test-bucket",
        "id_column": "driver_id",
        "feature_columns": ["avg_distance_completed",
                            "avg_customer_distance_completed",
                            "avg_distance_cancelled"],
    }

    importer = Importer.from_csv(path=source_csv,
                                 entity=driver_entity,
                                 granularity=day_granularity,
                                 owner=feature_owner,
                                 **optional_args)

    self._validate_csv_importer(importer, source_csv, driver_entity,
                                day_granularity, feature_owner,
                                **optional_args)
8388
8489 def test_from_csv_feature_columns_not_specified (self ):
8590 csv_path = "tests/data/driver_features.csv"
@@ -89,103 +94,109 @@ def test_from_csv_feature_columns_not_specified(self):
8994 staging_location = "gs://test-bucket"
9095 id_column = "driver_id"
9196 timestamp_column = "ts"
92- importer = Importer .from_csv (path = csv_path ,
93- entity = entity_name ,
94- granularity = granularity ,
95- owner = owner ,
96- staging_location = staging_location ,
97- id_column = id_column ,
98- timestamp_column = timestamp_column )
97+ importer = Importer .from_csv (path = csv_path ,
98+ entity = entity_name ,
99+ granularity = granularity ,
100+ owner = owner ,
101+ staging_location = staging_location ,
102+ id_column = id_column ,
103+ timestamp_column = timestamp_column )
99104
100105 self ._validate_csv_importer (importer , csv_path , entity_name ,
101- granularity , owner , staging_location = staging_location ,
102- id_column = id_column , timestamp_column = timestamp_column )
106+ granularity , owner ,
107+ staging_location = staging_location ,
108+ id_column = id_column ,
109+ timestamp_column = timestamp_column )
103110
def test_from_csv_staging_location_not_specified(self):
    """Check that from_csv rejects a missing or non-GCS staging location.

    Two cases are exercised against the same local CSV file:

    * no staging_location at all -> ValueError asking for one;
    * a staging_location that is a local path ("/home") -> ValueError
      stating that the staging location must be in GCS.
    """
    # Arguments shared by both failure cases; only staging_location varies.
    common_args = dict(
        path="tests/data/driver_features.csv",
        entity="driver",
        granularity=Granularity.DAY,
        owner="owner@feast.com",
        feature_columns=["avg_distance_completed",
                         "avg_customer_distance_completed"],
        timestamp_column="ts",
    )

    # Case 1: staging_location omitted entirely.
    with pytest.raises(
            ValueError,
            match="Specify staging_location for importing local "
                  "file/dataframe"):
        Importer.from_csv(**common_args)

    # Case 2: staging_location given, but not a GCS path. The exception
    # info is not inspected, so it is not bound to a name.
    with pytest.raises(ValueError,
                       match="Staging location must be in GCS"):
        Importer.from_csv(staging_location="/home", **common_args)
129136
def test_from_df(self):
    """Build an importer from a DataFrame and check the resulting spec.

    The DataFrame is read from the driver features CSV. The test checks
    that staging is required, that the remote path contains the expected
    "<staging_location>/tmp_<entity>" fragment, that every created
    feature maps to a DataFrame column with a "driver.day." id, and that
    the import spec's type, source options, entities and schema match.
    """
    csv_path = "tests/data/driver_features.csv"
    df = pd.read_csv(csv_path)
    staging_location = "gs://test-bucket"
    entity = "driver"

    importer = Importer.from_df(df=df,
                                entity=entity,
                                granularity=Granularity.DAY,
                                owner="owner@feast.com",
                                staging_location=staging_location,
                                id_column="driver_id",
                                timestamp_column="ts")

    # Plain truthiness check instead of comparing against True with ==.
    assert importer.require_staging
    # Substring check: only this fragment of the remote path is pinned.
    expected_fragment = "{}/tmp_{}".format(staging_location, entity)
    assert expected_fragment in importer.remote_path

    # Every feature must come from a DataFrame column and carry the
    # entity/granularity-qualified id.
    for feature in importer.features.values():
        assert feature.name in df.columns
        assert feature.id == "driver.day." + feature.name

    import_spec = importer.spec
    # NOTE(review): the CSV validation helper in this file expects type
    # "file.csv" and source options without "format" -- confirm that
    # from_df is really meant to differ from from_csv here.
    assert import_spec.type == "file"
    assert import_spec.sourceOptions == {"format": "csv",
                                         "path": importer.remote_path}
    assert import_spec.entities == ["driver"]

    schema = import_spec.schema
    assert schema.entityIdColumn == "driver_id"
    assert schema.timestampValue is not None

    # Schema fields are compared to DataFrame columns pairwise, in order;
    # only the feature columns are expected to carry a featureId.
    feature_columns = ["completed", "avg_distance_completed",
                       "avg_customer_distance_completed",
                       "avg_distance_cancelled"]
    for col, field in zip(df.columns.values, schema.fields):
        assert col == field.name
        if col in feature_columns:
            assert field.featureId == "driver.day." + col
167174
168175 def _validate_csv_importer (self ,
169- importer , csv_path , entity_name , feature_granularity , owner ,
170- staging_location = None , id_column = None , feature_columns = None ,
171- timestamp_column = None , timestamp_value = None ):
176+ importer , csv_path , entity_name ,
177+ feature_granularity , owner ,
178+ staging_location = None , id_column = None ,
179+ feature_columns = None ,
180+ timestamp_column = None , timestamp_value = None ):
172181 df = pd .read_csv (csv_path )
173182 assert not importer .require_staging == is_gs_path (csv_path )
174183 if importer .require_staging :
175184 assert importer .remote_path == "{}/{}" .format (staging_location ,
176- ntpath .basename (csv_path ))
185+ ntpath .basename (
186+ csv_path ))
177187
178188 # check features created
179189 for feature in importer .features .values ():
180190 assert feature .name in df .columns
181191 assert feature .id == "{}.{}.{}" .format (entity_name ,
182- Granularity_pb2 .Enum .Name (feature_granularity .value ).lower (),
183- feature .name )
192+ Granularity_pb2 .Enum .Name (
193+ feature_granularity .value ).lower (),
194+ feature .name )
184195
185196 import_spec = importer .spec
186- assert import_spec .type == "file"
197+ assert import_spec .type == "file.csv "
187198 path = importer .remote_path if importer .require_staging else csv_path
188- assert import_spec .sourceOptions == {"format" : "csv" , " path" : path }
199+ assert import_spec .sourceOptions == {"path" : path }
189200 assert import_spec .entities == [entity_name ]
190201
191202 schema = import_spec .schema
@@ -204,19 +215,23 @@ def _validate_csv_importer(self,
204215 for col , field in zip (df .columns .values , schema .fields ):
205216 assert col == field .name
206217 if col in feature_columns :
207- assert field .featureId == "{}.{}.{}" .format (entity_name ,
208- Granularity_pb2 .Enum .Name (feature_granularity .value ).lower (), col )
218+ assert field .featureId == \
219+ "{}.{}.{}" .format (entity_name ,
220+ Granularity_pb2 .Enum .Name (
221+ feature_granularity .value ).lower (),
222+ col )
209223
210224
211225class TestHelpers :
212226 def test_create_feature (self ):
213- col = pd .Series ([1 ]* 3 , dtype = 'int32' ,name = "test" )
227+ col = pd .Series ([1 ] * 3 , dtype = 'int32' , name = "test" )
214228 expected = Feature (name = "test" ,
215- entity = "test" ,
216- granularity = Granularity .NONE ,
217- owner = "person" ,
218- value_type = ValueType .INT32 )
219- actual = _create_feature (col , "test" , Granularity .NONE , "person" , None , None )
229+ entity = "test" ,
230+ granularity = Granularity .NONE ,
231+ owner = "person" ,
232+ value_type = ValueType .INT32 )
233+ actual = _create_feature (col , "test" , Granularity .NONE , "person" , None ,
234+ None )
220235 assert actual .id == expected .id
221236 assert actual .value_type == expected .value_type
222237 assert actual .owner == expected .owner
@@ -231,7 +246,8 @@ def test_create_feature_with_stores(self):
231246 serving_store = Datastore (id = "SERVING" ),
232247 warehouse_store = Datastore (id = "WAREHOUSE" ))
233248 actual = _create_feature (col , "test" , Granularity .NONE , "person" ,
234- Datastore (id = "SERVING" ), Datastore (id = "WAREHOUSE" ))
249+ Datastore (id = "SERVING" ),
250+ Datastore (id = "WAREHOUSE" ))
235251 assert actual .id == expected .id
236252 assert actual .value_type == expected .value_type
237253 assert actual .owner == expected .owner
0 commit comments