Skip to content

Commit 69e4a7d

Browse files
authored
fix: Dask zero division error if parquet dataset has only one partition (feast-dev#3236)
* fix: dask zero division error if parquet dataset has only one partition
  Signed-off-by: Max Zwiessle <ibinbei@gmail.com>
* Update file.py
  Signed-off-by: Max Zwiessle <ibinbei@gmail.com>
* Update file.py
  Signed-off-by: Max Zwiessle <ibinbei@gmail.com>

Signed-off-by: Max Zwiessle <ibinbei@gmail.com>
1 parent 1a446e2 commit 69e4a7d

File tree

3 files changed

+66
-6
lines changed

3 files changed

+66
-6
lines changed

sdk/python/feast/infra/offline_stores/file.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -662,14 +662,33 @@ def _drop_duplicates(
662662
created_timestamp_column: str,
663663
entity_df_event_timestamp_col: str,
664664
) -> dd.DataFrame:
665-
if created_timestamp_column:
666-
df_to_join = df_to_join.sort_values(
667-
by=created_timestamp_column, na_position="first"
668-
)
665+
column_order = df_to_join.columns
666+
667+
# try-catch block is added to deal with this issue https://github.com/dask/dask/issues/8939.
668+
# TODO(kevjumba): remove try catch when fix is merged upstream in Dask.
669+
try:
670+
if created_timestamp_column:
671+
df_to_join = df_to_join.sort_values(
672+
by=created_timestamp_column, na_position="first"
673+
)
674+
df_to_join = df_to_join.persist()
675+
676+
df_to_join = df_to_join.sort_values(by=timestamp_field, na_position="first")
669677
df_to_join = df_to_join.persist()
670678

671-
df_to_join = df_to_join.sort_values(by=timestamp_field, na_position="first")
672-
df_to_join = df_to_join.persist()
679+
except ZeroDivisionError:
680+
# Use 1 partition to get around case where everything in timestamp column is the same so the partition algorithm doesn't
681+
# try to divide by zero.
682+
if created_timestamp_column:
683+
df_to_join = df_to_join[column_order].sort_values(
684+
by=created_timestamp_column, na_position="first", npartitions=1
685+
)
686+
df_to_join = df_to_join.persist()
687+
688+
df_to_join = df_to_join[column_order].sort_values(
689+
by=timestamp_field, na_position="first", npartitions=1
690+
)
691+
df_to_join = df_to_join.persist()
673692

674693
df_to_join = df_to_join.drop_duplicates(
675694
all_join_keys + [entity_df_event_timestamp_col],

sdk/python/tests/integration/feature_repos/universal/data_sources/file.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from typing import Any, Dict, List, Optional
66

77
import pandas as pd
8+
import pyarrow as pa
9+
import pyarrow.parquet as pq
810
from minio import Minio
911
from testcontainers.core.generic import DockerContainer
1012
from testcontainers.core.waiting_utils import wait_for_logs
@@ -87,6 +89,39 @@ def teardown(self):
8789
shutil.rmtree(d)
8890

8991

92+
class FileParquetDatasetSourceCreator(FileDataSourceCreator):
    """Creates a FileSource backed by a parquet *dataset* (a directory of
    parquet files) instead of a single file, so tests exercise the
    multi-file / partitioned read path of the file offline store."""

    def create_data_source(
        self,
        df: pd.DataFrame,
        destination_name: str,
        timestamp_field="ts",
        created_timestamp_column="created_ts",
        field_mapping: Optional[Dict[str, str]] = None,
    ) -> DataSource:
        """Write ``df`` as a parquet dataset in a temp directory and return
        a FileSource pointing at that directory.

        Args:
            df: Data to persist.
            destination_name: Logical name; prefixed with the project name.
            timestamp_field: Event-timestamp column in ``df``.
            created_timestamp_column: Created-timestamp column in ``df``.
            field_mapping: Source-column renames; defaults to {"ts_1": "ts"}.
        """
        destination_name = self.get_prefixed_table_name(destination_name)

        # Use mkdtemp rather than TemporaryDirectory: when the only surviving
        # reference to a TemporaryDirectory object is its ``.name`` string,
        # the object is garbage-collected and its finalizer deletes the
        # directory while the data source is still in use. mkdtemp never
        # auto-deletes; the path is tracked for explicit cleanup.
        # NOTE(review): teardown appears to rmtree tracked paths — confirm
        # self.files entries are removed there.
        dataset_dir = tempfile.mkdtemp(
            prefix=f"{self.project_name}_{destination_name}"
        )

        table = pa.Table.from_pandas(df)
        # The directory argument of pq.write_to_dataset is ``root_path``;
        # ``base_dir`` is a pyarrow.dataset.write_dataset parameter and is
        # not accepted here.
        pq.write_to_dataset(
            table,
            root_path=dataset_dir,
            compression="snappy",
            format="parquet",
            existing_data_behavior="overwrite_or_ignore",
        )
        self.files.append(dataset_dir)

        return FileSource(
            file_format=ParquetFormat(),
            path=dataset_dir,
            timestamp_field=timestamp_field,
            created_timestamp_column=created_timestamp_column,
            field_mapping=field_mapping or {"ts_1": "ts"},
        )
123+
124+
90125
class S3FileDataSourceCreator(DataSourceCreator):
91126
f: Any
92127
minio: DockerContainer

sdk/python/tests/utils/e2e_test_validation.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
)
2828
from tests.integration.feature_repos.universal.data_sources.file import (
2929
FileDataSourceCreator,
30+
FileParquetDatasetSourceCreator,
3031
)
3132
from tests.integration.feature_repos.universal.data_sources.redshift import (
3233
RedshiftDataSourceCreator,
@@ -211,6 +212,11 @@ def make_feature_store_yaml(
211212
offline_store_creator=FileDataSourceCreator,
212213
online_store=None,
213214
),
215+
IntegrationTestRepoConfig(
216+
provider="local",
217+
offline_store_creator=FileParquetDatasetSourceCreator,
218+
online_store=None,
219+
),
214220
]
215221

216222
# Only test if this is NOT a local test

0 commit comments

Comments (0)