forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathModel.Fit.cs
More file actions
111 lines (104 loc) · 3.87 KB
/
Model.Fit.cs
File metadata and controls
111 lines (104 loc) · 3.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
using NumSharp;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine.DataAdapters;
namespace Tensorflow.Keras.Engine
{
public partial class Model
{
/// <summary>
/// Trains the model for a fixed number of epochs (iterations on a dataset).
/// </summary>
/// <param name="x">Input data. Only the leading rows selected by <paramref name="validation_split"/> are handed to the data handler.</param>
/// <param name="y">Target data, sliced along its first axis with the same row count as <paramref name="x"/>.</param>
/// <param name="batch_size">Forwarded to the data handler as <c>BatchSize</c>.</param>
/// <param name="epochs">Number of epochs; forwarded as <c>Epochs</c> and shown in the progress output.</param>
/// <param name="verbose">Per-step progress lines are written only when this is 1.</param>
/// <param name="validation_split">
/// Fraction of the trailing rows excluded from training; must be in the range [0, 1).
/// The excluded rows are currently not evaluated against.
/// </param>
/// <param name="shuffle">Forwarded to the data handler as <c>Shuffle</c>.</param>
/// <param name="initial_epoch">Forwarded to the data handler as <c>InitialEpoch</c>.</param>
/// <param name="max_queue_size">Forwarded to the data handler as <c>MaxQueueSize</c>.</param>
/// <param name="workers">Forwarded to the data handler as <c>Workers</c>.</param>
/// <param name="use_multiprocessing">Forwarded to the data handler as <c>UseMultiprocessing</c>.</param>
/// <exception cref="ArgumentNullException"><paramref name="x"/> or <paramref name="y"/> is null.</exception>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="validation_split"/> is NaN or outside [0, 1).</exception>
public void fit(NDArray x, NDArray y,
    int batch_size = -1,
    int epochs = 1,
    int verbose = 1,
    float validation_split = 0f,
    bool shuffle = true,
    int initial_epoch = 0,
    int max_queue_size = 10,
    int workers = 1,
    bool use_multiprocessing = false)
{
    if (x is null)
    {
        throw new ArgumentNullException(nameof(x));
    }

    if (y is null)
    {
        throw new ArgumentNullException(nameof(y));
    }

    // Written as a negated conjunction so that NaN (which fails every comparison) is rejected too.
    if (!(validation_split >= 0f && validation_split < 1f))
    {
        throw new ArgumentOutOfRangeException(nameof(validation_split), validation_split,
            "validation_split must be in the range [0, 1).");
    }

    // Number of leading rows kept for training; the remaining rows are the held-out tail.
    int train_count = Convert.ToInt32(x.shape[0] * (1 - validation_split));
    var train_x = x[new Slice(0, train_count)];
    var train_y = y[new Slice(0, train_count)];

    data_handler = new DataHandler(new DataHandlerArgs
    {
        X = train_x,
        Y = train_y,
        BatchSize = batch_size,
        InitialEpoch = initial_epoch,
        Epochs = epochs,
        Shuffle = shuffle,
        MaxQueueSize = max_queue_size,
        Workers = workers,
        UseMultiprocessing = use_multiprocessing,
        Model = this,
        StepsPerExecution = _steps_per_execution
    });

    FitInternal(epochs, verbose);
}
/// <summary>
/// Trains the model on a dataset for a fixed number of epochs.
/// </summary>
/// <param name="dataset">Training dataset; forwarded to the data handler as <c>Dataset</c>.</param>
/// <param name="validation_data">Currently not used by this overload.</param>
/// <param name="batch_size">Forwarded to the data handler as <c>BatchSize</c>.</param>
/// <param name="epochs">Number of epochs; forwarded as <c>Epochs</c> and shown in the progress output.</param>
/// <param name="verbose">Per-step progress lines are written only when this is 1.</param>
/// <param name="validation_split">Currently not used by this overload.</param>
/// <param name="shuffle">Forwarded to the data handler as <c>Shuffle</c>.</param>
/// <param name="initial_epoch">Forwarded to the data handler as <c>InitialEpoch</c>.</param>
/// <param name="max_queue_size">Forwarded to the data handler as <c>MaxQueueSize</c>.</param>
/// <param name="workers">Forwarded to the data handler as <c>Workers</c>.</param>
/// <param name="use_multiprocessing">Forwarded to the data handler as <c>UseMultiprocessing</c>.</param>
/// <exception cref="ArgumentNullException"><paramref name="dataset"/> is null.</exception>
public void fit(IDatasetV2 dataset,
    IDatasetV2 validation_data = null,
    int batch_size = -1,
    int epochs = 1,
    int verbose = 1,
    float validation_split = 0f,
    bool shuffle = true,
    int initial_epoch = 0,
    int max_queue_size = 10,
    int workers = 1,
    bool use_multiprocessing = false)
{
    // Fail at the API boundary instead of somewhere inside the data handler.
    if (dataset is null)
    {
        throw new ArgumentNullException(nameof(dataset));
    }

    data_handler = new DataHandler(new DataHandlerArgs
    {
        Dataset = dataset,
        BatchSize = batch_size,
        InitialEpoch = initial_epoch,
        Epochs = epochs,
        Shuffle = shuffle,
        MaxQueueSize = max_queue_size,
        Workers = workers,
        UseMultiprocessing = use_multiprocessing,
        Model = this,
        StepsPerExecution = _steps_per_execution
    });

    FitInternal(epochs, verbose);
}
/// <summary>
/// Runs the training loop over the epochs and steps produced by <c>data_handler</c>,
/// which must have been assigned by the caller beforehand.
/// </summary>
/// <param name="epochs">Total epoch count; used only when formatting the progress line.</param>
/// <param name="verbose">A progress line is written after each step only when this is exactly 1.</param>
void FitInternal(int epochs, int verbose)
{
    // NOTE(review): stop_training is reset here but is not consulted inside this loop.
    stop_training = false;
    _train_counter.assign(0);

    var reportProgress = verbose == 1;

    foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
    {
        reset_metrics();
        // callbacks.on_epoch_begin(epoch)
        // data_handler.catch_stop_iteration();
        foreach (var step in data_handler.steps())
        {
            // callbacks.on_train_batch_begin(step)
            var results = train_step_function(iterator);
            if (!reportProgress)
            {
                continue;
            }

            // Each result is rendered as "name: value" with six decimal places.
            var formatted = results.Select(metric => $"{metric.Item1}: {(float)metric.Item2:F6}");
            var summary = string.Join(", ", formatted);
            Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {summary}");
        }

        // A full collection is forced at the end of every epoch, then pending finalizers are drained.
        GC.Collect();
        GC.WaitForPendingFinalizers();
    }
}
}
}