Skip to content

Commit d7e04d5

Browse files
committed
test string const
1 parent 33449cd commit d7e04d5

11 files changed

Lines changed: 73 additions & 15 deletions

File tree

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,36 @@ public partial class Graph : IDisposable
2424
private List<String> _unfetchable_ops = new List<string>();
2525

2626
private string _name_stack;
27+
public Status Status { get; }
2728

2829
public Graph()
2930
{
3031
_handle = c_api.TF_NewGraph();
32+
Status = new Status();
3133
}
3234

3335
public Graph(IntPtr graph)
3436
{
3537
_handle = graph;
38+
Status = new Status();
3639
_nodes_by_id = new Dictionary<int, Operation>();
3740
_nodes_by_name = new Dictionary<string, Operation>();
3841
_names_in_use = new Dictionary<string, int>();
3942
}
4043

41-
public OperationDescription NewOperation(string opType, string opName)
44+
public Operation NewOperation(string opType, string opName, Tensor t)
4245
{
43-
return c_api.TF_NewOperation(_handle, opType, opName);
46+
var desc = c_api.TF_NewOperation(_handle, opType, opName);
47+
48+
c_api.TF_SetAttrTensor(desc, "value", t, Status);
49+
Status.Check();
50+
51+
c_api.TF_SetAttrType(desc, "dtype", t.dtype);
52+
53+
var op = c_api.TF_FinishOperation(desc, Status);
54+
Status.Check();
55+
56+
return op;
4457
}
4558

4659
public T as_graph_element<T>(T obj, bool allow_tensor = true, bool allow_operation = true)

src/TensorFlowNET.Core/Operations/c_api.ops.cs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,18 @@ namespace Tensorflow
77
{
88
public static partial class c_api
99
{
10+
/// <summary>
11+
/// Request that `desc` be co-located on the device where `op`
12+
/// is placed.
13+
///
14+
/// Use of this is discouraged since the implementation of device placement is
15+
/// subject to change. Primarily intended for internal libraries
16+
/// </summary>
17+
/// <param name="desc"></param>
18+
/// <param name="op"></param>
19+
[DllImport(TensorFlowLibName)]
20+
public static extern void TF_ColocateWith(IntPtr desc, IntPtr op);
21+
1022
/// <summary>
1123
/// Get the OpList of all OpDefs defined in this address space.
1224
/// </summary>
@@ -209,7 +221,7 @@ public static partial class c_api
209221
/// <param name="value">const void*</param>
210222
/// <param name="length">size_t</param>
211223
[DllImport(TensorFlowLibName)]
212-
public static extern void TF_SetAttrString(IntPtr desc, string attr_name, string value, uint length);
224+
public static extern void TF_SetAttrString(IntPtr desc, string attr_name, IntPtr value, uint length);
213225

214226
/// <summary>
215227
///

src/TensorFlowNET.Core/Sessions/BaseSession.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ private unsafe object[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] fe
119119
.Select(x => (object)*(float*)x)
120120
.ToArray();
121121

122+
var op = new Operation(fetch_list[0].oper);
123+
//var metadata = c_api.TF_OperationGetAttrMetadata(fetch_list[0].oper, "dtype", status);
124+
122125
return result;
123126
}
124127

src/TensorFlowNET.Core/Sessions/Session.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,19 @@ namespace Tensorflow
77
public class Session : BaseSession
88
{
99
private IntPtr _handle;
10+
public Status Status { get; }
11+
public SessionOptions Options { get; }
1012

1113
public Session(string target = "", Graph graph = null)
1214
{
15+
Status = new Status();
16+
if(graph == null)
17+
{
18+
graph = tf.get_default_graph();
19+
}
20+
Options = new SessionOptions();
21+
_handle = c_api.TF_NewSession(graph, Options, Status);
22+
Status.Check();
1323
}
1424

1525
public Session(IntPtr handle)

src/TensorFlowNET.Core/Status/Status.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,15 @@ public void SetStatus(TF_Code code, string msg)
3636
/// Check status
3737
/// Throw exception with error message if code != TF_OK
3838
/// </summary>
39-
public void Check()
39+
public void Check(bool throwException = false)
4040
{
4141
if(Code != TF_Code.TF_OK)
4242
{
4343
Console.WriteLine(Message);
44-
// throw new Exception(Message);
44+
if (throwException)
45+
{
46+
throw new Exception(Message);
47+
}
4548
}
4649
}
4750

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ public Tensor(NDArray nd)
6969
private IntPtr Allocate(NDArray nd)
7070
{
7171
var dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size);
72+
ulong size = (ulong)(nd.size * nd.dtypesize);
7273

7374
switch (nd.dtype.Name)
7475
{
@@ -81,16 +82,21 @@ private IntPtr Allocate(NDArray nd)
8182
case "Double":
8283
Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size);
8384
break;
85+
case "String":
86+
dotHandle = Marshal.StringToHGlobalAuto(nd.Data<string>()[0]);
87+
size = (ulong)nd.Data<string>()[0].Length;
88+
break;
8489
default:
8590
throw new NotImplementedException("Marshal.Copy failed.");
8691
}
8792

8893
var dataType = ToTFDataType(nd.dtype);
94+
8995
var tfHandle = c_api.TF_NewTensor(dataType,
9096
nd.shape.Select(x => (long)x).ToArray(), // shape
9197
nd.ndim,
9298
dotHandle,
93-
(ulong)(nd.size * nd.dtypesize),
99+
size,
94100
(IntPtr values, IntPtr len, ref bool closure) =>
95101
{
96102
// Free the original buffer and set flag
@@ -154,6 +160,8 @@ public TF_DataType ToTFDataType(Type type)
154160
return TF_DataType.TF_FLOAT;
155161
case "Double":
156162
return TF_DataType.TF_DOUBLE;
163+
case "String":
164+
return TF_DataType.TF_STRING;
157165
}
158166

159167
return TF_DataType.DtInvalid;

src/TensorFlowNET.Core/Tensors/constant_op.cs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,13 @@ public static Tensor Create(NDArray nd, string name = "Const", bool verify_shape
3434
attrs["dtype"] = dtype_value;
3535
attrs["value"] = tensor_value;
3636

37-
var const_tensor = g.create_op("Const",
38-
null,
39-
new TF_DataType[] { (TF_DataType)dtype_value.Type },
37+
var op = g.create_op("Const",
38+
null,
39+
new TF_DataType[] { (TF_DataType)dtype_value.Type },
4040
attrs: attrs,
41-
name: name).outputs[0];
41+
name: name);
4242

43+
var const_tensor = op.outputs[0];
4344
const_tensor.value = nd.Data();
4445

4546
return const_tensor;

src/TensorFlowNET.Core/Tensors/tf.constant.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ namespace Tensorflow
77
{
88
public static partial class tf
99
{
10-
public static Tensor constant(NDArray value, string name = "Const", bool verify_shape = false)
10+
public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false)
1111
{
12-
return constant_op.Create(value, name, verify_shape);
12+
return constant_op.Create(nd, name, verify_shape);
1313
}
1414
}
1515
}

test/TensorFlowNET.Examples/HelloWorld.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ of the Constant op.*/
2424
var sess = tf.Session();
2525

2626
// Run the op
27-
sess.run(hello);
27+
Console.WriteLine(sess.run(hello));
2828
}
2929
}
3030
}

test/TensorFlowNET.Examples/Program.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ static void Main(string[] args)
2323
Console.ReadLine();
2424
}
2525
}
26+
27+
Console.ReadLine();
2628
}
2729
}
2830
}

0 commit comments

Comments
 (0)