/******************************************************* * Copyright (c) 2014, ArrayFire * All rights reserved. * * This file is distributed under 3-clause BSD license. * The complete license agreement can be obtained at: * http://arrayfire.com/licenses/BSD-3-Clause ********************************************************/ #include #include #include #include #include #include #include #include #include #include #include using af::dim4; using namespace detail; template static inline af_array fftconvolve_fallback(const af_array signal, const af_array filter, bool expand) { const Array S = castArray(signal); const Array F = castArray(filter); const dim4 sdims = S.dims(); const dim4 fdims = F.dims(); dim4 odims(1, 1, 1, 1); dim4 psdims(1, 1, 1, 1); dim4 pfdims(1, 1, 1, 1); std::vector index(4); int count = 1; for (int i = 0; i < baseDim; i++) { dim_t tdim_i = sdims[i] + fdims[i] - 1; // Pad temporary buffers to power of 2 for performance odims[i] = nextpow2(tdim_i); psdims[i] = nextpow2(tdim_i); pfdims[i] = nextpow2(tdim_i); // The normalization factor count *= odims[i]; // Get the indexing params for output if (expand) { index[i].begin = 0; index[i].end = tdim_i - 1; } else { index[i].begin = fdims[i] / 2; index[i].end = index[i].begin + sdims[i] - 1; } index[i].step = 1; } for (int i = baseDim; i < 4; i++) { odims[i] = std::max(sdims[i], fdims[i]); psdims[i] = sdims[i]; pfdims[i] = fdims[i]; index[i] = af_span; } // fft(signal) Array T1 = fft(S, 1.0, baseDim, psdims.get()); // fft(filter) Array T2 = fft(F, 1.0, baseDim, pfdims.get()); // fft(signal) * fft(filter) T1 = arithOp(T1, T2, odims); // ifft(ffit(signal) * fft(filter)) T1 = fft(T1, 1.0 / (double)count, baseDim, odims.get()); // Index to proper offsets T1 = createSubArray(T1, index); if (getInfo(signal).isComplex() || getInfo(filter).isComplex()) { return getHandle(cast(T1)); } else { return getHandle(cast(real(T1))); } } template inline static af_array fftconvolve(const af_array &s, const af_array &f, const bool expand, AF_BATCH_KIND kind) { if (kind == AF_BATCH_DIFF) return fftconvolve_fallback(s, f, expand); else return getHandle(fftconvolve( getArray(s), castArray(f), expand, kind)); } template AF_BATCH_KIND identifyBatchKind(const dim4 &sDims, const dim4 &fDims) { dim_t sn = sDims.ndims(); dim_t fn = fDims.ndims(); if (sn == baseDim && fn == baseDim) return AF_BATCH_NONE; else if (sn == baseDim && (fn > baseDim && fn <= 4)) return AF_BATCH_RHS; else if ((sn > baseDim && sn <= 4) && fn == baseDim) return AF_BATCH_LHS; else if ((sn > baseDim && sn <= 4) && (fn > baseDim && fn <= 4)) { bool doesDimensionsMatch = true; bool isInterleaved = true; for (dim_t i = baseDim; i < 4; i++) { doesDimensionsMatch &= (sDims[i] == fDims[i]); isInterleaved &= (sDims[i] == 1 || fDims[i] == 1 || sDims[i] == fDims[i]); } if (doesDimensionsMatch) return AF_BATCH_SAME; return (isInterleaved ? AF_BATCH_DIFF : AF_BATCH_UNSUPPORTED); } else return AF_BATCH_UNSUPPORTED; } template af_err fft_convolve(af_array *out, const af_array signal, const af_array filter, const bool expand) { try { const ArrayInfo &sInfo = getInfo(signal); const ArrayInfo &fInfo = getInfo(filter); af_dtype stype = sInfo.getType(); dim4 sdims = sInfo.dims(); dim4 fdims = fInfo.dims(); AF_BATCH_KIND convBT = identifyBatchKind(sdims, fdims); ARG_ASSERT(1, (convBT != AF_BATCH_UNSUPPORTED)); af_array output; switch (stype) { case f64: output = fftconvolve( signal, filter, expand, convBT); break; case f32: output = fftconvolve( signal, filter, expand, convBT); break; case u32: output = fftconvolve( signal, filter, expand, convBT); break; case s32: output = fftconvolve( signal, filter, expand, convBT); break; case u64: output = fftconvolve( signal, filter, expand, convBT); break; case s64: output = fftconvolve( signal, filter, expand, convBT); break; case u16: output = fftconvolve( signal, filter, expand, convBT); break; case s16: output = fftconvolve( signal, filter, expand, convBT); break; case u8: output = fftconvolve( signal, filter, expand, convBT); break; case b8: output = fftconvolve( signal, filter, expand, convBT); break; case c32: output = fftconvolve_fallback( signal, filter, expand); break; case c64: output = fftconvolve_fallback( signal, filter, expand); break; default: TYPE_ERROR(1, stype); } std::swap(*out, output); } CATCHALL; return AF_SUCCESS; } af_err af_fft_convolve1(af_array *out, const af_array signal, const af_array filter, const af_conv_mode mode) { return fft_convolve<1>(out, signal, filter, mode == AF_CONV_EXPAND); } af_err af_fft_convolve2(af_array *out, const af_array signal, const af_array filter, const af_conv_mode mode) { if (getInfo(signal).dims().ndims() < 2 && getInfo(filter).dims().ndims() < 2) { return fft_convolve<1>(out, signal, filter, mode == AF_CONV_EXPAND); } else { return fft_convolve<2>(out, signal, filter, mode == AF_CONV_EXPAND); } } af_err af_fft_convolve3(af_array *out, const af_array signal, const af_array filter, const af_conv_mode mode) { if (getInfo(signal).dims().ndims() < 3 && getInfo(filter).dims().ndims() < 3) { return fft_convolve<2>(out, signal, filter, mode == AF_CONV_EXPAND); } else { return fft_convolve<3>(out, signal, filter, mode == AF_CONV_EXPAND); } }