forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathModel.Train.cs
More file actions
53 lines (47 loc) · 2.11 KB
/
Model.Train.cs
File metadata and controls
53 lines (47 loc) · 2.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Gradients;
using Tensorflow.Keras.Optimizers;
using static Tensorflow.Binding;
namespace Tensorflow.Keras.Engine
{
public partial class Model
{
/// <summary>
/// Fetches the next element from <paramref name="iterator"/>, runs one training
/// step on it, and then adds 1 to the train counter.
/// </summary>
/// <param name="iterator">Iterator whose next element supplies the step inputs at index 0 and index 1.</param>
/// <returns>The (name, value) pairs returned by <c>train_step</c>.</returns>
IEnumerable<(string, Tensor)> train_step_function(OwnedIterator iterator)
{
    var batch = iterator.next();
    var step_results = train_step(batch[0], batch[1]);

    // The counter update is issued under an empty control-dependency list.
    var no_dependencies = new object[0];
    tf_with(ops.control_dependencies(no_dependencies), scope => _train_counter.assign_add(1));

    return step_results;
}
/// <summary>
/// The logic for one training step: forward pass, loss computation,
/// gradient update through <c>_minimize</c>, and metric update.
/// </summary>
/// <param name="x">Input tensor; passed through <c>Expand1d</c> together with <paramref name="y"/> before use.</param>
/// <param name="y">Target tensor; passed through <c>Expand1d</c> together with <paramref name="x"/> before use.</param>
/// <returns>One (name, result) pair for each entry in <c>metrics</c>, materialized as a list.</returns>
List<(string, Tensor)> train_step(Tensor x, Tensor y)
{
    (x, y) = data_handler.DataAdapter.Expand1d(x, y);

    // The tape is created before the forward pass and disposed when the method exits.
    using var tape = tf.GradientTape();
    var y_pred = Apply(x, training: true);
    var loss = compiled_loss.Call(y, y_pred);

    // For custom training steps, users can just write:
    //   trainable_variables = self.trainable_variables
    //   gradients = tape.gradient(loss, trainable_variables)
    //   self.optimizer.apply_gradients(zip(gradients, trainable_variables))
    // The _minimize call does a few extra steps unnecessary in most cases,
    // such as loss scaling and gradient clipping.
    _minimize(tape, optimizer, loss, trainable_variables);
    compiled_metrics.update_state(y, y_pred);

    // The lambda parameter is named `metric` rather than `x` so it does not
    // shadow the input tensor parameter of this method.
    return metrics.Select(metric => (metric.Name, metric.result())).ToList();
}
/// <summary>
/// Computes the gradients of <paramref name="loss"/> with respect to
/// <paramref name="trainable_variables"/>, passes them through the optimizer's
/// aggregation and clipping hooks, and applies the result to the variables.
/// </summary>
/// <param name="tape">Gradient tape used to compute the gradients.</param>
/// <param name="optimizer">Optimizer that aggregates, clips and applies the gradients.</param>
/// <param name="loss">Loss tensor to differentiate.</param>
/// <param name="trainable_variables">Variables to update; every entry must be a <c>ResourceVariable</c>.</param>
/// <exception cref="System.InvalidCastException">
/// Raised while the variable sequence is enumerated if an entry is not a <c>ResourceVariable</c>.
/// </exception>
void _minimize(GradientTape tape, OptimizerV2 optimizer, Tensor loss, List<IVariableV1> trainable_variables)
{
    var gradients = tape.gradient(loss, trainable_variables);
    gradients = optimizer._aggregate_gradients(zip(gradients, trainable_variables));
    gradients = optimizer._clip_gradients(gradients);

    // Cast<T>() rather than `as`: an entry of the wrong type now fails with an
    // InvalidCastException instead of being silently turned into null and
    // handed to the optimizer.
    var resource_variables = trainable_variables.Cast<ResourceVariable>();

    // Aggregation was already requested explicitly above, so apply_gradients
    // is called with experimental_aggregate_gradients disabled.
    optimizer.apply_gradients(zip(gradients, resource_variables),
        experimental_aggregate_gradients: false);
}
}
}