Skip to content

Commit 8ca44dd

Browse files
yashk2810copybara-github
authored andcommitted
Switch division to tf.float32
PiperOrigin-RevId: 253790284
1 parent 3e6c617 commit 8ca44dd

2 files changed

Lines changed: 5 additions & 5 deletions

File tree

tensorflow_examples/models/densenet/distributed_train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ def custom_loop(self, train_dist_dataset, test_dist_dataset,
122122
"""
123123

124124
def distributed_train_epoch(ds):
125-
total_loss = 0.
126-
num_train_batches = 0
125+
total_loss = 0.0
126+
num_train_batches = 0.0
127127
for one_batch in ds:
128128
per_replica_loss = strategy.experimental_run_v2(
129129
self.train_step, args=(one_batch,))
@@ -133,7 +133,7 @@ def distributed_train_epoch(ds):
133133
return total_loss, num_train_batches
134134

135135
def distributed_test_epoch(ds):
136-
num_test_batches = 0
136+
num_test_batches = 0.0
137137
for one_batch in ds:
138138
strategy.experimental_run_v2(
139139
self.test_step, args=(one_batch,))

tensorflow_examples/models/nmt_with_attention/distributed_train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ def training_loop(self, train_dist_dataset, test_dist_dataset,
6767
"""
6868

6969
def distributed_train_epoch(ds):
70-
total_loss = 0.
71-
num_train_batches = 0
70+
total_loss = 0.0
71+
num_train_batches = 0.0
7272
for one_batch in ds:
7373
per_replica_loss = strategy.experimental_run_v2(
7474
self.train_step, args=(one_batch,))

0 commit comments

Comments
 (0)