Skip to content

Commit bf45277

Browse files
committed
add VariableScope and _VariableStore
1 parent 7c420df commit bf45277

11 files changed

Lines changed: 265 additions & 12 deletions

File tree

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public interface IInitializer
8+
{
9+
Tensor call(TensorShape shape, TF_DataType dtype);
10+
object get_config();
11+
}
12+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public static partial class tf
8+
{
9+
public static IInitializer zeros_initializer => new Zeros();
10+
11+
public class Zeros : IInitializer
12+
{
13+
private TF_DataType dtype;
14+
15+
public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT)
16+
{
17+
this.dtype = dtype;
18+
}
19+
20+
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
21+
{
22+
if (dtype == TF_DataType.DtInvalid)
23+
dtype = this.dtype;
24+
25+
return array_ops.zeros(shape, dtype);
26+
}
27+
28+
public object get_config()
29+
{
30+
return new { dtype = dtype.name() };
31+
}
32+
}
33+
}
34+
}

src/TensorFlowNET.Core/Tensors/dtypes.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ public static TF_DataType as_base_dtype(this TF_DataType type)
7171
type;
7272
}
7373

74+
public static int name(this TF_DataType type)
75+
{
76+
return (int)type;
77+
}
78+
7479
public static DataType as_base_dtype(this DataType type)
7580
{
7681
return (int)type > 100 ?
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public enum VariableAggregation
8+
{
9+
NONE = 0,
10+
SUM = 1,
11+
MEAN = 2,
12+
ONLY_FIRST_REPLICA = 3 // ONLY_FIRST_TOWER
13+
}
14+
}

src/TensorFlowNET.Core/Variables/VariableScope.cs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,35 @@ namespace Tensorflow
66
{
77
public class VariableScope
88
{
9-
public bool? use_resource { get; set; }
9+
public bool use_resource { get; set; }
10+
private _ReuseMode _reuse { get; set; }
11+
12+
private object _regularizer;
13+
private TF_DataType _dtype;
14+
public string name { get; set; }
15+
16+
public VariableScope()
17+
{
18+
_reuse = _ReuseMode.AUTO_REUSE;
19+
}
20+
21+
public RefVariable get_variable(_VariableStore var_store,
22+
string name,
23+
TensorShape shape = null,
24+
TF_DataType dtype = TF_DataType.DtInvalid,
25+
VariableSynchronization synchronization = VariableSynchronization.AUTO,
26+
VariableAggregation aggregation= VariableAggregation.NONE)
27+
{
28+
string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name;
29+
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(""), scope =>
30+
{
31+
if (dtype == TF_DataType.DtInvalid)
32+
dtype = _dtype;
33+
34+
return var_store.get_variable(full_name);
35+
36+
});
37+
38+
}
1039
}
1140
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
/// <summary>
8+
/// Mode for variable access within a variable scope.
9+
/// </summary>
10+
public enum _ReuseMode
11+
{
12+
// Indicates that variables are to be fetched if they already exist or
13+
// otherwise created.
14+
AUTO_REUSE = 1
15+
}
16+
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
/// <summary>
8+
/// Variable store that carries a number of named Variables.
9+
/// </summary>
10+
public class _VariableStore
11+
{
12+
private Dictionary<string, object> _vars;
13+
private Dictionary<string, object> _partitioned_vars;
14+
private bool _store_eager_variables;
15+
16+
public _VariableStore()
17+
{
18+
_vars = new Dictionary<string, object>();
19+
_partitioned_vars = new Dictionary<string, object>();
20+
_store_eager_variables = false;
21+
}
22+
23+
public RefVariable get_variable(string name,
24+
TensorShape shape = null,
25+
TF_DataType dtype = TF_DataType.TF_FLOAT,
26+
IInitializer initializer = null,
27+
bool trainable = false,
28+
bool validate_shape = true,
29+
VariableSynchronization synchronization = VariableSynchronization.AUTO,
30+
VariableAggregation aggregation = VariableAggregation.NONE)
31+
{
32+
dtype = dtype.as_base_dtype();
33+
trainable = variable_scope._get_trainable_value(synchronization, trainable);
34+
35+
return _true_getter(name,
36+
shape: shape,
37+
dtype: dtype,
38+
initializer: initializer,
39+
trainable: trainable,
40+
validate_shape: validate_shape,
41+
synchronization: synchronization,
42+
aggregation: aggregation);
43+
}
44+
45+
private RefVariable _true_getter(string name,
46+
TensorShape shape = null,
47+
TF_DataType dtype = TF_DataType.DtInvalid,
48+
IInitializer initializer = null,
49+
bool trainable = false,
50+
bool validate_shape = true,
51+
VariableSynchronization synchronization = VariableSynchronization.AUTO,
52+
VariableAggregation aggregation = VariableAggregation.NONE)
53+
{
54+
return _get_single_variable(name: name);
55+
}
56+
57+
private RefVariable _get_single_variable(string name,
58+
TensorShape shape = null,
59+
TF_DataType dtype = TF_DataType.DtInvalid,
60+
IInitializer initializer = null,
61+
bool reuse = false,
62+
bool trainable = false,
63+
bool validate_shape = false,
64+
VariableSynchronization synchronization = VariableSynchronization.AUTO,
65+
VariableAggregation aggregation = VariableAggregation.NONE)
66+
{
67+
if (_vars.ContainsKey(name))
68+
{
69+
if (!reuse)
70+
{
71+
var var = _vars[name];
72+
73+
}
74+
throw new NotImplementedException("_get_single_variable");
75+
}
76+
77+
throw new NotImplementedException("_get_single_variable");
78+
}
79+
}
80+
}

src/TensorFlowNET.Core/Variables/tf.variable.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,15 @@ public static Operation global_variables_initializer()
1111
var g = variables.global_variables();
1212
return variables.variables_initializer(g.ToArray());
1313
}
14+
15+
public static RefVariable get_variable(string name,
16+
TensorShape shape = null,
17+
IInitializer initializer = null,
18+
VariableSynchronization synchronization = VariableSynchronization.AUTO,
19+
VariableAggregation aggregation = VariableAggregation.NONE)
20+
{
21+
var store = variable_scope._get_default_variable_store();
22+
return variable_scope.get_variable_scope().get_variable(store, name, shape: shape);
23+
}
1424
}
1525
}

src/TensorFlowNET.Core/Variables/variable_scope.py.cs

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ namespace Tensorflow
66
{
77
public class variable_scope
88
{
9+
public static string _VARSTORE_KEY = "__variable_store";
910
public static string _VARSCOPESTORE_KEY = "__varscope";
1011
public static bool _DEFAULT_USE_RESOURCE = false;
1112

@@ -32,6 +33,17 @@ public static RefVariable default_variable_creator(object initial_value, string
3233
}
3334
}
3435

36+
public static _VariableStore _get_default_variable_store()
37+
{
38+
var store = ops.get_collection(_VARSTORE_KEY);
39+
if (store != null)
40+
return (store as List<_VariableStore>)[0];
41+
42+
var store1 = new _VariableStore();
43+
ops.add_to_collection(_VARSTORE_KEY, store1);
44+
return store1;
45+
}
46+
3547
public static VariableScope get_variable_scope()
3648
{
3749
return get_variable_scope_store().current_scope;
@@ -65,24 +77,18 @@ public static _VariableScopeStore get_variable_scope_store()
6577
return ret;
6678
}
6779

68-
public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = null)
80+
public static bool _get_trainable_value(VariableSynchronization synchronization, bool trainable = true)
6981
{
70-
if(synchronization == VariableSynchronization.ON_READ)
82+
if (synchronization == VariableSynchronization.ON_READ)
7183
{
72-
if (trainable.Value)
84+
if (trainable)
7385
throw new ValueError("Synchronization value can be set to " +
7486
"VariableSynchronization.ON_READ only for non-trainable variables. " +
7587
"You have specified trainable=True and " +
7688
"synchronization=VariableSynchronization.ON_READ.");
77-
else
78-
trainable = false;
7989
}
80-
else if (!trainable.HasValue)
81-
{
82-
trainable = true;
83-
}
84-
85-
return trainable.Value;
90+
91+
return trainable;
8692
}
8793
}
8894
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
using Tensorflow;
6+
7+
namespace TensorFlowNET.UnitTest
8+
{
9+
[TestClass]
10+
public class TrainSaverTest
11+
{
12+
[TestMethod]
13+
public void Save()
14+
{
15+
var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);
16+
var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer);
17+
18+
19+
}
20+
}
21+
}

0 commit comments

Comments
 (0)