Skip to content

Commit aeb12cd

Browse files
davidheryanto authored and feast-ci-bot committed
Fix BigQuery query template to retrieve training data (#182)
* Fix BigQuery query template to retrieve training data
* Update expected value BigQuery template test
* Use FeatureInfo to create Features in BigQueryDatasetTemplater so it's neater
1 parent d5c3809 commit aeb12cd

5 files changed

Lines changed: 39 additions & 44 deletions

File tree

core/src/main/java/feast/core/training/BigQueryDatasetTemplater.java

Lines changed: 10 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -22,19 +22,15 @@
2222
import feast.core.dao.FeatureInfoRepository;
2323
import feast.core.model.FeatureInfo;
2424
import feast.core.model.StorageInfo;
25-
import feast.specs.FeatureSpecProto.FeatureSpec;
2625
import feast.specs.StorageSpecProto.StorageSpec;
26+
import lombok.Getter;
27+
2728
import java.time.Instant;
2829
import java.time.ZoneId;
2930
import java.time.format.DateTimeFormatter;
3031
import java.time.temporal.ChronoUnit;
31-
import java.util.HashMap;
32-
import java.util.List;
33-
import java.util.Map;
34-
import java.util.NoSuchElementException;
35-
import java.util.Set;
32+
import java.util.*;
3633
import java.util.stream.Collectors;
37-
import lombok.Getter;
3834

3935
public class BigQueryDatasetTemplater {
4036
private final FeatureInfoRepository featureInfoRepository;
@@ -59,20 +55,18 @@ public BigQueryDatasetTemplater(
5955
* @param limit limit
6056
* @return SQL query for creating training table.
6157
*/
62-
public String createQuery(
63-
FeatureSet featureSet, Timestamp startDate, Timestamp endDate, long limit) {
58+
String createQuery(FeatureSet featureSet, Timestamp startDate, Timestamp endDate, long limit) {
6459
List<String> featureIds = featureSet.getFeatureIdsList();
6560
List<FeatureInfo> featureInfos = featureInfoRepository.findAllById(featureIds);
61+
Features features = new Features(featureInfos);
62+
6663
if (featureInfos.size() < featureIds.size()) {
6764
Set<String> foundFeatureIds =
6865
featureInfos.stream().map(FeatureInfo::getId).collect(Collectors.toSet());
6966
featureIds.removeAll(foundFeatureIds);
7067
throw new NoSuchElementException("features not found: " + featureIds);
7168
}
7269

73-
String tableId = getBqTableId(featureInfos.get(0));
74-
Features features = new Features(featureIds, tableId);
75-
7670
String startDateStr = formatDateString(startDate);
7771
String endDateStr = formatDateString(endDate);
7872
String limitStr = (limit != 0) ? String.valueOf(limit) : null;
@@ -90,7 +84,7 @@ private String renderTemplate(
9084
return jinjava.render(template, context);
9185
}
9286

93-
private String getBqTableId(FeatureInfo featureInfo) {
87+
private static String getBqTableId(FeatureInfo featureInfo) {
9488
StorageInfo whStorage = featureInfo.getWarehouseStore();
9589

9690
String type = whStorage.getType();
@@ -117,12 +111,9 @@ static final class Features {
117111
final List<String> columns;
118112
final String tableId;
119113

120-
public Features(List<String> featureIds, String tableId) {
121-
this.columns = featureIds.stream()
122-
.map(f -> f.replace(".", "_"))
123-
.collect(Collectors.toList());
124-
this.tableId = tableId;
114+
Features(List<FeatureInfo> featureInfos) {
115+
columns = featureInfos.stream().map(FeatureInfo::getName).collect(Collectors.toList());
116+
tableId = featureInfos.size() > 0 ? getBqTableId(featureInfos.get(0)) : "";
125117
}
126118
}
127-
128119
}

core/src/main/resources/templates/bq_training.tmpl

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,9 @@
11
SELECT
2-
{{ feature_set.tableId }}.id,
3-
{{ feature_set.tableId }}.event_timestamp
4-
{% for feature in feature_set.columns -%}
5-
,{{ feature }}
6-
{%- endfor %}
2+
id,
3+
event_timestamp{%- if feature_set.columns | length > 0 %},{%- endif %}
4+
{{ feature_set.columns | join(',') }}
75
FROM
8-
{{ feature_set.tableId }}
6+
`{{ feature_set.tableId }}`
97
WHERE event_timestamp >= TIMESTAMP("{{ start_date }}") AND event_timestamp <= TIMESTAMP(DATETIME_ADD("{{ end_date }}", INTERVAL 1 DAY))
108
{% if limit is not none -%}
119
LIMIT {{ limit }}

core/src/test/java/feast/core/training/BigQueryDatasetTemplaterTest.java

Lines changed: 15 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -96,10 +96,11 @@ public void shouldPassCorrectArgumentToTemplateEngine() {
9696
Timestamps.fromSeconds(Instant.parse("2019-01-01T00:00:00.00Z").getEpochSecond());
9797
int limit = 100;
9898
String featureId = "myentity.feature1";
99+
String featureName = "feature1";
99100
String tableId = "project.dataset.myentity";
100101

101102
when(featureInfoRespository.findAllById(any(List.class)))
102-
.thenReturn(Collections.singletonList(createFeatureInfo(featureId, tableId)));
103+
.thenReturn(Collections.singletonList(createFeatureInfo(featureId, featureName, tableId)));
103104

104105
FeatureSet fs =
105106
FeatureSet.newBuilder()
@@ -123,22 +124,25 @@ public void shouldPassCorrectArgumentToTemplateEngine() {
123124

124125
Features features = (Features) actualContext.get("feature_set");
125126
assertThat(features.getColumns().size(), equalTo(1));
126-
assertThat(features.getColumns().get(0), equalTo(featureId.replace(".", "_")));
127+
assertThat(features.getColumns().get(0), equalTo(featureName));
127128
assertThat(features.getTableId(), equalTo(tableId));
128129
}
129130

130131
@Test
131132
public void shouldRenderCorrectQuery1() throws Exception {
132133
String tableId1 = "project.dataset.myentity";
133134
String featureId1 = "myentity.feature1";
135+
String featureName1 = "feature1";
134136
String featureId2 = "myentity.feature2";
137+
String featureName2 = "feature2";
135138

136-
FeatureInfo featureInfo1 = createFeatureInfo(featureId1, tableId1);
137-
FeatureInfo featureInfo2 = createFeatureInfo(featureId2, tableId1);
139+
FeatureInfo featureInfo1 = createFeatureInfo(featureId1, featureName1, tableId1);
140+
FeatureInfo featureInfo2 = createFeatureInfo(featureId2, featureName2, tableId1);
138141

139142
String tableId2 = "project.dataset.myentity";
140143
String featureId3 = "myentity.feature3";
141-
FeatureInfo featureInfo3 = createFeatureInfo(featureId3, tableId2);
144+
String featureName3 = "feature3";
145+
FeatureInfo featureInfo3 = createFeatureInfo(featureId3, featureName3, tableId2);
142146

143147
when(featureInfoRespository.findAllById(any(List.class)))
144148
.thenReturn(Arrays.asList(featureInfo1, featureInfo2, featureInfo3));
@@ -166,8 +170,9 @@ public void shouldRenderCorrectQuery2() throws Exception {
166170

167171
String tableId = "project.dataset.myentity";
168172
String featureId = "myentity.feature1";
173+
String featureName = "feature1";
169174

170-
featureInfos.add(createFeatureInfo(featureId, tableId));
175+
featureInfos.add(createFeatureInfo(featureId, featureName, tableId));
171176
featureIds.add(featureId);
172177

173178
when(featureInfoRespository.findAllById(any(List.class))).thenReturn(featureInfos);
@@ -197,7 +202,7 @@ private void checkExpectedQuery(String query, String pathToExpQuery) throws Exce
197202
assertThat(query, equalTo(expQuery));
198203
}
199204

200-
private FeatureInfo createFeatureInfo(String id, String tableId) {
205+
private FeatureInfo createFeatureInfo(String featureId, String featureName, String tableId) {
201206
StorageSpec storageSpec =
202207
StorageSpec.newBuilder()
203208
.setId("BQ")
@@ -209,11 +214,12 @@ private FeatureInfo createFeatureInfo(String id, String tableId) {
209214

210215
FeatureSpec fs =
211216
FeatureSpec.newBuilder()
212-
.setId(id)
217+
.setId(featureId)
218+
.setName(featureName)
213219
.setDataStores(DataStores.newBuilder().setWarehouse(DataStore.newBuilder().setId("BQ")))
214220
.build();
215221

216-
EntitySpec entitySpec = EntitySpec.newBuilder().setName(id.split("\\.")[0]).build();
222+
EntitySpec entitySpec = EntitySpec.newBuilder().setName(featureId.split("\\.")[0]).build();
217223
EntityInfo entityInfo = new EntityInfo(entitySpec);
218224
return new FeatureInfo(fs, entityInfo, null, storageInfo, null);
219225
}
Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,11 @@
11
SELECT
2-
project.dataset.myentity.id,
3-
project.dataset.myentity.event_timestamp ,
4-
myentity_feature1,
5-
myentity_feature2,
6-
myentity_feature3
2+
id,
3+
event_timestamp,
4+
feature1,
5+
feature2,
6+
feature3
77
FROM
8-
project.dataset.myentity
8+
`project.dataset.myentity`
99
WHERE
1010
event_timestamp >= TIMESTAMP("2018-01-02")
1111
AND event_timestamp <= TIMESTAMP(DATETIME_ADD("2018-01-30", INTERVAL 1 DAY)) LIMIT 100
Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,9 @@
11
SELECT
2-
project.dataset.myentity.id,
3-
project.dataset.myentity.event_timestamp ,
4-
myentity_feature1
2+
id,
3+
event_timestamp,
4+
feature1
55
FROM
6-
project.dataset.myentity
6+
`project.dataset.myentity`
77
WHERE
88
event_timestamp >= TIMESTAMP("2018-01-02")
99
AND event_timestamp <= TIMESTAMP(DATETIME_ADD("2018-01-30", INTERVAL 1 DAY)) LIMIT 1000

0 commit comments

Comments
 (0)