@@ -19,13 +19,15 @@ using af::dim4;
1919using namespace detail ;
2020
2121template <typename T>
22- static inline af_array transform (const af_array in, const af_array tf, const af::dim4 &odims, const bool inverse)
22+ static inline af_array transform (const af_array in, const af_array tf, const af::dim4 &odims,
23+ const af_interp_type method, const bool inverse)
2324{
24- return getHandle (*transform<T>(getArray<T>(in), getArray<float >(tf), odims, inverse));
25+ return getHandle (*transform<T>(getArray<T>(in), getArray<float >(tf), odims, method, inverse));
2526}
2627
2728af_err af_transform (af_array *out, const af_array in, const af_array tf,
28- const dim_type odim0, const dim_type odim1, const bool inverse)
29+ const dim_type odim0, const dim_type odim1,
30+ const af_interp_type method, const bool inverse)
2931{
3032 try {
3133 ArrayInfo t_info = getInfo (tf);
@@ -36,6 +38,7 @@ af_err af_transform(af_array *out, const af_array in, const af_array tf,
3638 af_dtype itype = i_info.getType ();
3739
3840 ARG_ASSERT (2 , t_info.getType () == f32 );
41+ ARG_ASSERT (5 , method == AF_INTERP_NEAREST || method == AF_INTERP_BILINEAR);
3942 DIM_ASSERT (2 , (tdims[0 ] == 3 && tdims[1 ] == 2 ));
4043 DIM_ASSERT (1 , idims.elements () > 0 );
4144 DIM_ASSERT (1 , (idims.ndims () == 2 || idims.ndims () == 3 ));
@@ -50,11 +53,11 @@ af_err af_transform(af_array *out, const af_array in, const af_array tf,
5053
5154 af_array output = 0 ;
5255 switch (itype) {
53- case f32 : output = transform<float >(in, tf, odims, inverse); break ;
54- case f64 : output = transform<double >(in, tf, odims, inverse); break ;
55- case s32: output = transform<int >(in, tf, odims, inverse); break ;
56- case u32 : output = transform<uint >(in, tf, odims, inverse); break ;
57- case u8 : output = transform<uchar >(in, tf, odims, inverse); break ;
56+ case f32 : output = transform<float >(in, tf, odims, method, inverse); break ;
57+ case f64 : output = transform<double >(in, tf, odims, method, inverse); break ;
58+ case s32: output = transform<int >(in, tf, odims, method, inverse); break ;
59+ case u32 : output = transform<uint >(in, tf, odims, method, inverse); break ;
60+ case u8 : output = transform<uchar >(in, tf, odims, method, inverse); break ;
5861 default : TYPE_ERROR (1 , itype);
5962 }
6063 std::swap (*out,output);
@@ -64,7 +67,7 @@ af_err af_transform(af_array *out, const af_array in, const af_array tf,
6467 return AF_SUCCESS;
6568}
6669
67- af_err af_rotate (af_array *out, const af_array in, const float theta,
70+ af_err af_rotate (af_array *out, const af_array in, const float theta, const af_interp_type method,
6871 const bool crop, const bool recenter)
6972{
7073 af_err ret = AF_SUCCESS;
@@ -114,7 +117,7 @@ af_err af_rotate(af_array *out, const af_array in, const float theta,
114117 ret = af_create_array (&t, trans_mat, tdims.ndims (), tdims.get (), f32 );
115118
116119 if (ret == AF_SUCCESS) {
117- return af_transform (out, in, t, odims0, odims1, true );
120+ return af_transform (out, in, t, odims0, odims1, method, true );
118121 }
119122 }
120123 CATCHALL;
@@ -123,7 +126,7 @@ af_err af_rotate(af_array *out, const af_array in, const float theta,
123126}
124127
125128af_err af_translate (af_array *out, const af_array in, const float trans0, const float trans1,
126- const dim_type odim0, const dim_type odim1)
129+ const dim_type odim0, const dim_type odim1, const af_interp_type method )
127130{
128131 af_err ret = AF_SUCCESS;
129132
@@ -139,7 +142,7 @@ af_err af_translate(af_array *out, const af_array in, const float trans0, const
139142 ret = af_create_array (&t, trans_mat, tdims.ndims (), tdims.get (), f32 );
140143
141144 if (ret == AF_SUCCESS) {
142- ret = af_transform (out, in, t, odim0, odim1, true );
145+ ret = af_transform (out, in, t, odim0, odim1, method, true );
143146 }
144147 }
145148 CATCHALL;
@@ -148,20 +151,26 @@ af_err af_translate(af_array *out, const af_array in, const float trans0, const
148151}
149152
150153af_err af_scale (af_array *out, const af_array in, const float scale0, const float scale1,
151- const dim_type odim0, const dim_type odim1)
154+ const dim_type odim0, const dim_type odim1, const af_interp_type method )
152155{
153156 af_err ret = AF_SUCCESS;
154157 try {
155158 ArrayInfo i_info = getInfo (in);
156159 af::dim4 idims = i_info.dims ();
157160
158161 dim_type _odim0 = odim0, _odim1 = odim1;
159- float sx = 1 . f / scale0 , sy = 1 . f / scale1 ;
162+ float sx, sy;
160163 if (_odim0 == 0 && _odim1 == 0 ) {
164+ sx = 1 .f / scale0, sy = 1 .f / scale1;
161165 _odim0 = idims[0 ] / sx;
162166 _odim1 = idims[1 ] / sy;
163167 } else if ( _odim0 == 0 || _odim1 == 0 ) {
164- return AF_ERR_ARG;
168+ return AF_ERR_SIZE;
169+ } else if (scale0 == 0 && scale1 == 0 ) {
170+ sx = idims[0 ] / (float )_odim0;
171+ sy = idims[1 ] / (float )_odim1;
172+ } else {
173+ sx = 1 .f / scale0, sy = 1 .f / scale1;
165174 }
166175
167176 static float trans_mat[6 ] = {1 , 0 , 0 ,
@@ -174,7 +183,7 @@ af_err af_scale(af_array *out, const af_array in, const float scale0, const floa
174183 ret = af_create_array (&t, trans_mat, tdims.ndims (), tdims.get (), f32 );
175184
176185 if (ret == AF_SUCCESS) {
177- return af_transform (out, in, t, odim0, odim1 , true );
186+ return af_transform (out, in, t, _odim0, _odim1, method , true );
178187 }
179188 }
180189 CATCHALL;
@@ -183,7 +192,8 @@ af_err af_scale(af_array *out, const af_array in, const float scale0, const floa
183192}
184193
185194af_err af_skew (af_array *out, const af_array in, const float skew0, const float skew1,
186- const dim_type odim0, const dim_type odim1, const bool inverse)
195+ const dim_type odim0, const dim_type odim1, const af_interp_type method,
196+ const bool inverse)
187197{
188198 af_err ret = AF_SUCCESS;
189199 try {
@@ -214,7 +224,7 @@ af_err af_skew(af_array *out, const af_array in, const float skew0, const float
214224 ret = af_create_array (&t, trans_mat, tdims.ndims (), tdims.get (), f32 );
215225
216226 if (ret == AF_SUCCESS) {
217- return af_transform (out, in, t, odim0, odim1, true );
227+ return af_transform (out, in, t, odim0, odim1, method, true );
218228 }
219229 }
220230 CATCHALL;
0 commit comments