1010#include < af/dim4.hpp>
1111#include < af/defines.h>
1212#include < af/image.h>
13+ #include < af/signal.h>
1314#include < af/data.h>
1415#include < handle.hpp>
1516#include < err_common.hpp>
1920using af::dim4;
2021using namespace detail ;
2122
23+ af_err af_medfilt (af_array *out, const af_array in, const dim_t wind_length, const dim_t wind_width, const af_border_type edge_pad)
24+ {
25+ return af_medfilt2 (out, in, wind_length, wind_width, edge_pad);
26+ }
27+
2228template <typename T>
23- static af_array medfilt (af_array const &in, dim_t w_len , dim_t w_wid, af_border_type edge_pad)
29+ static af_array medfilt1 (af_array const &in, dim_t w_wid, af_border_type edge_pad)
2430{
2531 switch (edge_pad) {
26- case AF_PAD_ZERO : return getHandle<T>(medfilt <T, AF_PAD_ZERO >(getArray<T>(in), w_len , w_wid)); break ;
27- case AF_PAD_SYM : return getHandle<T>(medfilt <T, AF_PAD_SYM >(getArray<T>(in), w_len , w_wid)); break ;
28- default : return getHandle<T>(medfilt <T, AF_PAD_ZERO >(getArray<T>(in), w_len , w_wid)); break ;
32+ case AF_PAD_ZERO : return getHandle<T>(medfilt1 <T, AF_PAD_ZERO >(getArray<T>(in), w_wid)); break ;
33+ case AF_PAD_SYM : return getHandle<T>(medfilt1 <T, AF_PAD_SYM >(getArray<T>(in), w_wid)); break ;
34+ default : return getHandle<T>(medfilt1 <T, AF_PAD_ZERO >(getArray<T>(in), w_wid)); break ;
2935 }
3036}
3137
32- af_err af_medfilt (af_array *out, const af_array in, const dim_t wind_length , const dim_t wind_width, const af_border_type edge_pad)
38+ af_err af_medfilt1 (af_array *out, const af_array in, const dim_t wind_width, const af_border_type edge_pad)
3339{
3440 try {
35- ARG_ASSERT (2 , (wind_length==wind_width));
36- ARG_ASSERT (2 , (wind_length>0 ));
37- ARG_ASSERT (3 , (wind_width>0 ));
41+ ARG_ASSERT (2 , (wind_width>0 ));
3842 ARG_ASSERT (4 , (edge_pad>=AF_PAD_ZERO && edge_pad<=AF_PAD_SYM ));
3943
4044 ArrayInfo info = getInfo (in);
4145 af::dim4 dims = info.dims ();
4246
4347 dim_t input_ndims = dims.ndims ();
44- DIM_ASSERT (1 , (input_ndims >= 2 ));
48+ DIM_ASSERT (1 , (input_ndims >= 1 ));
4549
46- if (wind_length ==1 ) {
50+ if (wind_width ==1 ) {
4751 *out = retain (in);
4852 } else {
4953 af_array output;
5054 af_dtype type = info.getType ();
5155 switch (type) {
52- case f32 : output = medfilt <float >(in, wind_length , wind_width, edge_pad); break ;
53- case f64 : output = medfilt <double >(in, wind_length , wind_width, edge_pad); break ;
54- case b8 : output = medfilt <char >(in, wind_length , wind_width, edge_pad); break ;
55- case s32: output = medfilt <int >(in, wind_length , wind_width, edge_pad); break ;
56- case u32 : output = medfilt <uint >(in, wind_length , wind_width, edge_pad); break ;
57- case s16: output = medfilt <short >(in, wind_length , wind_width, edge_pad); break ;
58- case u16 : output = medfilt <ushort>(in, wind_length , wind_width, edge_pad); break ;
59- case u8 : output = medfilt <uchar >(in, wind_length , wind_width, edge_pad); break ;
56+ case f32 : output = medfilt1 <float >(in, wind_width, edge_pad); break ;
57+ case f64 : output = medfilt1 <double >(in, wind_width, edge_pad); break ;
58+ case b8 : output = medfilt1 <char >(in, wind_width, edge_pad); break ;
59+ case s32: output = medfilt1 <int >(in, wind_width, edge_pad); break ;
60+ case u32 : output = medfilt1 <uint >(in, wind_width, edge_pad); break ;
61+ case s16: output = medfilt1 <short >(in, wind_width, edge_pad); break ;
62+ case u16 : output = medfilt1 <ushort>(in, wind_width, edge_pad); break ;
63+ case u8 : output = medfilt1 <uchar >(in, wind_width, edge_pad); break ;
6064 default : TYPE_ERROR (1 , type);
6165 }
6266 std::swap (*out, output);
@@ -68,41 +72,43 @@ af_err af_medfilt(af_array *out, const af_array in, const dim_t wind_length, con
6872}
6973
7074template <typename T>
71- static af_array medfilt_1d (af_array const &in, dim_t w_wid, af_border_type edge_pad)
75+ static af_array medfilt2 (af_array const &in, dim_t w_len , dim_t w_wid, af_border_type edge_pad)
7276{
7377 switch (edge_pad) {
74- case AF_PAD_ZERO : return getHandle<T>(medfilt_1d <T, AF_PAD_ZERO >(getArray<T>(in), w_wid)); break ;
75- case AF_PAD_SYM : return getHandle<T>(medfilt_1d <T, AF_PAD_SYM >(getArray<T>(in), w_wid)); break ;
76- default : return getHandle<T>(medfilt_1d <T, AF_PAD_ZERO >(getArray<T>(in), w_wid)); break ;
78+ case AF_PAD_ZERO : return getHandle<T>(medfilt2 <T, AF_PAD_ZERO >(getArray<T>(in), w_len , w_wid)); break ;
79+ case AF_PAD_SYM : return getHandle<T>(medfilt2 <T, AF_PAD_SYM >(getArray<T>(in), w_len , w_wid)); break ;
80+ default : return getHandle<T>(medfilt2 <T, AF_PAD_ZERO >(getArray<T>(in), w_len , w_wid)); break ;
7781 }
7882}
7983
80- af_err af_medfilt_1d (af_array *out, const af_array in, const dim_t wind_width, const af_border_type edge_pad)
84+ af_err af_medfilt2 (af_array *out, const af_array in, const dim_t wind_length , const dim_t wind_width, const af_border_type edge_pad)
8185{
8286 try {
83- ARG_ASSERT (2 , (wind_width>0 ));
87+ ARG_ASSERT (2 , (wind_length==wind_width));
88+ ARG_ASSERT (2 , (wind_length>0 ));
89+ ARG_ASSERT (3 , (wind_width>0 ));
8490 ARG_ASSERT (4 , (edge_pad>=AF_PAD_ZERO && edge_pad<=AF_PAD_SYM ));
8591
8692 ArrayInfo info = getInfo (in);
8793 af::dim4 dims = info.dims ();
8894
8995 dim_t input_ndims = dims.ndims ();
90- DIM_ASSERT (1 , (input_ndims >= 1 ));
96+ DIM_ASSERT (1 , (input_ndims >= 2 ));
9197
92- if (wind_width ==1 ) {
98+ if (wind_length ==1 ) {
9399 *out = retain (in);
94100 } else {
95101 af_array output;
96102 af_dtype type = info.getType ();
97103 switch (type) {
98- case f32 : output = medfilt_1d <float >(in, wind_width, edge_pad); break ;
99- case f64 : output = medfilt_1d <double >(in, wind_width, edge_pad); break ;
100- case b8 : output = medfilt_1d <char >(in, wind_width, edge_pad); break ;
101- case s32: output = medfilt_1d <int >(in, wind_width, edge_pad); break ;
102- case u32 : output = medfilt_1d <uint >(in, wind_width, edge_pad); break ;
103- case s16: output = medfilt_1d <short >(in, wind_width, edge_pad); break ;
104- case u16 : output = medfilt_1d <ushort>(in, wind_width, edge_pad); break ;
105- case u8 : output = medfilt_1d <uchar >(in, wind_width, edge_pad); break ;
104+ case f32 : output = medfilt2 <float >(in, wind_length , wind_width, edge_pad); break ;
105+ case f64 : output = medfilt2 <double >(in, wind_length , wind_width, edge_pad); break ;
106+ case b8 : output = medfilt2 <char >(in, wind_length , wind_width, edge_pad); break ;
107+ case s32: output = medfilt2 <int >(in, wind_length , wind_width, edge_pad); break ;
108+ case u32 : output = medfilt2 <uint >(in, wind_length , wind_width, edge_pad); break ;
109+ case s16: output = medfilt2 <short >(in, wind_length , wind_width, edge_pad); break ;
110+ case u16 : output = medfilt2 <ushort>(in, wind_length , wind_width, edge_pad); break ;
111+ case u8 : output = medfilt2 <uchar >(in, wind_length , wind_width, edge_pad); break ;
106112 default : TYPE_ERROR (1 , type);
107113 }
108114 std::swap (*out, output);
0 commit comments