@@ -14,9 +14,11 @@ You may obtain a copy of the License at
1414 limitations under the License.
1515******************************************************************************/
1616
17+ using System . Linq ;
1718using Tensorflow . Keras . ArgsDefinition ;
1819using Tensorflow . Keras . Engine ;
1920using Tensorflow . Keras . Utils ;
21+ using static Tensorflow . Binding ;
2022
2123namespace Tensorflow . Keras . Layers
2224{
@@ -36,27 +38,31 @@ public Pooling1D(Pooling1DArgs args)
3638
3739 protected override Tensors Call ( Tensors inputs , Tensor state = null , bool ? training = null )
3840 {
39- int [ ] pool_shape ;
40- int [ ] strides ;
41+ int pad_axis = args . DataFormat == "channels_first" ? 2 : 3 ;
42+ inputs = tf . expand_dims ( inputs , pad_axis ) ;
43+ int [ ] pool_shape = new int [ ] { args . PoolSize , 1 } ;
44+ int [ ] strides = new int [ ] { args . Strides , 1 } ;
45+ var ndim = inputs [ 0 ] . ndim ;
46+
4147 if ( args . DataFormat == "channels_last" )
4248 {
43- pool_shape = new int [ ] { 1 , args . PoolSize , 1 } ;
44- strides = new int [ ] { 1 , args . Strides , 1 } ;
49+ pool_shape = new int [ ] { 1 } . Concat ( pool_shape ) . Concat ( new int [ ] { 1 } ) . ToArray ( ) ;
50+ strides = new int [ ] { 1 } . Concat ( strides ) . Concat ( new int [ ] { 1 } ) . ToArray ( ) ;
4551 }
4652 else
4753 {
48- pool_shape = new int [ ] { 1 , 1 , args . PoolSize } ;
49- strides = new int [ ] { 1 , 1 , args . Strides } ;
54+ pool_shape = new int [ ] { 1 , 1 } . Concat ( pool_shape ) . ToArray ( ) ;
55+ strides = new int [ ] { 1 , 1 } . Concat ( strides ) . ToArray ( ) ;
5056 }
5157
5258 var outputs = args . PoolFunction . Apply (
5359 inputs ,
5460 ksize : pool_shape ,
5561 strides : strides ,
5662 padding : args . Padding . ToUpper ( ) ,
57- data_format : conv_utils . convert_data_format ( args . DataFormat , 3 ) ) ;
63+ data_format : conv_utils . convert_data_format ( args . DataFormat , ndim ) ) ;
5864
59- return outputs ;
65+ return tf . squeeze ( outputs , pad_axis ) ;
6066 }
6167 }
6268}
0 commit comments