Skip to content

Commit 9f06445

Browse files
authored
unsafe Scalar::struct_unchecked in scalar_at vtable (#6741)
1 parent 0288f11 commit 9f06445

4 files changed

Lines changed: 33 additions & 11 deletions

File tree

vortex-array/public-api.lock

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12248,7 +12248,9 @@ pub fn vortex_array::scalar::Scalar::from_proto_value(value: &vortex_proto::scal
1224812248

1224912249
impl vortex_array::scalar::Scalar
1225012250

12251-
pub fn vortex_array::scalar::Scalar::struct_(dtype: vortex_array::dtype::DType, children: alloc::vec::Vec<vortex_array::scalar::Scalar>) -> Self
12251+
pub fn vortex_array::scalar::Scalar::struct_(dtype: vortex_array::dtype::DType, children: impl core::iter::traits::collect::IntoIterator<Item = vortex_array::scalar::Scalar>) -> Self
12252+
12253+
pub unsafe fn vortex_array::scalar::Scalar::struct_unchecked(dtype: vortex_array::dtype::DType, children: impl core::iter::traits::collect::IntoIterator<Item = vortex_array::scalar::Scalar>) -> Self
1225212254

1225312255
impl vortex_array::scalar::Scalar
1225412256

vortex-array/src/arrays/struct_/vtable/operations.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@ use crate::vtable::OperationsVTable;
1111

1212
impl OperationsVTable<StructVTable> for StructVTable {
1313
fn scalar_at(array: &StructArray, index: usize) -> VortexResult<Scalar> {
14-
let field_scalars: VortexResult<Vec<_>> = array
14+
let field_scalars: VortexResult<Vec<Scalar>> = array
1515
.unmasked_fields()
1616
.iter()
1717
.map(|field| field.scalar_at(index))
1818
.collect();
19-
Ok(Scalar::struct_(array.dtype().clone(), field_scalars?))
19+
// SAFETY: The vtable guarantees index is in-bounds and non-null before this is called.
20+
// Each field's scalar_at returns a scalar with the field's own dtype.
21+
Ok(unsafe { Scalar::struct_unchecked(array.dtype().clone(), field_scalars?) })
2022
}
2123
}

vortex-array/src/scalar/typed_view/struct_.rs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -279,12 +279,13 @@ impl<'a> StructScalar<'a> {
279279
}
280280

281281
impl Scalar {
282-
/// Creates a new struct scalar with the given fields.
283-
pub fn struct_(dtype: DType, children: Vec<Scalar>) -> Self {
282+
/// Creates a new struct scalar with the given fields, checking dtypes at runtime.
283+
pub fn struct_(dtype: DType, children: impl IntoIterator<Item = Scalar>) -> Self {
284284
let DType::Struct(struct_fields, _) = &dtype else {
285285
vortex_panic!("Expected struct dtype, found {}", dtype);
286286
};
287287

288+
let children: Vec<Scalar> = children.into_iter().collect();
288289
let field_dtypes = struct_fields.fields();
289290
if children.len() != field_dtypes.len() {
290291
vortex_panic!(
@@ -305,9 +306,24 @@ impl Scalar {
305306
}
306307
}
307308

308-
let mut value_children = Vec::with_capacity(children.len());
309-
value_children.extend(children.into_iter().map(|x| x.into_value()));
309+
let value_children: Vec<_> = children.into_iter().map(|x| x.into_value()).collect();
310+
Self::try_new(dtype, Some(ScalarValue::List(value_children)))
311+
.vortex_expect("unable to construct a struct `Scalar`")
312+
}
310313

314+
/// Creates a new struct scalar from an iterator of field scalars, skipping dtype checks.
315+
///
316+
/// # Safety
317+
///
318+
/// Caller must ensure:
319+
/// - `dtype` is `DType::Struct`
320+
/// - The iterator yields exactly as many scalars as `dtype` has fields
321+
/// - Each scalar's dtype matches the corresponding field dtype in `dtype`
322+
pub unsafe fn struct_unchecked(
323+
dtype: DType,
324+
children: impl IntoIterator<Item = Scalar>,
325+
) -> Self {
326+
let value_children: Vec<_> = children.into_iter().map(|s| s.into_value()).collect();
311327
Self::try_new(dtype, Some(ScalarValue::List(value_children)))
312328
.vortex_expect("unable to construct a struct `Scalar`")
313329
}

vortex-python/src/scalar/factory.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,14 @@ fn scalar_helper_inner(value: &Bound<'_, PyAny>, dtype: Option<&DType>) -> PyRes
151151
)));
152152
}
153153

154+
let children: Vec<Scalar> = dict
155+
.values()
156+
.into_iter()
157+
.map(|item| scalar_helper_inner(&item, None))
158+
.try_collect()?;
154159
return Ok(Scalar::struct_(
155160
DType::Struct(dtype.clone(), *nullability),
156-
dict.values()
157-
.into_iter()
158-
.map(|item| scalar_helper_inner(&item, None))
159-
.try_collect()?,
161+
children,
160162
));
161163
} else {
162164
let values: Vec<Scalar> = dict

0 commit comments

Comments
 (0)