# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright the Vortex contributors
from __future__ import annotations
from collections.abc import Iterable
from typing import TYPE_CHECKING, TypeVar, final
from ray.data import Datasource, ReadTask
from ray.data.block import BlockMetadata
from ray.data.context import DataContext
from ray.data.datasource import BaseFileMetadataProvider, DefaultFileMetadataProvider
from ray.data.datasource.path_util import (
_resolve_paths_and_filesystem, # pyright: ignore[reportPrivateUsage, reportUnknownVariableType]
)
from typing_extensions import override
from .. import open as vx_open
from ..arrow.expression import ensure_vortex_expression
from ..expr import Expr as VortexExpr
from ..type_aliases import IntoProjection
if TYPE_CHECKING:
import pandas
import pyarrow.compute as pc
T = TypeVar("T")


def partition(k: int, ls: list[T]) -> list[list[T]]:
    """Split *ls* into *k* contiguous sub-lists whose lengths differ by at most one.

    The first ``n % k`` sub-lists each receive one extra element, so the
    sub-lists concatenate back to *ls* in order. When ``k > len(ls)`` the
    trailing sub-lists are empty.

    Args:
        k: Number of sub-lists to produce; must be positive.
        ls: The list to split.

    Returns:
        A list of exactly ``k`` sub-lists.

    Raises:
        ValueError: If ``k`` is not positive.
    """
    if k <= 0:
        # A bare assert would be stripped under ``python -O``; validate explicitly.
        raise ValueError(f"k must be positive, got {k}")
    n = len(ls)
    out: list[list[T]] = []
    start = 0
    for i in range(k):
        # (n // k) * k == n - (n % k), so giving one extra element to the
        # first n % k sub-lists accounts for the remainder exactly.
        part_len = (n // k) + 1 if i < n % k else (n // k)
        out.append(ls[start : start + part_len])
        start += part_len
    return out
@final
class VortexDatasource(Datasource):
    """Read a folder of Vortex files as a row-wise-partitioned table.

    Each read task handles a disjoint subset of the resolved file paths, so
    rows are partitioned across tasks by file.
    """

    def __init__(
        self,
        *,
        url: str,
        columns: IntoProjection = None,
        filter: pc.Expression | VortexExpr | None = None,
        batch_size: int | None = None,
        meta_provider: BaseFileMetadataProvider | None = None,
    ):
        """Resolve *url* into the list of file paths this datasource will read.

        Args:
            url: File or directory URL; expanded into individual file paths.
            columns: Optional projection applied when reading each file.
            filter: Optional Arrow or Vortex filter expression.
            batch_size: Default per-task row limit, used when the caller of
                :meth:`get_read_tasks` does not supply one.
            meta_provider: File metadata provider used to expand *url*.
                Defaults to a fresh ``DefaultFileMetadataProvider``.
        """
        super().__init__()
        # A None sentinel instead of a call in the signature default: a
        # default evaluated at definition time would share one provider
        # instance across every construction of this class.
        if meta_provider is None:
            meta_provider = DefaultFileMetadataProvider()
        self._columns = columns
        self._filter = filter
        urls, fs = _resolve_paths_and_filesystem(url, None)  # pyright: ignore[reportUnknownVariableType]
        paths_and_sizes = list(
            meta_provider.expand_paths(
                urls,
                fs,  # pyright: ignore[reportUnknownArgumentType]
                None,
                ignore_missing_paths=False,
            )
        )
        # File sizes are discarded: on-disk size is not a usable estimate of
        # in-memory size (see estimate_inmemory_data_size).
        self._paths: list[str] = [path for path, _ in paths_and_sizes]
        self._batch_size = batch_size

    @override
    def estimate_inmemory_data_size(self) -> int | None:
        """Return an estimate of the in-memory data size, or None if unknown.

        Note that the in-memory data size may be larger than the on-disk data size.
        """
        return None

    @override
    def get_read_tasks(
        self, parallelism: int, per_task_row_limit: int | None = None, data_context: DataContext | None = None
    ) -> list[ReadTask]:
        """Execute the read and return read tasks.

        Args:
            parallelism: The requested read parallelism. The number of read
                tasks should equal to this value if possible.
            per_task_row_limit: The per-task row limit for the read tasks.
            data_context: Unused; accepted for interface compatibility.

        Returns:
            A list of read tasks that can be executed to read blocks from the
            datasource in parallel.
        """
        return [
            _read_task(
                paths,
                self._columns,
                self._filter,
                per_task_row_limit if per_task_row_limit is not None else self._batch_size,
            )
            for paths in partition(parallelism, self._paths)
            # partition() pads with empty sub-lists when parallelism exceeds
            # the file count; skip those rather than emit empty tasks.
            if len(paths) > 0
        ]

    @property
    @override
    def supports_distributed_reads(self) -> bool:
        """If ``False``, only launch read tasks on the driver's node."""
        return True
def _read_task(
    paths: list[str],
    columns: IntoProjection,
    filter: pc.Expression | VortexExpr | None,
    batch_size: int | None,
) -> ReadTask:
    """Build a ReadTask that streams the given Vortex files as pandas blocks.

    Args:
        paths: Non-empty list of file paths read by this task.
        columns: Projection forwarded to ``to_arrow``.
        filter: Optional Arrow/Vortex filter expression, coerced inside the
            task so the (unserializable) compiled expression is built lazily.
        batch_size: Per-batch row limit forwarded to ``to_arrow``.

    Returns:
        A ``ReadTask`` whose metadata carries the combined row count of the
        files and whose callable yields one pandas DataFrame per batch.

    Raises:
        ValueError: If *paths* is empty or the files disagree on schema.
    """
    if not paths:
        raise ValueError("no paths specified")
    files = [vx_open(path) for path in paths]
    schemas = [f.dtype.to_arrow_schema() for f in files]
    schema = schemas[0]
    # All files in one task are read as a single block, so their schemas must
    # agree. A bare assert would be stripped under ``python -O`` and would
    # not name the offending file.
    for path, s in zip(paths[1:], schemas[1:]):
        if s != schema:
            raise ValueError(f"schema mismatch: {path!r} differs from {paths[0]!r}")
    num_rows = sum(len(f) for f in files)
    metadata = BlockMetadata(
        num_rows=num_rows,
        size_bytes=None,
        exec_stats=None,
        input_files=paths,
    )

    def read() -> Iterable[pandas.DataFrame]:
        # If we could serialize a PyVortexFile and a PyExpr, we could set those up earlier.
        vx_filter = ensure_vortex_expression(filter, schema=schema)
        for path in paths:
            f = vx_open(path)
            for rb in f.to_arrow(columns, expr=vx_filter, batch_size=batch_size):
                # We would prefer to generate Arrow, but we run into this issue: https://github.com/apache/arrow/issues/47279
                #
                # yield pa.Table.from_batches([rb])
                #
                yield rb.to_pandas()  # pyright: ignore[reportUnknownMemberType]

    return ReadTask(read, metadata, schema)