Skip to content

Commit 78bd4c7

Browse files
Add api set_weights and get_weights
1 parent a075bba commit 78bd4c7

3 files changed

Lines changed: 32 additions & 3 deletions

File tree

src/TensorFlowNET.Core/Keras/Layers/ILayer.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Tensorflow.Keras.Engine;
22
using Tensorflow.Keras.Saving;
3+
using Tensorflow.NumPy;
34
using Tensorflow.Training;
45

56
namespace Tensorflow.Keras
@@ -18,6 +19,8 @@ public interface ILayer: IWithTrackable, IKerasConfigable
1819
List<IVariableV1> TrainableWeights { get; }
1920
List<IVariableV1> NonTrainableWeights { get; }
2021
List<IVariableV1> Weights { get; set; }
22+
void set_weights(List<NDArray> weights);
23+
List<NDArray> get_weights();
2124
Shape OutputShape { get; }
2225
Shape BatchInputShape { get; }
2326
TensorShapeConfig BuildInputShape { get; }

src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
using Tensorflow.Keras.ArgsDefinition.Rnn;
2222
using Tensorflow.Keras.Engine;
2323
using Tensorflow.Keras.Saving;
24+
using Tensorflow.NumPy;
2425
using Tensorflow.Operations;
2526
using Tensorflow.Train;
2627
using Tensorflow.Util;
@@ -71,7 +72,10 @@ public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell
7172

7273
public List<IVariableV1> TrainableVariables => throw new NotImplementedException();
7374
public List<IVariableV1> TrainableWeights => throw new NotImplementedException();
74-
public List<IVariableV1> Weights => throw new NotImplementedException();
75+
public List<IVariableV1> Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
76+
77+
public List<NDArray> get_weights() => throw new NotImplementedException();
78+
public void set_weights(List<NDArray> weights) => throw new NotImplementedException();
7579
public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException();
7680

7781
public Shape OutputShape => throw new NotImplementedException();
@@ -84,8 +88,6 @@ public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell
8488
protected bool built = false;
8589
public bool Built => built;
8690

87-
List<IVariableV1> ILayer.Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
88-
8991
public RnnCell(bool trainable = true,
9092
string name = null,
9193
TF_DataType dtype = TF_DataType.DtInvalid,

src/TensorFlowNET.Keras/Engine/Layer.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,30 @@ public virtual List<IVariableV1> Weights
120120
}
121121
}
122122

123+
public virtual void set_weights(List<NDArray> weights)
124+
{
125+
if (Weights.Count() != weights.Count()) throw new ValueError(
126+
$"You called `set_weights` on layer \"{this.name}\"" +
127+
$"with a weight list of length {len(weights)}, but the layer was " +
128+
$"expecting {len(Weights)} weights.");
129+
for (int i = 0; i < weights.Count(); i++)
130+
{
131+
if (weights[i].shape != Weights[i].shape)
132+
{
133+
throw new ValueError($"Layer weight shape {weights[i].shape} not compatible with provided weight shape {Weights[i].shape}");
134+
}
135+
}
136+
foreach (var (this_w, v_w) in zip(Weights, weights))
137+
this_w.assign(v_w, read_value: true);
138+
}
139+
140+
public List<NDArray> get_weights()
141+
{
142+
List<NDArray > weights = new List<NDArray>();
143+
weights.AddRange(Weights.ConvertAll(x => x.numpy()));
144+
return weights;
145+
}
146+
123147
protected int id;
124148
public int Id => id;
125149
protected string name;

0 commit comments

Comments
 (0)