Skip to content

Commit 5483037

Browse files
chore[array]: replace to_canonical with to execute. (#6061)
Signed-off-by: Joe Isaacs <joe.isaacs@live.co.uk>
1 parent 0f3fbad commit 5483037

51 files changed

Lines changed: 619 additions & 307 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

encodings/alp/src/alp_rd/array.rs

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ use vortex_array::IntoArray;
1818
use vortex_array::Precision;
1919
use vortex_array::ProstMetadata;
2020
use vortex_array::SerializeMetadata;
21-
use vortex_array::ToCanonical;
2221
use vortex_array::arrays::PrimitiveArray;
2322
use vortex_array::buffer::BufferHandle;
2423
use vortex_array::patches::Patches;
@@ -45,8 +44,9 @@ use vortex_error::VortexResult;
4544
use vortex_error::vortex_bail;
4645
use vortex_error::vortex_ensure;
4746
use vortex_error::vortex_err;
47+
use vortex_mask::Mask;
4848

49-
use crate::alp_rd::alp_rd_decode;
49+
use crate::alp_rd_decode;
5050

5151
vtable!(ALPRD);
5252

@@ -238,13 +238,19 @@ impl VTable for ALPRDVTable {
238238
Ok(())
239239
}
240240

241-
fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<Canonical> {
242-
let left_parts = array.left_parts().to_primitive();
243-
let right_parts = array.right_parts().to_primitive();
241+
fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<Canonical> {
242+
let left_parts = array.left_parts().clone().execute::<PrimitiveArray>(ctx)?;
243+
let right_parts = array.right_parts().clone().execute::<PrimitiveArray>(ctx)?;
244244

245245
// Decode the left_parts using our builtin dictionary.
246246
let left_parts_dict = array.left_parts_dictionary();
247247

248+
let validity = array
249+
.left_parts()
250+
.validity()?
251+
.to_array(array.len())
252+
.execute::<Mask>(ctx)?;
253+
248254
let decoded_array = if array.is_f32() {
249255
PrimitiveArray::new(
250256
alp_rd_decode::<f32>(
@@ -253,8 +259,9 @@ impl VTable for ALPRDVTable {
253259
array.right_bit_width,
254260
right_parts.into_buffer_mut::<u32>(),
255261
array.left_parts_patches(),
256-
),
257-
Validity::copy_from_array(array.as_ref()),
262+
ctx,
263+
)?,
264+
Validity::from_mask(validity, array.dtype().nullability()),
258265
)
259266
} else {
260267
PrimitiveArray::new(
@@ -264,8 +271,9 @@ impl VTable for ALPRDVTable {
264271
array.right_bit_width,
265272
right_parts.into_buffer_mut::<u64>(),
266273
array.left_parts_patches(),
267-
),
268-
Validity::copy_from_array(array.as_ref()),
274+
ctx,
275+
)?,
276+
Validity::from_mask(validity, array.dtype().nullability()),
269277
)
270278
};
271279

encodings/alp/src/alp_rd/mod.rs

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#![allow(clippy::cast_possible_truncation)]
55

66
pub use array::*;
7+
use vortex_array::ExecutionCtx;
78
use vortex_array::IntoArray;
89
use vortex_array::patches::Patches;
910
use vortex_array::validity::Validity;
@@ -22,7 +23,6 @@ use num_traits::One;
2223
use num_traits::PrimInt;
2324
use rustc_hash::FxBuildHasher;
2425
use vortex_array::Array;
25-
use vortex_array::ToCanonical;
2626
use vortex_array::arrays::PrimitiveArray;
2727
use vortex_array::vtable::ValidityHelper;
2828
use vortex_buffer::Buffer;
@@ -31,6 +31,7 @@ use vortex_dtype::DType;
3131
use vortex_dtype::NativePType;
3232
use vortex_dtype::match_each_integer_ptype;
3333
use vortex_error::VortexExpect;
34+
use vortex_error::VortexResult;
3435
use vortex_error::vortex_panic;
3536
use vortex_utils::aliases::hash_map::HashMap;
3637

@@ -290,7 +291,8 @@ pub fn alp_rd_decode<T: ALPRDFloat>(
290291
right_bit_width: u8,
291292
right_parts: BufferMut<T::UINT>,
292293
left_parts_patches: Option<&Patches>,
293-
) -> Buffer<T> {
294+
ctx: &mut ExecutionCtx,
295+
) -> VortexResult<Buffer<T>> {
294296
if left_parts.len() != right_parts.len() {
295297
vortex_panic!("alp_rd_decode: left_parts.len != right_parts.len");
296298
}
@@ -304,19 +306,45 @@ pub fn alp_rd_decode<T: ALPRDFloat>(
304306

305307
// Apply any patches
306308
if let Some(patches) = left_parts_patches {
307-
let indices = patches.indices().to_primitive();
308-
let patch_values = patches.values().to_primitive();
309-
match_each_integer_ptype!(indices.ptype(), |T| {
310-
indices
311-
.as_slice::<T>()
312-
.iter()
313-
.copied()
314-
.map(|idx| idx - patches.offset() as T)
315-
.zip(patch_values.as_slice::<u16>().iter())
316-
.for_each(|(idx, v)| values[idx as usize] = *v);
317-
})
309+
let indices = patches.indices().clone().execute::<PrimitiveArray>(ctx)?;
310+
let patch_values = patches.values().clone().execute::<PrimitiveArray>(ctx)?;
311+
alp_rd_apply_patches(&mut values, &indices, &patch_values, patches.offset());
318312
}
319313

314+
// Shift the left-parts and add in the right-parts.
315+
Ok(alp_rd_decode_core(
316+
left_parts_dict,
317+
right_bit_width,
318+
right_parts,
319+
values,
320+
))
321+
}
322+
323+
/// Apply patches to the decoded left-parts values.
324+
fn alp_rd_apply_patches(
325+
values: &mut BufferMut<u16>,
326+
indices: &PrimitiveArray,
327+
patch_values: &PrimitiveArray,
328+
offset: usize,
329+
) {
330+
match_each_integer_ptype!(indices.ptype(), |T| {
331+
indices
332+
.as_slice::<T>()
333+
.iter()
334+
.copied()
335+
.map(|idx| idx - offset as T)
336+
.zip(patch_values.as_slice::<u16>().iter())
337+
.for_each(|(idx, v)| values[idx as usize] = *v);
338+
})
339+
}
340+
341+
/// Core decode logic shared between `alp_rd_decode` and `execute_alp_rd_decode`.
342+
fn alp_rd_decode_core<T: ALPRDFloat>(
343+
_left_parts_dict: &[u16],
344+
right_bit_width: u8,
345+
right_parts: BufferMut<T::UINT>,
346+
values: BufferMut<u16>,
347+
) -> Buffer<T> {
320348
// Shift the left-parts and add in the right-parts.
321349
let mut index = 0;
322350
right_parts

encodings/datetime-parts/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,5 @@ vortex-scalar = { workspace = true }
2929
[dev-dependencies]
3030
rstest = { workspace = true }
3131
vortex-array = { workspace = true, features = ["_test-harness"] }
32+
vortex-error = { workspace = true }
33+
vortex-session = { workspace = true }

encodings/datetime-parts/src/array.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,8 @@ impl VTable for DateTimePartsVTable {
155155
Ok(())
156156
}
157157

158-
fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<Canonical> {
159-
Ok(Canonical::Extension(decode_to_temporal(array).into()))
158+
fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<Canonical> {
159+
Ok(Canonical::Extension(decode_to_temporal(array, ctx)?.into()))
160160
}
161161

162162
fn reduce_parent(

encodings/datetime-parts/src/canonical.rs

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use num_traits::AsPrimitive;
5+
use vortex_array::ExecutionCtx;
56
use vortex_array::IntoArray;
6-
use vortex_array::ToCanonical;
77
use vortex_array::arrays::PrimitiveArray;
88
use vortex_array::arrays::TemporalArray;
99
use vortex_array::compute::cast;
@@ -15,14 +15,18 @@ use vortex_dtype::datetime::TemporalMetadata;
1515
use vortex_dtype::datetime::TimeUnit;
1616
use vortex_dtype::match_each_integer_ptype;
1717
use vortex_error::VortexExpect as _;
18+
use vortex_error::VortexResult;
1819
use vortex_error::vortex_panic;
1920

2021
use crate::DateTimePartsArray;
2122

2223
/// Decode an [Array] into a [TemporalArray].
2324
///
2425
/// Enforces that the passed array is actually a [DateTimePartsArray] with proper metadata.
25-
pub fn decode_to_temporal(array: &DateTimePartsArray) -> TemporalArray {
26+
pub fn decode_to_temporal(
27+
array: &DateTimePartsArray,
28+
ctx: &mut ExecutionCtx,
29+
) -> VortexResult<TemporalArray> {
2630
let DType::Extension(ext) = array.dtype().clone() else {
2731
vortex_panic!(ComputeError: "expected dtype to be DType::Extension variant")
2832
};
@@ -44,7 +48,7 @@ pub fn decode_to_temporal(array: &DateTimePartsArray) -> TemporalArray {
4448
&DType::Primitive(PType::I64, array.dtype().nullability()),
4549
)
4650
.vortex_expect("must be able to cast days to i64")
47-
.to_primitive();
51+
.execute::<PrimitiveArray>(ctx)?;
4852

4953
// We start with the days component, which is always present.
5054
// And then add the seconds and subseconds components.
@@ -64,7 +68,7 @@ pub fn decode_to_temporal(array: &DateTimePartsArray) -> TemporalArray {
6468
*v += seconds;
6569
}
6670
} else {
67-
let seconds_buf = array.seconds().to_primitive();
71+
let seconds_buf = array.seconds().clone().execute::<PrimitiveArray>(ctx)?;
6872
match_each_integer_ptype!(seconds_buf.ptype(), |S| {
6973
for (v, second) in values.iter_mut().zip(seconds_buf.as_slice::<S>()) {
7074
let second: i64 = second.as_();
@@ -82,7 +86,7 @@ pub fn decode_to_temporal(array: &DateTimePartsArray) -> TemporalArray {
8286
*v += subseconds;
8387
}
8488
} else {
85-
let subseconds_buf = array.subseconds().to_primitive();
89+
let subseconds_buf = array.subseconds().clone().execute::<PrimitiveArray>(ctx)?;
8690
match_each_integer_ptype!(subseconds_buf.ptype(), |S| {
8791
for (v, subseconds) in values.iter_mut().zip(subseconds_buf.as_slice::<S>()) {
8892
let subseconds: i64 = subseconds.as_();
@@ -91,17 +95,18 @@ pub fn decode_to_temporal(array: &DateTimePartsArray) -> TemporalArray {
9195
});
9296
}
9397

94-
TemporalArray::new_timestamp(
98+
Ok(TemporalArray::new_timestamp(
9599
PrimitiveArray::new(values.freeze(), Validity::copy_from_array(array.as_ref()))
96100
.into_array(),
97101
temporal_metadata.time_unit(),
98102
temporal_metadata.time_zone().map(ToString::to_string),
99-
)
103+
))
100104
}
101105

102106
#[cfg(test)]
103107
mod test {
104108
use rstest::rstest;
109+
use vortex_array::ExecutionCtx;
105110
use vortex_array::IntoArray;
106111
use vortex_array::ToCanonical;
107112
use vortex_array::arrays::PrimitiveArray;
@@ -111,6 +116,8 @@ mod test {
111116
use vortex_array::vtable::ValidityHelper;
112117
use vortex_buffer::buffer;
113118
use vortex_dtype::datetime::TimeUnit;
119+
use vortex_error::VortexResult;
120+
use vortex_session::VortexSession;
114121

115122
use crate::DateTimePartsArray;
116123
use crate::canonical::decode_to_temporal;
@@ -120,7 +127,7 @@ mod test {
120127
#[case(Validity::AllValid)]
121128
#[case(Validity::AllInvalid)]
122129
#[case(Validity::from_iter([true, true, false, false, true, true]))]
123-
fn test_decode_to_temporal(#[case] validity: Validity) {
130+
fn test_decode_to_temporal(#[case] validity: Validity) -> VortexResult<()> {
124131
let milliseconds = PrimitiveArray::new(
125132
buffer![
126133
86_400i64, // element with only day component
@@ -144,11 +151,13 @@ mod test {
144151
validity.to_mask(date_times.len())
145152
);
146153

147-
let primitive_values = decode_to_temporal(&date_times)
154+
let mut ctx = ExecutionCtx::new(VortexSession::empty());
155+
let primitive_values = decode_to_temporal(&date_times, &mut ctx)?
148156
.temporal_values()
149157
.to_primitive();
150158

151159
assert_arrays_eq!(primitive_values, milliseconds);
152160
assert_eq!(primitive_values.validity(), &validity);
161+
Ok(())
153162
}
154163
}

encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ use vortex_array::IntoArray;
2020
use vortex_array::Precision;
2121
use vortex_array::ProstMetadata;
2222
use vortex_array::SerializeMetadata;
23-
use vortex_array::ToCanonical;
2423
use vortex_array::arrays::DecimalArray;
24+
use vortex_array::arrays::PrimitiveArray;
2525
use vortex_array::buffer::BufferHandle;
2626
use vortex_array::serde::ArrayChildren;
2727
use vortex_array::stats::ArrayStats;
@@ -138,8 +138,8 @@ impl VTable for DecimalBytePartsVTable {
138138
}))
139139
}
140140

141-
fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<Canonical> {
142-
to_canonical_decimal(array)
141+
fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<Canonical> {
142+
to_canonical_decimal(array, ctx)
143143
}
144144
}
145145

@@ -234,9 +234,12 @@ impl BaseArrayVTable<DecimalBytePartsVTable> for DecimalBytePartsVTable {
234234
}
235235

236236
/// Converts a DecimalBytePartsArray to its canonical DecimalArray representation.
237-
fn to_canonical_decimal(array: &DecimalBytePartsArray) -> VortexResult<Canonical> {
237+
fn to_canonical_decimal(
238+
array: &DecimalBytePartsArray,
239+
ctx: &mut ExecutionCtx,
240+
) -> VortexResult<Canonical> {
238241
// TODO(joe): support parts len != 1
239-
let prim = array.msp.to_primitive();
242+
let prim = array.msp.clone().execute::<PrimitiveArray>(ctx)?;
240243
// Depending on the decimal type and the min/max of the primitive array we can choose
241244
// the correct buffer size
242245

encodings/fastlanes/benches/canonicalize_bench.rs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use std::sync::LazyLock;
5+
46
use divan::Bencher;
57
use rand::SeedableRng;
68
use rand::prelude::StdRng;
79
use vortex_array::Array;
10+
use vortex_array::Canonical;
811
use vortex_array::IntoArray;
12+
use vortex_array::VortexSessionExecute;
913
use vortex_array::arrays::ChunkedArray;
1014
use vortex_array::builders::ArrayBuilder;
1115
use vortex_array::builders::PrimitiveBuilder;
1216
use vortex_array::compute::warm_up_vtables;
17+
use vortex_array::session::ArraySession;
1318
use vortex_error::VortexExpect;
1419
use vortex_fastlanes::bitpack_compress::test_harness::make_array;
20+
use vortex_session::VortexSession;
1521

1622
fn main() {
1723
warm_up_vtables();
@@ -32,6 +38,9 @@ const BENCH_ARGS: &[(usize, usize, f64)] = &[
3238
(10000, 1000, 0.00),
3339
];
3440

41+
static SESSION: LazyLock<VortexSession> =
42+
LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
43+
3544
#[divan::bench(args = BENCH_ARGS)]
3645
fn into_canonical_non_nullable(
3746
bencher: Bencher,
@@ -71,7 +80,7 @@ fn canonical_into_non_nullable(
7180
chunk_len * chunk_count,
7281
);
7382
chunked
74-
.append_to_builder(&mut primitive_builder)
83+
.append_to_builder(&mut primitive_builder, &mut SESSION.create_execution_ctx())
7584
.vortex_expect("append failed");
7685
primitive_builder.finish()
7786
});
@@ -103,8 +112,8 @@ fn into_canonical_nullable(
103112
let chunked = ChunkedArray::from_iter(chunks).into_array();
104113

105114
bencher
106-
.with_inputs(|| &chunked)
107-
.bench_refs(|chunked| chunked.to_canonical());
115+
.with_inputs(|| chunked.clone())
116+
.bench_values(|chunked| chunked.execute::<Canonical>(&mut SESSION.create_execution_ctx()));
108117
}
109118

110119
#[divan::bench(args = NULLABLE_BENCH_ARGS)]
@@ -128,7 +137,7 @@ fn canonical_into_nullable(
128137
chunk_len * chunk_count,
129138
);
130139
chunked
131-
.append_to_builder(&mut primitive_builder)
140+
.append_to_builder(&mut primitive_builder, &mut SESSION.create_execution_ctx())
132141
.vortex_expect("append failed");
133142
primitive_builder.finish()
134143
});

0 commit comments

Comments
 (0)