Skip to content

Commit bffb9cd

Browse files
committed
latest_checkpoint
1 parent c7e511a commit bffb9cd

9 files changed

Lines changed: 105 additions & 13 deletions

File tree

src/TensorFlowNET.Core/APIs/tf.train.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ public Saver import_meta_graph(string meta_graph_or_file,
5757
clear_devices: clear_devices,
5858
clear_extraneous_savers: clear_extraneous_savers,
5959
strip_default_attrs: strip_default_attrs);
60+
61+
/// <summary>
/// Finds the filename of the latest saved checkpoint file by forwarding to
/// <c>checkpoint_management.latest_checkpoint</c>.
/// </summary>
/// <param name="checkpoint_dir">Directory handed through to the checkpoint lookup.</param>
/// <param name="latest_filename">Optional value forwarded as <c>latest_filename</c>; null by default.</param>
/// <returns>The value returned by <c>checkpoint_management.latest_checkpoint</c>.</returns>
public string latest_checkpoint(string checkpoint_dir, string latest_filename = null)
{
    return checkpoint_management.latest_checkpoint(checkpoint_dir, latest_filename);
}
63+
64+
/// <summary>
/// Returns the checkpoint state for a directory by forwarding to
/// <c>checkpoint_management.get_checkpoint_state</c>.
/// </summary>
/// <param name="checkpoint_dir">Directory handed through to the state lookup.</param>
/// <param name="latest_filename">Optional value forwarded as <c>latest_filename</c>; null by default.</param>
/// <returns>The value returned by <c>checkpoint_management.get_checkpoint_state</c>.</returns>
public CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null)
{
    return checkpoint_management.get_checkpoint_state(checkpoint_dir, latest_filename);
}
6066
}
6167
}
6268
}

src/TensorFlowNET.Core/Estimators/Estimator.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,35 @@ public class Estimator : IObjectLife
1818

1919
string _model_dir;
2020

21+
Action _model_fn;
22+
2123
/// <summary>
/// Creates an estimator from a model function and a run configuration.
/// </summary>
/// <param name="model_fn">Model function; kept on the instance.</param>
/// <param name="config">
/// Run configuration; kept on the instance, and its <c>model_dir</c> and
/// <c>session_config</c> values are copied into their own fields.
/// </param>
public Estimator(Action model_fn, RunConfig config)
{
    this._model_fn = model_fn;
    this._config = config;

    // Copy the values this class reads from the configuration.
    this._model_dir = this._config.model_dir;
    this._session_config = this._config.session_config;
}
2730

28-
/// <summary>
/// Entry point for training. The implementation is incomplete: after the
/// checkpoint lookup and the call to <c>_train_model</c> it always throws
/// <see cref="NotImplementedException"/>.
/// </summary>
/// <param name="input_fn">Input function; not used by the current body.</param>
/// <param name="max_steps">When greater than zero, the checkpoint directory is inspected first.</param>
/// <param name="hooks">Training hooks; not used by the current body.</param>
/// <param name="saving_listeners">Checkpoint listeners; not used by the current body.</param>
/// <returns>Never returns normally in the current implementation.</returns>
/// <exception cref="NotImplementedException">Always thrown once <c>_train_model</c> returns.</exception>
public Estimator train(Action input_fn, int max_steps = 1, Action[] hooks = null,
    _NewCheckpointListenerForEvaluate[] saving_listeners = null)
{
    if (max_steps > 0)
    {
        // The returned step is not consumed yet; the call is kept because
        // it performs the checkpoint lookup for the model directory.
        _load_global_step_from_checkpoint_dir(_model_dir);
    }

    _train_model();
    throw new NotImplementedException("");
}
3442

43+
/// <summary>
/// Looks up the latest checkpoint in a directory. The lookup result is not
/// used yet and the method currently always returns 0.
/// </summary>
/// <param name="checkpoint_dir">Directory passed to <c>tf.train.latest_checkpoint</c>.</param>
/// <returns>Always 0 in the current implementation.</returns>
private int _load_global_step_from_checkpoint_dir(string checkpoint_dir)
{
    // The lookup result is not read anywhere below.
    var latest_path = tf.train.latest_checkpoint(checkpoint_dir);

    return 0;
}
49+
3550
private void _train_model()
3651
{
3752
_train_model_default();

src/TensorFlowNET.Core/Estimators/EvalSpec.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ namespace Tensorflow.Estimators
66
{
77
/// <summary>
/// Specification of an evaluation run. Only the name is stored; the other
/// constructor arguments are currently ignored.
/// </summary>
public class EvalSpec
{
    private string _name;

    /// <summary>
    /// Creates an evaluation spec.
    /// </summary>
    /// <param name="name">Name stored on the instance.</param>
    /// <param name="input_fn">Input function; not stored by the current body.</param>
    /// <param name="exporters">Exporter; not stored by the current body.</param>
    public EvalSpec(string name, Action input_fn, FinalExporter exporters)
    {
        this._name = name;
    }
}
1416
}

src/TensorFlowNET.Core/Estimators/TrainSpec.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,16 @@ namespace Tensorflow.Estimators
66
{
77
/// <summary>
/// Specification of a training run: the input function and the step limit.
/// Both values are fixed at construction time.
/// </summary>
public class TrainSpec
{
    /// <summary>Step limit supplied to the constructor.</summary>
    public int max_steps { get; }

    /// <summary>Input function supplied to the constructor.</summary>
    public Action input_fn { get; }

    /// <summary>
    /// Creates a training spec.
    /// </summary>
    /// <param name="input_fn">Input function exposed through <see cref="input_fn"/>.</param>
    /// <param name="max_steps">Step limit exposed through <see cref="max_steps"/>.</param>
    public TrainSpec(Action input_fn, int max_steps)
    {
        this.max_steps = max_steps;
        this.input_fn = input_fn;
    }
}
1621
}

src/TensorFlowNET.Core/Estimators/_NewCheckpointListenerForEvaluate.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,11 @@ namespace Tensorflow.Estimators
66
{
77
/// <summary>
/// Checkpoint listener that holds a reference to an evaluator.
/// </summary>
public class _NewCheckpointListenerForEvaluate
{
    private _Evaluator _evaluator;

    /// <summary>
    /// Creates the listener.
    /// </summary>
    /// <param name="evaluator">Evaluator stored on the instance.</param>
    /// <param name="eval_throttle_secs">Not used by the current body.</param>
    public _NewCheckpointListenerForEvaluate(_Evaluator evaluator, int eval_throttle_secs)
    {
        this._evaluator = evaluator;
    }
}
1016
}

src/TensorFlowNET.Core/Estimators/_TrainingExecutor.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,17 @@ public void run()
3232
/// </summary>
3333
private void run_local()
{
    // No training hooks are registered yet; an empty array is passed through.
    var train_hooks = new Action[0];

    Console.WriteLine("Start train and evaluate loop. The evaluate will happen " +
        "after every checkpoint. Checkpoint frequency is determined " +
        $"based on RunConfig arguments: save_checkpoints_steps {_estimator.config.save_checkpoints_steps} or " +
        $"save_checkpoints_secs {_estimator.config.save_checkpoints_secs}.");

    // The evaluator is constructed but not yet attached to a listener,
    // so the listener array handed to train() is empty as well.
    var evaluator = new _Evaluator(_estimator, _eval_spec, _train_spec.max_steps);
    var saving_listeners = new _NewCheckpointListenerForEvaluate[0];

    _estimator.train(input_fn: _train_spec.input_fn,
        max_steps: _train_spec.max_steps,
        hooks: train_hooks,
        saving_listeners: saving_listeners);
}
4547
}
4648
}

src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
using System.IO;
2020
using System.Linq;
2121
using static Tensorflow.SaverDef.Types;
22+
using static Tensorflow.Binding;
2223

2324
namespace Tensorflow
2425
{
@@ -144,5 +145,54 @@ private static string _prefix_to_checkpoint_path(string prefix, CheckpointFormat
144145
return prefix + ".index";
145146
return prefix;
146147
}
148+
149+
/// <summary>
/// Finds the filename of latest saved checkpoint file.
/// </summary>
/// <param name="checkpoint_dir">Directory whose checkpoint state is read.</param>
/// <param name="latest_filename">Optional name forwarded to <c>get_checkpoint_state</c>.</param>
/// <returns>
/// The model checkpoint path recorded in the checkpoint state when a matching
/// V2 or V1 file exists; null when there is no state or the recorded path is empty.
/// </returns>
/// <exception cref="ValueError">
/// The state records a checkpoint path but neither the V2 nor the V1 file exists.
/// </exception>
public static string latest_checkpoint(string checkpoint_dir, string latest_filename = null)
{
    // Pick the latest checkpoint based on checkpoint state.
    var state = get_checkpoint_state(checkpoint_dir, latest_filename);
    if (state == null || string.IsNullOrEmpty(state.ModelCheckpointPath))
        return null;

    // Either format is accepted: the V2 path is tested first, then V1.
    var path_v2 = _prefix_to_checkpoint_path(state.ModelCheckpointPath, CheckpointFormatVersion.V2);
    var path_v1 = _prefix_to_checkpoint_path(state.ModelCheckpointPath, CheckpointFormatVersion.V1);
    var found = File.Exists(path_v2) || File.Exists(path_v1);
    if (!found)
        throw new ValueError($"Couldn't match files for checkpoint {state.ModelCheckpointPath}");

    return state.ModelCheckpointPath;
}
171+
172+
/// <summary>
/// Loads the checkpoint state stored in a checkpoint directory.
/// </summary>
/// <param name="checkpoint_dir">Directory containing the checkpoint state file.</param>
/// <param name="latest_filename">Optional name forwarded to <c>_GetCheckpointFilename</c>.</param>
/// <returns>
/// The parsed state with relative checkpoint paths prefixed by
/// <paramref name="checkpoint_dir"/>, or null when the state file does not exist.
/// </returns>
/// <exception cref="ValueError">The parsed state has an empty model checkpoint path.</exception>
public static CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null)
{
    var state_file = _GetCheckpointFilename(checkpoint_dir, latest_filename);
    if (!File.Exists(state_file))
        return null;

    // NOTE(review): the file is parsed with the binary message parser. The
    // Python TensorFlow implementation reads this file as text-format
    // protobuf; confirm which encoding the writer in this project produces.
    var state = CheckpointState.Parser.ParseFrom(File.ReadAllBytes(state_file));
    if (string.IsNullOrEmpty(state.ModelCheckpointPath))
        throw new ValueError($"Invalid checkpoint state loaded from {checkpoint_dir}");

    // For relative model_checkpoint_path and all_model_checkpoint_paths,
    // prepend checkpoint_dir.
    if (!Path.IsPathRooted(state.ModelCheckpointPath))
        state.ModelCheckpointPath = Path.Combine(checkpoint_dir, state.ModelCheckpointPath);

    var total = len(state.AllModelCheckpointPaths);
    for (int i = 0; i < total; i++)
    {
        var entry = state.AllModelCheckpointPaths[i];
        if (!Path.IsPathRooted(entry))
            state.AllModelCheckpointPaths[i] = Path.Combine(checkpoint_dir, entry);
    }

    return state;
}
147197
}
148198
}

src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,28 @@ public TrainAndEvalDict create_estimator_and_inputs(RunConfig run_config,
1919
int sample_1_of_n_eval_on_train_examples = 1)
2020
{
2121
var config = ConfigUtil.get_configs_from_pipeline_file(pipeline_config_path);
22+
23+
// Create the input functions for TRAIN/EVAL/PREDICT.
24+
Action train_input_fn = () => { };
25+
2226
var eval_input_configs = config.EvalInputReader;
2327

2428
var eval_input_fns = new Action[eval_input_configs.Count];
2529
var eval_input_names = eval_input_configs.Select(eval_input_config => eval_input_config.Name).ToArray();
30+
Action eval_on_train_input_fn = () => { };
31+
Action predict_input_fn = () => { };
2632
Action model_fn = () => { };
2733
var estimator = tf.estimator.Estimator(model_fn: model_fn, config: run_config);
2834

2935
return new TrainAndEvalDict
3036
{
3137
estimator = estimator,
32-
train_steps = train_steps,
38+
train_input_fn = train_input_fn,
3339
eval_input_fns = eval_input_fns,
34-
eval_input_names = eval_input_names
40+
eval_input_names = eval_input_names,
41+
eval_on_train_input_fn = eval_on_train_input_fn,
42+
predict_input_fn = predict_input_fn,
43+
train_steps = train_steps
3544
};
3645
}
3746

@@ -46,10 +55,7 @@ public TrainAndEvalDict create_estimator_and_inputs(RunConfig run_config,
4655
.Select(x => x.ToString())
4756
.ToArray();
4857

49-
var eval_specs = new List<EvalSpec>()
50-
{
51-
new EvalSpec("", null, null) // for test.
52-
};
58+
var eval_specs = new List<EvalSpec>();
5359
foreach (var (index, (eval_spec_name, eval_input_fn)) in enumerate(zip(eval_spec_names, eval_input_fns).ToList()))
5460
{
5561
var exporter_name = index == 0 ? final_exporter_name : $"{final_exporter_name}_{eval_spec_name}";

test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public class Main : IExample
2121

2222
string model_dir = "D:/Projects/PythonLab/tf-models/research/object_detection/models/model";
2323
string pipeline_config_path = "ObjectDetection/Models/faster_rcnn_resnet101_voc07.config";
24-
int num_train_steps = 1;
24+
int num_train_steps = 50;
2525
int sample_1_of_n_eval_examples = 1;
2626
int sample_1_of_n_eval_on_train_examples = 5;
2727

0 commit comments

Comments
 (0)