Skip to content

Commit 321ddfc

Browse files
committed
Fix Model.build.
1 parent 0f7bf4d commit 321ddfc

15 files changed

Lines changed: 104 additions & 48 deletions

File tree

src/TensorFlowNET.Console/SimpleRnnTest.cs

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,16 @@ public class SimpleRnnTest
1212
{
1313
public void Run()
1414
{
15-
tf.keras = new KerasInterface();
16-
var inputs = np.random.random((32, 10, 8)).astype(np.float32);
17-
var simple_rnn = tf.keras.layers.SimpleRNN(4);
18-
var output = simple_rnn.Apply(inputs); // The output has shape `[32, 4]`.
19-
if (output.shape == (32, 4))
20-
{
15+
tf.UseKeras<KerasInterface>();
16+
var inputs = np.random.random((6, 10, 8)).astype(np.float32);
17+
//var simple_rnn = tf.keras.layers.SimpleRNN(4);
18+
//var output = simple_rnn.Apply(inputs); // The output has shape `[6, 4]`.
2119

22-
}
23-
/*simple_rnn = tf.keras.layers.SimpleRNN(
24-
4, return_sequences = True, return_state = True)
20+
var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true);
2521

26-
# whole_sequence_output has shape `[32, 10, 4]`.
27-
# final_state has shape `[32, 4]`.
28-
whole_sequence_output, final_state = simple_rnn(inputs)*/
22+
// whole_sequence_output has shape `[6, 10, 4]`.
23+
// final_state has shape `[6, 4]`.
24+
var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs);
2925
}
3026
}
3127
}

src/TensorFlowNET.Core/Keras/Layers/ILayer.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ public interface ILayer
99
string Name { get; }
1010
bool Trainable { get; }
1111
bool Built { get; }
12+
void build(Shape input_shape);
1213
List<ILayer> Layers { get; }
1314
List<INode> InboundNodes { get; }
1415
List<INode> OutboundNodes { get; }

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,9 @@ public ILayer SimpleRNN(int units,
163163
string activation = "tanh",
164164
string kernel_initializer = "glorot_uniform",
165165
string recurrent_initializer = "orthogonal",
166-
string bias_initializer = "zeros");
166+
string bias_initializer = "zeros",
167+
bool return_sequences = false,
168+
bool return_state = false);
167169

168170
public ILayer Subtract();
169171
}
Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,32 @@
11
using System;
2+
using System.Linq;
3+
using static Tensorflow.TensorShapeProto.Types;
24

35
namespace Tensorflow.Operations.Initializers
46
{
57
/// <summary>
/// Initializer registered under the Keras name "orthogonal".
/// Value generation is not implemented yet: <see cref="Apply"/> currently
/// always ends in a <see cref="NotImplementedException"/>.
/// </summary>
public class Orthogonal : IInitializer
{
    // Scaling factor requested by the caller. Kept so the future
    // implementation can apply it; nothing reads it yet.
    private readonly float _gain;

    // Optional random seed requested by the caller; nothing reads it yet.
    private readonly int? _seed;

    /// <summary>
    /// Creates the initializer and remembers its configuration.
    /// </summary>
    /// <param name="gain">Multiplicative factor for the generated values (default 1.0).</param>
    /// <param name="seed">Optional seed for the random source; <c>null</c> means unseeded.</param>
    public Orthogonal(float gain = 1.0f, int? seed = null)
    {
        // Previously both arguments were dropped and _gain stayed at 0.
        _gain = gain;
        _seed = seed;
    }

    /// <summary>
    /// Produces the initial value for a variable described by <paramref name="args"/>.
    /// </summary>
    /// <param name="args">Carries the target shape and dtype.</param>
    /// <returns>The generated tensor, once generation is implemented.</returns>
    /// <exception cref="NotImplementedException">Always, for now.</exception>
    public Tensor Apply(InitializerArgs args)
    {
        return _generate_init_val(args.Shape, args.DType);
    }

    /// <summary>
    /// Flattens <paramref name="shape"/> into a 2-D (rows, cols) layout and then
    /// throws, because the remaining generation steps are not written yet.
    /// </summary>
    /// <exception cref="NotImplementedException">Always, for now.</exception>
    private Tensor _generate_init_val(Shape shape, TF_DataType dtype)
    {
        // Collapse every dimension except the last one into a single row count.
        var num_rows = 1L;
        foreach (var dim in shape.dims.Take(shape.ndim - 1))
        {
            num_rows *= dim;
        }

        var num_cols = shape.dims.Last();

        // Larger extent first, smaller second; prepared for the unimplemented
        // generation step below.
        var flat_shape = (Math.Max(num_cols, num_rows), Math.Min(num_cols, num_rows));

        throw new NotImplementedException("");
    }
}
1232
}

src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,5 +147,10 @@ public LayerArgs get_config()
147147
{
148148
throw new NotImplementedException();
149149
}
150+
151+
/// <summary>
/// Build step for this cell. Not implemented: every call throws.
/// </summary>
/// <param name="input_shape">Shape of the incoming tensor (currently unused).</param>
/// <exception cref="NotImplementedException">Always.</exception>
public void build(Shape input_shape)
    => throw new NotImplementedException();
150155
}
151156
}

src/TensorFlowNET.Core/tensorflow.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ public tensorflow()
6565
InitGradientEnvironment();
6666
}
6767

68+
/// <summary>
/// Creates a new <typeparamref name="T"/> and assigns it to <c>keras</c>.
/// </summary>
/// <typeparam name="T">
/// Keras API implementation; must expose a public parameterless constructor.
/// </typeparam>
public void UseKeras<T>() where T : IKerasApi, new()
    => keras = new T();
72+
6873
public string VERSION => c_api.StringPiece(c_api.TF_Version());
6974

7075
private void InitGradientEnvironment()

src/TensorFlowNET.Keras/Engine/Functional.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,12 @@ protected void _init_graph_network(Tensors inputs, Tensors outputs)
6565
}
6666

6767
// Keep track of the network's nodes and layers.
68-
(NetworkNodes, NodesByDepth, _self_tracked_trackables, _) = MapGraphNetwork(inputs, outputs);
68+
(NetworkNodes, NodesByDepth, var layers, _) = MapGraphNetwork(inputs, outputs);
69+
70+
if (!_self_tracked_trackables.Any())
71+
{
72+
_self_tracked_trackables = layers;
73+
}
6974

7075
// Build self.input_names and self.output_names.
7176
_set_output_names();

src/TensorFlowNET.Keras/Engine/Model.Build.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
using System;
22
using System.Linq;
33
using Tensorflow.Graphs;
4-
using Tensorflow.Keras.ArgsDefinition;
5-
using Tensorflow.Keras.Losses;
6-
using Tensorflow.Keras.Optimizers;
74
using static Tensorflow.Binding;
85
using static Tensorflow.KerasApi;
96

@@ -13,6 +10,12 @@ public partial class Model
1310
{
1411
public override void build(Shape input_shape)
1512
{
13+
if (this is Functional || this is Sequential)
14+
{
15+
base.build(input_shape);
16+
return;
17+
}
18+
1619
var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph();
1720

1821
graph.as_default();

src/TensorFlowNET.Keras/Engine/Sequential.cs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -122,15 +122,9 @@ public void add(ILayer layer)
122122
else
123123
{
124124
_self_tracked_trackables.add(layer);
125-
_handle_deferred_layer_dependencies(layer);
126125
}
127126
}
128127

129-
void _handle_deferred_layer_dependencies(params ILayer[] layers)
130-
{
131-
_self_tracked_trackables.AddRange(layers);
132-
}
133-
134128
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
135129
{
136130
if (!_has_explicit_input_shape)
@@ -156,7 +150,7 @@ void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType inpu
156150
ops.init_scope();
157151
var inputs = keras.Input(batch_input_shape: input_shape,
158152
dtype: input_dtype,
159-
name: $"{_self_tracked_trackables[0].Name}_input");
153+
name: _self_tracked_trackables[0].Name.EndsWith("_input") ? _self_tracked_trackables[0].Name : $"{_self_tracked_trackables[0].Name}_input");
160154
Tensors layer_input = inputs;
161155
Tensors layer_output = null;
162156
Tensors outputs = null;

src/TensorFlowNET.Keras/Layers/LayersApi.cs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -658,14 +658,18 @@ public ILayer SimpleRNN(int units,
658658
string activation = "tanh",
659659
string kernel_initializer = "glorot_uniform",
660660
string recurrent_initializer = "orthogonal",
661-
string bias_initializer = "zeros")
661+
string bias_initializer = "zeros",
662+
bool return_sequences = false,
663+
bool return_state = false)
662664
=> new SimpleRNN(new SimpleRNNArgs
663665
{
664666
Units = units,
665667
Activation = GetActivationByName(activation),
666668
KernelInitializer = GetInitializerByName(kernel_initializer),
667-
RecurrentInitializer= GetInitializerByName(recurrent_initializer),
668-
BiasInitializer= GetInitializerByName(bias_initializer)
669+
RecurrentInitializer = GetInitializerByName(recurrent_initializer),
670+
BiasInitializer = GetInitializerByName(bias_initializer),
671+
ReturnSequences = return_sequences,
672+
ReturnState = return_state
669673
});
670674

671675
/// <summary>

0 commit comments

Comments
 (0)