Skip to content

Commit f0030ca

Browse files
committed
reset_metrics for every epoch.
1 parent f55650d commit f0030ca

5 files changed

Lines changed: 19 additions & 7 deletions

File tree

src/TensorFlowNET.Core/tensorflow.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public partial class tensorflow : ITensorFlowObject
4848
public tensorflow()
4949
{
5050
Logger = new LoggerConfiguration()
51-
.MinimumLevel.Debug()
51+
.MinimumLevel.Error()
5252
.WriteTo.Console()
5353
.CreateLogger();
5454

src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public TensorLikeDataAdapter(DataAdapterArgs args)
2323
num_samples = args.X.shape[0];
2424
var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
2525
_batch_size = batch_size;
26-
_size = num_samples < batch_size ? num_samples % batch_size : num_samples / batch_size;
26+
_size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0.0f)));
2727
num_full_batches = num_samples / batch_size;
2828
_partial_batch_size = num_samples % batch_size;
2929

src/TensorFlowNET.Keras/Engine/Model.Fit.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ void FitInternal(int epochs, int verbose)
8989
_train_counter.assign(0);
9090
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
9191
{
92-
// reset_metrics();
92+
reset_metrics();
9393
// callbacks.on_epoch_begin(epoch)
9494
// data_handler.catch_stop_iteration();
9595
foreach (var step in data_handler.steps())

src/TensorFlowNET.Keras/Engine/Model.Metrics.cs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ public IEnumerable<Metric> metrics
1010
get
1111
{
1212
var _metrics = new List<Metric>();
13+
1314
if (_is_compiled)
1415
{
1516
if (compiled_loss != null)
@@ -18,13 +19,17 @@ public IEnumerable<Metric> metrics
1819
_metrics.add(compiled_metrics.metrics);
1920
}
2021

21-
foreach (var layer in _flatten_layers())
22-
{
23-
// _metrics.extend(layer.metrics);
24-
}
22+
/*foreach (var layer in _flatten_layers())
23+
_metrics.extend(layer.metrics);*/
2524

2625
return _metrics;
2726
}
2827
}
28+
29+
void reset_metrics()
30+
{
31+
foreach (var metric in metrics)
32+
metric.reset_states();
33+
}
2934
}
3035
}

src/TensorFlowNET.Keras/Metrics/Metric.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using Tensorflow.Keras.ArgsDefinition;
33
using Tensorflow.Keras.Engine;
44
using static Tensorflow.Binding;
5+
using static Tensorflow.KerasApi;
56

67
namespace Tensorflow.Keras.Metrics
78
{
@@ -53,6 +54,12 @@ protected override IVariableV1 add_weight(string name,
5354
public virtual Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
5455
=> throw new NotImplementedException("");
5556

57+
public virtual void reset_states()
58+
{
59+
foreach (var v in weights)
60+
v.assign(0);
61+
}
62+
5663
public virtual Tensor result()
5764
=> throw new NotImplementedException("");
5865

0 commit comments

Comments
 (0)