11using System ;
22using System . Collections . Generic ;
33using System . Text ;
4+ using Tensorflow . Operations . Initializers ;
45
56namespace Tensorflow
67{
@@ -24,128 +25,5 @@ public static variable_scope variable_scope(VariableScope scope,
2425 default_name ,
2526 values ,
2627 auxiliary_name_scope ) ;
27-
28- public class Zeros : IInitializer
29- {
30- private TF_DataType dtype ;
31-
32- public Zeros ( TF_DataType dtype = TF_DataType . TF_FLOAT )
33- {
34- this . dtype = dtype ;
35- }
36-
37- public Tensor call ( TensorShape shape , TF_DataType dtype = TF_DataType . DtInvalid )
38- {
39- if ( dtype == TF_DataType . DtInvalid )
40- dtype = this . dtype ;
41-
42- return array_ops . zeros ( shape , dtype ) ;
43- }
44-
45- public object get_config ( )
46- {
47- return new { dtype = dtype . name ( ) } ;
48- }
49- }
50-
51- /// <summary>
52- /// Initializer capable of adapting its scale to the shape of weights tensors.
53- /// </summary>
54- public class VarianceScaling : IInitializer
55- {
56- protected float _scale ;
57- protected string _mode ;
58- protected string _distribution ;
59- protected int ? _seed ;
60- protected TF_DataType _dtype ;
61-
62- public VarianceScaling ( float scale = 1.0f ,
63- string mode = "fan_in" ,
64- string distribution = "truncated_normal" ,
65- int ? seed = null ,
66- TF_DataType dtype = TF_DataType . TF_FLOAT )
67- {
68- if ( scale < 0 )
69- throw new ValueError ( "`scale` must be positive float." ) ;
70- _scale = scale ;
71- _mode = mode ;
72- _distribution = distribution ;
73- _seed = seed ;
74- _dtype = dtype ;
75- }
76-
77- public Tensor call ( TensorShape shape , TF_DataType dtype )
78- {
79- var ( fan_in , fan_out ) = _compute_fans ( shape ) ;
80- if ( _mode == "fan_in" )
81- _scale /= Math . Max ( 1 , fan_in ) ;
82- else if ( _mode == "fan_out" )
83- _scale /= Math . Max ( 1 , fan_out ) ;
84- else
85- _scale /= Math . Max ( 1 , ( fan_in + fan_out ) / 2 ) ;
86-
87- if ( _distribution == "normal" || _distribution == "truncated_normal" )
88- {
89- throw new NotImplementedException ( "truncated_normal" ) ;
90- }
91- else if ( _distribution == "untruncated_normal" )
92- {
93- throw new NotImplementedException ( "truncated_normal" ) ;
94- }
95- else
96- {
97- var limit = Math . Sqrt ( 3.0f * _scale ) ;
98- return random_ops . random_uniform ( shape , ( float ) - limit , ( float ) limit , dtype , seed : _seed ) ;
99- }
100- }
101-
102- private ( int , int ) _compute_fans ( int [ ] shape )
103- {
104- if ( shape . Length < 1 )
105- return ( 1 , 1 ) ;
106- if ( shape . Length == 1 )
107- return ( shape [ 0 ] , shape [ 0 ] ) ;
108- if ( shape . Length == 2 )
109- return ( shape [ 0 ] , shape [ 1 ] ) ;
110- else
111- throw new NotImplementedException ( "VarianceScaling._compute_fans" ) ;
112- }
113-
114- public virtual object get_config ( )
115- {
116- return new
117- {
118- scale = _scale ,
119- mode = _mode ,
120- distribution = _distribution ,
121- seed = _seed ,
122- dtype = _dtype
123- } ;
124- }
125- }
126-
127- public class GlorotUniform : VarianceScaling
128- {
129- public GlorotUniform ( float scale = 1.0f ,
130- string mode = "fan_avg" ,
131- string distribution = "uniform" ,
132- int ? seed = null ,
133- TF_DataType dtype = TF_DataType . TF_FLOAT ) : base ( scale , mode , distribution , seed , dtype )
134- {
135-
136- }
137-
138- public object get_config ( )
139- {
140- return new
141- {
142- scale = _scale ,
143- mode = _mode ,
144- distribution = _distribution ,
145- seed = _seed ,
146- dtype = _dtype
147- } ;
148- }
149- }
15028 }
15129}
0 commit comments