1717package feast .core .model ;
1818
1919import com .google .protobuf .Duration ;
20+ import com .google .protobuf .InvalidProtocolBufferException ;
2021import com .google .protobuf .Timestamp ;
2122import feast .core .FeatureSetProto ;
2223import feast .core .FeatureSetProto .EntitySpec ;
2324import feast .core .FeatureSetProto .FeatureSetMeta ;
2425import feast .core .FeatureSetProto .FeatureSetSpec ;
2526import feast .core .FeatureSetProto .FeatureSetStatus ;
2627import feast .core .FeatureSetProto .FeatureSpec ;
27- import feast .types .ValueProto .ValueType ;
28+ import feast .types .ValueProto .ValueType . Enum ;
2829import java .util .ArrayList ;
2930import java .util .HashMap ;
3031import java .util .HashSet ;
4748import org .apache .commons .lang3 .builder .HashCodeBuilder ;
4849import org .hibernate .annotations .Fetch ;
4950import org .hibernate .annotations .FetchMode ;
51+ import org .tensorflow .metadata .v0 .BoolDomain ;
52+ import org .tensorflow .metadata .v0 .FeaturePresence ;
53+ import org .tensorflow .metadata .v0 .FeaturePresenceWithinGroup ;
54+ import org .tensorflow .metadata .v0 .FixedShape ;
55+ import org .tensorflow .metadata .v0 .FloatDomain ;
56+ import org .tensorflow .metadata .v0 .ImageDomain ;
57+ import org .tensorflow .metadata .v0 .IntDomain ;
58+ import org .tensorflow .metadata .v0 .NaturalLanguageDomain ;
59+ import org .tensorflow .metadata .v0 .StringDomain ;
60+ import org .tensorflow .metadata .v0 .StructDomain ;
61+ import org .tensorflow .metadata .v0 .TimeDomain ;
62+ import org .tensorflow .metadata .v0 .TimeOfDayDomain ;
63+ import org .tensorflow .metadata .v0 .URLDomain ;
64+ import org .tensorflow .metadata .v0 .ValueCount ;
5065
5166@ Getter
5267@ Setter
@@ -157,23 +172,23 @@ public static FeatureSet fromProto(FeatureSetProto.FeatureSet featureSetProto) {
157172 FeatureSetSpec featureSetSpec = featureSetProto .getSpec ();
158173 Source source = Source .fromProto (featureSetSpec .getSource ());
159174
160- List <Field > features = new ArrayList <>();
161- for (FeatureSpec feature : featureSetSpec .getFeaturesList ()) {
162- features .add (new Field (feature . getName (), feature . getValueType () ));
175+ List <Field > featureSpecs = new ArrayList <>();
176+ for (FeatureSpec featureSpec : featureSetSpec .getFeaturesList ()) {
177+ featureSpecs .add (new Field (featureSpec ));
163178 }
164179
165- List <Field > entities = new ArrayList <>();
166- for (EntitySpec entity : featureSetSpec .getEntitiesList ()) {
167- entities .add (new Field (entity . getName (), entity . getValueType () ));
180+ List <Field > entitySpecs = new ArrayList <>();
181+ for (EntitySpec entitySpec : featureSetSpec .getEntitiesList ()) {
182+ entitySpecs .add (new Field (entitySpec ));
168183 }
169184
170185 return new FeatureSet (
171186 featureSetProto .getSpec ().getName (),
172187 featureSetProto .getSpec ().getProject (),
173188 featureSetProto .getSpec ().getVersion (),
174189 featureSetSpec .getMaxAge ().getSeconds (),
175- entities ,
176- features ,
190+ entitySpecs ,
191+ featureSpecs ,
177192 source ,
178193 featureSetProto .getMeta ().getStatus ());
179194 }
@@ -202,24 +217,21 @@ public void addFeature(Field field) {
202217 features .add (field );
203218 }
204219
205- public FeatureSetProto .FeatureSet toProto () {
220+ public FeatureSetProto .FeatureSet toProto () throws InvalidProtocolBufferException {
206221 List <EntitySpec > entitySpecs = new ArrayList <>();
207- for (Field entity : entities ) {
208- entitySpecs .add (
209- EntitySpec .newBuilder ()
210- .setName (entity .getName ())
211- .setValueType (ValueType .Enum .valueOf (entity .getType ()))
212- .build ());
222+ for (Field entityField : entities ) {
223+ EntitySpec .Builder entitySpecBuilder = EntitySpec .newBuilder ();
224+ setEntitySpecFields (entitySpecBuilder , entityField );
225+ entitySpecs .add (entitySpecBuilder .build ());
213226 }
214227
215228 List <FeatureSpec > featureSpecs = new ArrayList <>();
216- for (Field feature : features ) {
217- featureSpecs .add (
218- FeatureSpec .newBuilder ()
219- .setName (feature .getName ())
220- .setValueType (ValueType .Enum .valueOf (feature .getType ()))
221- .build ());
229+ for (Field featureField : features ) {
230+ FeatureSpec .Builder featureSpecBuilder = FeatureSpec .newBuilder ();
231+ setFeatureSpecFields (featureSpecBuilder , featureField );
232+ featureSpecs .add (featureSpecBuilder .build ());
222233 }
234+
223235 FeatureSetMeta .Builder meta =
224236 FeatureSetMeta .newBuilder ()
225237 .setCreatedTimestamp (
@@ -239,6 +251,108 @@ public FeatureSetProto.FeatureSet toProto() {
239251 return FeatureSetProto .FeatureSet .newBuilder ().setMeta (meta ).setSpec (spec ).build ();
240252 }
241253
254+ // setEntitySpecFields and setFeatureSpecFields methods contain duplicated code because
255+ // Feast internally treat EntitySpec and FeatureSpec as Field class. However, the proto message
256+ // builder for EntitySpec and FeatureSpec are of different class.
257+ @ SuppressWarnings ("DuplicatedCode" )
258+ private void setEntitySpecFields (EntitySpec .Builder entitySpecBuilder , Field entityField )
259+ throws InvalidProtocolBufferException {
260+ entitySpecBuilder
261+ .setName (entityField .getName ())
262+ .setValueType (Enum .valueOf (entityField .getType ()));
263+
264+ if (entityField .getPresence () != null ) {
265+ entitySpecBuilder .setPresence (FeaturePresence .parseFrom (entityField .getPresence ()));
266+ } else if (entityField .getGroupPresence () != null ) {
267+ entitySpecBuilder
268+ .setGroupPresence (FeaturePresenceWithinGroup .parseFrom (entityField .getGroupPresence ()));
269+ }
270+
271+ if (entityField .getShape () != null ) {
272+ entitySpecBuilder .setShape (FixedShape .parseFrom (entityField .getShape ()));
273+ } else if (entityField .getValueCount () != null ) {
274+ entitySpecBuilder .setValueCount (ValueCount .parseFrom (entityField .getValueCount ()));
275+ }
276+
277+ if (entityField .getDomain () != null ) {
278+ entitySpecBuilder .setDomain (entityField .getDomain ());
279+ } else if (entityField .getIntDomain () != null ) {
280+ entitySpecBuilder .setIntDomain (IntDomain .parseFrom (entityField .getIntDomain ()));
281+ } else if (entityField .getFloatDomain () != null ) {
282+ entitySpecBuilder .setFloatDomain (FloatDomain .parseFrom (entityField .getFloatDomain ()));
283+ } else if (entityField .getStringDomain () != null ) {
284+ entitySpecBuilder .setStringDomain (StringDomain .parseFrom (entityField .getStringDomain ()));
285+ } else if (entityField .getBoolDomain () != null ) {
286+ entitySpecBuilder .setBoolDomain (BoolDomain .parseFrom (entityField .getBoolDomain ()));
287+ } else if (entityField .getStructDomain () != null ) {
288+ entitySpecBuilder .setStructDomain (StructDomain .parseFrom (entityField .getStructDomain ()));
289+ } else if (entityField .getNaturalLanguageDomain () != null ) {
290+ entitySpecBuilder .setNaturalLanguageDomain (
291+ NaturalLanguageDomain .parseFrom (entityField .getNaturalLanguageDomain ()));
292+ } else if (entityField .getImageDomain () != null ) {
293+ entitySpecBuilder .setImageDomain (ImageDomain .parseFrom (entityField .getImageDomain ()));
294+ } else if (entityField .getMidDomain () != null ) {
295+ entitySpecBuilder .setIntDomain (IntDomain .parseFrom (entityField .getIntDomain ()));
296+ } else if (entityField .getUrlDomain () != null ) {
297+ entitySpecBuilder .setUrlDomain (URLDomain .parseFrom (entityField .getUrlDomain ()));
298+ } else if (entityField .getTimeDomain () != null ) {
299+ entitySpecBuilder .setTimeDomain (TimeDomain .parseFrom (entityField .getTimeDomain ()));
300+ } else if (entityField .getTimeOfDayDomain () != null ) {
301+ entitySpecBuilder
302+ .setTimeOfDayDomain (TimeOfDayDomain .parseFrom (entityField .getTimeOfDayDomain ()));
303+ }
304+ }
305+
306+ // Refer to setEntitySpecFields method for the reason for code duplication.
307+ @ SuppressWarnings ("DuplicatedCode" )
308+ private void setFeatureSpecFields (FeatureSpec .Builder featureSpecBuilder , Field featureField )
309+ throws InvalidProtocolBufferException {
310+ featureSpecBuilder
311+ .setName (featureField .getName ())
312+ .setValueType (Enum .valueOf (featureField .getType ()));
313+
314+ if (featureField .getPresence () != null ) {
315+ featureSpecBuilder .setPresence (FeaturePresence .parseFrom (featureField .getPresence ()));
316+ } else if (featureField .getGroupPresence () != null ) {
317+ featureSpecBuilder
318+ .setGroupPresence (FeaturePresenceWithinGroup .parseFrom (featureField .getGroupPresence ()));
319+ }
320+
321+ if (featureField .getShape () != null ) {
322+ featureSpecBuilder .setShape (FixedShape .parseFrom (featureField .getShape ()));
323+ } else if (featureField .getValueCount () != null ) {
324+ featureSpecBuilder .setValueCount (ValueCount .parseFrom (featureField .getValueCount ()));
325+ }
326+
327+ if (featureField .getDomain () != null ) {
328+ featureSpecBuilder .setDomain (featureField .getDomain ());
329+ } else if (featureField .getIntDomain () != null ) {
330+ featureSpecBuilder .setIntDomain (IntDomain .parseFrom (featureField .getIntDomain ()));
331+ } else if (featureField .getFloatDomain () != null ) {
332+ featureSpecBuilder .setFloatDomain (FloatDomain .parseFrom (featureField .getFloatDomain ()));
333+ } else if (featureField .getStringDomain () != null ) {
334+ featureSpecBuilder .setStringDomain (StringDomain .parseFrom (featureField .getStringDomain ()));
335+ } else if (featureField .getBoolDomain () != null ) {
336+ featureSpecBuilder .setBoolDomain (BoolDomain .parseFrom (featureField .getBoolDomain ()));
337+ } else if (featureField .getStructDomain () != null ) {
338+ featureSpecBuilder .setStructDomain (StructDomain .parseFrom (featureField .getStructDomain ()));
339+ } else if (featureField .getNaturalLanguageDomain () != null ) {
340+ featureSpecBuilder .setNaturalLanguageDomain (
341+ NaturalLanguageDomain .parseFrom (featureField .getNaturalLanguageDomain ()));
342+ } else if (featureField .getImageDomain () != null ) {
343+ featureSpecBuilder .setImageDomain (ImageDomain .parseFrom (featureField .getImageDomain ()));
344+ } else if (featureField .getMidDomain () != null ) {
345+ featureSpecBuilder .setIntDomain (IntDomain .parseFrom (featureField .getIntDomain ()));
346+ } else if (featureField .getUrlDomain () != null ) {
347+ featureSpecBuilder .setUrlDomain (URLDomain .parseFrom (featureField .getUrlDomain ()));
348+ } else if (featureField .getTimeDomain () != null ) {
349+ featureSpecBuilder .setTimeDomain (TimeDomain .parseFrom (featureField .getTimeDomain ()));
350+ } else if (featureField .getTimeOfDayDomain () != null ) {
351+ featureSpecBuilder
352+ .setTimeOfDayDomain (TimeOfDayDomain .parseFrom (featureField .getTimeOfDayDomain ()));
353+ }
354+ }
355+
242356 /**
243357 * Checks if the given featureSet's schema and source has is different from this one.
244358 *
0 commit comments