forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathModel.Fit.cs
More file actions
144 lines (131 loc) · 4.94 KB
/
Model.Fit.cs
File metadata and controls
144 lines (131 loc) · 4.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
using Tensorflow.NumPy;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine.DataAdapters;
using System.Diagnostics;
namespace Tensorflow.Keras.Engine
{
    public partial class Model
    {
        /// <summary>
        /// Trains the model for a fixed number of epochs (iterations on a dataset).
        /// </summary>
        /// <param name="x">Input samples. The first dimension (<c>x.dims[0]</c>) is treated as the sample axis.</param>
        /// <param name="y">Targets. Must have the same size as <paramref name="x"/> along the first dimension.</param>
        /// <param name="batch_size">
        /// Forwarded to <c>DataHandlerArgs.BatchSize</c>. The default of -1 leaves the choice to the data handler
        /// (presumably its default batch size; confirm in DataHandler).
        /// </param>
        /// <param name="epochs">Number of epochs to train; also shown in the per-epoch header.</param>
        /// <param name="verbose">1 prints a per-step progress line; any other value suppresses it.</param>
        /// <param name="validation_split">
        /// Fraction in [0, 1) of the samples, taken from the end of <paramref name="x"/> and
        /// <paramref name="y"/>, that is excluded from training.
        /// NOTE(review): the held-out samples are currently not evaluated.
        /// </param>
        /// <param name="shuffle">Forwarded to <c>DataHandlerArgs.Shuffle</c>.</param>
        /// <param name="initial_epoch">Forwarded to <c>DataHandlerArgs.InitialEpoch</c>.</param>
        /// <param name="max_queue_size">Forwarded to <c>DataHandlerArgs.MaxQueueSize</c>.</param>
        /// <param name="workers">Forwarded to <c>DataHandlerArgs.Workers</c>.</param>
        /// <param name="use_multiprocessing">Forwarded to <c>DataHandlerArgs.UseMultiprocessing</c>.</param>
        /// <exception cref="ArgumentNullException"><paramref name="x"/> or <paramref name="y"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="validation_split"/> is not in [0, 1).</exception>
        /// <exception cref="ArgumentException"><paramref name="x"/> and <paramref name="y"/> differ in sample count.</exception>
        public void fit(NDArray x, NDArray y,
            int batch_size = -1,
            int epochs = 1,
            int verbose = 1,
            float validation_split = 0f,
            bool shuffle = true,
            int initial_epoch = 0,
            int max_queue_size = 10,
            int workers = 1,
            bool use_multiprocessing = false)
        {
            // `is null` rather than `== null`: NDArray overloads operator ==.
            if (x is null)
            {
                throw new ArgumentNullException(nameof(x));
            }
            if (y is null)
            {
                throw new ArgumentNullException(nameof(y));
            }
            // Written in negated form so NaN (for which every comparison is false) is rejected too.
            if (!(validation_split >= 0f && validation_split < 1f))
            {
                throw new ArgumentOutOfRangeException(nameof(validation_split), validation_split,
                    "validation_split must be in the range [0, 1).");
            }
            if (x.dims[0] != y.dims[0])
            {
                throw new ArgumentException(
                    $"x and y must contain the same number of samples, got {x.dims[0]} and {y.dims[0]}.",
                    nameof(y));
            }

            int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
            // NOTE(review): the tail x[train_count..] / y[train_count..] is held out from training,
            // but no validation pass over it is implemented yet.
            var train_x = x[new Slice(0, train_count)];
            var train_y = y[new Slice(0, train_count)];

            data_handler = new DataHandler(new DataHandlerArgs
            {
                X = train_x,
                Y = train_y,
                BatchSize = batch_size,
                InitialEpoch = initial_epoch,
                Epochs = epochs,
                Shuffle = shuffle,
                MaxQueueSize = max_queue_size,
                Workers = workers,
                UseMultiprocessing = use_multiprocessing,
                Model = this,
                StepsPerExecution = _steps_per_execution
            });

            FitInternal(epochs, verbose);
        }

        /// <summary>
        /// Trains the model on a dataset for a fixed number of epochs.
        /// </summary>
        /// <param name="dataset">Training data.</param>
        /// <param name="validation_data">NOTE(review): currently ignored by this overload.</param>
        /// <param name="batch_size">Forwarded to <c>DataHandlerArgs.BatchSize</c>.</param>
        /// <param name="epochs">Number of epochs to train; also shown in the per-epoch header.</param>
        /// <param name="verbose">1 prints a per-step progress line; any other value suppresses it.</param>
        /// <param name="validation_split">NOTE(review): currently ignored by this overload.</param>
        /// <param name="shuffle">Forwarded to <c>DataHandlerArgs.Shuffle</c>.</param>
        /// <param name="initial_epoch">Forwarded to <c>DataHandlerArgs.InitialEpoch</c>.</param>
        /// <param name="max_queue_size">Forwarded to <c>DataHandlerArgs.MaxQueueSize</c>.</param>
        /// <param name="workers">Forwarded to <c>DataHandlerArgs.Workers</c>.</param>
        /// <param name="use_multiprocessing">Forwarded to <c>DataHandlerArgs.UseMultiprocessing</c>.</param>
        /// <exception cref="ArgumentNullException"><paramref name="dataset"/> is null.</exception>
        public void fit(IDatasetV2 dataset,
            IDatasetV2 validation_data = null,
            int batch_size = -1,
            int epochs = 1,
            int verbose = 1,
            float validation_split = 0f,
            bool shuffle = true,
            int initial_epoch = 0,
            int max_queue_size = 10,
            int workers = 1,
            bool use_multiprocessing = false)
        {
            if (dataset is null)
            {
                throw new ArgumentNullException(nameof(dataset));
            }

            data_handler = new DataHandler(new DataHandlerArgs
            {
                Dataset = dataset,
                BatchSize = batch_size,
                InitialEpoch = initial_epoch,
                Epochs = epochs,
                Shuffle = shuffle,
                MaxQueueSize = max_queue_size,
                Workers = workers,
                UseMultiprocessing = use_multiprocessing,
                Model = this,
                StepsPerExecution = _steps_per_execution
            });

            FitInternal(epochs, verbose);
        }

        /// <summary>
        /// Runs the epoch/step training loop over the already-configured <c>data_handler</c>.
        /// </summary>
        /// <remarks>
        /// Training stops after the current epoch if <c>stop_training</c> has been set.
        /// </remarks>
        void FitInternal(int epochs, int verbose)
        {
            stop_training = false;
            _train_counter.assign(0);
            var stepTimer = new Stopwatch();

            foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
            {
                reset_metrics();
                on_epoch_begin(epoch, epochs);

                foreach (var step in data_handler.steps())
                {
                    // Time only the training step itself, not the progress reporting.
                    stepTimer.Restart();
                    var results = train_step_function(iterator);
                    stepTimer.Stop();

                    on_train_batch_begin(verbose, step, stepTimer.ElapsedMilliseconds, results);

                    // Slow steps tend to leave large native-backed temporaries behind; collect eagerly so
                    // memory is recycled more frequently.
                    if (stepTimer.ElapsedMilliseconds > 100)
                    {
                        GC.Collect();
                    }
                }

                // Terminate the progress line on the same writer that produced it.
                Binding.tf_output_redirect.WriteLine();
                GC.Collect();
                GC.WaitForPendingFinalizers();

                if (stop_training)
                {
                    break;
                }
            }
        }

        /// <summary>
        /// Writes the "Epoch: NNN/NNN" header (1-based epoch number) to <c>Binding.tf_output_redirect</c>.
        /// </summary>
        void on_epoch_begin(int epoch, int epochs)
        {
            Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}");
        }

        /// <summary>
        /// Writes a single-line progress report for a training step when <paramref name="verbose"/> is 1.
        /// Despite its name, it is invoked after the step has run.
        /// </summary>
        /// <param name="verbose">Only the value 1 produces output.</param>
        /// <param name="step">Zero-based index of the step within the epoch.</param>
        /// <param name="elapse">Wall-clock duration of the step in milliseconds.</param>
        /// <param name="results">Metric name/value pairs returned by the step; each value is cast to float.</param>
        void on_train_batch_begin(int verbose, long step, long elapse, IEnumerable<(string, Tensor)> results)
        {
            if (verbose != 1)
            {
                return;
            }

            const int barWidth = 30;
            long totalSteps = data_handler.Inferredsteps;
            long currentStep = step + 1;

            // Scale against the total step count so the bar is always barWidth characters wide for any
            // number of steps, and never divides by zero when the step count is 0/unknown.
            int filled = totalSteps > 0
                ? (int)Math.Min(barWidth, currentStep * barWidth / totalSteps)
                : 0;
            int shown = Math.Max(filled, 1);
            bool finished = totalSteps > 0 && currentStep >= totalSteps;
            string bar = new string('=', shown - 1)
                + (finished ? '=' : '>')
                + new string('.', barWidth - shown);

            var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}"));

            // The trailing carriage return moves back to column 0 so the next step overwrites this line.
            // Unlike setting Console.CursorLeft, it does not throw when output is redirected and it targets
            // the same writer as the text itself.
            Binding.tf_output_redirect.Write(
                $"{currentStep:D4}/{totalSteps:D4} [{bar}] - {elapse}ms/step {result_pairs}\r");
        }
    }
}