Skip to content

Commit c60920b

Browse files
committed
Default.Reduce*: Proper handling when empty array or scalar
1 parent bcb5bba commit c60920b

3 files changed

Lines changed: 7 additions & 18 deletions

File tree

src/NumSharp.Core/Backends/Default/Math/Reduction/Default.Reduction.CumAdd.cs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,7 @@ public override unsafe NDArray ReduceCumAdd(in NDArray arr, int? axis_, NPTypeCo
1616
return arr;
1717

1818
if (shape.IsScalar || shape.size == 1 && shape.dimensions.Length == 1)
19-
{
20-
var r = typeCode.HasValue ? Cast(arr, typeCode.Value, true) : arr.Clone();
21-
if (!r.Shape.IsScalar && r.Shape.size == 1 && r.ndim == 1)
22-
r.Storage.Reshape(Shape.Scalar);
23-
return r;
24-
}
19+
return typeCode.HasValue ? Cast(arr, typeCode.Value, copy: true) : arr.Clone();
2520

2621
if (axis_ == null)
2722
{

src/NumSharp.Core/Backends/Default/Math/Reduction/Default.Reduction.Mean.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,9 @@ public override NDArray ReduceMean(in NDArray arr, int? axis_, bool keepdims = f
1717

1818
if (shape.IsScalar || (shape.size == 1 && shape.NDim == 1))
1919
{
20-
var r = typeCode.HasValue ? Cast(arr, typeCode.Value, true) : arr.Clone();
20+
var r = NDArray.Scalar(typeCode.HasValue ? Converts.ChangeType(arr.GetAtIndex(0), typeCode.Value) : arr.GetAtIndex(0));
2121
if (keepdims)
2222
r.Storage.ExpandDimension(0);
23-
else if (!r.Shape.IsScalar && r.Shape.size == 1 && r.ndim == 1)
24-
r.Storage.Reshape(Shape.Scalar);
2523
return r;
2624
}
2725

src/NumSharp.Core/Backends/Default/Math/Reduction/Default.Reduction.Product.cs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,15 @@ public override NDArray ReduceProduct(NDArray arr, int? axis_, bool keepdims = f
1212
//the size of the array is [1, 2, n, m] all shapes after 2nd multiplied gives size
1313
//the size of what we need to reduce is the size of the shape of the given axis (shape[axis])
1414
var shape = arr.Shape;
15-
if (shape.IsEmpty)
16-
return arr;
15+
if (shape.IsEmpty || shape.size==0)
16+
return NDArray.Scalar(1, (typeCode ?? arr.typecode));
1717

18-
if (shape.IsScalar || (shape.size <= 1 && shape.NDim == 1))
18+
if (shape.IsScalar || (shape.size == 1 && shape.NDim == 1))
1919
{
20-
var r = typeCode.HasValue ? Cast(arr, typeCode.Value, true) : arr.Clone();
20+
var r = NDArray.Scalar(typeCode.HasValue ? Converts.ChangeType(arr.GetAtIndex(0), typeCode.Value) : arr.GetAtIndex(0));
2121
if (keepdims)
2222
r.Storage.ExpandDimension(0);
23-
else if (!r.Shape.IsScalar && r.Shape.size == 1 && r.ndim == 1)
24-
r.Storage.Reshape(Shape.Scalar);
25-
else if (!r.Shape.IsScalar && r.Shape.size == 0 && r.ndim == 1)
26-
return NDArray.Scalar(1, arr.typecode);
27-
23+
2824
return r;
2925
}
3026

0 commit comments

Comments
 (0)