Skip to content

Commit 9d6525e

Browse files
committed
Buffer: Revamped and all perf-optted all use-cases.
- Fixed all test cases to use using(Buffer) - Fixed all test cases to explicitly specify session
1 parent 7e46d4f commit 9d6525e

20 files changed

Lines changed: 216 additions & 109 deletions

File tree

src/TensorFlowNET.Core/Buffers/Buffer.cs

Lines changed: 80 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,58 +15,116 @@ limitations under the License.
1515
******************************************************************************/
1616

1717
using System;
18+
using System.Runtime.CompilerServices;
1819
using System.Runtime.InteropServices;
20+
using NumSharp.Backends.Unmanaged;
21+
using static Tensorflow.c_api;
1922

2023
namespace Tensorflow
2124
{
25+
/// <summary>
26+
/// Represents a TF_Buffer that can be passed to Tensorflow.
27+
/// </summary>
2228
public class Buffer : DisposableObject
2329
{
24-
private TF_Buffer buffer => Marshal.PtrToStructure<TF_Buffer>(_handle);
30+
private unsafe TF_Buffer buffer
31+
{
32+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
33+
get => *bufferptr;
34+
}
35+
36+
private unsafe TF_Buffer* bufferptr
37+
{
38+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
39+
get => (TF_Buffer*) _handle;
40+
}
2541

26-
public byte[] Data
42+
/// <summary>
43+
/// The memory block representing this buffer.
44+
/// </summary>
45+
/// <remarks>The deallocator is set to null.</remarks>
46+
public UnmanagedMemoryBlock<byte> MemoryBlock
2747
{
28-
get
48+
get
2949
{
30-
var data = new byte[buffer.length];
31-
if (data.Length > 0)
32-
Marshal.Copy(buffer.data, data, 0, data.Length);
33-
return data;
50+
unsafe
51+
{
52+
EnsureNotDisposed();
53+
var buff = (TF_Buffer*) _handle;
54+
return new UnmanagedMemoryBlock<byte>((byte*) buff->data.ToPointer(), (long) buff->length);
55+
}
3456
}
3557
}
3658

37-
public int Length => (int)buffer.length;
38-
39-
public Buffer()
59+
/// <summary>
60+
/// The bytes length of this buffer.
61+
/// </summary>
62+
public ulong Length
4063
{
41-
_handle = c_api.TF_NewBuffer();
64+
get
65+
{
66+
EnsureNotDisposed();
67+
return buffer.length;
68+
}
4269
}
4370

44-
public Buffer(IntPtr handle)
71+
public Buffer() => _handle = TF_NewBuffer();
72+
73+
internal Buffer(IntPtr handle)
4574
{
75+
if (handle == IntPtr.Zero)
76+
throw new ArgumentException("Handle (IntPtr) can't be zero.", nameof(handle));
77+
4678
_handle = handle;
4779
}
4880

49-
public Buffer(byte[] data)
50-
{
51-
var dst = Marshal.AllocHGlobal(data.Length);
52-
Marshal.Copy(data, 0, dst, data.Length);
81+
public Buffer(byte[] data) : this(_toBuffer(data))
82+
{ }
5383

54-
_handle = c_api.TF_NewBufferFromString(dst, (ulong)data.Length);
84+
private static IntPtr _toBuffer(byte[] data)
85+
{
86+
if (data == null)
87+
throw new ArgumentNullException(nameof(data));
5588

56-
Marshal.FreeHGlobal(dst);
89+
unsafe
90+
{
91+
fixed (byte* src = data)
92+
return TF_NewBufferFromString(new IntPtr(src), (ulong) data.LongLength);
93+
}
5794
}
5895

5996
public static implicit operator IntPtr(Buffer buffer)
6097
{
98+
buffer.EnsureNotDisposed();
6199
return buffer._handle;
62100
}
63101

64-
public static implicit operator byte[](Buffer buffer)
102+
public static explicit operator byte[](Buffer buffer) => buffer.ToArray(); //has to be explicit, developer will assume it doesn't cost.
103+
104+
/// <summary>
105+
/// Copies this buffer's contents onto a <see cref="byte"/> array.
106+
/// </summary>
107+
public byte[] ToArray()
65108
{
66-
return buffer.Data;
109+
EnsureNotDisposed();
110+
111+
unsafe
112+
{
113+
var len = buffer.length;
114+
if (len == 0)
115+
return Array.Empty<byte>();
116+
117+
byte[] data = new byte[len];
118+
fixed (byte* dst = data)
119+
System.Buffer.MemoryCopy((void*) bufferptr->data, dst, len, len);
120+
121+
return data;
122+
}
67123
}
68124

69125
protected override void DisposeUnmanagedResources(IntPtr handle)
70-
=> c_api.TF_DeleteBuffer(handle);
126+
{
127+
TF_DeleteBuffer(handle);
128+
}
71129
}
72-
}
130+
}

src/TensorFlowNET.Core/Framework/op_def_registry.py.cs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ limitations under the License.
1515
******************************************************************************/
1616

1717
using System.Collections.Generic;
18+
using System.IO;
19+
using Tensorflow.Util;
1820

1921
namespace Tensorflow
2022
{
@@ -27,12 +29,12 @@ public static Dictionary<string, OpDef> get_registered_ops()
2729
if(_registered_ops == null)
2830
{
2931
_registered_ops = new Dictionary<string, OpDef>();
30-
var handle = c_api.TF_GetAllOpList();
31-
var buffer = new Buffer(handle);
32-
var op_list = OpList.Parser.ParseFrom(buffer);
33-
34-
foreach (var op_def in op_list.Op)
35-
_registered_ops[op_def.Name] = op_def;
32+
using (var buffer = new Buffer(c_api.TF_GetAllOpList()))
33+
{
34+
var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream());
35+
foreach (var op_def in op_list.Op)
36+
_registered_ops[op_def.Name] = op_def;
37+
}
3638
}
3739

3840
return _registered_ops;

src/TensorFlowNET.Core/Graphs/Graph.Control.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515
******************************************************************************/
1616

1717
using System.Collections.Generic;
18+
using System.Diagnostics.CodeAnalysis;
1819
using System.Linq;
1920
using Tensorflow.Operations;
2021

@@ -66,8 +67,9 @@ public ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[
6667
/// within the context should have control dependencies on
6768
/// `control_inputs`.
6869
/// </summary>
70+
[SuppressMessage("ReSharper", "CoVariantArrayConversion")]
6971
public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
70-
=> control_dependencies(control_inputs == null ? null : control_inputs.OfType<object>().ToArray());
72+
=> control_dependencies((object[])control_inputs);
7173

7274
/// <summary>
7375
/// Returns a context manager that specifies control dependencies.

src/TensorFlowNET.Core/Graphs/Graph.Export.cs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using System.IO;
18+
using Tensorflow.Util;
19+
1720
namespace Tensorflow
1821
{
1922
public partial class Graph
@@ -23,29 +26,27 @@ public Buffer ToGraphDef(Status s)
2326
var buffer = new Buffer();
2427
c_api.TF_GraphToGraphDef(_handle, buffer, s);
2528
s.Check(true);
26-
// var def = GraphDef.Parser.ParseFrom(buffer);
27-
// buffer.Dispose();
2829

2930
return buffer;
3031
}
3132

3233
private GraphDef _as_graph_def(bool add_shapes = false)
3334
{
34-
var status = new Status();
35-
var buffer = ToGraphDef(status);
36-
status.Check(true);
37-
status.Dispose();
38-
39-
var def = GraphDef.Parser.ParseFrom(buffer);
40-
buffer.Dispose();
35+
GraphDef def;
36+
using (var status = new Status())
37+
using (var buffer = ToGraphDef(status))
38+
{
39+
status.Check(true);
40+
def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream());
41+
}
4142

4243
// Strip the experimental library field iff it's empty.
4344
// if(def.Library.Function.Count == 0)
4445

4546
return def;
4647
}
4748

48-
public GraphDef as_graph_def(bool add_shapes = false)
49+
public GraphDef as_graph_def(bool add_shapes = false)
4950
=> _as_graph_def(add_shapes);
5051
}
51-
}
52+
}

src/TensorFlowNET.Core/Graphs/Graph.Operation.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818
using System.Collections.Generic;
1919
using System.Linq;
2020
using System.Runtime.InteropServices;
21+
using Tensorflow.Util;
2122
using static Tensorflow.Binding;
2223

2324
namespace Tensorflow
@@ -30,7 +31,7 @@ public OpDef GetOpDef(string type)
3031
using (var status = new Status())
3132
{
3233
c_api.TF_GraphGetOpDef(_handle, type, buffer, status);
33-
return OpDef.Parser.ParseFrom(buffer.Data);
34+
return OpDef.Parser.ParseFrom(buffer.MemoryBlock.Stream());
3435
}
3536
}
3637

@@ -71,7 +72,7 @@ public Operation OperationByName(string operName)
7172

7273
public ITensorOrOperation[] get_operations()
7374
{
74-
return _nodes_by_name.Values.Select(x => x).ToArray();
75+
return _nodes_by_name.Values.ToArray();
7576
}
7677

7778
/// <summary>
@@ -85,7 +86,7 @@ public Operation get_operation_by_name(string name)
8586

8687
public ITensorOrOperation _get_operation_by_name_unsafe(string name)
8788
{
88-
return _nodes_by_name.ContainsKey(name) ? _nodes_by_name[name] : null;
89+
return _nodes_by_name.TryGetValue(name, out var val) ? val : null;
8990
}
9091

9192
public ITensorOrOperation _get_operation_by_tf_operation(IntPtr tf_oper)

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ limitations under the License.
1717
using Google.Protobuf.Collections;
1818
using System;
1919
using System.Collections.Generic;
20+
using System.IO;
2021
using System.Linq;
22+
using Tensorflow.Util;
2123

2224
namespace Tensorflow
2325
{
@@ -226,9 +228,12 @@ public object get_attr(string name)
226228
using (var status = new Status())
227229
using (var buf = new Buffer())
228230
{
229-
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);
230-
status.Check(true);
231-
x = AttrValue.Parser.ParseFrom(buf);
231+
unsafe
232+
{
233+
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);
234+
status.Check(true);
235+
x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream());
236+
}
232237
}
233238

234239
string oneof_value = x.ValueCase.ToString();
@@ -259,7 +264,7 @@ private NodeDef GetNodeDef()
259264
{
260265
c_api.TF_OperationToNodeDef(_handle, buffer, s);
261266
s.Check();
262-
return NodeDef.Parser.ParseFrom(buffer);
267+
return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream());
263268
}
264269
}
265270

@@ -299,17 +304,15 @@ private void _assert_same_graph(Tensor tensor)
299304
/// </summary>
300305
public TF_Output _tf_output(int output_idx)
301306
{
302-
var tf_output = new TF_Output(op, output_idx);
303-
return tf_output;
307+
return new TF_Output(op, output_idx);
304308
}
305309

306310
/// <summary>
307311
/// Create and return a new TF_Input for input_idx'th input of this op.
308312
/// </summary>
309313
public TF_Input _tf_input(int input_idx)
310314
{
311-
var tf_input = new TF_Input(op, input_idx);
312-
return tf_input;
315+
return new TF_Input(op, input_idx);
313316
}
314317
}
315318
}

src/TensorFlowNET.Core/Sessions/SessionOptions.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ protected override void DisposeUnmanagedResources(IntPtr handle)
3737

3838
public void SetConfig(ConfigProto config)
3939
{
40-
var bytes = config.ToByteArray();
41-
var proto = Marshal.AllocHGlobal(bytes.Length);
40+
var bytes = config.ToByteArray(); //TODO! we can use WriteTo
41+
var proto = Marshal.AllocHGlobal(bytes.Length); //TODO! potential memory leak
4242
Marshal.Copy(bytes, 0, proto, bytes.Length);
4343

4444
using (var status = new Status())

src/TensorFlowNET.Core/ops.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,8 @@ public static (IntPtr, IntPtr) _create_c_op<T>(Graph graph, NodeDef node_def, T[
230230
// Add attrs
231231
foreach (var attr in node_def.Attr)
232232
{
233-
var bytes = attr.Value.ToByteArray();
234-
var proto = Marshal.AllocHGlobal(bytes.Length);
233+
var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream.
234+
var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak
235235
Marshal.Copy(bytes, 0, proto, bytes.Length);
236236
uint len = (uint)bytes.Length;
237237
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status);

src/TensorFlowNET.Core/tensorflow.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ public void enable_eager_execution()
6464

6565
public Session Session()
6666
{
67-
defaultSession = new Session();
68-
return defaultSession;
67+
return new Session();
6968
}
7069

7170
public Session Session(Graph graph)

test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ public bool Run()
102102

103103
// Display logs per epoch step
104104
if ((epoch + 1) % display_step == 0)
105-
print($"Epoch: {(epoch + 1).ToString("D4")} Cost: {avg_cost.ToString("G9")} Elapse: {sw.ElapsedMilliseconds}ms");
105+
print($"Epoch: {(epoch + 1):D4} Cost: {avg_cost:G9} Elapse: {sw.ElapsedMilliseconds}ms");
106106

107107
sw.Reset();
108108
}
@@ -114,8 +114,8 @@ public bool Run()
114114
var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1));
115115
// Calculate accuracy
116116
var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32));
117-
float acc = accuracy.eval((x, mnist.Test.Data), (y, mnist.Test.Labels));
118-
print($"Accuracy: {acc.ToString("F4")}");
117+
float acc = accuracy.eval(sess, (x, mnist.Test.Data), (y, mnist.Test.Labels));
118+
print($"Accuracy: {acc:F4}");
119119

120120
return acc > 0.9;
121121
}

0 commit comments

Comments
 (0)