Skip to content

Commit 1f6ea31

Browse files
committed
c_test_util.GetAttrValue
1 parent a23be5a commit 1f6ea31

7 files changed

Lines changed: 80 additions & 13 deletions

File tree

src/TensorFlowNET.Core/Buffers/Buffer.cs

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,43 @@
55

66
namespace Tensorflow
77
{
8-
public class Buffer
8+
public class Buffer : IDisposable
99
{
1010
private IntPtr _handle;
1111

12-
private TF_Buffer buffer;
12+
private TF_Buffer buffer => Marshal.PtrToStructure<TF_Buffer>(_handle);
1313

14-
public byte[] Data;
14+
public byte[] Data
15+
{
16+
get
17+
{
18+
var data = new byte[buffer.length];
19+
if (buffer.length > 0)
20+
Marshal.Copy(buffer.data, data, 0, (int)buffer.length);
21+
return data;
22+
}
23+
}
1524

1625
public int Length => (int)buffer.length;
1726

18-
public unsafe Buffer(IntPtr handle)
27+
public Buffer()
28+
{
29+
_handle = c_api.TF_NewBuffer();
30+
}
31+
32+
public Buffer(IntPtr handle)
1933
{
2034
_handle = handle;
21-
buffer = Marshal.PtrToStructure<TF_Buffer>(_handle);
22-
Data = new byte[buffer.length];
23-
if (buffer.length > 0)
24-
Marshal.Copy(buffer.data, Data, 0, (int)buffer.length);
35+
}
36+
37+
public static implicit operator IntPtr(Buffer buffer)
38+
{
39+
return buffer._handle;
40+
}
41+
42+
public void Dispose()
43+
{
44+
c_api.TF_DeleteBuffer(_handle);
2545
}
2646
}
2747
}

src/TensorFlowNET.Core/Buffers/c_api.buffer.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ namespace Tensorflow
77
{
88
public static partial class c_api
99
{
10+
[DllImport(TensorFlowLibName)]
11+
public static extern void TF_DeleteBuffer(IntPtr buffer);
12+
1013
/// <summary>
1114
/// Useful for passing *out* a protobuf.
1215
/// </summary>
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class Function
8+
{
9+
10+
}
11+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Runtime.InteropServices;
4+
using System.Text;
5+
6+
namespace Tensorflow
7+
{
8+
public static partial class c_api
9+
{
10+
/// <summary>
11+
/// Write out a serialized representation of `func` (as a FunctionDef protocol
12+
/// message) to `output_func_def` (allocated by TF_NewBuffer()).
13+
/// `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer()
14+
/// is called.
15+
/// </summary>
16+
/// <param name="func"></param>
17+
/// <param name="output_func_def"></param>
18+
/// <param name="status"></param>
19+
[DllImport(TensorFlowLibName)]
20+
public static extern void TF_FunctionToFunctionDef(IntPtr func, IntPtr output_func_def, IntPtr status);
21+
}
22+
}

src/TensorFlowNET.Core/Operations/c_api.ops.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@ public static partial class c_api
3131
[DllImport(TensorFlowLibName)]
3232
public static extern string TF_OperationDevice(IntPtr oper);
3333

34+
/// <summary>
35+
/// Sets `output_attr_value` to the binary-serialized AttrValue proto
36+
/// representation of the value of the `attr_name` attr of `oper`.
37+
/// </summary>
38+
/// <param name="oper"></param>
39+
/// <returns></returns>
40+
[DllImport(TensorFlowLibName)]
41+
public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, IntPtr status);
42+
3443
[DllImport(TensorFlowLibName)]
3544
public static extern string TF_OperationName(IntPtr oper);
3645

test/TensorFlowNET.UnitTest/GraphTest.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ public void Graph()
2828
Assert.AreEqual(0, feed.NumControlInputs);
2929
Assert.AreEqual(0, feed.NumControlOutputs);
3030

31-
var attr_value = new AttrValue();
32-
c_test_util.GetAttrValue(feed, "dtype", attr_value, s);
31+
AttrValue attr_value = null;
32+
c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s);
3333
}
3434
}
3535
}

test/TensorFlowNET.UnitTest/c_test_util.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@ namespace TensorFlowNET.UnitTest
99
{
1010
public static class c_test_util
1111
{
12-
public static bool GetAttrValue(Operation oper, string attr_name, AttrValue attr_value, Status s)
12+
public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s)
1313
{
14-
var buffer = c_api.TF_NewBuffer();
15-
14+
var buffer = new Buffer();
15+
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s);
16+
attr_value = AttrValue.Parser.ParseFrom(buffer.Data);
17+
buffer.Dispose();
1618
return s.Code == TF_Code.TF_OK;
1719
}
1820

0 commit comments

Comments
 (0)