Skip to content

Commit 961e25a

Browse files
committed
WIP on .expand() for expanding array element into dimension
1 parent 209d171 commit 961e25a

3 files changed

Lines changed: 126 additions & 1 deletion

File tree

src/impl_expand.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// Copyright 2021 bluss and ndarray developers.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
use crate::imp_prelude::*;
10+
11+
use crate::data_traits::RawDataSubst;
12+
13+
use num_complex::Complex;
14+
15+
16+
pub trait MultiElement {
17+
type Elem;
18+
const LEN: usize;
19+
}
20+
21+
impl<A, const N: usize> MultiElement for [A; N] {
22+
type Elem = A;
23+
const LEN: usize = N;
24+
}
25+
26+
impl<A> MultiElement for Complex<A> {
27+
type Elem = A;
28+
const LEN: usize = 2;
29+
}
30+
31+
impl<'a, A, D, const N: usize> ArrayView<'a, [A; N], D>
32+
where
33+
D: Dimension,
34+
{
35+
///
36+
/// Note: expanding a zero-element array leads to a new axis of length zero,
37+
/// i.e. the array becomes empty.
38+
///
39+
/// **Panics** if the product of non-zero axis lengths overflows `isize`.
40+
pub fn expand(self, new_axis: Axis) -> ArrayView<'a, A, D::Larger> {
41+
let mut strides = self.strides.insert_axis(new_axis);
42+
let mut dim = self.dim.insert_axis(new_axis);
43+
let len = N as isize;
44+
for ax in 0..strides.ndim() {
45+
if Axis(ax) == new_axis {
46+
continue;
47+
}
48+
if dim[ax] > 1 {
49+
strides[ax] = ((strides[ax] as isize) * len) as usize;
50+
}
51+
}
52+
dim[new_axis.index()] = N;
53+
// TODO nicer assertion
54+
crate::dimension::size_of_shape_checked(&dim).unwrap();
55+
56+
// safe because
57+
// size still fits in isize;
58+
// new strides are adapted to new element type, inside the same allocation.
59+
unsafe {
60+
ArrayBase::from_data_ptr(self.data.data_subst(), self.ptr.cast())
61+
.with_strides_dim(strides, dim)
62+
}
63+
}
64+
}

src/lib.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
// Copyright 2014-2020 bluss and ndarray developers.
1+
// Copyright 2014-2021 ndarray developers.
2+
// Main authors:
3+
//
4+
// Ulrik Sverdrup "bluss"
5+
// Jim Turner "jturner314"
6+
// and many others
27
//
38
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
49
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
@@ -1515,6 +1520,7 @@ mod impl_internal_constructors;
15151520
mod impl_constructors;
15161521

15171522
mod impl_methods;
1523+
mod impl_expand;
15181524
mod impl_owned_array;
15191525
mod impl_special_element_types;
15201526

tests/expand.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
2+
use ndarray::prelude::*;
3+
4+
#[test]
5+
fn test_expand_from_zero() {
6+
let a = Array::from_elem((), [[[1, 2], [3, 4], [5, 6]],
7+
[[11, 12], [13, 14], [15, 16]]]);
8+
let av = a.view();
9+
println!("{:?}", av);
10+
let av = av.expand(Axis(0));
11+
println!("{:?}", av);
12+
let av = av.expand(Axis(1));
13+
println!("{:?}", av);
14+
let av = av.expand(Axis(2));
15+
println!("{:?}", av);
16+
assert!(av.is_standard_layout());
17+
assert_eq!(av, av.to_owned());
18+
19+
let av = a.view();
20+
println!("{:?}", av);
21+
let av = av.expand(Axis(0));
22+
println!("{:?}", av);
23+
let av = av.expand(Axis(0));
24+
println!("{:?}", av);
25+
let av = av.expand(Axis(0));
26+
println!("{:?}", av);
27+
assert!(av.t().is_standard_layout());
28+
assert_eq!(av, av.to_owned());
29+
}
30+
31+
#[test]
32+
fn test_expand_zero() {
33+
let a = Array::from_elem((3, 4), [0.; 0]);
34+
35+
for ax in 0..=2 {
36+
let mut new_shape = [3, 4, 4];
37+
new_shape[1] = if ax == 0 { 3 } else { 4 };
38+
new_shape[ax] = 0;
39+
let av = a.view();
40+
let av = av.expand(Axis(ax));
41+
assert_eq!(av.shape(), &new_shape);
42+
}
43+
}
44+
45+
#[test]
46+
fn test_expand1() {
47+
let a = Array::from_elem((3, 3), [1, 2, 3]);
48+
println!("{:?}", a);
49+
let b = a.view().expand(Axis(2));
50+
println!("{:?}", b);
51+
let b = a.view().expand(Axis(1));
52+
println!("{:?}", b);
53+
let b = a.view().expand(Axis(0));
54+
println!("{:?}", b);
55+
}

0 commit comments

Comments
 (0)