-
Notifications
You must be signed in to change notification settings - Fork 145
Expand file tree
/
Copy pathrange_to_sequence.rs
More file actions
127 lines (110 loc) · 4.5 KB
/
range_to_sequence.rs
File metadata and controls
127 lines (110 loc) · 4.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
use vortex::array::ArrayRef;
use vortex::array::IntoArray;
use vortex::array::arrays::PrimitiveArray;
use vortex::array::validity::Validity;
use vortex::buffer::Buffer;
use vortex::dtype::DType;
use vortex::dtype::NativePType;
use vortex::dtype::Nullability;
use vortex::encodings::sequence::Sequence;
use vortex::error::VortexExpect;
use vortex::error::VortexResult;
use vortex::error::vortex_bail;
use vortex::scalar::PValue;
/// Builds an array containing the values of Python's `range(start, stop, step)`, i.e. the
/// arithmetic progression `start, start + step, start + 2 * step, ...` that stops before
/// reaching `stop`.
///
/// A non-empty range is encoded as a [`Sequence`] array with element type `T`. An empty range
/// (a zero-length one such as `start == stop`, or one whose `step` points away from `stop`) is
/// materialized as an empty [`PrimitiveArray`] of `T`.
///
/// `dtype` is consulted only for its nullability (a nullable `dtype` yields an all-valid array)
/// and for error messages; it is not checked against `T`.
///
/// # Errors
///
/// Returns an error if `step` is zero, if a non-empty range's `start` or `step` cannot be
/// represented as `T` (e.g. a negative `step` with an unsigned `T`), or if
/// [`Sequence::try_new_typed`] rejects the parameters.
pub fn sequence_array_from_range<T: NativePType + TryFrom<isize> + Into<PValue>>(
    start: isize,
    stop: isize,
    step: isize,
    dtype: DType,
) -> VortexResult<ArrayRef> {
    if step == 0 {
        vortex_bail!("Step must not be zero");
    }
    let nullability = dtype.nullability();

    // An empty range yields no values, so `start` and `step` never need to fit in `T`. Treat a
    // zero-length range exactly like a backwards one so that, for example, `range(300, 300)` as
    // `u8` is empty rather than an error, just like `range(300, 200)`.
    let len = match range_len(start, stop, step) {
        Some(len) if len > 0 => len,
        _ => {
            let validity = match nullability {
                Nullability::NonNullable => Validity::NonNullable,
                Nullability::Nullable => Validity::AllValid,
            };
            return Ok(PrimitiveArray::new::<T>(Buffer::empty(), validity).into_array());
        }
    };

    let Ok(start) = T::try_from(start) else {
        vortex_bail!(
            "Start, {}, does not fit in requested dtype: {}",
            start,
            dtype
        );
    };
    let Ok(step) = T::try_from(step) else {
        vortex_bail!("Step, {}, does not fit in requested dtype: {}", step, dtype);
    };

    Ok(Sequence::try_new_typed::<T>(start, step, nullability, len)?.into_array())
}
fn range_len(start: isize, stop: isize, step: isize) -> Option<usize> {
if step > 0 {
if start > stop {
return None;
}
let len = (stop - start + step - 1) / step;
let len =
usize::try_from(len).vortex_expect("stop >= start, step > 0, so len is non-negative");
Some(len)
} else {
assert!(step != 0);
if stop > start {
return None;
}
let len = (start - stop + -step - 1) / -step;
let len =
usize::try_from(len).vortex_expect("start >= stop, step < 0, so len is non-negative");
Some(len)
}
}
#[cfg(test)]
mod test {
    use vortex::array::IntoArray as _;
    use vortex::array::assert_arrays_eq;
    use vortex::buffer::buffer;
    use vortex::dtype::DType;
    use vortex::dtype::Nullability;
    use vortex::dtype::PType;

    use crate::arrays::range_to_sequence::range_len;
    use crate::arrays::range_to_sequence::sequence_array_from_range;

    /// Checks `range_len` against hand-computed lengths of Python-style ranges, including an
    /// empty range (`Some(0)`) and ranges whose step points away from `stop` (`None`).
    #[test]
    fn test_range_len() {
        // (start, stop, step, expected length)
        let cases: [(isize, isize, isize, Option<usize>); 10] = [
            (0, 10, 1, Some(10)),
            (0, 10, 5, Some(2)),
            (0, 10, 10, Some(1)),
            (0, 10, 100, Some(1)),
            (-5, -5, 1, Some(0)),
            (-5, 5, 3, Some(4)),
            (-7, -5, 1, Some(2)),
            (3, -3, -1, Some(6)),
            (10, 3, 1, None),
            (0, 10, -1, None),
        ];
        for (start, stop, step, expected) in cases {
            assert_eq!(
                range_len(start, stop, step),
                expected,
                "range_len({start}, {stop}, {step})"
            );
        }
    }

    /// Checks the values produced by `sequence_array_from_range` for ascending and descending
    /// ranges, and that a negative step is rejected for an unsigned element type.
    #[test]
    fn test_sequence_array_from_len() {
        let non_nullable = |ptype: PType| DType::Primitive(ptype, Nullability::NonNullable);

        // Ascending, unit step.
        let actual = sequence_array_from_range::<u16>(0, 10, 1, non_nullable(PType::U16)).unwrap();
        assert_arrays_eq!(actual, buffer![0u16, 1, 2, 3, 4, 5, 6, 7, 8, 9].into_array());

        // Ascending, step larger than one.
        let actual = sequence_array_from_range::<i32>(0, 10, 5, non_nullable(PType::I32)).unwrap();
        assert_arrays_eq!(actual, buffer![0i32, 5].into_array());

        // Ascending across zero, where the last value falls short of `stop`.
        let actual = sequence_array_from_range::<i8>(-5, 5, 3, non_nullable(PType::I8)).unwrap();
        assert_arrays_eq!(actual, buffer![-5i8, -2, 1, 4].into_array());

        // Descending, unit step.
        let actual = sequence_array_from_range::<i8>(3, -3, -1, non_nullable(PType::I8)).unwrap();
        assert_arrays_eq!(actual, buffer![3i8, 2, 1, 0, -1, -2].into_array());

        // A negative step cannot be represented by an unsigned element type.
        let result =
            sequence_array_from_range::<u32>(1_000_000, 10, -500_000, non_nullable(PType::U32));
        let fit_error = "does not fit in requested dtype";
        assert!(result.is_err_and(|err| err.to_string().contains(fit_error)));

        // The same descending range succeeds with a signed element type.
        let actual =
            sequence_array_from_range::<i32>(1_000_000, 10, -500_000, non_nullable(PType::I32))
                .unwrap();
        assert_arrays_eq!(actual, buffer![1_000_000i32, 500_000].into_array());
    }
}