Skip to content

Commit 6cdd173

Browse files
pradithya authored and feast-ci-bot committed
Fix import spec created from Importer.from_csv (feast-dev#143)
1 parent 294a6ec commit 6cdd173

2 files changed

Lines changed: 107 additions & 89 deletions

File tree

sdk/python/feast/sdk/importer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ def from_csv(cls, path, entity, granularity, owner, staging_location=None,
106106
Returns:
107107
Importer: the importer for the dataset provided.
108108
"""
109-
source_options = {"format": "csv"}
109+
src_type = "file.csv"
110+
source_options = {}
110111
source_options["path"], require_staging = \
111112
_get_remote_location(path, staging_location)
112113
if is_gs_path(path):
@@ -118,9 +119,10 @@ def from_csv(cls, path, entity, granularity, owner, staging_location=None,
118119
feature_columns, timestamp_column,
119120
timestamp_value, serving_store,
120121
warehouse_store, df)
121-
iport_spec = _create_import("file", source_options, job_options, entity, schema)
122+
iport_spec = _create_import(src_type, source_options, job_options,
123+
entity, schema)
122124

123-
props = (_properties("csv", len(df.index), require_staging,
125+
props = (_properties(src_type, len(df.index), require_staging,
124126
source_options["path"]))
125127
specs = _specs(iport_spec, Entity(name=entity), features)
126128

sdk/python/tests/sdk/test_importer.py

Lines changed: 102 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
import pandas as pd
1616
import pytest
1717
import ntpath
18-
from feast.sdk.resources.feature import Feature, Granularity, ValueType, Datastore
18+
from feast.sdk.resources.feature import Feature, Granularity, ValueType, \
19+
Datastore
1920
from feast.sdk.importer import _create_feature, Importer
2021
from feast.sdk.utils.gs_utils import is_gs_path
2122
from feast.types.Granularity_pb2 import Granularity as Granularity_pb2
@@ -30,56 +31,60 @@ def test_from_csv(self):
3031
staging_location = "gs://test-bucket"
3132
id_column = "driver_id"
3233
feature_columns = ["avg_distance_completed",
33-
"avg_customer_distance_completed"]
34+
"avg_customer_distance_completed"]
3435
timestamp_column = "ts"
3536

36-
importer = Importer.from_csv(path = csv_path,
37-
entity = entity_name,
38-
granularity = feature_granularity,
39-
owner = owner,
40-
staging_location=staging_location,
41-
id_column = id_column,
42-
feature_columns=feature_columns,
43-
timestamp_column=timestamp_column)
37+
importer = Importer.from_csv(path=csv_path,
38+
entity=entity_name,
39+
granularity=feature_granularity,
40+
owner=owner,
41+
staging_location=staging_location,
42+
id_column=id_column,
43+
feature_columns=feature_columns,
44+
timestamp_column=timestamp_column)
4445

4546
self._validate_csv_importer(importer, csv_path, entity_name,
46-
feature_granularity, owner, staging_location, id_column,
47-
feature_columns, timestamp_column)
47+
feature_granularity, owner,
48+
staging_location, id_column,
49+
feature_columns, timestamp_column)
4850

4951
def test_from_csv_id_column_not_specified(self):
5052
with pytest.raises(ValueError,
51-
match="Column with name driver is not found") as e_info:
53+
match="Column with name driver is not found"):
5254
feature_columns = ["avg_distance_completed",
53-
"avg_customer_distance_completed"]
55+
"avg_customer_distance_completed"]
5456
csv_path = "tests/data/driver_features.csv"
55-
importer = Importer.from_csv(path = csv_path,
56-
entity = "driver",
57-
granularity = Granularity.DAY,
58-
owner = "owner@feast.com",
59-
staging_location="gs://test-bucket",
60-
feature_columns=feature_columns,
61-
timestamp_column="ts")
57+
Importer.from_csv(path=csv_path,
58+
entity="driver",
59+
granularity=Granularity.DAY,
60+
owner="owner@feast.com",
61+
staging_location="gs://test-bucket",
62+
feature_columns=feature_columns,
63+
timestamp_column="ts")
6264

6365
def test_from_csv_timestamp_column_not_specified(self):
6466
feature_columns = ["avg_distance_completed",
65-
"avg_customer_distance_completed", "avg_distance_cancelled"]
67+
"avg_customer_distance_completed",
68+
"avg_distance_cancelled"]
6669
csv_path = "tests/data/driver_features.csv"
6770
entity_name = "driver"
6871
granularity = Granularity.DAY
6972
owner = "owner@feast.com"
7073
staging_location = "gs://test-bucket"
7174
id_column = "driver_id"
72-
importer = Importer.from_csv(path = csv_path,
73-
entity = entity_name,
74-
granularity = granularity,
75-
owner = owner,
76-
staging_location=staging_location,
77-
id_column = id_column,
78-
feature_columns= feature_columns)
75+
importer = Importer.from_csv(path=csv_path,
76+
entity=entity_name,
77+
granularity=granularity,
78+
owner=owner,
79+
staging_location=staging_location,
80+
id_column=id_column,
81+
feature_columns=feature_columns)
7982

8083
self._validate_csv_importer(importer, csv_path, entity_name,
81-
granularity, owner, staging_location = staging_location,
82-
id_column=id_column, feature_columns=feature_columns)
84+
granularity, owner,
85+
staging_location=staging_location,
86+
id_column=id_column,
87+
feature_columns=feature_columns)
8388

8489
def test_from_csv_feature_columns_not_specified(self):
8590
csv_path = "tests/data/driver_features.csv"
@@ -89,103 +94,109 @@ def test_from_csv_feature_columns_not_specified(self):
8994
staging_location = "gs://test-bucket"
9095
id_column = "driver_id"
9196
timestamp_column = "ts"
92-
importer = Importer.from_csv(path = csv_path,
93-
entity = entity_name,
94-
granularity = granularity,
95-
owner = owner,
96-
staging_location=staging_location,
97-
id_column = id_column,
98-
timestamp_column=timestamp_column)
97+
importer = Importer.from_csv(path=csv_path,
98+
entity=entity_name,
99+
granularity=granularity,
100+
owner=owner,
101+
staging_location=staging_location,
102+
id_column=id_column,
103+
timestamp_column=timestamp_column)
99104

100105
self._validate_csv_importer(importer, csv_path, entity_name,
101-
granularity, owner, staging_location = staging_location,
102-
id_column=id_column, timestamp_column=timestamp_column)
106+
granularity, owner,
107+
staging_location=staging_location,
108+
id_column=id_column,
109+
timestamp_column=timestamp_column)
103110

104111
def test_from_csv_staging_location_not_specified(self):
105112
with pytest.raises(ValueError,
106-
match="Specify staging_location for importing local file/dataframe") as e_info:
113+
match="Specify staging_location for importing local file/dataframe"):
107114
feature_columns = ["avg_distance_completed",
108-
"avg_customer_distance_completed"]
115+
"avg_customer_distance_completed"]
109116
csv_path = "tests/data/driver_features.csv"
110-
importer = Importer.from_csv(path = csv_path,
111-
entity = "driver",
112-
granularity = Granularity.DAY,
113-
owner = "owner@feast.com",
114-
feature_columns=feature_columns,
115-
timestamp_column="ts")
117+
Importer.from_csv(path=csv_path,
118+
entity="driver",
119+
granularity=Granularity.DAY,
120+
owner="owner@feast.com",
121+
feature_columns=feature_columns,
122+
timestamp_column="ts")
116123

117124
with pytest.raises(ValueError,
118-
match="Staging location must be in GCS") as e_info:
125+
match="Staging location must be in GCS") as e_info:
119126
feature_columns = ["avg_distance_completed",
120-
"avg_customer_distance_completed"]
127+
"avg_customer_distance_completed"]
121128
csv_path = "tests/data/driver_features.csv"
122-
importer = Importer.from_csv(path = csv_path,
123-
entity = "driver",
124-
granularity = Granularity.DAY,
125-
owner = "owner@feast.com",
126-
staging_location = "/home",
127-
feature_columns=feature_columns,
128-
timestamp_column="ts")
129+
Importer.from_csv(path=csv_path,
130+
entity="driver",
131+
granularity=Granularity.DAY,
132+
owner="owner@feast.com",
133+
staging_location="/home",
134+
feature_columns=feature_columns,
135+
timestamp_column="ts")
129136

130137
def test_from_df(self):
131138
csv_path = "tests/data/driver_features.csv"
132139
df = pd.read_csv(csv_path)
133140
staging_location = "gs://test-bucket"
134141
entity = "driver"
135142

136-
importer = Importer.from_df(df = df,
137-
entity = entity,
138-
granularity = Granularity.DAY,
139-
owner = "owner@feast.com",
140-
staging_location=staging_location,
141-
id_column = "driver_id",
142-
timestamp_column="ts")
143-
143+
importer = Importer.from_df(df=df,
144+
entity=entity,
145+
granularity=Granularity.DAY,
146+
owner="owner@feast.com",
147+
staging_location=staging_location,
148+
id_column="driver_id",
149+
timestamp_column="ts")
144150

145151
assert importer.require_staging == True
146152
assert ("{}/tmp_{}".format(staging_location, entity)
147-
in importer.remote_path)
153+
in importer.remote_path)
148154
for feature in importer.features.values():
149155
assert feature.name in df.columns
150156
assert feature.id == "driver.day." + feature.name
151157

152158
import_spec = importer.spec
153159
assert import_spec.type == "file"
154-
assert import_spec.sourceOptions == {"format" : "csv", "path" : importer.remote_path}
160+
assert import_spec.sourceOptions == {"format": "csv",
161+
"path": importer.remote_path}
155162
assert import_spec.entities == ["driver"]
156163

157164
schema = import_spec.schema
158165
assert schema.entityIdColumn == "driver_id"
159166
assert schema.timestampValue is not None
160167
feature_columns = ["completed", "avg_distance_completed",
161-
"avg_customer_distance_completed",
162-
"avg_distance_cancelled"]
168+
"avg_customer_distance_completed",
169+
"avg_distance_cancelled"]
163170
for col, field in zip(df.columns.values, schema.fields):
164171
assert col == field.name
165172
if col in feature_columns:
166173
assert field.featureId == "driver.day." + col
167174

168175
def _validate_csv_importer(self,
169-
importer, csv_path, entity_name, feature_granularity, owner,
170-
staging_location = None, id_column = None, feature_columns = None,
171-
timestamp_column = None, timestamp_value = None):
176+
importer, csv_path, entity_name,
177+
feature_granularity, owner,
178+
staging_location=None, id_column=None,
179+
feature_columns=None,
180+
timestamp_column=None, timestamp_value=None):
172181
df = pd.read_csv(csv_path)
173182
assert not importer.require_staging == is_gs_path(csv_path)
174183
if importer.require_staging:
175184
assert importer.remote_path == "{}/{}".format(staging_location,
176-
ntpath.basename(csv_path))
185+
ntpath.basename(
186+
csv_path))
177187

178188
# check features created
179189
for feature in importer.features.values():
180190
assert feature.name in df.columns
181191
assert feature.id == "{}.{}.{}".format(entity_name,
182-
Granularity_pb2.Enum.Name(feature_granularity.value).lower(),
183-
feature.name)
192+
Granularity_pb2.Enum.Name(
193+
feature_granularity.value).lower(),
194+
feature.name)
184195

185196
import_spec = importer.spec
186-
assert import_spec.type == "file"
197+
assert import_spec.type == "file.csv"
187198
path = importer.remote_path if importer.require_staging else csv_path
188-
assert import_spec.sourceOptions == {"format" : "csv", "path" : path}
199+
assert import_spec.sourceOptions == {"path": path}
189200
assert import_spec.entities == [entity_name]
190201

191202
schema = import_spec.schema
@@ -204,19 +215,23 @@ def _validate_csv_importer(self,
204215
for col, field in zip(df.columns.values, schema.fields):
205216
assert col == field.name
206217
if col in feature_columns:
207-
assert field.featureId == "{}.{}.{}".format(entity_name,
208-
Granularity_pb2.Enum.Name(feature_granularity.value).lower(), col)
218+
assert field.featureId == \
219+
"{}.{}.{}".format(entity_name,
220+
Granularity_pb2.Enum.Name(
221+
feature_granularity.value).lower(),
222+
col)
209223

210224

211225
class TestHelpers:
212226
def test_create_feature(self):
213-
col = pd.Series([1]*3,dtype='int32',name="test")
227+
col = pd.Series([1] * 3, dtype='int32', name="test")
214228
expected = Feature(name="test",
215-
entity="test",
216-
granularity=Granularity.NONE,
217-
owner="person",
218-
value_type=ValueType.INT32)
219-
actual = _create_feature(col, "test", Granularity.NONE, "person", None, None)
229+
entity="test",
230+
granularity=Granularity.NONE,
231+
owner="person",
232+
value_type=ValueType.INT32)
233+
actual = _create_feature(col, "test", Granularity.NONE, "person", None,
234+
None)
220235
assert actual.id == expected.id
221236
assert actual.value_type == expected.value_type
222237
assert actual.owner == expected.owner
@@ -231,7 +246,8 @@ def test_create_feature_with_stores(self):
231246
serving_store=Datastore(id="SERVING"),
232247
warehouse_store=Datastore(id="WAREHOUSE"))
233248
actual = _create_feature(col, "test", Granularity.NONE, "person",
234-
Datastore(id="SERVING"), Datastore(id="WAREHOUSE"))
249+
Datastore(id="SERVING"),
250+
Datastore(id="WAREHOUSE"))
235251
assert actual.id == expected.id
236252
assert actual.value_type == expected.value_type
237253
assert actual.owner == expected.owner

0 commit comments

Comments (0)