forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathGraph.Import.cs
More file actions
46 lines (41 loc) · 1.61 KB
/
Graph.Import.cs
File metadata and controls
46 lines (41 loc) · 1.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.InteropServices;
using System.Text;
namespace Tensorflow
{
/// <summary>
/// Members of <see cref="Graph"/> that load a serialized graph definition into a graph.
/// </summary>
public partial class Graph
{
    /// <summary>
    /// Imports <paramref name="graph_def"/> into this graph through
    /// <c>c_api.TF_GraphImportGraphDefWithReturnOutputs</c> and returns the outputs
    /// the native call wrote back.
    /// </summary>
    /// <param name="graph_def">Buffer holding the graph definition handed to the native call.</param>
    /// <param name="opts">Import options; <c>opts.NumReturnOutputs</c> determines the length of the result.</param>
    /// <param name="s">Status object passed to the native call. It is not inspected here, so callers should check it.</param>
    /// <returns>An array of <c>opts.NumReturnOutputs</c> <see cref="TF_Output"/> values.</returns>
    public unsafe TF_Output[] ImportGraphDefWithReturnOutputs(Buffer graph_def, ImportGraphDefOptions opts, Status s)
    {
        var num_return_outputs = opts.NumReturnOutputs;
        var return_outputs = new TF_Output[num_return_outputs];
        int size = Marshal.SizeOf<TF_Output>();

        // Unmanaged scratch buffer that the native call fills with TF_Output structs.
        var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs);
        try
        {
            c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s);

            // Copy each struct out of unmanaged memory into the managed result array.
            for (int i = 0; i < num_return_outputs; i++)
            {
                var handle = return_output_handle + i * size;
                return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle);
            }
        }
        finally
        {
            // The buffer is only needed while copying; release it even if the call
            // or the marshalling throws, otherwise every import leaks it.
            Marshal.FreeHGlobal(return_output_handle);
        }

        return return_outputs;
    }

    /// <summary>
    /// Reads the file at <paramref name="file_path"/> and imports its bytes into this graph
    /// through <c>c_api.TF_GraphImportGraphDef</c> using freshly created import options.
    /// </summary>
    /// <param name="file_path">Path of the file whose raw bytes are passed on as the graph definition.</param>
    /// <returns>This graph's <c>Status</c> member, which is the status object given to the native import call.</returns>
    public Status Import(string file_path)
    {
        var bytes = File.ReadAllBytes(file_path);
        var graph_def = new Tensorflow.Buffer(bytes);
        // NOTE(review): the options created here are never released in this method;
        // presumably they need a matching delete call - confirm against the c_api bindings.
        var opts = c_api.TF_NewImportGraphDefOptions();
        c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, Status);
        return Status;
    }

    /// <summary>
    /// Obtains a graph from <c>tf.Graph().as_default()</c>, parses a <c>GraphDef</c> from the
    /// bytes of <paramref name="file_path"/>, and passes it to <c>importer.import_graph_def</c>.
    /// </summary>
    /// <param name="file_path">Path of the file containing the serialized <c>GraphDef</c>.</param>
    /// <returns>The graph obtained from <c>tf.Graph().as_default()</c>.</returns>
    public static Graph ImportFromPB(string file_path)
    {
        var graph = tf.Graph().as_default();
        var graph_def = GraphDef.Parser.ParseFrom(File.ReadAllBytes(file_path));
        importer.import_graph_def(graph_def);
        return graph;
    }
}
}