forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTrackable.cs
More file actions
62 lines (56 loc) · 2.18 KB
/
Trackable.cs
File metadata and controls
62 lines (56 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
using System;
using System.Collections.Generic;
using System.Text;
namespace Tensorflow.Train
{
/// <summary>
/// Base class for objects that take part in checkpoint dependency tracking.
/// Most of the tracking logic is still stubbed out (see the TODO notes on the
/// individual members).
/// </summary>
public abstract class Trackable
{
    /// <summary>
    /// Set to -1 by <see cref="_maybe_initialize_trackable"/>.
    /// NOTE(review): nothing else in this file reads or writes it; confirm its
    /// exact meaning against the derived classes.
    /// </summary>
    protected int _self_update_uid;

    /// <summary>
    /// Restore-on-create for a variable to be saved with this `Checkpointable`.
    /// Creates the variable through <paramref name="getter"/> and then hands it
    /// to <see cref="_track_checkpointable"/>.
    /// </summary>
    /// <param name="name">Variable name; forwarded to the getter and to the tracker.</param>
    /// <param name="shape">Variable shape; forwarded to the getter.</param>
    /// <param name="dtype">Element type; forwarded to the getter.</param>
    /// <param name="initializer">Initializer; forwarded to the getter (may be null).</param>
    /// <param name="getter">
    /// Factory that actually creates the variable. Although the parameter has a
    /// <c>null</c> default to keep the signature optional, a getter is required.
    /// </param>
    /// <param name="overwrite">Forwarded to <see cref="_track_checkpointable"/>.</param>
    /// <param name="trainable">Forwarded to the getter.</param>
    /// <returns>The variable produced by <paramref name="getter"/>.</returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="getter"/> is <c>null</c>.
    /// </exception>
    protected virtual RefVariable _add_variable_with_custom_getter(string name,
        int[] shape,
        TF_DataType dtype = TF_DataType.TF_FLOAT,
        IInitializer initializer = null,
        Func<string, int[], TF_DataType, IInitializer, bool, RefVariable> getter = null,
        bool overwrite = false,
        bool trainable = false)
    {
        // Fail with a descriptive error instead of a NullReferenceException
        // when the caller relies on the null default.
        if (getter == null)
        {
            throw new ArgumentNullException(nameof(getter));
        }

        var new_variable = getter(name, shape, dtype, initializer, trainable);

        // If we set an initializer and the variable processed it, tracking will not
        // assign again. It will add this variable to our dependencies, and if there
        // is a non-trivial restoration queued, it will handle that. This also
        // handles slot variables.
        if (!overwrite || new_variable is RefVariable)
        {
            return _track_checkpointable(new_variable, name: name,
                overwrite: overwrite);
        }

        return new_variable;
    }

    /// <summary>
    /// Pop and load any deferred checkpoint restores into `trackable`.
    /// Currently this only calls <see cref="_maybe_initialize_trackable"/>;
    /// the restore logic itself is not implemented yet.
    /// </summary>
    /// <param name="name">Not used by the current implementation.</param>
    /// <param name="trackable">Not used by the current implementation.</param>
    protected void _handle_deferred_dependencies(string name, RefVariable trackable)
    {
        _maybe_initialize_trackable();
        // TODO
    }

    /// <summary>
    /// Registers <paramref name="checkpointable"/> as a tracked dependency.
    /// The current implementation is a pass-through: it returns its argument
    /// unchanged and ignores <paramref name="name"/> and <paramref name="overwrite"/>.
    /// </summary>
    /// <param name="checkpointable">The variable to track.</param>
    /// <param name="name">Not used by the current implementation.</param>
    /// <param name="overwrite">Not used by the current implementation.</param>
    /// <returns>The same <paramref name="checkpointable"/> instance that was passed in.</returns>
    protected RefVariable _track_checkpointable(RefVariable checkpointable, string name, bool overwrite = false)
    {
        return checkpointable;
    }

    /// <summary>
    /// Initialize dependency management.
    /// Sets <see cref="_self_update_uid"/> to -1 on every call; there is no
    /// "already initialized" guard in the current implementation.
    /// </summary>
    protected void _maybe_initialize_trackable()
    {
        // _self_unconditional_checkpoint_dependencies = []
        _self_update_uid = -1;
    }
}
}