Skip to content

Commit 1d2fe5b

Browse files
committed
Scalar and String constant creation.
1 parent 8aa30b9 commit 1d2fe5b

17 files changed

Lines changed: 214 additions & 65 deletions

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataT
3939
_outputs = new Tensor[num_outputs];
4040
for (int i = 0; i < num_outputs; i++)
4141
{
42-
_outputs[i] = new Tensor(this, i, TF_DataType.TF_FLOAT);
42+
_outputs[i] = new Tensor(this, i, output_types[i]);
4343
}
4444

4545
_graph._add_op(this);

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public static Tensor add(Tensor a, Tensor b)
2525
private static OpDefLibrary _InitOpDefLibrary()
2626
{
2727
// c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle);
28-
var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_math.bin");
28+
var bytes = File.ReadAllBytes("Operations/op_list_proto_math.bin");
2929
var op_list = OpList.Parser.ParseFrom(bytes);
3030
var op_def_lib = new OpDefLibrary();
3131
op_def_lib.add_op_list(op_list);

src/TensorFlowNET.Core/Protobuf/op_list_proto_array.bin renamed to src/TensorFlowNET.Core/Operations/op_list_proto_array.bin

File renamed without changes.

src/TensorFlowNET.Core/Protobuf/op_list_proto_math.bin renamed to src/TensorFlowNET.Core/Operations/op_list_proto_math.bin

File renamed without changes.

src/TensorFlowNET.Core/Operations/ops.cs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,16 @@ public static Graph get_default_graph()
1616
return tf.Graph();
1717
}
1818

19-
public static Tensor convert_to_tensor()
19+
public static Tensor convert_to_tensor(object value, string name = "")
2020
{
21-
return internal_convert_to_tensor();
21+
return internal_convert_to_tensor(value, name);
2222
}
2323

24-
private static Tensor internal_convert_to_tensor()
24+
private static Tensor internal_convert_to_tensor(object value, string name = "")
2525
{
26-
return null;
26+
return tf.constant(value);
2727
}
2828

29-
30-
3129
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs)
3230
{
3331
var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name);

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@
2727
<None Update="tensorflow.dll">
2828
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
2929
</None>
30-
<None Update="Protobuf\op_list_proto_array.bin">
30+
<None Update="Operations\op_list_proto_array.bin">
3131
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
3232
</None>
33-
<None Update="Protobuf\op_list_proto_math.bin">
33+
<None Update="Operations\op_list_proto_math.bin">
3434
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
3535
</None>
3636
</ItemGroup>

src/TensorFlowNET.Core/Tensors/RefVariable.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ namespace Tensorflow
77
public class RefVariable : Variable
88
{
99
public bool _in_graph_mode = true;
10+
public Tensor _initial_value;
1011

1112
public RefVariable(object initial_value,
1213
TF_DataType trainable,
@@ -16,9 +17,10 @@ public RefVariable(object initial_value,
1617

1718
}
1819

19-
private void _init_from_args()
20+
private void _init_from_args(object initial_value,
21+
TF_DataType trainable)
2022
{
21-
23+
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value");
2224
}
2325
}
2426
}

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ namespace Tensorflow
1313
/// </summary>
1414
public class Tensor
1515
{
16-
public Operation op { get; }
17-
public int value_index { get; }
18-
1916
public Graph graph => op.graph;
17+
public Operation op { get; }
2018

2119
public string name;
20+
public object value;
21+
public int value_index { get; }
2222

2323
public TF_DataType dtype { get; }
2424
public IntPtr handle { get; }

src/TensorFlowNET.Core/Tensors/TensorShape.cs

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,14 @@
66

77
namespace Tensorflow
88
{
9+
/// <summary>
10+
/// Represents the shape of a `Tensor`.
11+
/// </summary>
912
public class TensorShape : Shape
1013
{
11-
public TensorShape(params int[] shape) : base(shape)
14+
public TensorShape(params int[] dims) : base(dims)
1215
{
1316

1417
}
15-
16-
public TensorShape as_shape()
17-
{
18-
return this;
19-
}
20-
21-
public TensorShapeProto as_proto()
22-
{
23-
TensorShapeProto dim = new TensorShapeProto();
24-
25-
return new TensorShapeProto(dim);
26-
}
2718
}
2819
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class constant_op
8+
{
9+
/// <summary>
10+
/// Creates a constant tensor.
11+
///
12+
/// The resulting tensor is populated with values of type `dtype`, as
13+
/// specified by arguments `value` and (optionally) `shape`
14+
/// </summary>
15+
/// <param name="value">A constant value (or list) of output type `dtype`.</param>
16+
/// <param name="dtype">The type of the elements of the resulting tensor.</param>
17+
/// <param name="shape">Optional dimensions of resulting tensor.</param>
18+
/// <param name="name">Optional name for the tensor.</param>
19+
/// <param name="verify_shape">Boolean that enables verification of a shape of values.</param>
20+
/// <returns></returns>
21+
public static Tensor Create(object value, TF_DataType dtype = TF_DataType.DtInvalid, TensorShape shape = null, string name = "Const", bool verify_shape = false)
22+
{
23+
Graph g = ops.get_default_graph();
24+
var tensor_value = new AttrValue();
25+
var tensor_pb = tensor_util.make_tensor_proto(value, dtype, shape, verify_shape);
26+
tensor_value.Tensor = tensor_pb;
27+
var dtype_value = new AttrValue
28+
{
29+
Type = tensor_value.Tensor.Dtype,
30+
};
31+
32+
var attrs = new Dictionary<string, AttrValue>();
33+
attrs["dtype"] = dtype_value;
34+
attrs["value"] = tensor_value;
35+
var const_tensor = g.create_op("Const", null, new TF_DataType[] { (TF_DataType)dtype_value.Type }, attrs: attrs).outputs[0];
36+
const_tensor.value = value;
37+
38+
return const_tensor;
39+
}
40+
}
41+
}

0 commit comments

Comments
 (0)