Skip to content

Commit dff4f51

Browse files
committed
model.get_config for Keras.
1 parent d59db72 commit dff4f51

32 files changed

Lines changed: 555 additions & 116 deletions

src/TensorFlowNET.Console/TensorFlowNET.Console.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
<PropertyGroup>
44
<OutputType>Exe</OutputType>
5-
<TargetFramework>netcoreapp3.1</TargetFramework>
5+
<TargetFramework>net5.0</TargetFramework>
66
<RootNamespace>Tensorflow</RootNamespace>
77
<AssemblyName>Tensorflow</AssemblyName>
88
</PropertyGroup>

src/TensorFlowNET.Core/Keras/Engine/INode.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
using System.Collections.Generic;
1+
using System;
2+
using System.Collections.Generic;
3+
using Tensorflow.Keras.Saving;
24

35
namespace Tensorflow.Keras.Engine
46
{
@@ -10,5 +12,7 @@ public interface INode
1012
List<Tensor> KerasInputs { get; set; }
1113
INode[] ParentNodes { get; }
1214
IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound();
15+
bool is_input { get; }
16+
NodeConfig serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map);
1317
}
1418
}

src/TensorFlowNET.Core/Keras/Layers/ILayer.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System.Collections.Generic;
2+
using Tensorflow.Keras.ArgsDefinition;
23
using Tensorflow.Keras.Engine;
34

45
namespace Tensorflow.Keras
@@ -14,5 +15,6 @@ public interface ILayer
1415
List<IVariableV1> trainable_variables { get; }
1516
TensorShape output_shape { get; }
1617
int count_params();
18+
LayerArgs get_config();
1719
}
1820
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.ArgsDefinition;
5+
using Tensorflow.Keras.Engine;
6+
7+
namespace Tensorflow.Keras.Saving
8+
{
9+
public class LayerConfig
10+
{
11+
public string Name { get; set; }
12+
public string ClassName { get; set; }
13+
public LayerArgs Config { get; set; }
14+
public List<INode> InboundNodes { get; set; }
15+
}
16+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.Engine;
5+
6+
namespace Tensorflow.Keras.Saving
7+
{
8+
public class ModelConfig
9+
{
10+
public string Name { get; set; }
11+
public List<LayerConfig> Layers { get; set; }
12+
public List<ILayer> InputLayers { get; set; }
13+
public List<ILayer> OutputLayers { get; set; }
14+
}
15+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.Saving
6+
{
7+
public class NodeConfig
8+
{
9+
public string Name { get; set; }
10+
public int NodeIndex { get; set; }
11+
public int TensorIndex { get; set; }
12+
}
13+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.Engine;
5+
using Tensorflow.Train;
6+
7+
namespace Tensorflow.ModelSaving
8+
{
9+
public class ModelSaver
10+
{
11+
public void save(Trackable obj, string export_dir, SaveOptions options = null)
12+
{
13+
var saved_model = new SavedModel();
14+
var meta_graph_def = new MetaGraphDef();
15+
saved_model.MetaGraphs.Add(meta_graph_def);
16+
_build_meta_graph(obj, export_dir, options, meta_graph_def);
17+
}
18+
19+
void _build_meta_graph(Trackable obj, string export_dir, SaveOptions options,
20+
MetaGraphDef meta_graph_def = null)
21+
{
22+
23+
}
24+
}
25+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.ModelSaving
6+
{
7+
/// <summary>
8+
/// Options for saving to SavedModel.
9+
/// </summary>
10+
public class SaveOptions
11+
{
12+
bool save_debug_info;
13+
public SaveOptions(bool save_debug_info = false)
14+
{
15+
this.save_debug_info = save_debug_info;
16+
}
17+
}
18+
}

src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
using System;
1818
using System.Collections.Generic;
1919
using Tensorflow.Keras;
20+
using Tensorflow.Keras.ArgsDefinition;
2021
using Tensorflow.Keras.Engine;
2122
using Tensorflow.Operations;
2223
using Tensorflow.Util;
@@ -132,5 +133,10 @@ public int count_params()
132133
{
133134
throw new NotImplementedException();
134135
}
136+
137+
public LayerArgs get_config()
138+
{
139+
throw new NotImplementedException();
140+
}
135141
}
136142
}

src/TensorFlowNET.Core/Protobuf/Gen.bat

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/summary.pro
2727
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/op_def.proto
2828
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saver.proto
2929
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saved_object_graph.proto
30+
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saved_model.proto
3031
ECHO Download `any.proto` from https://github.com/protocolbuffers/protobuf/tree/master/src/google/protobuf
3132
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/meta_graph.proto
3233
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/cluster.proto

0 commit comments

Comments
 (0)