forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathRefVariable.cs
More file actions
297 lines (256 loc) · 11.7 KB
/
RefVariable.cs
File metadata and controls
297 lines (256 loc) · 11.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
using Google.Protobuf;
using Google.Protobuf.Collections;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using static Tensorflow.Python;
namespace Tensorflow
{
public partial class RefVariable : VariableV1, IProtoBuf<VariableDef, RefVariable>
{
/// <summary>Initialized to <c>true</c> here and set to <c>true</c> again by the constructor.</summary>
public bool _in_graph_mode = true;
/// <summary>
/// Tensor holding the initial value. Assigned by <c>_init_from_args</c> or <c>_init_from_proto</c>;
/// the latter sets it to null when the proto carries no initial value name.
/// </summary>
public Tensor _initial_value;
/// <summary>The <c>graph_key</c> of the default graph, captured in <c>_init_from_args</c>.</summary>
public string _graph_key;
/// <summary>Trainable flag, taken from the constructor argument or from the <c>VariableDef</c>.</summary>
public bool _trainable;
/// <summary>The variable tensor; the properties below delegate to it.</summary>
public Tensor _variable;
/// <summary>
/// Tensor returned by <c>value()</c> and <c>_AsTensor()</c>. Created as an identity of
/// <c>_variable</c> in <c>_init_from_args</c>, or looked up by name in <c>_init_from_proto</c>.
/// </summary>
public Tensor _snapshot;
/// <summary>
/// Not assigned anywhere in the visible part of this partial class;
/// <c>to_proto</c> throws <c>NotImplementedException</c> when it is true.
/// </summary>
public bool _save_slice_info;
// Backing field for the `initializer` property.
private Operation _initializer_op;
/// <summary>The operation stored in <c>_initializer_op</c>.</summary>
public override Operation initializer => _initializer_op;
/// <summary>The operation of the underlying variable tensor.</summary>
public override Operation op => _variable.op;
/// <summary>The graph of the underlying variable tensor.</summary>
public Graph graph => _variable.graph;
/// <summary>The data type reported by the underlying variable tensor.</summary>
public TF_DataType dtype => _variable.dtype;
/// <summary>The variable tensor's shape, converted through <c>tensor_util.to_shape</c>.</summary>
public TensorShape shape => tensor_util.to_shape(_variable.shape);
/// <summary>The name of the underlying variable tensor.</summary>
public override string name => _variable.name;
/// <summary>
/// Creates a variable either from an initial value or from a serialized
/// <c>VariableDef</c>; the two sources are mutually exclusive.
/// </summary>
/// <param name="initial_value">Initial value (or initializer function); must be null when <paramref name="variable_def"/> is given.</param>
/// <param name="trainable">Forwarded to the base constructor and to <c>_init_from_args</c>.</param>
/// <param name="collections">Forwarded to the base constructor and to <c>_init_from_args</c>.</param>
/// <param name="validate_shape">Forwarded to the base constructor and to <c>_init_from_args</c>.</param>
/// <param name="caching_device">Forwarded to the base constructor and to <c>_init_from_args</c>.</param>
/// <param name="name">Forwarded to the base constructor and to <c>_init_from_args</c>.</param>
/// <param name="variable_def">When not null, the variable is bound to existing graph elements through <c>_init_from_proto</c>.</param>
/// <param name="dtype">Forwarded to the base constructor and to <c>_init_from_args</c>.</param>
/// <param name="import_scope">Name scope handed to <c>_init_from_proto</c>; only used together with <paramref name="variable_def"/>.</param>
/// <exception cref="ValueError">Both <paramref name="variable_def"/> and <paramref name="initial_value"/> were supplied.</exception>
public RefVariable(object initial_value = null,
    bool trainable = true,
    List<string> collections = null,
    bool validate_shape = true,
    string caching_device = "",
    string name = null,
    VariableDef variable_def = null,
    TF_DataType dtype = TF_DataType.DtInvalid,
    string import_scope = "") : base(initial_value,
        trainable,
        collections,
        validate_shape,
        caching_device,
        name,
        dtype)
{
    _in_graph_mode = true;

    // Without a serialized definition the variable is built from the arguments.
    if (variable_def == null)
    {
        _init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype);
        return;
    }

    // A serialized definition already names its initial value, so an explicit one is rejected.
    if (initial_value != null)
        throw new ValueError("variable_def and initial_value are mutually exclusive.");

    _init_from_proto(variable_def, import_scope: import_scope);
}
/// <summary>
/// Binds this instance to elements of the default graph whose names are stored in
/// <paramref name="variable_def"/>, each prefixed with <paramref name="import_scope"/>
/// through <c>ops.prepend_name_scope</c>.
/// </summary>
/// <param name="variable_def">Serialized description naming the variable, initializer, snapshot and (optionally) initial value.</param>
/// <param name="import_scope">Name scope passed to <c>ops.prepend_name_scope</c> for every lookup.</param>
/// <exception cref="NotImplementedException">The definition carries a <c>SaveSliceInfoDef</c>.</exception>
private void _init_from_proto(VariableDef variable_def, string import_scope = "")
{
    var default_graph = ops.get_default_graph();

    var variable_name = ops.prepend_name_scope(variable_def.VariableName,
        import_scope: import_scope);
    _variable = default_graph.as_graph_element(variable_name) as Tensor;

    var initializer_name = ops.prepend_name_scope(variable_def.InitializerName,
        import_scope: import_scope);
    _initializer_op = default_graph.as_graph_element(initializer_name) as Operation;

    // The initial value name is optional (older definitions may not carry it),
    // so it is only resolved when present.
    if (string.IsNullOrEmpty(variable_def.InitialValueName))
    {
        _initial_value = null;
    }
    else
    {
        var initial_value_name = ops.prepend_name_scope(variable_def.InitialValueName,
            import_scope: import_scope);
        _initial_value = default_graph.as_graph_element(initial_value_name) as Tensor;
    }

    _trainable = variable_def.Trainable;

    var snapshot_name = ops.prepend_name_scope(variable_def.SnapshotName,
        import_scope: import_scope);
    _snapshot = default_graph.as_graph_element(snapshot_name) as Tensor;

    // Slice info is not supported yet; the remaining state
    // (_save_slice_info, _caching_device, _constraint) is left untouched.
    if (variable_def.SaveSliceInfoDef != null)
        throw new NotImplementedException("save_slice_info_def");
}
/// <summary>
/// Builds the variable in the default graph: the initial-value tensor, the variable
/// op, the initializer (assign) op and the "read" snapshot, and then registers the
/// variable in <paramref name="collections"/>.
/// </summary>
/// <param name="initial_value">
/// Either a value accepted by <c>ops.convert_to_tensor</c>, or a <c>Func&lt;Tensor&gt;</c>
/// that is invoked inside an "Initializer" name scope to produce the initial tensor.
/// </param>
/// <param name="trainable">When true, <c>ops.GraphKeys.TRAINABLE_VARIABLES</c> is added to the collections used for registration.</param>
/// <param name="collections">
/// Collection names to register the variable in; defaults to
/// <c>ops.GraphKeys.GLOBAL_VARIABLES</c>. A caller-supplied list is never modified.
/// </param>
/// <param name="validate_shape">Passed to the assign op; when true the initial value must have a fully defined shape.</param>
/// <param name="caching_device">Only tested for emptiness here; see the note next to the snapshot creation.</param>
/// <param name="name">Name passed to <c>ops.name_scope</c>; replaced by the resulting scope.</param>
/// <param name="dtype">Used when converting the result of an initializer function; otherwise replaced by the initial value's dtype.</param>
/// <exception cref="ValueError">
/// <paramref name="initial_value"/> is null, is a single-type-argument <c>Func</c> that is not a
/// <c>Func&lt;Tensor&gt;</c>, or (with <paramref name="validate_shape"/>) has a shape that
/// is not fully defined.
/// </exception>
private void _init_from_args(object initial_value,
    bool trainable = true,
    List<string> collections = null,
    bool validate_shape = true,
    string caching_device = "",
    string name = null,
    TF_DataType dtype = TF_DataType.DtInvalid)
{
    if (initial_value is null)
        throw new ValueError("initial_value must be specified.");

    // Any single-type-argument Func is treated as an initializer function.
    var init_from_fn = initial_value.GetType().Name == "Func`1";

    if (collections == null)
    {
        collections = new List<string> { ops.GraphKeys.GLOBAL_VARIABLES };
    }

    // Store the graph key so optimizers know how to only retrieve variables from
    // this graph.
    _graph_key = ops.get_default_graph().graph_key;

    _trainable = trainable;
    if (trainable && !collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES))
    {
        // Work on a copy so a list owned by the caller is not mutated.
        collections = new List<string>(collections) { ops.GraphKeys.TRAINABLE_VARIABLES };
    }

    ops.init_scope();
    var values = init_from_fn ? new object[0] : new object[] { initial_value };
    with(ops.name_scope(name, "Variable", values), scope =>
    {
        name = scope;
        if (init_from_fn)
        {
            // Fail with a descriptive error instead of a NullReferenceException
            // when the delegate does not produce a Tensor.
            var initializer_fn = initial_value as Func<Tensor>;
            if (initializer_fn == null)
                throw new ValueError($"initial_value must be a Func<Tensor> when it is a function, but was {initial_value.GetType()}.");

            // Use attr_scope and device(None) to simulate the behavior of
            // colocate_with when the variable we want to colocate with doesn't
            // yet exist.
            // NOTE(review): `attr` is built but not applied to any scope below.
            string true_name = ops._name_from_scope_name(name);
            var attr = new AttrValue
            {
                List = new AttrValue.Types.ListValue()
            };
            attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}"));
            with(ops.name_scope("Initializer"), scope2 =>
            {
                _initial_value = initializer_fn();
                _initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype);
            });
            _variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
        }
        // Or get the initial value from a Tensor or Python object.
        else
        {
            // NOTE(review): the requested dtype is not forwarded here, unlike the
            // initializer-function path above - confirm whether that is intended.
            _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value");
            var shape = _initial_value.shape;
            dtype = _initial_value.dtype;
            _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope);
        }

        // Manually overrides the variable's shape with the initial value's.
        if (validate_shape)
        {
            var initial_value_shape = _initial_value.GetShape();
            if (!initial_value_shape.is_fully_defined())
                throw new ValueError($"initial_value must have a shape specified: {_initial_value}");
        }

        // If 'initial_value' makes use of other variables, make sure we don't
        // have an issue if these other variables aren't initialized first by
        // using their initialized_value() method.
        var _initial_value2 = _try_guard_against_uninitialized_dependencies(_initial_value);
        _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op;

        if (String.IsNullOrEmpty(caching_device))
        {
            ops.colocate_with(_initializer_op);
        }

        // The snapshot is created on every path so that value(), _AsTensor() and
        // to_proto never see a null tensor. No device placement is applied for a
        // non-empty caching_device.
        // The name is passed as a named argument; the former `name = "read"` was an
        // assignment expression that also overwrote the captured `name`.
        _snapshot = gen_array_ops.identity(_variable, name: "read");

        ops.add_to_collections(collections, this as VariableV1);
    });
}
/// <summary>Returns the underlying variable tensor.</summary>
public Tensor _ref()
{
    return _variable;
}

/// <summary>Returns the snapshot tensor of this variable.</summary>
public Tensor value()
{
    return _snapshot;
}

/// <summary>Returns the snapshot tensor; same result as <c>value()</c>.</summary>
public Tensor _AsTensor()
{
    return _snapshot;
}

/// <summary>Returns the underlying variable tensor; same result as <c>_ref()</c>.</summary>
public Tensor _as_graph_element()
{
    return _variable;
}
/// <summary>
/// Converts this variable to a tensor: the variable tensor itself when
/// <paramref name="as_ref"/> is true, otherwise its snapshot.
/// </summary>
/// <param name="dtype">Not used by this implementation.</param>
/// <param name="name">Not used by this implementation.</param>
/// <param name="as_ref">Selects between <c>_ref()</c> and <c>value()</c>.</param>
/// <returns>The tensor selected by <paramref name="as_ref"/>.</returns>
public Tensor _TensorConversionFunction(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
    => as_ref ? _ref() : value();
/// <summary>
/// Attempts to guard against dependencies on uninitialized variables by passing
/// the tensor through <c>_safe_initial_value_from_tensor</c> with a fresh, empty
/// operation cache.
/// </summary>
/// <param name="initial_value">The tensor used as the variable's initial value.</param>
/// <returns>The tensor produced by <c>_safe_initial_value_from_tensor</c>.</returns>
private Tensor _try_guard_against_uninitialized_dependencies(Tensor initial_value)
{
    // Each top-level call starts with its own cache.
    var fresh_cache = new Dictionary<string, Operation>();
    return _safe_initial_value_from_tensor(initial_value, fresh_cache);
}
/// <summary>
/// Replace dependencies on variables with their initialized values.
/// </summary>
/// <remarks>
/// The operation that produces <paramref name="tensor"/> is looked up in
/// <paramref name="op_cache"/> by name; on a miss (or a null entry) it is obtained
/// from <c>_safe_initial_value_from_op</c> and stored in the cache.
/// </remarks>
/// <param name="tensor">A `Tensor`. The tensor to replace.</param>
/// <param name="op_cache">A dict mapping operation names to `Operation`s; updated in place.</param>
/// <returns>
/// The output of the resolved operation at the same <c>value_index</c> as
/// <paramref name="tensor"/>.
/// </returns>
private Tensor _safe_initial_value_from_tensor(Tensor tensor, Dictionary<string, Operation> op_cache)
{
    var op = tensor.op;
    var op_name = op.name;

    // Single dictionary lookup; a cached null is treated like a missing entry.
    if (!op_cache.TryGetValue(op_name, out var new_op) || new_op == null)
    {
        new_op = _safe_initial_value_from_op(op, op_cache);
        op_cache[op_name] = new_op;
    }

    return new_op.outputs[tensor.value_index];
}
/// <summary>
/// Resolves the operation to use in place of <paramref name="op"/> when building a
/// safe initial value. In its current form every path returns <paramref name="op"/>
/// unchanged.
/// </summary>
/// <param name="op">The operation to inspect.</param>
/// <param name="op_cache">Cache of already resolved operations; not used by this implementation.</param>
/// <returns>The operation <paramref name="op"/> itself.</returns>
private Operation _safe_initial_value_from_op(Operation op, Dictionary<string, Operation> op_cache)
{
    var kind = op.node_def.Op;

    // Initialization probes and variable reads are returned as they are.
    bool is_passthrough = kind == "IsVariableInitialized"
        || kind == "VarIsInitializedOp"
        || kind == "ReadVariableOp";
    if (is_passthrough)
        return op;

    // "Variable", "VariableV2" and "VarHandleOp" are recognised by the original
    // switch but receive no special treatment, and the recursive rebuild of
    // initializer expressions for inputs is not implemented, so the operation is
    // returned unchanged.
    return op;
}
/// <summary>
/// Creates an assign op that writes <paramref name="value"/> into this variable.
/// </summary>
/// <param name="value">The new value for this variable.</param>
/// <param name="use_locking">Forwarded to the assign op as its <c>use_locking</c> argument.</param>
/// <param name="name">The name of the operation to be created.</param>
/// <param name="read_value">
/// When true the result of the assign call is returned; when false its
/// underlying operation is returned instead.
/// </param>
/// <returns>
/// The assign result (which evaluates to the new value of the variable), or the
/// assign operation when <paramref name="read_value"/> is false.
/// </returns>
public ITensorOrOperation assign(object value, bool use_locking = false, string name = null, bool read_value = true)
{
    var assigned = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name);

    if (!read_value)
        return assigned.op;

    return assigned;
}
/// <summary>Describes the variable by its name, shape and dtype.</summary>
public override string ToString()
    => $"tf.RefVariable '{name}' shape={shape} dtype={dtype}";
/// <summary>
/// Serializes this variable into a <c>VariableDef</c>, stripping
/// <paramref name="export_scope"/> from every recorded name.
/// </summary>
/// <param name="export_scope">
/// Name scope to remove. When it is null or empty every variable is exported;
/// otherwise the variable name must start with it.
/// </param>
/// <returns>The populated <c>VariableDef</c>.</returns>
/// <exception cref="NotImplementedException">
/// The variable lies outside <paramref name="export_scope"/>, or
/// <c>_save_slice_info</c> is set.
/// </exception>
public VariableDef to_proto(string export_scope)
{
    // Graph names are identifiers, so the prefix test is ordinal rather than
    // culture-sensitive.
    bool in_scope = string.IsNullOrEmpty(export_scope)
        || _variable.name.StartsWith(export_scope, StringComparison.Ordinal);
    if (!in_scope)
        throw new NotImplementedException("to_proto RefVariable");

    var var_def = new VariableDef();
    var_def.VariableName = ops.strip_name_scope(_variable.name, export_scope);
    // The initial value is optional (it is null after loading a proto without one).
    if (_initial_value != null)
        var_def.InitialValueName = ops.strip_name_scope(_initial_value.name, export_scope);
    var_def.Trainable = _trainable;
    var_def.InitializerName = ops.strip_name_scope(initializer.name, export_scope);
    var_def.SnapshotName = ops.strip_name_scope(_snapshot.name, export_scope);

    if (_save_slice_info)
        throw new NotImplementedException("to_proto _save_slice_info");

    return var_def;
}
/// <summary>
/// Placeholder for rebuilding a variable from a <c>VariableDef</c>; not implemented.
/// </summary>
/// <param name="proto">Serialized variable definition (unused).</param>
/// <param name="import_scope">Name scope for the import (unused).</param>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public RefVariable from_proto(VariableDef proto, string import_scope)
    => throw new NotImplementedException();
}
}