@@ -49,72 +49,139 @@ class FFTGPUBase : public OpKernel {
4949 void Compute (OpKernelContext* ctx) override {
5050 const Tensor& in = ctx->input (0 );
5151 const TensorShape& shape = in.shape ();
52+ const int fft_rank = Rank ();
5253 OP_REQUIRES (
53- ctx, shape.dims () >= Rank () ,
54- errors::InvalidArgument (" Input must have rank of at least " , Rank () ,
54+ ctx, shape.dims () >= fft_rank ,
55+ errors::InvalidArgument (" Input must have rank of at least " , fft_rank ,
5556 " but got: " , shape.DebugString ()));
57+
5658 Tensor* out;
57- OP_REQUIRES_OK (ctx, ctx->allocate_output (0 , shape, &out));
59+ TensorShape output_shape = shape;
60+ uint64 fft_shape[3 ] = {0 , 0 , 0 };
61+
62+ // In R2C or C2R mode, we use a second input to specify the FFT length
63+ // instead of inferring it from the input shape.
64+ if (IsReal ()) {
65+ const Tensor& fft_length = ctx->input (1 );
66+ OP_REQUIRES (ctx,
67+ fft_length.shape ().dims () == 1 &&
68+ fft_length.shape ().dim_size (0 ) == fft_rank,
69+ errors::InvalidArgument (" fft_length must have shape [" ,
70+ fft_rank, " ]" ));
71+
72+ auto fft_length_as_vec = fft_length.vec <int32>();
73+ for (int i = 0 ; i < fft_rank; ++i) {
74+ fft_shape[i] = fft_length_as_vec (i);
75+ uint64 dim = IsForward () && i == fft_rank - 1 && fft_shape[i] != 0
76+ ? fft_shape[i] / 2 + 1
77+ : fft_shape[i];
78+ output_shape.set_dim (output_shape.dims () - fft_rank + i, dim);
79+ }
80+ } else {
81+ for (int i = 0 ; i < fft_rank; ++i) {
82+ fft_shape[i] =
83+ output_shape.dim_size (output_shape.dims () - fft_rank + i);
84+ }
85+ }
86+
87+ OP_REQUIRES_OK (ctx, ctx->allocate_output (0 , output_shape, &out));
5888 if (shape.num_elements () == 0 ) {
5989 return ;
6090 }
61- DoFFT (ctx, in, out);
91+
92+ DoFFT (ctx, in, fft_shape, out);
6293 }
6394
6495 protected:
6596 virtual int Rank () const = 0;
6697 virtual bool IsForward () const = 0;
98+ virtual bool IsReal () const = 0;
6799
68100 private:
69- void DoFFT (OpKernelContext* ctx, const Tensor& in, Tensor* out) {
101+ void DoFFT (OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
102+ Tensor* out) {
70103 auto * stream = ctx->op_device_context ()->stream ();
71104 OP_REQUIRES (ctx, stream, errors::Internal (" No GPU stream available." ));
72105
73- const TensorShape& shape = in.shape ();
74- auto src = AsDeviceMemory<complex64>(in.flat <complex64>().data ());
75- auto dst = AsDeviceMemory<complex64>(out->flat <complex64>().data ());
106+ const TensorShape& input_shape = in.shape ();
107+ const TensorShape& output_shape = out->shape ();
76108
77- const int rank = Rank ();
109+ const int fft_rank = Rank ();
78110 int batch_size = 1 ;
79- for (int i = 0 ; i < shape .dims () - rank ; ++i) {
80- batch_size *= shape .dim_size (i);
111+ for (int i = 0 ; i < input_shape .dims () - fft_rank ; ++i) {
112+ batch_size *= input_shape .dim_size (i);
81113 }
82- uint64 data_length = 1 ;
83- uint64 data_dims[3 ];
84- for (int i = 0 ; i < rank; ++i) {
85- auto dim = shape.dim_size (shape.dims () - rank + i);
86- data_length *= dim;
87- data_dims[i] = dim;
114+ uint64 input_embed[3 ];
115+ uint64 input_stride = 1 ;
116+ uint64 input_distance = 1 ;
117+ uint64 output_embed[3 ];
118+ uint64 output_stride = 1 ;
119+ uint64 output_distance = 1 ;
120+
121+ for (int i = 0 ; i < fft_rank; ++i) {
122+ auto dim_offset = input_shape.dims () - fft_rank + i;
123+ input_embed[i] = input_shape.dim_size (dim_offset);
124+ input_distance *= input_shape.dim_size (dim_offset);
125+ output_embed[i] = output_shape.dim_size (dim_offset);
126+ output_distance *= output_shape.dim_size (dim_offset);
88127 }
89128
90- constexpr uint64* kInputEmbed = nullptr ;
91- constexpr uint64 kInputStride = 1 ;
92- constexpr uint64 kInputDistance = 1 ;
93- constexpr uint64* kOutputEmbed = nullptr ;
94- constexpr uint64 kOutputStride = 1 ;
95- constexpr uint64 kOutputDistance = 1 ;
96129 constexpr bool kInPlaceFft = false ;
130+ const auto kFftType =
131+ IsReal () ? (IsForward () ? perftools::gputools::fft::Type::kR2C
132+ : perftools::gputools::fft::Type::kC2R )
133+ : (IsForward () ? perftools::gputools::fft::Type::kC2CForward
134+ : perftools::gputools::fft::Type::kC2CInverse );
97135
98136 auto plan = stream->parent ()->AsFft ()->CreateBatchedPlan (
99- stream, rank, data_dims, kInputEmbed , kInputStride , kInputDistance ,
100- kOutputEmbed , kOutputStride , kOutputDistance ,
101- IsForward () ? perftools::gputools::fft::Type::kC2CForward
102- : perftools::gputools::fft::Type::kC2CInverse ,
103- kInPlaceFft , batch_size);
104-
105- OP_REQUIRES (
106- ctx, stream->ThenFft (plan.get (), src, &dst).ok (),
107- errors::Internal (" c2c fft failed : in.shape=" , shape.DebugString ()));
108- if (!IsForward ()) {
109- auto alpha = complex64 (1 .f / data_length);
137+ stream, fft_rank, fft_shape, input_embed, input_stride, input_distance,
138+ output_embed, output_stride, output_distance, kFftType , kInPlaceFft ,
139+ batch_size);
140+
141+ if (IsReal ()) {
142+ if (IsForward ()) {
143+ auto src = AsDeviceMemory<float >(in.flat <float >().data ());
144+ auto dst = AsDeviceMemory<complex64>(out->flat <complex64>().data ());
145+ OP_REQUIRES (
146+ ctx, stream->ThenFft (plan.get (), src, &dst).ok (),
147+ errors::Internal (" fft failed : type=" , static_cast <int >(kFftType ),
148+ " in.shape=" , input_shape.DebugString ()));
149+ } else {
150+ auto src = AsDeviceMemory<complex64>(in.flat <complex64>().data ());
151+ auto dst = AsDeviceMemory<float >(out->flat <float >().data ());
152+ OP_REQUIRES (
153+ ctx, stream->ThenFft (plan.get (), src, &dst).ok (),
154+ errors::Internal (" fft failed : type=" , static_cast <int >(kFftType ),
155+ " in.shape=" , input_shape.DebugString ()));
156+ auto alpha = 1 .f / output_distance;
157+ OP_REQUIRES (
158+ ctx,
159+ stream->ThenBlasScal (output_shape.num_elements (), alpha, &dst, 1 )
160+ .ok (),
161+ errors::Internal (" BlasScal failed : in.shape=" ,
162+ input_shape.DebugString ()));
163+ }
164+ } else {
165+ auto src = AsDeviceMemory<complex64>(in.flat <complex64>().data ());
166+ auto dst = AsDeviceMemory<complex64>(out->flat <complex64>().data ());
110167 OP_REQUIRES (
111- ctx, stream->ThenBlasScal (shape.num_elements (), alpha, &dst, 1 ).ok (),
112- errors::Internal (" BlasScal failed : in.shape=" , shape.DebugString ()));
168+ ctx, stream->ThenFft (plan.get (), src, &dst).ok (),
169+ errors::Internal (" fft failed : type=" , static_cast <int >(kFftType ),
170+ " in.shape=" , input_shape.DebugString ()));
171+ if (!IsForward ()) {
172+ auto alpha = complex64 (1 .f / output_distance);
173+ OP_REQUIRES (
174+ ctx,
175+ stream->ThenBlasScal (output_shape.num_elements (), alpha, &dst, 1 )
176+ .ok (),
177+ errors::Internal (" BlasScal failed : in.shape=" ,
178+ input_shape.DebugString ()));
179+ }
113180 }
114181 }
115182};
116183
117- template <bool Forward, int FFTRank>
184+ template <bool Forward, bool _Real, int FFTRank>
118185class FFTGPU : public FFTGPUBase {
119186 public:
120187 static_assert (FFTRank >= 1 && FFTRank <= 3 ,
@@ -124,24 +191,53 @@ class FFTGPU : public FFTGPUBase {
124191 protected:
125192 int Rank () const override { return FFTRank; }
126193 bool IsForward () const override { return Forward; }
194+ bool IsReal () const override { return _Real; }
127195};
128196
129- REGISTER_KERNEL_BUILDER (Name(" FFT" ).Device(DEVICE_GPU), FFTGPU<true , 1 >);
130- REGISTER_KERNEL_BUILDER (Name(" IFFT" ).Device(DEVICE_GPU), FFTGPU<false , 1 >);
131- REGISTER_KERNEL_BUILDER (Name(" FFT2D" ).Device(DEVICE_GPU), FFTGPU<true , 2 >);
132- REGISTER_KERNEL_BUILDER (Name(" IFFT2D" ).Device(DEVICE_GPU), FFTGPU<false , 2 >);
133- REGISTER_KERNEL_BUILDER (Name(" FFT3D" ).Device(DEVICE_GPU), FFTGPU<true , 3 >);
134- REGISTER_KERNEL_BUILDER (Name(" IFFT3D" ).Device(DEVICE_GPU), FFTGPU<false , 3 >);
197+ REGISTER_KERNEL_BUILDER (Name(" FFT" ).Device(DEVICE_GPU), FFTGPU<true , false , 1 >);
198+ REGISTER_KERNEL_BUILDER (Name(" IFFT" ).Device(DEVICE_GPU),
199+ FFTGPU<false , false , 1 >);
200+ REGISTER_KERNEL_BUILDER (Name(" FFT2D" ).Device(DEVICE_GPU),
201+ FFTGPU<true , false , 2 >);
202+ REGISTER_KERNEL_BUILDER (Name(" IFFT2D" ).Device(DEVICE_GPU),
203+ FFTGPU<false , false , 2 >);
204+ REGISTER_KERNEL_BUILDER (Name(" FFT3D" ).Device(DEVICE_GPU),
205+ FFTGPU<true , false , 3 >);
206+ REGISTER_KERNEL_BUILDER (Name(" IFFT3D" ).Device(DEVICE_GPU),
207+ FFTGPU<false , false , 3 >);
208+
209+ REGISTER_KERNEL_BUILDER (
210+ Name (" RFFT" ).Device(DEVICE_GPU).HostMemory(" fft_length" ),
211+ FFTGPU<true, true, 1>);
212+ REGISTER_KERNEL_BUILDER (
213+ Name (" IRFFT" ).Device(DEVICE_GPU).HostMemory(" fft_length" ),
214+ FFTGPU<false, true, 1>);
215+ REGISTER_KERNEL_BUILDER (
216+ Name (" RFFT2D" ).Device(DEVICE_GPU).HostMemory(" fft_length" ),
217+ FFTGPU<true, true, 2>);
218+ REGISTER_KERNEL_BUILDER (
219+ Name (" IRFFT2D" ).Device(DEVICE_GPU).HostMemory(" fft_length" ),
220+ FFTGPU<false, true, 2>);
221+ REGISTER_KERNEL_BUILDER (
222+ Name (" RFFT3D" ).Device(DEVICE_GPU).HostMemory(" fft_length" ),
223+ FFTGPU<true, true, 3>);
224+ REGISTER_KERNEL_BUILDER (
225+ Name (" IRFFT3D" ).Device(DEVICE_GPU).HostMemory(" fft_length" ),
226+ FFTGPU<false, true, 3>);
135227
136228// Deprecated kernels.
137- REGISTER_KERNEL_BUILDER (Name(" BatchFFT" ).Device(DEVICE_GPU), FFTGPU<true , 1 >);
138- REGISTER_KERNEL_BUILDER (Name(" BatchIFFT" ).Device(DEVICE_GPU), FFTGPU<false , 1 >);
139- REGISTER_KERNEL_BUILDER (Name(" BatchFFT2D" ).Device(DEVICE_GPU), FFTGPU<true , 2 >);
229+ REGISTER_KERNEL_BUILDER (Name(" BatchFFT" ).Device(DEVICE_GPU),
230+ FFTGPU<true, false, 1>);
231+ REGISTER_KERNEL_BUILDER (Name(" BatchIFFT" ).Device(DEVICE_GPU),
232+ FFTGPU<false, false, 1>);
233+ REGISTER_KERNEL_BUILDER (Name(" BatchFFT2D" ).Device(DEVICE_GPU),
234+ FFTGPU<true, false, 2>);
140235REGISTER_KERNEL_BUILDER (Name(" BatchIFFT2D" ).Device(DEVICE_GPU),
141- FFTGPU<false , 2 >);
142- REGISTER_KERNEL_BUILDER (Name(" BatchFFT3D" ).Device(DEVICE_GPU), FFTGPU<true , 3 >);
236+ FFTGPU<false, false, 2>);
237+ REGISTER_KERNEL_BUILDER (Name(" BatchFFT3D" ).Device(DEVICE_GPU),
238+ FFTGPU<true, false, 3>);
143239REGISTER_KERNEL_BUILDER (Name(" BatchIFFT3D" ).Device(DEVICE_GPU),
144- FFTGPU<false , 3 >);
240+ FFTGPU<false, false, 3>);
145241
146242} // end namespace tensorflow
147243
0 commit comments