Skip to content

Commit 11377d6

Browse files
committed
add NewTF_Tensor test
1 parent fca0368 commit 11377d6

17 files changed

Lines changed: 171 additions & 32 deletions

TensorFlow.NET.sln

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T
99
EndProject
1010
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}"
1111
EndProject
12+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{6ACED8FF-F08E-40E6-A75D-D01BAAA41072}"
13+
EndProject
1214
Global
1315
GlobalSection(SolutionConfigurationPlatforms) = preSolution
1416
Debug|Any CPU = Debug|Any CPU
@@ -27,6 +29,10 @@ Global
2729
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU
2830
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU
2931
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU
32+
{6ACED8FF-F08E-40E6-A75D-D01BAAA41072}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
33+
{6ACED8FF-F08E-40E6-A75D-D01BAAA41072}.Debug|Any CPU.Build.0 = Debug|Any CPU
34+
{6ACED8FF-F08E-40E6-A75D-D01BAAA41072}.Release|Any CPU.ActiveCfg = Release|Any CPU
35+
{6ACED8FF-F08E-40E6-A75D-D01BAAA41072}.Release|Any CPU.Build.0 = Release|Any CPU
3036
EndGlobalSection
3137
GlobalSection(SolutionProperties) = preSolution
3238
HideSolutionNode = FALSE

src/TensorFlowNET.Core/OpDefLibrary.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public unsafe Operation _apply_op_helper(string op_type_name, string name = "",
5050

5151
// Perform input type inference
5252
var inputs = new List<Tensor>();
53-
var input_types = new List<DataType>();
53+
var input_types = new List<TF_DataType>();
5454

5555
foreach (var input_arg in op_def.InputArg)
5656
{
@@ -106,7 +106,7 @@ public unsafe Operation _apply_op_helper(string op_type_name, string name = "",
106106
}
107107

108108
// Determine output types (possibly using attrs)
109-
var output_types = new List<DataType>();
109+
var output_types = new List<TF_DataType>();
110110

111111
foreach (var arg in op_def.OutputArg)
112112
{
@@ -116,7 +116,7 @@ public unsafe Operation _apply_op_helper(string op_type_name, string name = "",
116116
}
117117
else if (!String.IsNullOrEmpty(arg.TypeAttr))
118118
{
119-
output_types.Add(attr_protos[arg.TypeAttr].Type);
119+
output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type);
120120
}
121121
}
122122

src/TensorFlowNET.Core/Operation.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public Operation(Graph g, string opType, string oper_name)
2424
var status = new Status();
2525

2626
var desc = c_api.TF_NewOperation(g.Handle, opType, oper_name);
27-
c_api.TF_SetAttrType(desc, "dtype", DataType.DtInt32);
27+
c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32);
2828
c_api.TF_FinishOperation(desc, status.Handle);
2929
}
3030

@@ -39,7 +39,7 @@ public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataT
3939
_outputs = new Tensor[num_outputs];
4040
for (int i = 0; i < num_outputs; i++)
4141
{
42-
_outputs[i] = new Tensor(this, i, TF_DataType.DtFloat);
42+
_outputs[i] = new Tensor(this, i, TF_DataType.TF_FLOAT);
4343
}
4444

4545
_graph._add_op(this);

src/TensorFlowNET.Core/Session/BaseSession.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System;
33
using System.Collections.Generic;
44
using System.Linq;
5+
using System.Runtime.InteropServices;
56
using System.Text;
67

78
namespace Tensorflow
@@ -113,10 +114,9 @@ private unsafe object[] _call_tf_sessionrun(TF_Output[] fetch_list)
113114
run_metadata: IntPtr.Zero,
114115
status: status.Handle);
115116

116-
var result = output_values.Select(x => new Tensor(x).buffer).Select(x =>
117-
{
118-
return (object)*(float*)x;
119-
}).ToArray();
117+
var result = output_values.Select(x => c_api.TF_TensorData(x))
118+
.Select(x => (object)*(float*)x)
119+
.ToArray();
120120

121121
return result;
122122
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public enum TF_DataType
8+
{
9+
TF_FLOAT = 1,
10+
TF_DOUBLE = 2,
11+
TF_INT32 = 3, // Int32 tensors are always in 'host' memory.
12+
TF_UINT8 = 4,
13+
TF_INT16 = 5,
14+
TF_INT8 = 6,
15+
TF_STRING = 7,
16+
TF_COMPLEX64 = 8, // Single-precision complex
17+
TF_COMPLEX = 8, // Old identifier kept for API backwards compatibility
18+
TF_INT64 = 9,
19+
TF_BOOL = 10,
20+
TF_QINT8 = 11, // Quantized int8
21+
TF_QUINT8 = 12, // Quantized uint8
22+
TF_QINT32 = 13, // Quantized int32
23+
TF_BFLOAT16 = 14, // Float32 truncated to 16 bits. Only for cast ops.
24+
TF_QINT16 = 15, // Quantized int16
25+
TF_QUINT16 = 16, // Quantized uint16
26+
TF_UINT16 = 17,
27+
TF_COMPLEX128 = 18, // Double-precision complex
28+
TF_HALF = 19,
29+
TF_RESOURCE = 20,
30+
TF_VARIANT = 21,
31+
TF_UINT32 = 22,
32+
TF_UINT64 = 23
33+
}
34+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Runtime.InteropServices;
4+
using System.Text;
5+
6+
namespace Tensorflow
7+
{
8+
[StructLayout(LayoutKind.Sequential)]
9+
public struct TF_Tensor
10+
{
11+
public TF_DataType dtype;
12+
public IntPtr shape;
13+
public IntPtr buffer;
14+
}
15+
}
Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Runtime.InteropServices;
34
using System.Text;
45

56
namespace Tensorflow
@@ -10,23 +11,28 @@ public class Tensor
1011
public Operation op => _op;
1112
private readonly int _value_index;
1213
public int value_index => _value_index;
13-
private DataType _dtype;
14-
public DataType dtype => _dtype;
14+
private TF_DataType _dtype;
15+
public TF_DataType dtype => _dtype;
1516

1617
public Graph graph => _op.graph;
1718

1819
public string name;
1920

2021
private readonly IntPtr _handle;
2122
public IntPtr handle => _handle;
22-
public IntPtr buffer => c_api.TF_TensorData(_handle);
23+
24+
private TF_Tensor tensor;
25+
26+
public IntPtr buffer => c_api.TF_TensorData(tensor.buffer);
2327

2428
public Tensor(IntPtr handle)
2529
{
2630
_handle = handle;
31+
tensor = Marshal.PtrToStructure<TF_Tensor>(handle);
32+
_dtype = tensor.dtype;
2733
}
2834

29-
public Tensor(Operation op, int value_index, DataType dtype)
35+
public Tensor(Operation op, int value_index, TF_DataType dtype)
3036
{
3137
_op = op;
3238
_value_index = value_index;
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class TensorBuffer
8+
{
9+
}
10+
}
File renamed without changes.

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717

1818
<ItemGroup>
1919
<PackageReference Include="Google.Protobuf" Version="3.6.1" />
20-
<PackageReference Include="NumSharp" Version="0.6.0" />
20+
</ItemGroup>
21+
22+
<ItemGroup>
23+
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
2124
</ItemGroup>
2225

2326
<ItemGroup>

0 commit comments

Comments
 (0)