forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathModel.cs
More file actions
148 lines (127 loc) · 5.06 KB
/
Model.cs
File metadata and controls
148 lines (127 loc) · 5.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Train;
namespace Tensorflow.Keras.Engine
{
/// <summary>
/// `Model` groups layers into an object with training and inference features.
/// </summary>
/// <remarks>
/// This is one part of a partial class; several members referenced here
/// (for example <c>_get_trainable_state</c> and <c>_compiled_trainable_state</c>)
/// are declared outside this file.
/// </remarks>
public partial class Model : Layer, IModel
{
#pragma warning disable CS0169 // The field 'Model._cloning' is never used
    bool _cloning;
#pragma warning restore CS0169 // The field 'Model._cloning' is never used
#pragma warning disable CS0108 // Member hides inherited member; missing new keyword
#pragma warning disable CS0414 // The field 'Model._is_compiled' is assigned but its value is never used
    bool _is_compiled;
#pragma warning restore CS0414 // The field 'Model._is_compiled' is assigned but its value is never used
#pragma warning restore CS0108 // Member hides inherited member; missing new keyword
    ILossFunc loss;
    // Backing field for the Optimizer property.
    IOptimizer optimizer;
    // Assigned by _configure_steps_per_execution.
    IVariableV1 _steps_per_execution;
    // Backing field for IsGraphNetwork.
    protected bool _is_graph_network;
    protected Tensors inputs;
    protected Tensors outputs;
    public string[] output_names;
    // The three batch counters below are (re)created by _init_batch_counters.
    IVariableV1 _train_counter;
    IVariableV1 _test_counter;
    IVariableV1 _predict_counter;
    bool _base_model_initialized;
    bool stop_training;

    /// <summary>
    /// Exposes the value of <c>_is_graph_network</c>.
    /// </summary>
    public bool IsGraphNetwork => _is_graph_network;

    /// <summary>
    /// Gets or sets the optimizer stored on this model.
    /// </summary>
    public IOptimizer Optimizer
    {
        get => optimizer;
        set => optimizer = value;
    }

    /// <summary>
    /// Creates the model from <paramref name="args"/> and initializes the batch counters.
    /// </summary>
    /// <param name="args">Arguments forwarded to the <see cref="Layer"/> base constructor.</param>
    public Model(ModelArgs args)
        : base(args)
    {
        _init_batch_counters();
    }

    /// <summary>
    /// Creates fresh batch counters, then delegates to the base initialization.
    /// </summary>
    /// <param name="args">Arguments forwarded to <c>base.Initialize</c>.</param>
    internal override void Initialize(LayerArgs args)
    {
        _init_batch_counters();
        base.Initialize(args);
    }

    /// <summary>
    /// Stores <paramref name="steps_per_execution"/> in an int64 variable
    /// created with <see cref="VariableAggregation.OnlyFirstReplica"/> aggregation.
    /// </summary>
    /// <param name="steps_per_execution">Initial value of the variable.</param>
    void _configure_steps_per_execution(int steps_per_execution)
    {
        _steps_per_execution = tf.Variable(steps_per_execution,
            dtype: TF_DataType.TF_INT64,
            aggregation: VariableAggregation.OnlyFirstReplica);
    }

    /// <summary>
    /// Snapshots the current trainable state and clears <c>keras.backend._GRAPH</c>.
    /// </summary>
    void _reset_compile_cache()
    {
        // Used to cache `trainable` attr of `Layer`s for `fit`.
        _compiled_trainable_state = _get_trainable_state();
        keras.backend._GRAPH = null;
    }

    /// <summary>
    /// Replaces the train, test and predict counters with new int64 variables
    /// that start at zero.
    /// </summary>
    void _init_batch_counters()
    {
        _train_counter = NewCounter();
        _test_counter = NewCounter();
        _predict_counter = NewCounter();

        // All three counters are built with exactly the same arguments.
        IVariableV1 NewCounter() => tf.Variable(0L,
            dtype: TF_DataType.TF_INT64,
            aggregation: VariableAggregation.OnlyFirstReplica);
    }

    /// <summary>
    /// The result of <c>_flatten_layers(recursive: false, include_self: false)</c>,
    /// materialized as a new list on every access.
    /// </summary>
    public override List<ILayer> Layers
        => _flatten_layers(recursive: false, include_self: false).ToList();

    /// <summary>
    /// Trainable weights of the tracked objects that are themselves trainable,
    /// followed by the weights in <c>_trainable_weights</c>, with duplicates removed.
    /// Returns an empty list when the model is not <c>Trainable</c>.
    /// </summary>
    public override List<IVariableV1> TrainableWeights
    {
        get
        {
            // skip the assertion of weights created.
            var variables = new List<IVariableV1>();
            if (!Trainable)
            {
                return variables;
            }

            foreach (var trackable_obj in _self_tracked_trackables)
            {
                if (trackable_obj.Trainable)
                {
                    variables.AddRange(trackable_obj.TrainableWeights);
                }
            }

            variables.AddRange(_trainable_weights);
            return variables.Distinct().ToList();
        }
    }

    /// <summary>
    /// Non-trainable weights of the tracked objects plus the weights in
    /// <c>_non_trainable_weights</c>, with duplicates removed. When the model
    /// is not <c>Trainable</c>, the trainable weights of the tracked objects and
    /// the weights in <c>_trainable_weights</c> are reported here as well.
    /// </summary>
    public override List<IVariableV1> NonTrainableWeights
    {
        get
        {
            // skip the assertion of weights created.
            var variables = new List<IVariableV1>();
            foreach (var trackable_obj in _self_tracked_trackables)
            {
                variables.AddRange(trackable_obj.NonTrainableWeights);
            }

            if (!Trainable)
            {
                // Frozen model: every trainable weight counts as non-trainable.
                foreach (var trackable_obj in _self_tracked_trackables)
                {
                    variables.AddRange(trackable_obj.TrainableWeights);
                }

                variables.AddRange(_trainable_weights);
            }

            // Added in both cases. Previously this only happened for a frozen
            // model, so a trainable model silently omitted these weights.
            variables.AddRange(_non_trainable_weights);
            return variables.Distinct().ToList();
        }
    }

    /// <summary>
    /// Returns the trackable children reported by the base implementation.
    /// </summary>
    /// <param name="save_type">Forwarded unchanged to the base implementation.</param>
    /// <param name="cache">Forwarded unchanged to the base implementation.</param>
    /// <returns>The dictionary produced by <c>base._trackable_children</c>.</returns>
    public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
    {
        //TODO: for `SaveType.SAVEDMODEL`, deal with `train_function`, `test_function`, `predict_function`, `train_tf_function`.
        return base._trackable_children(save_type, cache);
    }
}
}