forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathEstimator.cs
More file actions
138 lines (115 loc) · 4.16 KB
/
Estimator.cs
File metadata and controls
138 lines (115 loc) · 4.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Text;
using Tensorflow.Data;
using Tensorflow.Train;
using static Tensorflow.Binding;
namespace Tensorflow.Estimators
{
/// <summary>
/// Estimator class to train and evaluate TensorFlow models.
/// </summary>
public class Estimator : IObjectLife
{
    RunConfig _config;

    /// <summary>Run configuration this estimator was constructed with.</summary>
    public RunConfig config => _config;

    ConfigProto _session_config;

    /// <summary>Session configuration taken from <see cref="config"/> at construction time.</summary>
    public ConfigProto session_config => _session_config;

    string _model_dir;
    Action _model_fn;

    /// <summary>
    /// Creates an estimator from a model function and a run configuration.
    /// </summary>
    /// <param name="model_fn">Model function; stored by the constructor but not invoked by any code in this class yet.</param>
    /// <param name="config">Run configuration that supplies the model directory and the session configuration.</param>
    /// <exception cref="ArgumentNullException"><paramref name="config"/> is null.</exception>
    public Estimator(Action model_fn, RunConfig config)
    {
        // Fail fast with a descriptive exception rather than a NullReferenceException
        // when reading config.model_dir below.
        _config = config ?? throw new ArgumentNullException(nameof(config));
        _model_dir = _config.model_dir;
        _session_config = _config.session_config;
        _model_fn = model_fn;
    }

    /// <summary>
    /// Trains the model on the dataset produced by <paramref name="input_fn"/>.
    /// </summary>
    /// <param name="input_fn">Function that builds the input dataset.</param>
    /// <param name="max_steps">
    /// When positive, training is skipped if the step recovered from the model
    /// directory's checkpoint state is already at or beyond this value.
    /// </param>
    /// <param name="hooks">Currently unused.</param>
    /// <param name="saving_listeners">Currently unused.</param>
    /// <returns>This estimator, to allow call chaining.</returns>
    public Estimator train(Func<DatasetV1Adapter> input_fn, int max_steps = 1, Action[] hooks = null,
        _NewCheckpointListenerForEvaluate[] saving_listeners = null)
    {
        if (max_steps > 0)
        {
            var start_step = _load_global_step_from_checkpoint_dir(_model_dir);
            if (max_steps <= start_step)
            {
                Console.WriteLine("Skipping training since max_steps has already saved.");
                return this;
            }
        }

        var loss = _train_model(input_fn);
        print($"Loss for final step: {loss}.");
        return this;
    }

    /// <summary>
    /// Approximates the global step already reached by checkpoints in <paramref name="checkpoint_dir"/>.
    /// </summary>
    /// <param name="checkpoint_dir">Directory that may hold checkpoint state.</param>
    /// <returns>
    /// The number of recorded checkpoint paths minus one, or 0 when no directory is
    /// given or no checkpoint state can be found (nothing has been trained yet).
    /// </returns>
    private int _load_global_step_from_checkpoint_dir(string checkpoint_dir)
    {
        // Ideally the step would be read from the latest checkpoint through a
        // checkpoint reader (NewCheckpointReader), which is not implemented; the count
        // of recorded checkpoint paths is used as an approximation instead.
        if (string.IsNullOrEmpty(checkpoint_dir))
            return 0;

        var cp = tf.train.get_checkpoint_state(checkpoint_dir);

        // No checkpoint state (e.g. the very first run) means training starts from step 0
        // rather than crashing on a null dereference.
        if (cp is null)
            return 0;

        return cp.AllModelCheckpointPaths.Count - 1;
    }

    /// <summary>
    /// Trains the model; currently always delegates to <see cref="_train_model_default"/>.
    /// </summary>
    private Tensor _train_model(Func<DatasetV1Adapter> input_fn)
    {
        return _train_model_default(input_fn);
    }

    /// <summary>
    /// Default training path: builds a fresh graph with a global step and the input
    /// pipeline. The remainder of the training loop is not implemented yet.
    /// </summary>
    /// <exception cref="NotImplementedException">Always thrown after the graph setup completes.</exception>
    private Tensor _train_model_default(Func<DatasetV1Adapter> input_fn)
    {
        using (var g = tf.Graph().as_default())
        {
            var global_step_tensor = _create_and_assert_global_step(g);

            // Skip creating a read variable if _create_and_assert_global_step
            // returns null (e.g. tf.contrib.estimator.SavedModelEstimator).
            if (global_step_tensor != null)
                TrainingUtil._get_or_create_global_step_read(g);

            // features/labels are built here but not consumed yet.
            var (features, labels) = _get_features_and_labels_from_input_fn(input_fn, "train");
        }

        throw new NotImplementedException("Estimator training beyond building the global step and input pipeline is not implemented.");
    }

    /// <summary>
    /// Invokes the input function and splits its result into feature and label dictionaries.
    /// </summary>
    /// <param name="input_fn">Function that builds the input dataset.</param>
    /// <param name="mode">Mode name passed through to <see cref="_call_input_fn"/>.</param>
    private (Dictionary<string, Tensor>, Dictionary<string, Tensor>) _get_features_and_labels_from_input_fn(Func<DatasetV1Adapter> input_fn, string mode)
    {
        var result = _call_input_fn(input_fn, mode);
        return EstimatorUtil.parse_input_fn_result(result);
    }

    /// <summary>
    /// Calls the input function.
    /// </summary>
    /// <param name="input_fn">Function that builds the input dataset.</param>
    /// <param name="mode">Currently unused; the input function takes no arguments.</param>
    /// <returns>The dataset returned by <paramref name="input_fn"/>.</returns>
    private DatasetV1Adapter _call_input_fn(Func<DatasetV1Adapter> input_fn, string mode)
    {
        return input_fn();
    }

    /// <summary>
    /// Creates the global step in <paramref name="graph"/> and, in DEBUG builds, checks
    /// that it is the graph's registered global step and has an integer dtype.
    /// </summary>
    private Tensor _create_and_assert_global_step(Graph graph)
    {
        var step = _create_global_step(graph);
        // Debug.Assert is compiled out of release builds, so these checks only run in DEBUG.
        Debug.Assert(step == tf.train.get_global_step(graph));
        Debug.Assert(step.dtype.is_integer());
        return step;
    }

    /// <summary>Creates the global step variable in <paramref name="graph"/>.</summary>
    private RefVariable _create_global_step(Graph graph)
    {
        return tf.train.create_global_step(graph);
    }

    /// <summary>Not implemented.</summary>
    public void __init__()
    {
        throw new NotImplementedException();
    }

    /// <summary>Not implemented.</summary>
    public void __enter__()
    {
        throw new NotImplementedException();
    }

    /// <summary>Not implemented.</summary>
    public void __del__()
    {
        throw new NotImplementedException();
    }

    /// <summary>Not implemented.</summary>
    public void __exit__()
    {
        throw new NotImplementedException();
    }

    /// <summary>
    /// Releases resources held by the estimator. The estimator owns no disposable
    /// resources, so this is a no-op; it must not throw, so that <c>using</c> blocks
    /// over an estimator complete normally.
    /// </summary>
    public void Dispose()
    {
    }
}
}