Skip to content

Commit b76ee66

Browse files
committed
Make batch_source optional in PushSource (#5440)
Signed-off-by: snehsuresh <snehsuresh02@gmail.com>
1 parent 99afd6d commit b76ee66

File tree

2 files changed

+28
-8
lines changed

2 files changed

+28
-8
lines changed

sdk/python/feast/data_source.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -764,13 +764,13 @@ class PushSource(DataSource):
764764

765765
# TODO(adchia): consider adding schema here in case where Feast manages pushing events to the offline store
766766
# TODO(adchia): consider a "mode" to support pushing raw vs transformed events
767-
batch_source: DataSource
767+
batch_source: Optional[DataSource] = None
768768

769769
def __init__(
770770
self,
771771
*,
772772
name: str,
773-
batch_source: DataSource,
773+
batch_source: Optional[DataSource] = None,
774774
description: Optional[str] = "",
775775
tags: Optional[Dict[str, str]] = None,
776776
owner: Optional[str] = "",
@@ -815,8 +815,12 @@ def get_table_column_names_and_types(
815815

816816
@staticmethod
817817
def from_proto(data_source: DataSourceProto):
818-
assert data_source.HasField("batch_source")
819-
batch_source = DataSource.from_proto(data_source.batch_source)
818+
# assert data_source.HasField("batch_source")
819+
batch_source = (
820+
DataSource.from_proto(data_source.batch_source)
821+
if data_source.HasField("batch_source")
822+
else None
823+
)
820824

821825
return PushSource(
822826
name=data_source.name,
@@ -827,19 +831,19 @@ def from_proto(data_source: DataSourceProto):
827831
)
828832

829833
def to_proto(self) -> DataSourceProto:
830-
batch_source_proto = None
831-
if self.batch_source:
832-
batch_source_proto = self.batch_source.to_proto()
834+
# batch_source_proto = None
833835

834836
data_source_proto = DataSourceProto(
835837
name=self.name,
836838
type=DataSourceProto.PUSH_SOURCE,
837839
description=self.description,
838840
tags=self.tags,
839841
owner=self.owner,
840-
batch_source=batch_source_proto,
841842
)
842843

844+
if self.batch_source:
845+
data_source_proto.batch_source.MergeFrom(self.batch_source.to_proto())
846+
843847
return data_source_proto
844848

845849
def get_table_query_string(self) -> str:

sdk/python/tests/unit/test_data_sources.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,22 @@ def test_push_with_batch():
3030
assert push_source.batch_source.name == push_source_unproto.batch_source.name
3131

3232

33+
def test_push_source_without_batch_source():
34+
# Create PushSource with no batch_source
35+
push_source = PushSource(name="test_push_source")
36+
37+
# Convert to proto
38+
push_source_proto = push_source.to_proto()
39+
40+
# Assert batch_source is not present in proto
41+
assert not push_source_proto.HasField("batch_source")
42+
43+
# Deserialize and check again
44+
push_source_unproto = PushSource.from_proto(push_source_proto)
45+
assert push_source_unproto.batch_source is None
46+
assert push_source_unproto.name == "test_push_source"
47+
48+
3349
def test_request_source_primitive_type_to_proto():
3450
schema = [
3551
Field(name="f1", dtype=Float32),

0 commit comments

Comments
 (0)