Skip to content

Commit fe06a29

Browse files
committed
graph_io.write_graph
1 parent 596afe2 commit fe06a29

5 files changed

Lines changed: 46 additions & 1 deletion

File tree

src/TensorFlowNET.Core/Graphs/Graph.Export.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,18 @@ public Buffer ToGraphDef(Status s)
1717

1818
return buffer;
1919
}
20+
21+
public GraphDef _as_graph_def()
22+
{
23+
var buffer = ToGraphDef(Status);
24+
Status.Check();
25+
var def = GraphDef.Parser.ParseFrom(buffer);
26+
buffer.Dispose();
27+
28+
// Strip the experimental library field iff it's empty.
29+
// if(def.Library.Function.Count == 0)
30+
31+
return def;
32+
}
2033
}
2134
}

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ private Tensor _as_graph_element(object obj)
5353

5454
return null;
5555
}
56-
56+
5757
private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true)
5858
{
5959
string types_str = "";
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.IO;
4+
using System.Text;
5+
6+
namespace Tensorflow
7+
{
8+
public class graph_io
9+
{
10+
public static string write_graph(Graph graph, string logdir, string name, bool as_text = true)
11+
{
12+
var def = graph._as_graph_def();
13+
string path = Path.Combine(logdir, name);
14+
string text = def.ToString();
15+
if (as_text)
16+
File.WriteAllText(path, text);
17+
18+
return path;
19+
}
20+
}
21+
}

src/TensorFlowNET.Core/Train/tf.optimizers.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.IO;
34
using System.Text;
45

56
namespace Tensorflow
@@ -11,6 +12,8 @@ public static class train
1112
public static Optimizer GradientDescentOptimizer(double learning_rate) => new GradientDescentOptimizer(learning_rate);
1213

1314
public static Saver Saver() => new Saver();
15+
16+
public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) => graph_io.write_graph(graph, logdir, name, as_text);
1417
}
1518
}
1619
}

test/TensorFlowNET.UnitTest/TrainSaverTest.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@ namespace TensorFlowNET.UnitTest
99
[TestClass]
1010
public class TrainSaverTest : Python
1111
{
12+
[TestMethod]
13+
public void WriteGraph()
14+
{
15+
var v = tf.Variable(0, name: "my_variable");
16+
var sess = tf.Session();
17+
tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt");
18+
}
19+
1220
[TestMethod]
1321
public void Save()
1422
{

0 commit comments

Comments
 (0)