Skip to content

Commit d001252

Browse files
committed
handle empty arrays in a variety of functions
allow indexing with empty arrays
1 parent 072d507 commit d001252

34 files changed

Lines changed: 542 additions & 14 deletions

src/api/c/assign.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,6 @@ af_err af_assign_gen(af_array *out,
206206
spanner.isSeq = true;
207207

208208
try {
209-
ARG_ASSERT(2, (ndims>0));
210209
ARG_ASSERT(3, (indexs!=NULL));
211210

212211
int track = 0;
@@ -233,6 +232,15 @@ af_err af_assign_gen(af_array *out,
233232
af_dtype lhsType= lInfo.getType();
234233
af_dtype rhsType= rInfo.getType();
235234

235+
if(rhsDims.ndims() == 0) {
236+
return af_retain_array(out, lhs);
237+
}
238+
239+
if(lhsDims.ndims() == 0) {
240+
dim_t my_dims[] = { 0, 0, 0, 0 };
241+
return af_create_handle(out, AF_MAX_DIMS, my_dims, lhsType);
242+
}
243+
236244
ARG_ASSERT(2, (ndims == 1) || (ndims == (dim_t)lInfo.ndims()));
237245

238246
if (ndims == 1 && ndims != (dim_t)lInfo.ndims()) {
@@ -246,7 +254,6 @@ af_err af_assign_gen(af_array *out,
246254
}
247255

248256
ARG_ASSERT(1, (lhsType==rhsType));
249-
ARG_ASSERT(3, (rhsDims.ndims()>0));
250257
ARG_ASSERT(1, (lhsDims.ndims()>=rhsDims.ndims()));
251258
ARG_ASSERT(2, (lhsDims.ndims()>=ndims));
252259

src/api/c/binary.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,11 @@ static af_err af_bitwise(af_array *out, const af_array lhs, const af_array rhs,
332332

333333
dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode);
334334

335+
if(odims.ndims() == 0) {
336+
dim_t my_dims[] = {0, 0, 0, 0};
337+
return af_create_handle(out, AF_MAX_DIMS, my_dims, type);
338+
}
339+
335340
af_array res;
336341
switch (type) {
337342
case s32: res = bitOp<int , op>(lhs, rhs, odims); break;

src/api/c/blas.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ af_err af_dot( af_array *out,
105105
af_dtype lhs_type = lhsInfo.getType();
106106
af_dtype rhs_type = rhsInfo.getType();
107107

108+
if(lhsInfo.ndims() == 0) {
109+
return af_retain_array(out, lhs);
110+
}
108111
if (lhsInfo.ndims() > 1 ||
109112
rhsInfo.ndims() > 1) {
110113
AF_ERROR("dot can not be used in batch mode", AF_ERR_BATCH);

src/api/c/cast.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ static af_array cast(const af_array in, const af_dtype type)
4848
af_err af_cast(af_array *out, const af_array in, const af_dtype type)
4949
{
5050
try {
51+
const ArrayInfo info = getInfo(in);
52+
dim4 idims = info.dims();
53+
if(idims.elements() == 0) {
54+
dim_t my_dims[] = {0, 0, 0, 0};
55+
return af_create_handle(out, AF_MAX_DIMS, my_dims, type);
56+
}
57+
5158
af_array res = cast(in, type);
5259
std::swap(*out, res);
5360
}

src/api/c/cholesky.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ af_err af_cholesky(af_array *out, int *info, const af_array in, const bool is_up
4545
ARG_ASSERT(2, i_info.isFloating()); // Only floating and complex types
4646
DIM_ASSERT(1, i_info.dims()[0] == i_info.dims()[1]); // Only square matrices
4747

48+
if(i_info.ndims() == 0) {
49+
dim_t my_dims[] = {0, 0, 0, 0};
50+
return af_create_handle(out, AF_MAX_DIMS, my_dims, type);
51+
}
52+
4853
af_array output;
4954
switch(type) {
5055
case f32: output = cholesky<float >(info, in, is_upper); break;
@@ -74,6 +79,10 @@ af_err af_cholesky_inplace(int *info, af_array in, const bool is_upper)
7479
ARG_ASSERT(1, i_info.isFloating()); // Only floating and complex types
7580
DIM_ASSERT(1, i_info.dims()[0] == i_info.dims()[1]); // Only square matrices
7681

82+
if(i_info.ndims() == 0) {
83+
return AF_SUCCESS;
84+
}
85+
7786
int out;
7887

7988
switch(type) {

src/api/c/convolve.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ af_err convolve(af_array *out, const af_array signal, const af_array filter)
7272
dim4 sdims = sInfo.dims();
7373
dim4 fdims = fInfo.dims();
7474

75+
if(fdims.ndims() == 0 || sdims.ndims() == 0) {
76+
return af_retain_array(out, signal);
77+
}
78+
7579
AF_BATCH_KIND convBT = identifyBatchKind<baseDim>(sdims, fdims);
7680

7781
ARG_ASSERT(1, (convBT != AF_BATCH_UNSUPPORTED && convBT != AF_BATCH_DIFF));

src/api/c/data.cpp

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828

2929
using af::dim4;
3030
using namespace detail;
31-
using namespace std;
3231

3332
dim4 verifyDims(const unsigned ndims, const dim_t * const dims)
3433
{
@@ -54,7 +53,14 @@ af_err af_constant(af_array *result, const double value,
5453
af_array out;
5554
AF_CHECK(af_init());
5655

57-
dim4 d = verifyDims(ndims, dims);
56+
dim4 d(1, 1, 1, 1);
57+
if(ndims <= 0) {
58+
dim_t my_dims[] = {0, 0, 0, 0};
59+
return af_create_handle(result, AF_MAX_DIMS, my_dims, type);
60+
} else {
61+
d = verifyDims(ndims, dims);
62+
}
63+
5864

5965
switch(type) {
6066
case f32: out = createHandleFromValue<float >(d, value); break;
@@ -229,6 +235,11 @@ af_err af_identity(af_array *out, const unsigned ndims, const dim_t * const dims
229235
af_array result;
230236
AF_CHECK(af_init());
231237

238+
if(ndims == 0) {
239+
dim_t my_dims[] = {0, 0, 0, 0};
240+
return af_create_handle(out, AF_MAX_DIMS, my_dims, type);
241+
}
242+
232243
dim4 d = verifyDims(ndims, dims);
233244

234245
switch(type) {
@@ -301,6 +312,11 @@ af_err af_iota(af_array *result, const unsigned ndims, const dim_t * const dims,
301312
af_array out;
302313
AF_CHECK(af_init());
303314

315+
if(ndims == 0) {
316+
dim_t my_dims[] = {0, 0, 0, 0};
317+
return af_create_handle(result, AF_MAX_DIMS, my_dims, type);
318+
}
319+
304320
DIM_ASSERT(1, ndims > 0 && ndims <= 4);
305321
DIM_ASSERT(3, t_ndims > 0 && t_ndims <= 4);
306322

@@ -345,6 +361,12 @@ af_err af_diag_create(af_array *out, const af_array in, const int num)
345361
af_dtype type = in_info.getType();
346362

347363
af_array result;
364+
365+
if(in_info.dims()[0] == 0) {
366+
dim_t my_dims[] = {0, 0, 0, 0};
367+
return af_create_handle(out, AF_MAX_DIMS, my_dims, type);
368+
}
369+
348370
switch(type) {
349371
case f32: result = diagCreate<float >(in, num); break;
350372
case c32: result = diagCreate<cfloat >(in, num); break;
@@ -372,9 +394,15 @@ af_err af_diag_extract(af_array *out, const af_array in, const int num)
372394

373395
try {
374396
ArrayInfo in_info = getInfo(in);
375-
DIM_ASSERT(1, in_info.ndims() >= 2);
376397
af_dtype type = in_info.getType();
377398

399+
if(in_info.ndims() == 0) {
400+
dim_t my_dims[] = {0, 0, 0, 0};
401+
return af_create_handle(out, AF_MAX_DIMS, my_dims, type);
402+
}
403+
404+
DIM_ASSERT(1, in_info.ndims() >= 2);
405+
378406
af_array result;
379407
switch(type) {
380408
case f32: result = diagExtract<float >(in, num); break;

src/api/c/det.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ T det(const af_array a)
2929

3030
const int num = A.dims()[0];
3131

32+
if(num == 0) {
33+
T res = scalar<T>(1.0);
34+
return res;
35+
}
36+
3237
std::vector<T> hD(num);
3338
std::vector<int> hP(num);
3439

src/api/c/diff.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ af_err af_diff1(af_array *out, const af_array in, const int dim)
4040
af_dtype type = info.getType();
4141

4242
af::dim4 in_dims = info.dims();
43+
if(in_dims[dim] < 2) {
44+
dim_t my_dims[] = {0, 0, 0, 0};
45+
return af_create_handle(out, AF_MAX_DIMS, my_dims, type);
46+
}
47+
4348
DIM_ASSERT(1, in_dims[dim] >= 2);
4449

4550
af_array output;
@@ -77,6 +82,10 @@ af_err af_diff2(af_array *out, const af_array in, const int dim)
7782
af_dtype type = info.getType();
7883

7984
af::dim4 in_dims = info.dims();
85+
if(in_dims[dim] < 3) {
86+
dim_t my_dims[] = {0, 0, 0, 0};
87+
return af_create_handle(out, AF_MAX_DIMS, my_dims, type);
88+
}
8089
DIM_ASSERT(1, in_dims[dim] >= 3);
8190

8291
af_array output;

src/api/c/fft.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ static af_err fft(af_array *out, const af_array in, const double norm_factor, co
3333
af_dtype type = info.getType();
3434
af::dim4 dims = info.dims();
3535

36+
if(dims.ndims() == 0) {
37+
return af_retain_array(out, in);
38+
}
39+
3640
DIM_ASSERT(1, (dims.ndims()>=rank));
3741

3842
af_array output;
@@ -104,6 +108,9 @@ static af_err fft_inplace(af_array in, const double norm_factor)
104108
af_dtype type = info.getType();
105109
af::dim4 dims = info.dims();
106110

111+
if(dims.ndims() == 0) {
112+
return AF_SUCCESS;
113+
}
107114
DIM_ASSERT(1, (dims.ndims()>=rank));
108115

109116
switch(type) {
@@ -163,6 +170,9 @@ static af_err fft_r2c(af_array *out, const af_array in, const double norm_factor
163170
af_dtype type = info.getType();
164171
af::dim4 dims = info.dims();
165172

173+
if(dims.ndims() == 0) {
174+
return af_retain_array(out, in);
175+
}
166176
DIM_ASSERT(1, (dims.ndims()>=rank));
167177

168178
af_array output;
@@ -215,6 +225,9 @@ static af_err fft_c2r(af_array *out, const af_array in, const double norm_factor
215225
af_dtype type = info.getType();
216226
af::dim4 idims = info.dims();
217227

228+
if(idims.ndims() == 0) {
229+
return af_retain_array(out, in);
230+
}
218231
DIM_ASSERT(1, (idims.ndims()>=rank));
219232

220233
dim4 odims = idims;

0 commit comments

Comments
 (0)