Skip to content

Commit a4e52ab

Browse files
committed
add tf.keras
1 parent 209c3b9 commit a4e52ab

22 files changed

Lines changed: 414 additions & 153 deletions

TensorFlow.NET.sln

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T
1111
EndProject
1212
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Visualization", "TensorFlowNET.Visualization\TensorFlowNET.Visualization.csproj", "{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}"
1313
EndProject
14-
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{E8340C61-12C1-4BEE-A340-403E7C1ACD82}"
15-
EndProject
16-
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "scikit-learn", "..\scikit-learn.net\src\scikit-learn\scikit-learn.csproj", "{199DDAD8-4A6F-43B3-A560-C0393619E304}"
14+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}"
1715
EndProject
1816
Global
1917
GlobalSection(SolutionConfigurationPlatforms) = preSolution
@@ -37,14 +35,10 @@ Global
3735
{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Debug|Any CPU.Build.0 = Debug|Any CPU
3836
{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Release|Any CPU.ActiveCfg = Release|Any CPU
3937
{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Release|Any CPU.Build.0 = Release|Any CPU
40-
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
41-
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Debug|Any CPU.Build.0 = Debug|Any CPU
42-
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Release|Any CPU.ActiveCfg = Release|Any CPU
43-
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Release|Any CPU.Build.0 = Release|Any CPU
44-
{199DDAD8-4A6F-43B3-A560-C0393619E304}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
45-
{199DDAD8-4A6F-43B3-A560-C0393619E304}.Debug|Any CPU.Build.0 = Debug|Any CPU
46-
{199DDAD8-4A6F-43B3-A560-C0393619E304}.Release|Any CPU.ActiveCfg = Release|Any CPU
47-
{199DDAD8-4A6F-43B3-A560-C0393619E304}.Release|Any CPU.Build.0 = Release|Any CPU
38+
{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
39+
{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Debug|Any CPU.Build.0 = Debug|Any CPU
40+
{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Release|Any CPU.ActiveCfg = Release|Any CPU
41+
{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Release|Any CPU.Build.0 = Release|Any CPU
4842
EndGlobalSection
4943
GlobalSection(SolutionProperties) = preSolution
5044
HideSolutionNode = FALSE
Lines changed: 1 addition & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow.Operations.Initializers;
45

56
namespace Tensorflow
67
{
@@ -24,128 +25,5 @@ public static variable_scope variable_scope(VariableScope scope,
2425
default_name,
2526
values,
2627
auxiliary_name_scope);
27-
28-
public class Zeros : IInitializer
29-
{
30-
private TF_DataType dtype;
31-
32-
public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT)
33-
{
34-
this.dtype = dtype;
35-
}
36-
37-
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
38-
{
39-
if (dtype == TF_DataType.DtInvalid)
40-
dtype = this.dtype;
41-
42-
return array_ops.zeros(shape, dtype);
43-
}
44-
45-
public object get_config()
46-
{
47-
return new { dtype = dtype.name() };
48-
}
49-
}
50-
51-
/// <summary>
52-
/// Initializer capable of adapting its scale to the shape of weights tensors.
53-
/// </summary>
54-
public class VarianceScaling : IInitializer
55-
{
56-
protected float _scale;
57-
protected string _mode;
58-
protected string _distribution;
59-
protected int? _seed;
60-
protected TF_DataType _dtype;
61-
62-
public VarianceScaling(float scale = 1.0f,
63-
string mode = "fan_in",
64-
string distribution= "truncated_normal",
65-
int? seed = null,
66-
TF_DataType dtype = TF_DataType.TF_FLOAT)
67-
{
68-
if (scale < 0)
69-
throw new ValueError("`scale` must be positive float.");
70-
_scale = scale;
71-
_mode = mode;
72-
_distribution = distribution;
73-
_seed = seed;
74-
_dtype = dtype;
75-
}
76-
77-
public Tensor call(TensorShape shape, TF_DataType dtype)
78-
{
79-
var (fan_in, fan_out) = _compute_fans(shape);
80-
if (_mode == "fan_in")
81-
_scale /= Math.Max(1, fan_in);
82-
else if (_mode == "fan_out")
83-
_scale /= Math.Max(1, fan_out);
84-
else
85-
_scale /= Math.Max(1, (fan_in + fan_out) / 2);
86-
87-
if (_distribution == "normal" || _distribution == "truncated_normal")
88-
{
89-
throw new NotImplementedException("truncated_normal");
90-
}
91-
else if(_distribution == "untruncated_normal")
92-
{
93-
throw new NotImplementedException("truncated_normal");
94-
}
95-
else
96-
{
97-
var limit = Math.Sqrt(3.0f * _scale);
98-
return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed);
99-
}
100-
}
101-
102-
private (int, int) _compute_fans(int[] shape)
103-
{
104-
if (shape.Length < 1)
105-
return (1, 1);
106-
if (shape.Length == 1)
107-
return (shape[0], shape[0]);
108-
if (shape.Length == 2)
109-
return (shape[0], shape[1]);
110-
else
111-
throw new NotImplementedException("VarianceScaling._compute_fans");
112-
}
113-
114-
public virtual object get_config()
115-
{
116-
return new
117-
{
118-
scale = _scale,
119-
mode = _mode,
120-
distribution = _distribution,
121-
seed = _seed,
122-
dtype = _dtype
123-
};
124-
}
125-
}
126-
127-
public class GlorotUniform : VarianceScaling
128-
{
129-
public GlorotUniform(float scale = 1.0f,
130-
string mode = "fan_avg",
131-
string distribution = "uniform",
132-
int? seed = null,
133-
TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale, mode, distribution, seed, dtype)
134-
{
135-
136-
}
137-
138-
public object get_config()
139-
{
140-
return new
141-
{
142-
scale = _scale,
143-
mode = _mode,
144-
distribution = _distribution,
145-
seed = _seed,
146-
dtype = _dtype
147-
};
148-
}
149-
}
15028
}
15129
}

src/TensorFlowNET.Core/APIs/tf.random.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,12 @@ public static Tensor random_normal(int[] shape,
2222
TF_DataType dtype = TF_DataType.TF_FLOAT,
2323
int? seed = null,
2424
string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name);
25+
26+
public static Tensor random_uniform(int[] shape,
27+
float minval = 0,
28+
float? maxval = null,
29+
TF_DataType dtype = TF_DataType.TF_FLOAT,
30+
int? seed = null,
31+
string name = null) => random_ops.random_uniform(shape, minval, maxval, dtype, seed, name);
2532
}
2633
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Operations.Initializers;
5+
6+
namespace Tensorflow.Keras
7+
{
8+
public class Initializers
9+
{
10+
/// <summary>
11+
/// He normal initializer.
12+
/// </summary>
13+
/// <param name="seed"></param>
14+
/// <returns></returns>
15+
public IInitializer he_normal(int? seed = null)
16+
{
17+
return new VarianceScaling(scale: 20f, mode: "fan_in", distribution: "truncated_normal", seed: seed);
18+
}
19+
}
20+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras;
5+
6+
namespace Tensorflow
7+
{
8+
public static partial class tf
9+
{
10+
public static class keras
11+
{
12+
public static Initializers initializers => new Initializers();
13+
}
14+
}
15+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Operations.Initializers
6+
{
7+
public class GlorotUniform : VarianceScaling
8+
{
9+
public GlorotUniform(float scale = 1.0f,
10+
string mode = "fan_avg",
11+
string distribution = "uniform",
12+
int? seed = null,
13+
TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale, mode, distribution, seed, dtype)
14+
{
15+
16+
}
17+
18+
public object get_config()
19+
{
20+
return new
21+
{
22+
scale = _scale,
23+
mode = _mode,
24+
distribution = _distribution,
25+
seed = _seed,
26+
dtype = _dtype
27+
};
28+
}
29+
}
30+
}

src/TensorFlowNET.Core/Operations/IInitializer.cs renamed to src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ namespace Tensorflow
66
{
77
public interface IInitializer
88
{
9-
Tensor call(TensorShape shape, TF_DataType dtype);
9+
Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid);
1010
object get_config();
1111
}
1212
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Operations.Initializers
6+
{
7+
public class TruncatedNormal : IInitializer
8+
{
9+
private float mean;
10+
private float stddev;
11+
private int? seed;
12+
private TF_DataType dtype;
13+
14+
public TruncatedNormal(float mean = 0.0f,
15+
float stddev = 1.0f,
16+
int? seed = null,
17+
TF_DataType dtype = TF_DataType.TF_FLOAT)
18+
{
19+
this.mean = mean;
20+
this.stddev = stddev;
21+
this.seed = seed;
22+
this.dtype = dtype;
23+
}
24+
25+
public Tensor call(TensorShape shape, TF_DataType dtype)
26+
{
27+
throw new NotImplementedException("");
28+
}
29+
30+
public object get_config()
31+
{
32+
return new
33+
{
34+
mean = mean,
35+
stddev = stddev,
36+
seed = seed,
37+
dtype = dtype.name()
38+
};
39+
}
40+
}
41+
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Operations.Initializers
6+
{
7+
/// <summary>
8+
/// Initializer capable of adapting its scale to the shape of weights tensors.
9+
/// </summary>
10+
public class VarianceScaling : IInitializer
11+
{
12+
protected float _scale;
13+
protected string _mode;
14+
protected string _distribution;
15+
protected int? _seed;
16+
protected TF_DataType _dtype;
17+
18+
public VarianceScaling(float scale = 1.0f,
19+
string mode = "fan_in",
20+
string distribution = "truncated_normal",
21+
int? seed = null,
22+
TF_DataType dtype = TF_DataType.TF_FLOAT)
23+
{
24+
if (scale < 0)
25+
throw new ValueError("`scale` must be positive float.");
26+
_scale = scale;
27+
_mode = mode;
28+
_distribution = distribution;
29+
_seed = seed;
30+
_dtype = dtype;
31+
}
32+
33+
public Tensor call(TensorShape shape, TF_DataType dtype)
34+
{
35+
var (fan_in, fan_out) = _compute_fans(shape);
36+
if (_mode == "fan_in")
37+
_scale /= Math.Max(1, fan_in);
38+
else if (_mode == "fan_out")
39+
_scale /= Math.Max(1, fan_out);
40+
else
41+
_scale /= Math.Max(1, (fan_in + fan_out) / 2);
42+
43+
if (_distribution == "normal" || _distribution == "truncated_normal")
44+
{
45+
throw new NotImplementedException("truncated_normal");
46+
}
47+
else if (_distribution == "untruncated_normal")
48+
{
49+
throw new NotImplementedException("truncated_normal");
50+
}
51+
else
52+
{
53+
var limit = Math.Sqrt(3.0f * _scale);
54+
return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed);
55+
}
56+
}
57+
58+
private (int, int) _compute_fans(int[] shape)
59+
{
60+
if (shape.Length < 1)
61+
return (1, 1);
62+
if (shape.Length == 1)
63+
return (shape[0], shape[0]);
64+
if (shape.Length == 2)
65+
return (shape[0], shape[1]);
66+
else
67+
throw new NotImplementedException("VarianceScaling._compute_fans");
68+
}
69+
70+
public virtual object get_config()
71+
{
72+
return new
73+
{
74+
scale = _scale,
75+
mode = _mode,
76+
distribution = _distribution,
77+
seed = _seed,
78+
dtype = _dtype
79+
};
80+
}
81+
}
82+
}

0 commit comments

Comments
 (0)