Skip to content

Commit 50ef868

Browse files
committed
_name_stack isn't correct before entering random_uniform
1 parent 37f98b5 commit 50ef868

12 files changed

Lines changed: 309 additions & 16 deletions

File tree

src/TensorFlowNET.Core/APIs/tf.init.cs

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@ namespace Tensorflow
77
public static partial class tf
88
{
99
public static IInitializer zeros_initializer => new Zeros();
10+
public static IInitializer glorot_uniform => new GlorotUniform();
1011

12+
public static variable_scope variable_scope(string name_or_scope,
13+
string default_name = null,
14+
object values = null) => new variable_scope(name_or_scope, default_name, values);
15+
1116
public class Zeros : IInitializer
1217
{
1318
private TF_DataType dtype;
@@ -30,5 +35,105 @@ public object get_config()
3035
return new { dtype = dtype.name() };
3136
}
3237
}
38+
39+
/// <summary>
40+
/// Initializer capable of adapting its scale to the shape of weights tensors.
41+
/// </summary>
42+
public class VarianceScaling : IInitializer
43+
{
44+
protected float _scale;
45+
protected string _mode;
46+
protected string _distribution;
47+
protected int? _seed;
48+
protected TF_DataType _dtype;
49+
50+
public VarianceScaling(float scale = 1.0f,
51+
string mode = "fan_in",
52+
string distribution= "truncated_normal",
53+
int? seed = null,
54+
TF_DataType dtype = TF_DataType.TF_FLOAT)
55+
{
56+
if (scale < 0)
57+
throw new ValueError("`scale` must be positive float.");
58+
_scale = scale;
59+
_mode = mode;
60+
_distribution = distribution;
61+
_seed = seed;
62+
_dtype = dtype;
63+
}
64+
65+
public Tensor call(TensorShape shape, TF_DataType dtype)
66+
{
67+
var (fan_in, fan_out) = _compute_fans(shape);
68+
if (_mode == "fan_in")
69+
_scale /= Math.Max(1, fan_in);
70+
else if (_mode == "fan_out")
71+
_scale /= Math.Max(1, fan_out);
72+
else
73+
_scale /= Math.Max(1, (fan_in + fan_out) / 2);
74+
75+
if (_distribution == "normal" || _distribution == "truncated_normal")
76+
{
77+
throw new NotImplementedException("truncated_normal");
78+
}
79+
else if(_distribution == "untruncated_normal")
80+
{
81+
throw new NotImplementedException("truncated_normal");
82+
}
83+
else
84+
{
85+
var limit = Math.Sqrt(3.0f * _scale);
86+
return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed);
87+
}
88+
}
89+
90+
private (int, int) _compute_fans(int[] shape)
91+
{
92+
if (shape.Length < 1)
93+
return (1, 1);
94+
if (shape.Length == 1)
95+
return (shape[0], shape[0]);
96+
if (shape.Length == 2)
97+
return (shape[0], shape[1]);
98+
else
99+
throw new NotImplementedException("VarianceScaling._compute_fans");
100+
}
101+
102+
public virtual object get_config()
103+
{
104+
return new
105+
{
106+
scale = _scale,
107+
mode = _mode,
108+
distribution = _distribution,
109+
seed = _seed,
110+
dtype = _dtype
111+
};
112+
}
113+
}
114+
115+
public class GlorotUniform : VarianceScaling
116+
{
117+
public GlorotUniform(float scale = 1.0f,
118+
string mode = "fan_avg",
119+
string distribution = "uniform",
120+
int? seed = null,
121+
TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale, mode, distribution, seed, dtype)
122+
{
123+
124+
}
125+
126+
public object get_config()
127+
{
128+
return new
129+
{
130+
scale = _scale,
131+
mode = _mode,
132+
distribution = _distribution,
133+
seed = _seed,
134+
dtype = _dtype
135+
};
136+
}
137+
}
33138
}
34139
}

src/TensorFlowNET.Core/Operations/random_ops.py.cs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,18 @@
44

55
namespace Tensorflow
66
{
7-
public class random_ops
7+
public class random_ops : Python
88
{
9+
/// <summary>
10+
///
11+
/// </summary>
12+
/// <param name="shape"></param>
13+
/// <param name="mean"></param>
14+
/// <param name="stddev"></param>
15+
/// <param name="dtype"></param>
16+
/// <param name="seed"></param>
17+
/// <param name="name"></param>
18+
/// <returns></returns>
919
public static Tensor random_normal(int[] shape,
1020
float mean = 0.0f,
1121
float stddev = 1.0f,
@@ -26,6 +36,30 @@ public static Tensor random_normal(int[] shape,
2636
});
2737
}
2838

39+
/// <summary>
40+
/// Outputs random values from a uniform distribution.
41+
/// </summary>
42+
/// <param name="shape"></param>
43+
/// <param name="minval"></param>
44+
/// <param name="maxval"></param>
45+
/// <param name="dtype">The type of the output</param>
46+
/// <param name="seed">Used to create a random seed for the distribution.</param>
47+
/// <param name="name">A name for the operation</param>
48+
/// <returns>A tensor of the specified shape filled with random uniform values.</returns>
49+
public static Tensor random_uniform(int[] shape,
50+
float minval = 0,
51+
float? maxval = null,
52+
TF_DataType dtype = TF_DataType.TF_FLOAT,
53+
int? seed = null,
54+
string name = null)
55+
{
56+
return with<ops.name_scope, Tensor>(new ops.name_scope(name, "random_uniform", new { shape, minval, maxval }), scope =>
57+
{
58+
name = scope;
59+
return null;
60+
});
61+
}
62+
2963
private static Tensor _ShapeTensor(int[] shape)
3064
{
3165
return ops.convert_to_tensor(shape, name: "shape");

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
<TargetFramework>netstandard2.0</TargetFramework>
55
<AssemblyName>TensorFlow.NET</AssemblyName>
66
<RootNamespace>Tensorflow</RootNamespace>
7-
<Version>0.4.0</Version>
7+
<Version>0.4.1</Version>
88
<Authors>Haiping Chen</Authors>
99
<Company>SciSharp STACK</Company>
1010
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -16,11 +16,11 @@
1616
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET</PackageTags>
1717
<Description>Google's TensorFlow binding in .NET Standard.
1818
Docs: https://tensorflownet.readthedocs.io</Description>
19-
<AssemblyVersion>0.4.0.0</AssemblyVersion>
20-
<PackageReleaseNotes>Added Linear Regression example.
21-
</PackageReleaseNotes>
19+
<AssemblyVersion>0.4.1.0</AssemblyVersion>
20+
<PackageReleaseNotes>Added ConfigProto to control CPU and GPU resource.
21+
Fixed import name scope issue.</PackageReleaseNotes>
2222
<LangVersion>7.2</LangVersion>
23-
<FileVersion>0.4.0.0</FileVersion>
23+
<FileVersion>0.4.1.0</FileVersion>
2424
</PropertyGroup>
2525

2626
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class PureVariableScope : IPython
8+
{
9+
private string _name_or_scope;
10+
private string _new_name;
11+
private string _old_name_scope;
12+
private bool _reuse;
13+
private _VariableStore _var_store;
14+
private VariableScope _old;
15+
private _VariableScopeStore _var_scope_store;
16+
private VariableScope variable_scope_object;
17+
18+
public PureVariableScope(string name_or_scope,
19+
string old_name_scope = null,
20+
TF_DataType dtype = TF_DataType.DtInvalid)
21+
{
22+
_name_or_scope = name_or_scope;
23+
_old_name_scope = old_name_scope;
24+
_var_store = variable_scope._get_default_variable_store();
25+
_var_scope_store = variable_scope.get_variable_scope_store();
26+
}
27+
28+
public void __enter__()
29+
{
30+
_old = _var_scope_store.current_scope;
31+
_new_name = string.IsNullOrEmpty(_old.name) ? _name_or_scope : _old.name + "/" + _name_or_scope;
32+
_reuse = _reuse || _old.resue;
33+
string name_scope = _old_name_scope == null ? _name_or_scope : _old_name_scope;
34+
35+
variable_scope_object = new VariableScope(_reuse,
36+
name: _new_name,
37+
name_scope: name_scope);
38+
39+
_var_scope_store.open_variable_scope(_new_name);
40+
_var_scope_store.current_scope = variable_scope_object;
41+
}
42+
43+
public void Dispose()
44+
{
45+
46+
}
47+
48+
public void __exit__()
49+
{
50+
51+
}
52+
53+
public static implicit operator VariableScope(PureVariableScope scope)
54+
{
55+
return scope.variable_scope_object;
56+
}
57+
}
58+
}

src/TensorFlowNET.Core/Variables/VariableScope.cs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,26 @@
44

55
namespace Tensorflow
66
{
7+
/// <summary>
8+
/// Variable scope object to carry defaults to provide to `get_variable`
9+
/// </summary>
710
public class VariableScope
811
{
912
public bool use_resource { get; set; }
10-
private _ReuseMode _reuse { get; set; }
13+
private _ReuseMode _reuse;
14+
public bool resue;
1115

12-
private object _regularizer;
1316
private TF_DataType _dtype;
1417
public string name { get; set; }
18+
public string name_scope { get; set; }
1519

16-
public VariableScope(TF_DataType dtype = TF_DataType.TF_FLOAT)
20+
public VariableScope(bool reuse,
21+
string name = "",
22+
string name_scope = "",
23+
TF_DataType dtype = TF_DataType.TF_FLOAT)
1724
{
25+
this.name = name;
26+
this.name_scope = name_scope;
1827
_reuse = _ReuseMode.AUTO_REUSE;
1928
_dtype = dtype;
2029
}
@@ -29,7 +38,7 @@ public RefVariable get_variable(_VariableStore var_store,
2938
VariableAggregation aggregation= VariableAggregation.NONE)
3039
{
3140
string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name;
32-
return Python.with<ops.name_scope, RefVariable>(new ops.name_scope(""), scope =>
41+
return Python.with<ops.name_scope, RefVariable>(new ops.name_scope(null), scope =>
3342
{
3443
if (dtype == TF_DataType.DtInvalid)
3544
dtype = _dtype;

src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,20 @@ namespace Tensorflow
77
public class _VariableScopeStore
88
{
99
public VariableScope current_scope { get; set; }
10+
private Dictionary<string, int> variable_scopes_count;
1011

1112
public _VariableScopeStore()
1213
{
13-
current_scope = new VariableScope();
14+
current_scope = new VariableScope(false);
15+
variable_scopes_count = new Dictionary<string, int>();
16+
}
17+
18+
public void open_variable_scope(string scope_name)
19+
{
20+
if (variable_scopes_count.ContainsKey(scope_name))
21+
variable_scopes_count[scope_name] += 1;
22+
else
23+
variable_scopes_count[scope_name] = 1;
1424
}
1525
}
1626
}

src/TensorFlowNET.Core/Variables/_VariableStore.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,14 @@ private RefVariable _get_single_variable(string name,
8787
}
8888

8989
Tensor init_val = null;
90+
91+
// Create the tensor to initialize the variable with default value.
92+
if (initializer == null)
93+
{
94+
if (dtype.is_floating())
95+
initializer = tf.glorot_uniform;
96+
}
97+
9098
ops.init_scope();
9199
{
92100
if (initializing_from_value)

src/TensorFlowNET.Core/Variables/tf.variable.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ public static RefVariable get_variable(string name,
2020
VariableSynchronization synchronization = VariableSynchronization.AUTO,
2121
VariableAggregation aggregation = VariableAggregation.NONE)
2222
{
23-
var scope = variable_scope.get_variable_scope();
24-
var store = variable_scope._get_default_variable_store();
23+
var scope = Tensorflow.variable_scope.get_variable_scope();
24+
var store = Tensorflow.variable_scope._get_default_variable_store();
2525
return scope.get_variable(store,
2626
name,
2727
shape: shape,

0 commit comments

Comments
 (0)