Skip to content

Commit 43d061a

Browse files
committed
Fixed wrong shapes with BS=1
1 parent 596d11a commit 43d061a

1 file changed

Lines changed: 2 additions & 5 deletions

File tree

  • TensorFlow/Classification/ConvNets/model

TensorFlow/Classification/ConvNets/model/resnet.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,12 +187,9 @@ def __call__(self, features, labels, mode, params):
187187
reuse=False,
188188
use_final_conv=params['use_final_conv']
189189
)
190-
191-
if mode!=tf.estimator.ModeKeys.PREDICT:
192-
logits = tf.squeeze(logits)
193190

194-
if mode!=tf.estimator.ModeKeys.PREDICT:
195-
logits = tf.squeeze(logits)
191+
if params['use_final_conv']:
192+
logits = tf.squeeze(logits, axis=[-2, -1])
196193

197194
y_preds = tf.argmax(logits, axis=1, output_type=tf.int32)
198195

0 commit comments

Comments
 (0)