Skip to content

Commit caae2db

Browse files
sharwellOceania2018
authored andcommitted
Implement SafeBufferHandle as a wrapper for TF_Buffer
1 parent 33e17df commit caae2db

18 files changed

Lines changed: 121 additions & 89 deletions

File tree

src/TensorFlowNET.Core/Attributes/c_api.ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ public partial class c_api
5555
/// <param name="oper"></param>
5656
/// <returns></returns>
5757
[DllImport(TensorFlowLibName)]
58-
public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, SafeStatusHandle status);
58+
public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, SafeBufferHandle output_attr_value, SafeStatusHandle status);
5959

6060
[DllImport(TensorFlowLibName)]
6161
public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value);

src/TensorFlowNET.Core/Buffers/Buffer.cs

Lines changed: 35 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -14,45 +14,42 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using NumSharp.Backends.Unmanaged;
1718
using System;
1819
using System.Runtime.CompilerServices;
1920
using System.Runtime.InteropServices;
20-
using NumSharp.Backends.Unmanaged;
21+
using Tensorflow.Util;
2122
using static Tensorflow.c_api;
2223

2324
namespace Tensorflow
2425
{
2526
/// <summary>
2627
/// Represents a TF_Buffer that can be passed to Tensorflow.
2728
/// </summary>
28-
public class Buffer : DisposableObject
29+
public sealed class Buffer : IDisposable
2930
{
30-
private unsafe TF_Buffer buffer
31-
{
32-
[MethodImpl(MethodImplOptions.AggressiveInlining)]
33-
get => *bufferptr;
34-
}
31+
public SafeBufferHandle Handle { get; }
3532

36-
private unsafe TF_Buffer* bufferptr
37-
{
38-
[MethodImpl(MethodImplOptions.AggressiveInlining)]
39-
get => (TF_Buffer*) _handle;
40-
}
33+
/// <remarks>
34+
/// <inheritdoc cref="SafeHandleLease" path="/devdoc/usage"/>
35+
/// </remarks>
36+
private unsafe ref readonly TF_Buffer DangerousBuffer
37+
=> ref Unsafe.AsRef<TF_Buffer>(Handle.DangerousGetHandle().ToPointer());
4138

4239
/// <summary>
4340
/// The memory block representing this buffer.
4441
/// </summary>
45-
/// <remarks>The deallocator is set to null.</remarks>
46-
public UnmanagedMemoryBlock<byte> MemoryBlock
42+
/// <remarks>
43+
/// <para>The deallocator is set to null.</para>
44+
///
45+
/// <inheritdoc cref="SafeHandleLease" path="/devdoc/usage"/>
46+
/// </remarks>
47+
public unsafe UnmanagedMemoryBlock<byte> DangerousMemoryBlock
4748
{
4849
get
4950
{
50-
unsafe
51-
{
52-
EnsureNotDisposed();
53-
var buff = (TF_Buffer*) _handle;
54-
return new UnmanagedMemoryBlock<byte>((byte*) buff->data.ToPointer(), (long) buff->length);
55-
}
51+
ref readonly TF_Buffer buffer = ref DangerousBuffer;
52+
return new UnmanagedMemoryBlock<byte>((byte*)buffer.data.ToPointer(), (long)buffer.length);
5653
}
5754
}
5855

@@ -63,25 +60,23 @@ public ulong Length
6360
{
6461
get
6562
{
66-
EnsureNotDisposed();
67-
return buffer.length;
63+
using (Handle.Lease())
64+
{
65+
return DangerousBuffer.length;
66+
}
6867
}
6968
}
7069

71-
public Buffer() => _handle = TF_NewBuffer();
72-
73-
public Buffer(IntPtr handle)
74-
{
75-
if (handle == IntPtr.Zero)
76-
throw new ArgumentException("Handle (IntPtr) can't be zero.", nameof(handle));
70+
public Buffer()
71+
=> Handle = TF_NewBuffer();
7772

78-
_handle = handle;
79-
}
73+
public Buffer(SafeBufferHandle handle)
74+
=> Handle = handle;
8075

81-
public Buffer(byte[] data) : this(_toBuffer(data))
82-
{ }
76+
public Buffer(byte[] data)
77+
=> Handle = _toBuffer(data);
8378

84-
private static IntPtr _toBuffer(byte[] data)
79+
private static SafeBufferHandle _toBuffer(byte[] data)
8580
{
8681
if (data == null)
8782
throw new ArgumentNullException(nameof(data));
@@ -93,38 +88,25 @@ private static IntPtr _toBuffer(byte[] data)
9388
}
9489
}
9590

96-
public static implicit operator IntPtr(Buffer buffer)
97-
{
98-
buffer.EnsureNotDisposed();
99-
return buffer._handle;
100-
}
101-
102-
public static explicit operator byte[](Buffer buffer) => buffer.ToArray(); //has to be explicit, developer will assume it doesn't cost.
103-
10491
/// <summary>
10592
/// Copies this buffer's contents onto a <see cref="byte"/> array.
10693
/// </summary>
10794
public byte[] ToArray()
10895
{
109-
EnsureNotDisposed();
110-
111-
unsafe
96+
using (Handle.Lease())
11297
{
113-
var len = buffer.length;
98+
var block = DangerousMemoryBlock;
99+
var len = block.Count;
114100
if (len == 0)
115101
return Array.Empty<byte>();
116102

117-
byte[] data = new byte[len];
118-
fixed (byte* dst = data)
119-
System.Buffer.MemoryCopy((void*) bufferptr->data, dst, len, len);
120-
103+
var data = new byte[len];
104+
block.CopyTo(data, 0);
121105
return data;
122106
}
123107
}
124108

125-
protected override void DisposeUnmanagedResources(IntPtr handle)
126-
{
127-
TF_DeleteBuffer(handle);
128-
}
109+
public void Dispose()
110+
=> Handle.Dispose();
129111
}
130112
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
18+
using Tensorflow.Util;
19+
20+
namespace Tensorflow
21+
{
22+
public sealed class SafeBufferHandle : SafeTensorflowHandle
23+
{
24+
private SafeBufferHandle()
25+
{
26+
}
27+
28+
public SafeBufferHandle(IntPtr handle)
29+
: base(handle)
30+
{
31+
}
32+
33+
protected override bool ReleaseHandle()
34+
{
35+
c_api.TF_DeleteBuffer(handle);
36+
SetHandle(IntPtr.Zero);
37+
return true;
38+
}
39+
}
40+
}

src/TensorFlowNET.Core/Buffers/c_api.buffer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public partial class c_api
2929
/// </summary>
3030
/// <returns></returns>
3131
[DllImport(TensorFlowLibName)]
32-
public static extern IntPtr TF_NewBuffer();
32+
public static extern SafeBufferHandle TF_NewBuffer();
3333

3434
[DllImport(TensorFlowLibName)]
3535
public static extern IntPtr TF_GetBuffer(TF_Buffer buffer);
@@ -42,6 +42,6 @@ public partial class c_api
4242
/// <param name="proto_len">size_t</param>
4343
/// <returns></returns>
4444
[DllImport(TensorFlowLibName)]
45-
public static extern IntPtr TF_NewBufferFromString(IntPtr proto, ulong proto_len);
45+
public static extern SafeBufferHandle TF_NewBufferFromString(IntPtr proto, ulong proto_len);
4646
}
4747
}

src/TensorFlowNET.Core/Framework/importer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ public static ITensorOrOperation[] import_graph_def(GraphDef graph_def,
6262
{
6363
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);
6464
// need to create a class ImportGraphDefWithResults with IDisposal
65-
results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options.Handle, status.Handle);
65+
results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer.Handle, scoped_options.Handle, status.Handle);
6666
status.Check(true);
6767
}
6868

src/TensorFlowNET.Core/Framework/op_def_registry.py.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public static Dictionary<string, OpDef> get_registered_ops()
3030
{
3131
_registered_ops = new Dictionary<string, OpDef>();
3232
using var buffer = new Buffer(c_api.TF_GetAllOpList());
33-
using var stream = buffer.MemoryBlock.Stream();
33+
using var stream = buffer.DangerousMemoryBlock.Stream();
3434
var op_list = OpList.Parser.ParseFrom(stream);
3535
foreach (var op_def in op_list.Op)
3636
_registered_ops[op_def.Name] = op_def;

src/TensorFlowNET.Core/GraphTransformation/GraphTransformer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ public GraphDef TransformGraph(GraphDef input_graph_def,
3434
inputs_string,
3535
outputs_string,
3636
transforms_string,
37-
buffer,
37+
buffer.Handle,
3838
status.Handle);
3939

4040
status.Check(false);

src/TensorFlowNET.Core/GraphTransformation/c_api.transform_graph.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public static extern int TransformGraphWithStringInputs(byte[] graph_def_string,
2727
string inputs_string,
2828
string outputs_string,
2929
string transforms_string,
30-
IntPtr output_buffer,
30+
SafeBufferHandle output_buffer,
3131
SafeStatusHandle status);
3232
}
3333
}

src/TensorFlowNET.Core/Graphs/Graph.Export.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public partial class Graph
2525
public Buffer ToGraphDef(Status s)
2626
{
2727
var buffer = new Buffer();
28-
c_api.TF_GraphToGraphDef(_handle, buffer, s.Handle);
28+
c_api.TF_GraphToGraphDef(_handle, buffer.Handle, s.Handle);
2929
s.Check(true);
3030

3131
return buffer;
@@ -39,7 +39,7 @@ private GraphDef _as_graph_def(bool add_shapes = false)
3939
{
4040
status.Check(true);
4141
// limit size to 250M, recursion to max 100
42-
var inputStream = CodedInputStream.CreateWithLimits(buffer.MemoryBlock.Stream(), 250 * 1024 * 1024, 100);
42+
var inputStream = CodedInputStream.CreateWithLimits(buffer.DangerousMemoryBlock.Stream(), 250 * 1024 * 1024, 100);
4343
def = GraphDef.Parser.ParseFrom(inputStream);
4444
}
4545

src/TensorFlowNET.Core/Graphs/Graph.Import.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public unsafe TF_Output[] ImportGraphDefWithReturnOutputs(Buffer graph_def, Impo
2929
int size = Marshal.SizeOf<TF_Output>();
3030
var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs);
3131

32-
c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts.Handle, return_output_handle, num_return_outputs, s.Handle);
32+
c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def.Handle, opts.Handle, return_output_handle, num_return_outputs, s.Handle);
3333

3434
var tf_output_ptr = (TF_Output*) return_output_handle;
3535
for (int i = 0; i < num_return_outputs; i++)
@@ -54,7 +54,7 @@ public bool Import(byte[] bytes, string prefix = "")
5454
{
5555
as_default();
5656
c_api.TF_ImportGraphDefOptionsSetPrefix(opts.Handle, prefix);
57-
c_api.TF_GraphImportGraphDef(_handle, graph_def, opts.Handle, status.Handle);
57+
c_api.TF_GraphImportGraphDef(_handle, graph_def.Handle, opts.Handle, status.Handle);
5858
status.Check(true);
5959
return status.Code == TF_Code.TF_OK;
6060
}

0 commit comments

Comments
 (0)