@@ -7,7 +7,12 @@ namespace Tensorflow
77 public static partial class tf
88 {
99 public static IInitializer zeros_initializer => new Zeros ( ) ;
10+ public static IInitializer glorot_uniform => new GlorotUniform ( ) ;
1011
12+ public static variable_scope variable_scope ( string name_or_scope ,
13+ string default_name = null ,
14+ object values = null ) => new variable_scope ( name_or_scope , default_name , values ) ;
15+
1116 public class Zeros : IInitializer
1217 {
1318 private TF_DataType dtype ;
@@ -30,5 +35,105 @@ public object get_config()
3035 return new { dtype = dtype . name ( ) } ;
3136 }
3237 }
38+
39+ /// <summary>
40+ /// Initializer capable of adapting its scale to the shape of weights tensors.
41+ /// </summary>
42+ public class VarianceScaling : IInitializer
43+ {
44+ protected float _scale ;
45+ protected string _mode ;
46+ protected string _distribution ;
47+ protected int ? _seed ;
48+ protected TF_DataType _dtype ;
49+
50+ public VarianceScaling ( float scale = 1.0f ,
51+ string mode = "fan_in" ,
52+ string distribution = "truncated_normal" ,
53+ int ? seed = null ,
54+ TF_DataType dtype = TF_DataType . TF_FLOAT )
55+ {
56+ if ( scale < 0 )
57+ throw new ValueError ( "`scale` must be positive float." ) ;
58+ _scale = scale ;
59+ _mode = mode ;
60+ _distribution = distribution ;
61+ _seed = seed ;
62+ _dtype = dtype ;
63+ }
64+
65+ public Tensor call ( TensorShape shape , TF_DataType dtype )
66+ {
67+ var ( fan_in , fan_out ) = _compute_fans ( shape ) ;
68+ if ( _mode == "fan_in" )
69+ _scale /= Math . Max ( 1 , fan_in ) ;
70+ else if ( _mode == "fan_out" )
71+ _scale /= Math . Max ( 1 , fan_out ) ;
72+ else
73+ _scale /= Math . Max ( 1 , ( fan_in + fan_out ) / 2 ) ;
74+
75+ if ( _distribution == "normal" || _distribution == "truncated_normal" )
76+ {
77+ throw new NotImplementedException ( "truncated_normal" ) ;
78+ }
79+ else if ( _distribution == "untruncated_normal" )
80+ {
81+ throw new NotImplementedException ( "truncated_normal" ) ;
82+ }
83+ else
84+ {
85+ var limit = Math . Sqrt ( 3.0f * _scale ) ;
86+ return random_ops . random_uniform ( shape , ( float ) - limit , ( float ) limit , dtype , seed : _seed ) ;
87+ }
88+ }
89+
90+ private ( int , int ) _compute_fans ( int [ ] shape )
91+ {
92+ if ( shape . Length < 1 )
93+ return ( 1 , 1 ) ;
94+ if ( shape . Length == 1 )
95+ return ( shape [ 0 ] , shape [ 0 ] ) ;
96+ if ( shape . Length == 2 )
97+ return ( shape [ 0 ] , shape [ 1 ] ) ;
98+ else
99+ throw new NotImplementedException ( "VarianceScaling._compute_fans" ) ;
100+ }
101+
102+ public virtual object get_config ( )
103+ {
104+ return new
105+ {
106+ scale = _scale ,
107+ mode = _mode ,
108+ distribution = _distribution ,
109+ seed = _seed ,
110+ dtype = _dtype
111+ } ;
112+ }
113+ }
114+
115+ public class GlorotUniform : VarianceScaling
116+ {
117+ public GlorotUniform ( float scale = 1.0f ,
118+ string mode = "fan_avg" ,
119+ string distribution = "uniform" ,
120+ int ? seed = null ,
121+ TF_DataType dtype = TF_DataType . TF_FLOAT ) : base ( scale , mode , distribution , seed , dtype )
122+ {
123+
124+ }
125+
126+ public object get_config ( )
127+ {
128+ return new
129+ {
130+ scale = _scale ,
131+ mode = _mode ,
132+ distribution = _distribution ,
133+ seed = _seed ,
134+ dtype = _dtype
135+ } ;
136+ }
137+ }
33138 }
34139}
0 commit comments