forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathModel.cs
More file actions
102 lines (89 loc) · 3.4 KB
/
Model.cs
File metadata and controls
102 lines (89 loc) · 3.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine.DataAdapters;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Optimizers;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Engine
{
/// <summary>
/// `Model` groups layers into an object with training and inference features.
/// </summary>
public partial class Model : Layer, IModel
{
    // NOTE(review): not read in this part of the partial class; the pragma only
    // silences the unused-field warning.
#pragma warning disable CS0169 // The field 'Model._cloning' is never used
    bool _cloning;
#pragma warning restore CS0169 // The field 'Model._cloning' is never used
#pragma warning disable CS0108 // Member hides inherited member; missing new keyword
#pragma warning disable CS0414 // The field 'Model._is_compiled' is assigned but its value is never used
    bool _is_compiled;
#pragma warning restore CS0414 // The field 'Model._is_compiled' is assigned but its value is never used
#pragma warning restore CS0108 // Member hides inherited member; missing new keyword

    // Loss and optimizer; presumably assigned by compile() in another part of this partial class — TODO confirm.
    ILossFunc loss;
    OptimizerV2 optimizer;

    // Int64 variable created by _configure_steps_per_execution.
    IVariableV1 _steps_per_execution;

    protected bool _is_graph_network;
    protected Tensors inputs;
    protected Tensors outputs;
    public string[] output_names;

    // Int64 counters (initial value 0) created by _init_batch_counters in the constructor.
    IVariableV1 _train_counter;
    IVariableV1 _test_counter;
    IVariableV1 _predict_counter;

    bool _base_model_initialized;
    bool stop_training;
    DataHandler data_handler;

    /// <summary>
    /// Creates a model from <paramref name="args"/> (forwarded to <see cref="Layer"/>)
    /// and allocates the train/test/predict batch counters.
    /// </summary>
    /// <param name="args">Model construction arguments.</param>
    public Model(ModelArgs args)
        : base(args)
    {
        _init_batch_counters();
    }

    /// <summary>
    /// Stores <paramref name="steps_per_execution"/> in an int64 variable with
    /// <see cref="VariableAggregation.OnlyFirstReplica"/> aggregation.
    /// </summary>
    /// <param name="steps_per_execution">Value to store in <c>_steps_per_execution</c>.</param>
    void _configure_steps_per_execution(int steps_per_execution)
    {
        // Pass a long so the CLR value matches the declared TF_INT64 dtype, the same way
        // _init_batch_counters passes 0L for its int64 counters.
        _steps_per_execution = tf.Variable((long)steps_per_execution,
            dtype: TF_DataType.TF_INT64,
            aggregation: VariableAggregation.OnlyFirstReplica);
    }

    /// <summary>
    /// Re-captures the layers' trainable state and clears <c>keras.backend._GRAPH</c>.
    /// </summary>
    void _reset_compile_cache()
    {
        // Used to cache `trainable` attr of `Layer`s for `fit`.
        // (_compiled_trainable_state / _get_trainable_state are defined elsewhere.)
        _compiled_trainable_state = _get_trainable_state();
        keras.backend._GRAPH = null;
    }

    /// <summary>
    /// Creates the train, test and predict counters as int64 variables starting at 0,
    /// each with <see cref="VariableAggregation.OnlyFirstReplica"/> aggregation.
    /// </summary>
    void _init_batch_counters()
    {
        _train_counter = tf.Variable(0L,
            dtype: TF_DataType.TF_INT64,
            aggregation: VariableAggregation.OnlyFirstReplica);
        _test_counter = tf.Variable(0L,
            dtype: TF_DataType.TF_INT64,
            aggregation: VariableAggregation.OnlyFirstReplica);
        _predict_counter = tf.Variable(0L,
            dtype: TF_DataType.TF_INT64,
            aggregation: VariableAggregation.OnlyFirstReplica);
    }

    /// <summary>
    /// Trainable variables of this model, gathered from its tracked sub-layers.
    /// </summary>
    /// <remarks>
    /// Returns an empty list when the model itself is not <c>Trainable</c>. Sub-layers
    /// whose <c>Trainable</c> flag is false are skipped. Variables are collected from
    /// <c>_self_tracked_trackables</c> first, then from <c>_layers</c>. A variable
    /// reachable more than once (for example, a layer present in both collections or
    /// shared between sub-layers) is returned only once, at its first position. This
    /// matches the identity-based deduplication Keras applies to trainable weights.
    /// A fresh list is built on every access.
    /// </remarks>
    public override List<IVariableV1> trainable_variables
    {
        get
        {
            var variables = new List<IVariableV1>();
            if (!Trainable)
            {
                return variables;
            }

            // Duplicates are detected by object identity, never by any Equals overload
            // a variable implementation may define.
            var seen = new HashSet<IVariableV1>(VariableIdentityComparer.Instance);

            foreach (var trackable_obj in _self_tracked_trackables)
            {
                if (!trackable_obj.Trainable)
                    continue;
                foreach (var variable in trackable_obj.trainable_variables)
                {
                    if (seen.Add(variable))
                        variables.Add(variable);
                }
            }

            foreach (var layer in _layers)
            {
                if (!layer.Trainable)
                    continue;
                foreach (var variable in layer.trainable_variables)
                {
                    if (seen.Add(variable))
                        variables.Add(variable);
                }
            }

            // NOTE(review): unlike Keras, weights owned directly by the model
            // (_trainable_weights) are not included here — confirm this is intended.
            return variables;
        }
    }

    /// <summary>
    /// Equality comparer that treats two variables as equal only if they are the same
    /// object instance. It is the C# counterpart of Python's <c>id()</c>-based dedup.
    /// </summary>
    sealed class VariableIdentityComparer : IEqualityComparer<IVariableV1>
    {
        public static readonly VariableIdentityComparer Instance = new VariableIdentityComparer();

        public bool Equals(IVariableV1 x, IVariableV1 y) => ReferenceEquals(x, y);

        public int GetHashCode(IVariableV1 obj)
            => System.Runtime.CompilerServices.RuntimeHelpers.GetHashCode(obj);
    }
}
}