2828
2929using af::dim4;
3030using namespace detail ;
31- using namespace std ;
3231
3332dim4 verifyDims (const unsigned ndims, const dim_t * const dims)
3433{
@@ -54,7 +53,14 @@ af_err af_constant(af_array *result, const double value,
5453 af_array out;
5554 AF_CHECK (af_init ());
5655
57- dim4 d = verifyDims (ndims, dims);
56+ dim4 d (1 , 1 , 1 , 1 );
57+ if (ndims <= 0 ) {
58+ dim_t my_dims[] = {0 , 0 , 0 , 0 };
59+ return af_create_handle (result, AF_MAX_DIMS, my_dims, type);
60+ } else {
61+ d = verifyDims (ndims, dims);
62+ }
63+
5864
5965 switch (type) {
6066 case f32 : out = createHandleFromValue<float >(d, value); break ;
@@ -229,6 +235,11 @@ af_err af_identity(af_array *out, const unsigned ndims, const dim_t * const dims
229235 af_array result;
230236 AF_CHECK (af_init ());
231237
238+ if (ndims == 0 ) {
239+ dim_t my_dims[] = {0 , 0 , 0 , 0 };
240+ return af_create_handle (out, AF_MAX_DIMS, my_dims, type);
241+ }
242+
232243 dim4 d = verifyDims (ndims, dims);
233244
234245 switch (type) {
@@ -301,6 +312,11 @@ af_err af_iota(af_array *result, const unsigned ndims, const dim_t * const dims,
301312 af_array out;
302313 AF_CHECK (af_init ());
303314
315+ if (ndims == 0 ) {
316+ dim_t my_dims[] = {0 , 0 , 0 , 0 };
317+ return af_create_handle (result, AF_MAX_DIMS, my_dims, type);
318+ }
319+
304320 DIM_ASSERT (1 , ndims > 0 && ndims <= 4 );
305321 DIM_ASSERT (3 , t_ndims > 0 && t_ndims <= 4 );
306322
@@ -345,6 +361,12 @@ af_err af_diag_create(af_array *out, const af_array in, const int num)
345361 af_dtype type = in_info.getType ();
346362
347363 af_array result;
364+
365+ if (in_info.dims ()[0 ] == 0 ) {
366+ dim_t my_dims[] = {0 , 0 , 0 , 0 };
367+ return af_create_handle (out, AF_MAX_DIMS, my_dims, type);
368+ }
369+
348370 switch (type) {
349371 case f32 : result = diagCreate<float >(in, num); break ;
350372 case c32: result = diagCreate<cfloat >(in, num); break ;
@@ -372,9 +394,15 @@ af_err af_diag_extract(af_array *out, const af_array in, const int num)
372394
373395 try {
374396 ArrayInfo in_info = getInfo (in);
375- DIM_ASSERT (1 , in_info.ndims () >= 2 );
376397 af_dtype type = in_info.getType ();
377398
399+ if (in_info.ndims () == 0 ) {
400+ dim_t my_dims[] = {0 , 0 , 0 , 0 };
401+ return af_create_handle (out, AF_MAX_DIMS, my_dims, type);
402+ }
403+
404+ DIM_ASSERT (1 , in_info.ndims () >= 2 );
405+
378406 af_array result;
379407 switch (type) {
380408 case f32 : result = diagExtract<float >(in, num); break ;
0 commit comments