Skip to content

Commit 2bc78f9

Browse files
committed
fix: Identify s3/remote uri path correctly
Signed-off-by: ntkathole <nikhilkathole2683@gmail.com>
1 parent f3a24de commit 2bc78f9

File tree

4 files changed

+44
-17
lines changed

4 files changed

+44
-17
lines changed

sdk/python/feast/infra/offline_stores/dask.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,9 @@ def persist(
100100
# Check if the specified location already exists.
101101
if not allow_overwrite and os.path.exists(storage.file_options.uri):
102102
raise SavedDatasetLocationAlreadyExists(location=storage.file_options.uri)
103-
104-
if not Path(storage.file_options.uri).is_absolute():
105-
absolute_path = Path(self.repo_path) / storage.file_options.uri
106-
else:
107-
absolute_path = Path(storage.file_options.uri)
103+
absolute_path = FileSource.get_uri_for_file_path(
104+
repo_path=self.repo_path, uri=storage.file_options.uri
105+
)
108106

109107
filesystem, path = FileSource.create_filesystem_and_path(
110108
str(absolute_path),

sdk/python/feast/infra/offline_stores/duckdb.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,9 @@ def _write_data_source(
5151

5252
file_options = data_source.file_options
5353

54-
if not Path(file_options.uri).is_absolute():
55-
absolute_path = Path(repo_path) / file_options.uri
56-
else:
57-
absolute_path = Path(file_options.uri)
54+
absolute_path = FileSource.get_uri_for_file_path(
55+
repo_path=repo_path, uri=file_options.uri
56+
)
5857

5958
if (
6059
mode == "overwrite"

sdk/python/feast/infra/offline_stores/file_source.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pathlib import Path
22
from typing import Callable, Dict, Iterable, List, Optional, Tuple
3+
from urllib.parse import urlparse
34

45
import pyarrow
56
from packaging import version
@@ -154,17 +155,21 @@ def validate(self, config: RepoConfig):
154155
def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
155156
return type_map.pa_to_feast_value_type
156157

158+
@staticmethod
159+
def get_uri_for_file_path(repo_path, uri):
160+
parsed_uri = urlparse(uri)
161+
if parsed_uri.scheme and parsed_uri.netloc:
162+
return uri # Keep remote URIs as they are
163+
if repo_path is not None and not Path(uri).is_absolute():
164+
return str(Path(repo_path) / uri)
165+
return str(Path(uri))
166+
157167
def get_table_column_names_and_types(
158168
self, config: RepoConfig
159169
) -> Iterable[Tuple[str, str]]:
160-
if (
161-
config.repo_path is not None
162-
and not Path(self.file_options.uri).is_absolute()
163-
):
164-
absolute_path = config.repo_path / self.file_options.uri
165-
else:
166-
absolute_path = Path(self.file_options.uri)
167-
170+
absolute_path = self.get_uri_for_file_path(
171+
repo_path=config.repo_path, uri=self.file_options.uri
172+
)
168173
filesystem, path = FileSource.create_filesystem_and_path(
169174
str(absolute_path), self.file_options.s3_endpoint_override
170175
)

sdk/python/tests/unit/infra/offline_stores/test_offline_store.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from pathlib import Path
12
from typing import List, Optional
23
from unittest.mock import MagicMock, patch
34

@@ -21,6 +22,7 @@
2122
TrinoRetrievalJob,
2223
)
2324
from feast.infra.offline_stores.dask import DaskRetrievalJob
25+
from feast.infra.offline_stores.file_source import FileSource
2426
from feast.infra.offline_stores.offline_store import RetrievalJob, RetrievalMetadata
2527
from feast.infra.offline_stores.redshift import (
2628
RedshiftOfflineStoreConfig,
@@ -246,3 +248,26 @@ def test_to_arrow_timeout(retrieval_job, timeout: Optional[int]):
246248
with patch.object(retrieval_job, "_to_arrow_internal") as mock_to_arrow_internal:
247249
retrieval_job.to_arrow(timeout=timeout)
248250
mock_to_arrow_internal.assert_called_once_with(timeout=timeout)
251+
252+
253+
@pytest.mark.parametrize(
254+
"repo_path, uri, expected",
255+
[
256+
# Remote URI - Should return as-is
257+
(
258+
Path("/some/repo"),
259+
"s3://bucket-name/file.parquet",
260+
"s3://bucket-name/file.parquet",
261+
),
262+
# Absolute Path - Should return as-is
263+
(Path("/some/repo"), "/abs/path/file.parquet", "/abs/path/file.parquet"),
264+
# Relative Path with repo_path - Should combine
265+
(Path("/some/repo"), "data/output.parquet", "/some/repo/data/output.parquet"),
266+
# Relative Path without repo_path - Should return absolute path
267+
(None, "C:/path/to/file.parquet", "C:/path/to/file.parquet"),
268+
],
269+
ids=["s3_uri", "absolute_path", "relative_path", "windows_path"],
270+
)
271+
def test_get_uri_for_file_path(repo_path, uri, expected):
272+
result = FileSource.get_uri_for_file_path(repo_path=repo_path, uri=uri)
273+
assert result == expected, f"Expected {expected}, but got {result}"

0 commit comments

Comments
 (0)