forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnn_ops.py
More file actions
1345 lines (1099 loc) · 55.8 KB
/
nn_ops.py
File metadata and controls
1345 lines (1099 loc) · 55.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrappers for primitive Neural Net (NN) Operations."""
# pylint: disable=invalid-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_nn_ops import *
# pylint: enable=wildcard-import
# Aliases for some automatically-generated names.
# `local_response_normalization` is bound to the very same callable as
# `gen_nn_ops.lrn`; it only provides a longer, more descriptive name.
local_response_normalization = gen_nn_ops.lrn
def atrous_conv2d(value, filters, rate, padding, name=None):
  """Atrous convolution (a.k.a. convolution with holes or dilated convolution).

  Computes a 2-D atrous convolution, also known as convolution with holes or
  dilated convolution, given 4-D `value` and `filters` tensors. If the `rate`
  parameter is equal to one, it performs regular 2-D convolution. If the `rate`
  parameter is greater than one, it performs convolution with holes, sampling
  the input values every `rate` pixels in the `height` and `width` dimensions.
  This is equivalent to convolving the input with a set of upsampled filters,
  produced by inserting `rate - 1` zeros between two consecutive values of the
  filters along the `height` and `width` dimensions, hence the name atrous
  convolution or convolution with holes (the French word trous means holes in
  English).

  More specifically:

      output[b, i, j, k] = sum_{di, dj, q} filters[di, dj, q, k] *
            value[b, i + rate * di, j + rate * dj, q]

  Atrous convolution allows us to explicitly control how densely to compute
  feature responses in fully convolutional networks. Used in conjunction with
  bilinear interpolation, it offers an alternative to `conv2d_transpose` in
  dense prediction tasks such as semantic image segmentation, optical flow
  computation, or depth estimation. It also allows us to effectively enlarge
  the field of view of filters without increasing the number of parameters or
  the amount of computation.

  For a description of atrous convolution and how it can be used for dense
  feature extraction, please see: [Semantic Image Segmentation with Deep
  Convolutional Nets and Fully Connected CRFs](http://arxiv.org/abs/1412.7062).
  The same operation is investigated further in [Multi-Scale Context Aggregation
  by Dilated Convolutions](http://arxiv.org/abs/1511.07122). Previous works
  that effectively use atrous convolution in different ways are, among others,
  [OverFeat: Integrated Recognition, Localization and Detection using
  Convolutional Networks](http://arxiv.org/abs/1312.6229) and [Fast Image
  Scanning with Deep Max-Pooling Convolutional Neural Networks]
  (http://arxiv.org/abs/1302.1700). Atrous convolution is also closely related
  to the so-called noble identities in multi-rate signal processing.

  There are many different ways to implement atrous convolution (see the refs
  above). The implementation here reduces

      atrous_conv2d(value, filters, rate, padding=padding)

  to the following three operations:

      paddings = ...
      net = space_to_batch(value, paddings, block_size=rate)
      net = conv2d(net, filters, strides=[1, 1, 1, 1], padding="VALID")
      crops = ...
      net = batch_to_space(net, crops, block_size=rate)

  Advanced usage. Note the following optimization: A sequence of `atrous_conv2d`
  operations with identical `rate` parameters, 'SAME' `padding`, and filters
  with odd heights/ widths:

      net = atrous_conv2d(net, filters1, rate, padding="SAME")
      net = atrous_conv2d(net, filters2, rate, padding="SAME")
      ...
      net = atrous_conv2d(net, filtersK, rate, padding="SAME")

  can be equivalently performed cheaper in terms of computation and memory as:

      pad = ...  # padding so that the input dims are multiples of rate
      net = space_to_batch(net, paddings=pad, block_size=rate)
      net = conv2d(net, filters1, strides=[1, 1, 1, 1], padding="SAME")
      net = conv2d(net, filters2, strides=[1, 1, 1, 1], padding="SAME")
      ...
      net = conv2d(net, filtersK, strides=[1, 1, 1, 1], padding="SAME")
      net = batch_to_space(net, crops=pad, block_size=rate)

  because a pair of consecutive `space_to_batch` and `batch_to_space` ops with
  the same `block_size` cancel out when their respective `paddings` and `crops`
  inputs are identical.

  Args:
    value: A 4-D `Tensor` of type `float`. It needs to be in the default "NHWC"
      format. Its shape is `[batch, in_height, in_width, in_channels]`.
    filters: A 4-D `Tensor` with the same type as `value` and shape
      `[filter_height, filter_width, in_channels, out_channels]`. `filters`'
      `in_channels` dimension must match that of `value`. Atrous convolution is
      equivalent to standard convolution with upsampled filters with effective
      height `filter_height + (filter_height - 1) * (rate - 1)` and effective
      width `filter_width + (filter_width - 1) * (rate - 1)`, produced by
      inserting `rate - 1` zeros along consecutive elements across the
      `filters`' spatial dimensions.
    rate: A positive int32. The stride with which we sample input values across
      the `height` and `width` dimensions. Equivalently, the rate by which we
      upsample the filter values by inserting zeros across the `height` and
      `width` dimensions. In the literature, the same parameter is sometimes
      called `input stride` or `dilation`.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
    name: Optional name for the returned tensor.

  Returns:
    A `Tensor` with the same type as `value`.

  Raises:
    ValueError: If input/output depth does not match `filters`' shape, or if
      padding is other than `'VALID'` or `'SAME'`.
  """
  with ops.op_scope([value, filters], name, "atrous_conv2d") as scope:
    value = ops.convert_to_tensor(value, name="value")
    filters = ops.convert_to_tensor(filters, name="filters")

    value_channels = value.get_shape()[3]
    filter_channels = filters.get_shape()[2]
    if not value_channels.is_compatible_with(filter_channels):
      raise ValueError(
          "value's input channels does not match filters' input channels, "
          "{} != {}".format(value_channels, filter_channels))
    if rate < 1:
      raise ValueError("rate {} cannot be less than one".format(rate))

    if rate == 1:
      # No holes to insert: this is an ordinary stride-1 convolution.
      return gen_nn_ops.conv2d(input=value,
                               filter=filters,
                               strides=[1, 1, 1, 1],
                               padding=padding)

    # Two padding contributions are combined below. The first turns a "SAME"
    # convolution into a "VALID" one; the second rounds the padded height and
    # width up to multiples of `rate`.
    if padding == "SAME":
      # The filter shape may be unknown at graph-construction time, in which
      # case the arithmetic below is carried out on tensors.
      static_filter_shape = filters.get_shape()
      if static_filter_shape.is_fully_defined():
        kernel_dims = static_filter_shape.as_list()
      else:
        kernel_dims = array_ops.shape(filters)
      kernel_h, kernel_w = kernel_dims[0], kernel_dims[1]

      # Spatial extent of the upsampled filters, i.e. with (rate - 1) zeros
      # inserted between consecutive filter values.
      eff_kernel_h = kernel_h + (kernel_h - 1) * (rate - 1)
      eff_kernel_w = kernel_w + (kernel_w - 1) * (rate - 1)

      total_pad_h = eff_kernel_h - 1
      total_pad_w = eff_kernel_w - 1

      # An odd total puts the extra row (column) at the bottom (right), which
      # is the same convention conv2d() follows.
      pad_top = total_pad_h // 2
      pad_bottom = total_pad_h - pad_top
      pad_left = total_pad_w // 2
      pad_right = total_pad_w - pad_left
    elif padding == "VALID":
      pad_top = pad_bottom = pad_left = pad_right = 0
    else:
      raise ValueError("Invalid padding")

    # The input shape may likewise be unknown at graph-construction time.
    static_value_shape = value.get_shape()
    if static_value_shape.is_fully_defined():
      input_dims = static_value_shape.as_list()
    else:
      input_dims = array_ops.shape(value)

    padded_h = input_dims[1] + pad_top + pad_bottom
    padded_w = input_dims[2] + pad_left + pad_right

    # Additional bottom/right padding so that `rate` divides both dimensions.
    extra_bottom = (rate - padded_h % rate) % rate
    extra_right = (rate - padded_w % rate) % rate

    # space_to_batch receives both padding components at once.
    paddings = [[pad_top, pad_bottom + extra_bottom],
                [pad_left, pad_right + extra_right]]
    net = array_ops.space_to_batch(input=value,
                                   paddings=paddings,
                                   block_size=rate)

    net = gen_nn_ops.conv2d(input=net,
                            filter=filters,
                            strides=[1, 1, 1, 1],
                            padding="VALID",
                            name=scope)

    # batch_to_space only has to crop away the divisibility padding.
    crops = [[0, extra_bottom], [0, extra_right]]
    net = array_ops.batch_to_space(input=net,
                                   crops=crops,
                                   block_size=rate)
    return net
def conv2d_transpose(value,
                     filter,
                     output_shape,
                     strides,
                     padding="SAME",
                     name=None):
  """The transpose of `conv2d`.

  This operation is sometimes called "deconvolution" after [Deconvolutional
  Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf), but is
  actually the transpose (gradient) of `conv2d` rather than an actual
  deconvolution.

  Args:
    value: A 4-D `Tensor` of type `float` and shape
      `[batch, height, width, in_channels]`.
    filter: A 4-D `Tensor` with the same type as `value` and shape
      `[height, width, output_channels, in_channels]`.  `filter`'s
      `in_channels` dimension must match that of `value`.
    output_shape: A 1-D `Tensor` representing the output shape of the
      deconvolution op.
    strides: A list of ints. The stride of the sliding window for each
      dimension of the input tensor.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
      See the [comment here](https://www.tensorflow.org/api_docs/python/nn.html#convolution)
    name: Optional name for the returned tensor.

  Returns:
    A `Tensor` with the same type as `value`.

  Raises:
    ValueError: If input/output depth does not match `filter`'s shape, or if
      padding is other than `'VALID'` or `'SAME'`.
  """
  with ops.op_scope([value, filter, output_shape], name,
                    "conv2d_transpose") as scope:
    value = ops.convert_to_tensor(value, name="value")
    filter = ops.convert_to_tensor(filter, name="filter")

    # The last axis of both `value` and `filter` holds the input channels.
    value_depth = value.get_shape()[3]
    filter_in_depth = filter.get_shape()[3]
    if not value_depth.is_compatible_with(filter_in_depth):
      raise ValueError("input channels does not match filter's input channels, "
                       "{} != {}".format(value_depth, filter_in_depth))

    shape_tensor = ops.convert_to_tensor(output_shape, name="output_shape")
    shape_of_shape = shape_tensor.get_shape()
    if not shape_of_shape.is_compatible_with(tensor_shape.vector(4)):
      raise ValueError("output_shape must have shape (4,), got {}"
                       .format(shape_of_shape))

    if isinstance(output_shape, (list, np.ndarray)):
      # The caller gave concrete values (and they are known to form a
      # 4-vector by now), so the output channel count can be checked here.
      filter_out_depth = filter.get_shape()[2]
      if not filter_out_depth.is_compatible_with(output_shape[3]):
        raise ValueError(
            "output_shape does not match filter's output channels, "
            "{} != {}".format(output_shape[3], filter_out_depth))

    if padding not in ("VALID", "SAME"):
      raise ValueError("padding must be either VALID or SAME:"
                       " {}".format(padding))

    return gen_nn_ops.conv2d_backprop_input(input_sizes=shape_tensor,
                                            filter=filter,
                                            out_backprop=value,
                                            strides=strides,
                                            padding=padding,
                                            name=scope)
def conv3d_transpose(value,
                     filter,
                     output_shape,
                     strides,
                     padding="SAME",
                     name=None):
  """The transpose of `conv3d`.

  This operation is sometimes called "deconvolution" after [Deconvolutional
  Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf), but is
  actually the transpose (gradient) of `conv3d` rather than an actual
  deconvolution.

  Args:
    value: A 5-D `Tensor` of type `float` and shape
      `[batch, depth, height, width, in_channels]`.
    filter: A 5-D `Tensor` with the same type as `value` and shape
      `[depth, height, width, output_channels, in_channels]`.  `filter`'s
      `in_channels` dimension must match that of `value`.
    output_shape: A 1-D `Tensor` representing the output shape of the
      deconvolution op.
    strides: A list of ints. The stride of the sliding window for each
      dimension of the input tensor.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
      See the [comment here](https://www.tensorflow.org/api_docs/python/nn.html#convolution)
    name: Optional name for the returned tensor.

  Returns:
    A `Tensor` with the same type as `value`.

  Raises:
    ValueError: If input/output depth does not match `filter`'s shape, or if
      padding is other than `'VALID'` or `'SAME'`.
  """
  with ops.op_scope([value, filter, output_shape], name,
                    "conv3d_transpose") as scope:
    value = ops.convert_to_tensor(value, name="value")
    filter = ops.convert_to_tensor(filter, name="filter")

    # The last axis of both `value` and `filter` holds the input channels.
    value_depth = value.get_shape()[4]
    filter_in_depth = filter.get_shape()[4]
    if not value_depth.is_compatible_with(filter_in_depth):
      raise ValueError("input channels does not match filter's input channels, "
                       "{} != {}".format(value_depth, filter_in_depth))

    shape_tensor = ops.convert_to_tensor(output_shape, name="output_shape")
    shape_of_shape = shape_tensor.get_shape()
    if not shape_of_shape.is_compatible_with(tensor_shape.vector(5)):
      raise ValueError("output_shape must have shape (5,), got {}"
                       .format(shape_of_shape))

    if isinstance(output_shape, (list, np.ndarray)):
      # The caller gave concrete values (and they are known to form a
      # 5-vector by now), so the output channel count can be checked here.
      filter_out_depth = filter.get_shape()[3]
      if not filter_out_depth.is_compatible_with(output_shape[4]):
        raise ValueError(
            "output_shape does not match filter's output channels, "
            "{} != {}".format(output_shape[4], filter_out_depth))

    if padding not in ("VALID", "SAME"):
      raise ValueError("padding must be either VALID or SAME:"
                       " {}".format(padding))

    return gen_nn_ops.conv3d_backprop_input_v2(input_sizes=shape_tensor,
                                               filter=filter,
                                               out_backprop=value,
                                               strides=strides,
                                               padding=padding,
                                               name=scope)
# pylint: disable=protected-access
def bias_add(value, bias, data_format=None, name=None):
  """Adds `bias` to `value`.

  This is (mostly) a special case of `tf.add` where `bias` is restricted to 1-D.
  Broadcasting is supported, so `value` may have any number of dimensions.
  Unlike `tf.add`, the type of `bias` is allowed to differ from `value` in the
  case where both types are quantized.

  Args:
    value: A `Tensor` with type `float`, `double`, `int64`, `int32`, `uint8`,
      `int16`, `int8`, `complex64`, or `complex128`.
    bias: A 1-D `Tensor` with size matching the last dimension of `value`.
      Must be the same type as `value` unless `value` is a quantized type,
      in which case a different quantized type may be used.
    data_format: A string. 'NHWC' and 'NCHW' are supported.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as `value`.
  """
  with ops.op_scope([value, bias], name, "BiasAdd") as scope:
    value_tensor = ops.convert_to_tensor(value, name="input")
    # `bias` is converted using the dtype that `value` ended up with.
    bias_tensor = ops.convert_to_tensor(
        bias, dtype=value_tensor.dtype, name="bias")
    return gen_nn_ops._bias_add(
        value_tensor, bias_tensor, data_format=data_format, name=scope)
# Register the shape functions from `common_shapes` for the "BiasAdd" op and
# for its gradient op "BiasAddGrad".
ops.RegisterShape("BiasAdd")(common_shapes.bias_add_shape)
ops.RegisterShape("BiasAddGrad")(common_shapes.bias_add_grad_shape)
# pylint: disable=protected-access
def bias_add_v1(value, bias, name=None):
  """Adds `bias` to `value`.

  This is a deprecated version of bias_add and will soon be removed.

  This is (mostly) a special case of `tf.add` where `bias` is restricted to 1-D.
  Broadcasting is supported, so `value` may have any number of dimensions.
  Unlike `tf.add`, the type of `bias` is allowed to differ from `value` in the
  case where both types are quantized.

  Args:
    value: A `Tensor` with type `float`, `double`, `int64`, `int32`, `uint8`,
      `int16`, `int8`, `complex64`, or `complex128`.
    bias: A 1-D `Tensor` with size matching the last dimension of `value`.
      Must be the same type as `value` unless `value` is a quantized type,
      in which case a different quantized type may be used.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as `value`.
  """
  with ops.op_scope([value, bias], name, "BiasAddV1") as scope:
    value_tensor = ops.convert_to_tensor(value, name="input")
    # `bias` is converted using the dtype that `value` ended up with.
    bias_tensor = ops.convert_to_tensor(
        bias, dtype=value_tensor.dtype, name="bias")
    return gen_nn_ops._bias_add_v1(value_tensor, bias_tensor, name=scope)
# The V1 ops are registered with the same `common_shapes` shape functions as
# "BiasAdd" / "BiasAddGrad".
ops.RegisterShape("BiasAddV1")(common_shapes.bias_add_shape)
ops.RegisterShape("BiasAddGradV1")(common_shapes.bias_add_grad_shape)
def relu6(features, name=None):
  """Computes Rectified Linear 6: `min(max(features, 0), 6)`.

  Args:
    features: A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`,
      `int16`, or `int8`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as `features`.
  """
  with ops.op_scope([features], name, "Relu6") as scope:
    features_tensor = ops.convert_to_tensor(features, name="features")
    return gen_nn_ops._relu6(features_tensor, name=scope)
def softmax_cross_entropy_with_logits(logits, labels, name=None):
  """Computes softmax cross entropy between `logits` and `labels`.

  Measures the probability error in discrete classification tasks in which the
  classes are mutually exclusive (each entry is in exactly one class).  For
  example, each CIFAR-10 image is labeled with one and only one label: an image
  can be a dog or a truck, but not both.

  **NOTE:**  While the classes are mutually exclusive, their probabilities
  need not be.  All that is required is that each row of `labels` is
  a valid probability distribution.  If they are not, the computation of the
  gradient will be incorrect.

  If using exclusive `labels` (wherein one and only
  one class is true at a time), see `sparse_softmax_cross_entropy_with_logits`.

  **WARNING:** This op expects unscaled logits, since it performs a `softmax`
  on `logits` internally for efficiency.  Do not call this op with the
  output of `softmax`, as it will produce incorrect results.

  `logits` and `labels` must have the same shape `[batch_size, num_classes]`
  and the same dtype (either `float16`, `float32`, or `float64`).

  When `logits` is `float16` the loss is computed in `float32` and the result
  is cast back to `float16`.

  Args:
    logits: Unscaled log probabilities.
    labels: Each row `labels[i]` must be a valid probability distribution.
    name: A name for the operation (optional).

  Returns:
    A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the
    softmax cross entropy loss.
  """
  # TODO(pcmurray) Raise an error when the labels do not sum to 1. Note: This
  # could break users who call this with bad labels, but disregard the bad
  # results.
  logits = ops.convert_to_tensor(logits)
  is_half_precision = logits.dtype == dtypes.float16

  if is_half_precision:
    precise_logits = math_ops.cast(logits, dtypes.float32)
    # The logits were just promoted to float32, so promote the labels as well.
    # Otherwise float16 label tensors (the documented "same dtype as logits"
    # case) would reach the op with a dtype that no longer matches the logits.
    labels = math_ops.cast(labels, dtypes.float32)
  else:
    precise_logits = logits

  # The second output tensor contains the gradients.  We use it in
  # _CrossEntropyGrad() in nn_grad but not here.
  cost, unused_backprop = gen_nn_ops._softmax_cross_entropy_with_logits(
      precise_logits, labels, name=name)

  if is_half_precision:
    return math_ops.cast(cost, dtypes.float16)
  return cost
def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
  """Computes sparse softmax cross entropy between `logits` and `labels`.

  Measures the probability error in discrete classification tasks in which the
  classes are mutually exclusive (each entry is in exactly one class).  For
  example, each CIFAR-10 image is labeled with one and only one label: an image
  can be a dog or a truck, but not both.

  **NOTE:**  For this operation, the probability of a given label is considered
  exclusive.  That is, soft classes are not allowed, and the `labels` vector
  must provide a single specific index for the true class for each row of
  `logits` (each minibatch entry).  For soft softmax classification with
  a probability distribution for each entry, see
  `softmax_cross_entropy_with_logits`.

  **WARNING:** This op expects unscaled logits, since it performs a softmax
  on `logits` internally for efficiency.  Do not call this op with the
  output of `softmax`, as it will produce incorrect results.

  A common use case is to have logits of shape `[batch_size, num_classes]` and
  labels of shape `[batch_size]`. But higher dimensions are supported.

  Args:
    logits: Unscaled log probabilities of rank `r` and shape
      `[d_0, d_1, ..., d_{r-2}, num_classes]` and dtype `float32` or `float64`.
    labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-2}]` and dtype `int32` or
      `int64`. Each entry in `labels` must be an index in `[0, num_classes)`.
      Other values will result in a loss of 0, but incorrect gradient
      computations.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of the same shape as `labels` and of the same type as `logits`
    with the softmax cross entropy loss.

  Raises:
    ValueError: If logits are scalars (need to have rank >= 1) or if the rank
      of the labels is not equal to the rank of the logits minus one.
  """
  # TODO(pcmurray) Raise an error when the label is not an index in
  # [0, num_classes). Note: This could break users who call this with bad
  # labels, but disregard the bad results.

  # Reshape logits and labels to rank 2.
  with ops.op_scope([labels, logits], name,
                    "SparseSoftmaxCrossEntropyWithLogits"):
    labels = ops.convert_to_tensor(labels)
    logits = ops.convert_to_tensor(logits)
    # float16 logits are promoted to float32 for the computation; the result
    # is cast back to float16 before it is returned.
    precise_logits = math_ops.cast(logits, dtypes.float32) if (
        dtypes.as_dtype(logits.dtype) == dtypes.float16) else logits

    # Store label shape for result later.
    labels_static_shape = labels.get_shape()
    labels_shape = array_ops.shape(labels)

    logits_rank = logits.get_shape().ndims
    if logits_rank is not None and logits_rank == 0:
      # Format eagerly: extra positional arguments to ValueError are stored
      # on the exception but never interpolated into the message.
      raise ValueError("Logits cannot be scalars - received shape %s." %
                       (logits.get_shape(),))
    if logits_rank is not None and (
        labels_static_shape.ndims is not None and
        labels_static_shape.ndims != logits_rank - 1):
      raise ValueError("Rank mismatch: Labels rank (received %s) should equal "
                       "logits rank (received %s) - 1." %
                       (labels_static_shape.ndims, logits_rank))

    # Check if no reshapes are required.
    if logits_rank == 2:
      cost, _ = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
          precise_logits, labels, name=name)
      if logits.dtype == dtypes.float16:
        return math_ops.cast(cost, dtypes.float16)
      else:
        return cost

    # Reshape logits to 2 dim, labels to 1 dim.
    num_classes = array_ops.gather(array_ops.shape(logits),
                                   array_ops.rank(logits) - 1)
    precise_logits = array_ops.reshape(precise_logits, [-1, num_classes])
    labels = array_ops.reshape(labels, [-1])

    # The second output tensor contains the gradients.  We use it in
    # _CrossEntropyGrad() in nn_grad but not here.
    cost, _ = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
        precise_logits, labels, name=name)

    # Give the flat cost vector the original labels shape, both dynamically
    # and in the static shape information.
    cost = array_ops.reshape(cost, labels_shape)
    cost.set_shape(labels_static_shape)
    if logits.dtype == dtypes.float16:
      return math_ops.cast(cost, dtypes.float16)
    else:
      return cost
@ops.RegisterShape("SparseSoftmaxCrossEntropyWithLogits")
def _SparseSoftmaxCrossEntropyWithLogitsShape(op):
  """Shape function for SparseSoftmaxCrossEntropyWithLogits op.

  Args:
    op: The op; input 0 is treated as the logits and input 1 as the labels.

  Returns:
    A two-element list: a vector shape whose length is the first dimension of
    the logits, followed by the rank-2 logits shape itself.
  """
  logits_2d = op.inputs[0].get_shape().with_rank(2)
  num_rows = logits_2d[0]
  # The merged shape is discarded: the call is made only to check the labels
  # shape against a vector with one entry per logits row.
  op.inputs[1].get_shape().merge_with(tensor_shape.vector(num_rows))
  return [tensor_shape.vector(num_rows.value), logits_2d]
@ops.RegisterShape("SoftmaxCrossEntropyWithLogits")
def _SoftmaxCrossEntropyWithLogitsShape(op):
  """Shape function for SoftmaxCrossEntropyWithLogits op.

  Args:
    op: The op; input 0 is treated as the logits and input 1 as the labels.

  Returns:
    A two-element list: a vector shape whose length is the first dimension of
    the merged input shape, followed by that merged rank-2 shape.
  """
  logits_shape, labels_shape = (tensor.get_shape() for tensor in op.inputs[:2])
  merged_2d = logits_shape.merge_with(labels_shape).with_rank(2)
  num_rows = merged_2d[0]
  return [tensor_shape.vector(num_rows.value), merged_2d]
def avg_pool(value, ksize, strides, padding, data_format="NHWC", name=None):
  """Performs the average pooling on the input.

  Each entry in `output` is the mean of the corresponding size `ksize`
  window in `value`.

  Args:
    value: A 4-D `Tensor` of shape `[batch, height, width, channels]` and type
      `float32`, `float64`, `qint8`, `quint8`, or `qint32`.
    ksize: A list of ints that has length >= 4.
      The size of the window for each dimension of the input tensor.
    strides: A list of ints that has length >= 4.
      The stride of the sliding window for each dimension of the
      input tensor.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
      See the [comment here](https://www.tensorflow.org/api_docs/python/nn.html#convolution)
    data_format: A string. 'NHWC' and 'NCHW' are supported.
    name: Optional name for the operation.

  Returns:
    A `Tensor` with the same type as `value`.  The average pooled output tensor.
  """
  with ops.op_scope([value], name, "AvgPool") as scope:
    input_tensor = ops.convert_to_tensor(value, name="input")
    pooled = gen_nn_ops._avg_pool(input_tensor,
                                  ksize=ksize,
                                  strides=strides,
                                  padding=padding,
                                  data_format=data_format,
                                  name=scope)
    return pooled
def max_pool(value, ksize, strides, padding, data_format="NHWC", name=None):
  """Performs the max pooling on the input.

  Args:
    value: A 4-D `Tensor` with shape `[batch, height, width, channels]` and
      type `tf.float32`.
    ksize: A list of ints that has length >= 4.  The size of the window for
      each dimension of the input tensor.
    strides: A list of ints that has length >= 4.  The stride of the sliding
      window for each dimension of the input tensor.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
      See the [comment here](https://www.tensorflow.org/api_docs/python/nn.html#convolution)
    data_format: A string. 'NHWC' and 'NCHW' are supported.
    name: Optional name for the operation.

  Returns:
    A `Tensor` with type `tf.float32`.  The max pooled output tensor.
  """
  with ops.op_scope([value], name, "MaxPool") as scope:
    input_tensor = ops.convert_to_tensor(value, name="input")
    pooled = gen_nn_ops._max_pool(input_tensor,
                                  ksize=ksize,
                                  strides=strides,
                                  padding=padding,
                                  data_format=data_format,
                                  name=scope)
    return pooled
# All five activation ops below share one shape function,
# `common_shapes.unchanged_shape`.
ops.RegisterShape("Relu")(common_shapes.unchanged_shape)
ops.RegisterShape("Relu6")(common_shapes.unchanged_shape)
ops.RegisterShape("Elu")(common_shapes.unchanged_shape)
ops.RegisterShape("Softplus")(common_shapes.unchanged_shape)
ops.RegisterShape("Softsign")(common_shapes.unchanged_shape)
@ops.RegisterShape("ReluGrad")
@ops.RegisterShape("Relu6Grad")
@ops.RegisterShape("EluGrad")
@ops.RegisterShape("SoftplusGrad")
@ops.RegisterShape("SoftsignGrad")
def _BinaryElementwiseShape(op):
  """Returns same shape as both inputs to op.

  Registered as the shape function of the activation gradient ops listed in
  the decorators above.

  Args:
    op: Input operation.

  Returns:
    A single-element list holding the merge of the shapes of the first two
    inputs to `op`.
  """
  first_shape = op.inputs[0].get_shape()
  second_shape = op.inputs[1].get_shape()
  return [first_shape.merge_with(second_shape)]
# "L2Loss" uses the `common_shapes.scalar_shape` shape function; "LRN" uses
# the function returned by `common_shapes.unchanged_shape_with_rank(4)`.
ops.RegisterShape("L2Loss")(common_shapes.scalar_shape)
ops.RegisterShape("LRN")(common_shapes.unchanged_shape_with_rank(4))
@ops.RegisterShape("LRNGrad")
def _LRNGradShape(op):
  """Shape function for LRNGrad op.

  Args:
    op: The op; its first three inputs must each have rank 4.

  Returns:
    A single-element list holding the merge of the three input shapes.
  """
  grads_shape = op.inputs[0].get_shape().with_rank(4)
  image_shape = op.inputs[1].get_shape().with_rank(4)
  output_shape = op.inputs[2].get_shape().with_rank(4)
  merged = grads_shape.merge_with(image_shape)
  merged = merged.merge_with(output_shape)
  return [merged]
# "Softmax" and "LogSoftmax" each use a shape function produced by
# `common_shapes.unchanged_shape_with_rank(2)`.
ops.RegisterShape("Softmax")(common_shapes.unchanged_shape_with_rank(2))
ops.RegisterShape("LogSoftmax")(common_shapes.unchanged_shape_with_rank(2))
@ops.RegisterShape("InTopK")
def _InTopKShape(op):
  """Shape function for InTopK op.

  Args:
    op: The op; input 0 (predictions) must have rank 2 and input 1 (targets)
      must have rank 1.

  Returns:
    A single-element list holding a vector shape whose length is the merge of
    the first dimensions of the two inputs.
  """
  predictions_2d = op.inputs[0].get_shape().with_rank(2)
  targets_1d = op.inputs[1].get_shape().with_rank(1)
  num_rows = predictions_2d[0].merge_with(targets_1d[0])
  return [tensor_shape.vector(num_rows.value)]
@ops.RegisterShape("TopK")
@ops.RegisterShape("TopKV2")
def _TopKShape(op):
  """Shape function for TopK and TopKV2 ops.

  Args:
    op: The op. When it has at least two inputs, `k` is taken from the
      constant value of the second input (which may be `None`); otherwise it
      is read from the op's "k" attribute.

  Returns:
    A two-element list containing the same shape twice: the input shape with
    its last dimension replaced by `k`.

  Raises:
    ValueError: If both the last input dimension and `k` are known and the
      last dimension is smaller than `k`.
  """
  input_shape = op.inputs[0].get_shape().with_rank_at_least(1)

  k_is_an_input = len(op.inputs) >= 2
  if k_is_an_input:
    k = tensor_util.constant_value(op.inputs[1])
  else:
    k = op.get_attr("k")

  last_dim = input_shape[-1].value
  both_known = last_dim is not None and k is not None
  if both_known and last_dim < k:
    raise ValueError("input.shape %s must have last dimension >= k = %d" %
                     (input_shape, k))

  result_shape = input_shape[:-1].concatenate([k])
  return [result_shape, result_shape]
@ops.RegisterShape("BatchNormWithGlobalNormalization")
def _BatchNormShape(op):
  """Shape function for BatchNormWithGlobalNormalization op.

  Args:
    op: A BatchNormWithGlobalNormalization operation. Input 0 must be rank 4;
      inputs 1-4 (mean, variance, beta, gamma) must each be rank 1.

  Returns:
    A single-element list with the rank-4 shape of input 0.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)
  # Rank-check mean, variance, beta and gamma (in that order) up front.
  param_shapes = [
      op.inputs[index].get_shape().with_rank(1) for index in range(1, 5)]
  # Each parameter vector is merged against the last input dimension; the
  # merge result is discarded, so this acts purely as a compatibility check.
  for param_shape in param_shapes:
    param_shape[0].merge_with(input_shape[3])
  return [input_shape]
@ops.RegisterShape("BatchNormWithGlobalNormalizationGrad")
def _BatchNormGradShape(op):
  """Shape function for BatchNormWithGlobalNormalizationGrad op.

  Args:
    op: A BatchNormWithGlobalNormalizationGrad operation. Inputs 0 and 4
      (input and output backprop) must be rank 4; inputs 1-3 (mean, variance,
      beta) must be rank 1.

  Returns:
    A five-element list: the merged rank-4 input/backprop shape, followed by
    four copies of a vector shape whose length is the last dimension of that
    merged shape, itself merged with the lengths of mean, variance and beta.
  """
  data_shape = op.inputs[0].get_shape().with_rank(4)
  stat_shapes = [
      op.inputs[index].get_shape().with_rank(1) for index in (1, 2, 3)]
  backprop_shape = op.inputs[4].get_shape().with_rank(4)

  merged_shape = data_shape.merge_with(backprop_shape)
  depth = merged_shape[3]
  for stat_shape in stat_shapes:
    depth = depth.merge_with(stat_shape[0])

  vector_shape = tensor_shape.vector(depth)
  return [merged_shape, vector_shape, vector_shape, vector_shape, vector_shape]
# The 2-D convolution and pooling ops delegate shape inference to the
# correspondingly named helpers in common_shapes.
ops.RegisterShape("Conv2D")(common_shapes.conv2d_shape)
ops.RegisterShape("DepthwiseConv2dNative")(
common_shapes.depthwise_conv2d_native_shape)
ops.RegisterShape("AvgPool")(common_shapes.avg_pool_shape)
ops.RegisterShape("MaxPool")(common_shapes.max_pool_shape)
@ops.RegisterShape("MaxPoolWithArgmax")
def _MaxPoolWithArgMaxShape(op):
  """Shape function for MaxPoolWithArgmax op.

  Args:
    op: A MaxPoolWithArgmax operation.

  Returns:
    The result of `common_shapes.max_pool_shape(op)` repeated twice.
  """
  pooled_shapes = common_shapes.max_pool_shape(op)
  return pooled_shapes * 2
@ops.RegisterShape("AvgPoolGrad")
def _AvgPoolGradShape(op):
  """Shape function for the AvgPoolGrad op.

  Args:
    op: An AvgPoolGrad operation whose input 0 carries the original input
      shape as a tensor value.

  Returns:
    A single-element list: the shape given by the constant value of input 0
    when that value is available, otherwise an unknown shape of rank 4.
  """
  static_sizes = tensor_util.constant_value(op.inputs[0])
  if static_sizes is None:
    # NOTE(mrry): We could in principle work out the shape from the
    # gradients and the attrs, but if we do not know orig_input_shape
    # statically, then we are unlikely to know the shape of the
    # gradients either.
    return [tensor_shape.unknown_shape(ndims=4)]
  return [tensor_shape.TensorShape(static_sizes.tolist())]
@ops.RegisterShape("Conv2DBackpropFilter")
def _Conv2DBackpropFilterShape(op):
  """Shape function for the Conv2DBackpropFilter op.

  Args:
    op: A Conv2DBackpropFilter operation whose input 1 carries the filter
      shape as a tensor value.

  Returns:
    A single-element list: the shape given by the constant value of input 1
    when that value is available, otherwise an unknown shape of rank 4.
  """
  static_sizes = tensor_util.constant_value(op.inputs[1])
  if static_sizes is None:
    # NOTE(mrry): We could in principle work out the shape from the
    # gradients and the attrs, but if we do not know filter_shape
    # statically, then we are unlikely to know the shape of the
    # gradients either.
    return [tensor_shape.unknown_shape(ndims=4)]
  return [tensor_shape.TensorShape(static_sizes.tolist())]
@ops.RegisterShape("Conv2DBackpropInput")
def _Conv2DBackpropInputShape(op):
  """Shape function for the Conv2DBackpropInput op.

  Args:
    op: A Conv2DBackpropInput operation whose input 0 carries the input
      shape as a tensor value.

  Returns:
    A single-element list: the shape given by the constant value of input 0
    when that value is available, otherwise an unknown shape of rank 4.
  """
  static_sizes = tensor_util.constant_value(op.inputs[0])
  if static_sizes is None:
    # NOTE(mrry): We could in principle work out the shape from the
    # gradients and the attrs, but if we do not know input_shape
    # statically, then we are unlikely to know the shape of the
    # gradients either.
    return [tensor_shape.unknown_shape(ndims=4)]
  return [tensor_shape.TensorShape(static_sizes.tolist())]
@ops.RegisterShape("DepthwiseConv2dNativeBackpropFilter")
def _DepthwiseConv2dNativeBackpropFilterShape(op):
  """Shape function for the DepthwiseConv2dNativeBackpropFilter op.

  Args:
    op: A DepthwiseConv2dNativeBackpropFilter operation whose input 1 carries
      the filter shape as a tensor value.

  Returns:
    A single-element list: the shape given by the constant value of input 1
    when that value is available, otherwise an unknown shape of rank 4.
  """
  static_sizes = tensor_util.constant_value(op.inputs[1])
  if static_sizes is None:
    return [tensor_shape.unknown_shape(ndims=4)]
  return [tensor_shape.TensorShape(static_sizes.tolist())]
@ops.RegisterShape("DepthwiseConv2dNativeBackpropInput")
def _DepthwiseConv2dNativeBackpropInputShape(op):
  """Shape function for the DepthwiseConv2dNativeBackpropInput op.

  Args:
    op: A DepthwiseConv2dNativeBackpropInput operation whose input 0 carries
      the input shape as a tensor value.

  Returns:
    A single-element list: the shape given by the constant value of input 0
    when that value is available, otherwise an unknown shape of rank 4.
  """
  static_sizes = tensor_util.constant_value(op.inputs[0])
  if static_sizes is None:
    return [tensor_shape.unknown_shape(ndims=4)]
  return [tensor_shape.TensorShape(static_sizes.tolist())]
@ops.RegisterShape("MaxPoolGrad")
@ops.RegisterShape("MaxPoolGradWithArgmax")
def _MaxPoolGradShape(op):
  """Shape function for the MaxPoolGrad and MaxPoolGradWithArgmax ops.

  Args:
    op: The gradient operation; input 0 (the original input) must be rank 4.

  Returns:
    A single-element list with the rank-4 static shape of input 0.
  """
  return [op.inputs[0].get_shape().with_rank(4)]
@ops.RegisterStatistics("Conv2D", "flops")
def _calc_conv_flops(graph, node):
  """Calculates the compute resources needed for Conv2D.

  Args:
    graph: Graph used to resolve tensor shapes by node-def name.
    node: NodeDef of the Conv2D op; `node.input[0]` names the data input and
      `node.input[1]` names the filter.

  Returns:
    An `ops.OpStats` for "flops" equal to
    output_element_count * filter_in_depth * filter_height * filter_width * 2.
  """

  def _fully_defined_shape(tensor_name):
    """Looks up the shape for `tensor_name` and requires it to be defined."""
    shape = graph_util.tensor_shape_from_node_def_name(graph, tensor_name)
    shape.assert_is_fully_defined()
    return shape

  # The data input shape is validated but does not enter the formula.
  _fully_defined_shape(node.input[0])
  kernel_shape = _fully_defined_shape(node.input[1])
  result_shape = _fully_defined_shape(node.name)

  kernel_height, kernel_width, kernel_in_depth = [
      int(kernel_shape[axis]) for axis in range(3)]
  output_count = np.prod(result_shape.as_list())
  flops = output_count * kernel_in_depth * kernel_height * kernel_width * 2
  return ops.OpStats("flops", flops)
@ops.RegisterStatistics("Conv2D", "weight_parameters")
def _calc_conv_weight_params(graph, node):
  """Calculates the on-disk size of the weights for Conv2D.

  Args:
    graph: Graph used to resolve tensor shapes by node-def name.
    node: NodeDef of the Conv2D op; `node.input[0]` names the data input and
      `node.input[1]` names the filter.

  Returns:
    An `ops.OpStats` for "weight_parameters" equal to the product of the four
    filter dimensions.
  """

  def _fully_defined_shape(tensor_name):
    """Looks up the shape for `tensor_name` and requires it to be defined."""
    shape = graph_util.tensor_shape_from_node_def_name(graph, tensor_name)
    shape.assert_is_fully_defined()
    return shape

  # Input and output shapes are validated but only the filter is counted.
  _fully_defined_shape(node.input[0])
  kernel_shape = _fully_defined_shape(node.input[1])
  _fully_defined_shape(node.name)

  kernel_height, kernel_width, kernel_in_depth, kernel_out_depth = [
      int(kernel_shape[axis]) for axis in range(4)]
  weight_count = (
      kernel_height * kernel_width * kernel_in_depth * kernel_out_depth)
  return ops.OpStats("weight_parameters", weight_count)
@ops.RegisterStatistics("DepthwiseConv2dNative", "flops")
def _calc_depthwise_conv_flops(graph, node):
  """Calculates the compute resources needed for DepthwiseConv2dNative.

  Args:
    graph: Graph used to resolve tensor shapes by node-def name.
    node: NodeDef of the DepthwiseConv2dNative op; `node.input[0]` names the
      data input and `node.input[1]` names the filter.

  Returns:
    An `ops.OpStats` for "flops" equal to
    output_element_count * filter_height * filter_width * 2.
  """

  def _fully_defined_shape(tensor_name):
    """Looks up the shape for `tensor_name` and requires it to be defined."""
    shape = graph_util.tensor_shape_from_node_def_name(graph, tensor_name)
    shape.assert_is_fully_defined()
    return shape

  # The data input shape is validated but does not enter the formula.
  _fully_defined_shape(node.input[0])
  kernel_shape = _fully_defined_shape(node.input[1])
  result_shape = _fully_defined_shape(node.name)

  kernel_height = int(kernel_shape[0])
  kernel_width = int(kernel_shape[1])
  output_count = np.prod(result_shape.as_list())
  flops = output_count * kernel_height * kernel_width * 2
  return ops.OpStats("flops", flops)
@ops.RegisterStatistics("DepthwiseConv2dNative", "weight_parameters")
def _calc_depthwise_conv_weight_params(graph, node):
  """Calculates the on-disk size of the weights for DepthwiseConv2dNative.

  Args:
    graph: Graph used to resolve tensor shapes by node-def name.
    node: NodeDef of the DepthwiseConv2dNative op; `node.input[0]` names the
      data input and `node.input[1]` names the filter.

  Returns:
    An `ops.OpStats` for "weight_parameters" equal to the product of the four
    filter dimensions (height, width, in-depth, channel multiplier).
  """

  def _fully_defined_shape(tensor_name):
    """Looks up the shape for `tensor_name` and requires it to be defined."""
    shape = graph_util.tensor_shape_from_node_def_name(graph, tensor_name)
    shape.assert_is_fully_defined()
    return shape

  # Input and output shapes are validated but only the filter is counted.
  _fully_defined_shape(node.input[0])
  kernel_shape = _fully_defined_shape(node.input[1])
  _fully_defined_shape(node.name)

  kernel_height, kernel_width, kernel_in_depth, channel_multiplier = [
      int(kernel_shape[axis]) for axis in range(4)]
  weight_count = (
      kernel_height * kernel_width * kernel_in_depth * channel_multiplier)
  return ops.OpStats("weight_parameters", weight_count)
@ops.RegisterShape("Conv3D")
def _Conv3DShape(op):
  """Shape function for Conv3D.

  Args:
    op: A Conv3D operation. Input 0 (data) and input 1 (filter) must both be
      rank 5, and the "strides" attr must hold five values.

  Returns:
    A single-element list with the output shape
    [batch, out_planes, out_rows, out_cols, out_channels], where batch is
    dimension 0 of the data input and out_channels is dimension 4 of the
    filter.

  Raises:
    ValueError: If the stride in the batch (first) or depth (last) dimension
      is not 1.
  """
  input_shape = op.inputs[0].get_shape().with_rank(5)
  filter_shape = op.inputs[1].get_shape().with_rank(5)

  batch_size = input_shape[0]
  out_channels = filter_shape[4]
  # Check that the input number of channels is compatible between
  # input data and filter size.
  input_shape[4].assert_is_compatible_with(filter_shape[3])

  stride_b, stride_p, stride_r, stride_c, stride_d = op.get_attr("strides")
  # Validate with an explicit raise rather than `assert`: assert statements
  # are stripped under `python -O`, which would silently accept unsupported
  # strides and produce a wrong output shape.
  if stride_b != 1 or stride_d != 1:
    raise ValueError(
        "Conv3D does not support strides in the batch and depth dimensions; "
        "got strides = %s" %
        ([stride_b, stride_p, stride_r, stride_c, stride_d],))

  padding_type = op.get_attr("padding")
  out_planes, out_rows, out_cols = common_shapes.get_conv_output_size(
      input_shape[1:4], filter_shape[0:3], (stride_p, stride_r, stride_c),
      padding_type)
  return [tensor_shape.TensorShape([batch_size, out_planes, out_rows, out_cols,
                                    out_channels])]
@ops.RegisterShape("MaxPool3D")
@ops.RegisterShape("AvgPool3D")
def _Pool3DShape(op):
  """Shape function for Max/AvgPool3D.

  Args:
    op: A MaxPool3D or AvgPool3D operation. Input 0 must be rank 5, and the
      "ksize" and "strides" attrs must each hold five values.

  Returns:
    A single-element list with the output shape
    [batch, out_planes, out_rows, out_cols, channels], where batch and
    channels are dimensions 0 and 4 of the input.

  Raises:
    ValueError: If the window size or the stride in the batch (first) or
      depth (last) dimension is not 1.
  """
  input_shape = op.inputs[0].get_shape().with_rank(5)

  # Validate with explicit raises rather than `assert`: assert statements are
  # stripped under `python -O`, which would silently accept unsupported
  # window sizes / strides and produce a wrong output shape.
  ksize_b, ksize_p, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
  if ksize_b != 1 or ksize_d != 1:
    raise ValueError(
        "Max/AvgPool3D does not support pooling in the batch and depth "
        "dimensions; got ksize = %s" %
        ([ksize_b, ksize_p, ksize_r, ksize_c, ksize_d],))

  stride_b, stride_p, stride_r, stride_c, stride_d = op.get_attr("strides")
  if stride_b != 1 or stride_d != 1:
    raise ValueError(
        "Max/AvgPool3D does not support strides in the batch and depth "
        "dimensions; got strides = %s" %
        ([stride_b, stride_p, stride_r, stride_c, stride_d],))

  batch_size = input_shape[0]
  channels = input_shape[4]

  padding = op.get_attr("padding")
  out_planes, out_rows, out_cols = common_shapes.get_conv_output_size(
      input_shape[1:4], (ksize_p, ksize_r, ksize_c),
      (stride_p, stride_r, stride_c), padding)
  return [tensor_shape.TensorShape([batch_size, out_planes, out_rows, out_cols,
                                    channels])]
@ops.RegisterShape("Conv3DBackpropFilter")
def _Conv3DBackpropFilterShape(op):
  """Shape function for the Conv3DBackpropFilter op.

  Args:
    op: A Conv3DBackpropFilter operation; input 1 (the filter) must be
      rank 5.

  Returns:
    A single-element list with the rank-5 static shape of input 1.
  """
  return [op.inputs[1].get_shape().with_rank(5)]
@ops.RegisterShape("Conv3DBackpropInput")
def _Conv3DBackpropInputShape(op):
  """Shape function for the Conv3DBackpropInput op.

  Args:
    op: A Conv3DBackpropInput operation; input 0 (the input) must be rank 5.

  Returns:
    A single-element list with the rank-5 static shape of input 0.
  """
  return [op.inputs[0].get_shape().with_rank(5)]
@ops.RegisterShape("Conv3DBackpropFilterV2")
def _Conv3DBackpropFilterShapeV2(op):
  """Shape function for the Conv3DBackpropFilterV2 op.

  Args:
    op: A Conv3DBackpropFilterV2 operation whose input 1 carries the filter
      shape as a tensor value.

  Returns:
    A single-element list: a `TensorShape` built from the constant value of
    input 1, constrained to rank 5.
  """
  filter_sizes = tensor_util.constant_value(op.inputs[1])
  static_shape = tensor_shape.TensorShape(filter_sizes)
  return [static_shape.with_rank(5)]
@ops.RegisterShape("Conv3DBackpropInputV2")
def _Conv3DBackpropInputShapeV2(op):
  """Shape function for the Conv3DBackpropInputV2 op.

  Args:
    op: A Conv3DBackpropInputV2 operation whose input 0 carries the input
      shape as a tensor value.

  Returns:
    A single-element list: a `TensorShape` built from the constant value of
    input 0, constrained to rank 5.
  """
  input_sizes = tensor_util.constant_value(op.inputs[0])
  static_shape = tensor_shape.TensorShape(input_sizes)
  return [static_shape.with_rank(5)]
@ops.RegisterShape("AvgPool3DGrad")
def _AvgPool3DGradShape(op):