Skip to content

Commit 8aa30b9

Browse files
committed
RefVairable and Variable
1 parent b084ae6 commit 8aa30b9

5 files changed

Lines changed: 73 additions & 0 deletions

File tree

src/TensorFlowNET.Core/Operations/ops.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,18 @@ public static Graph get_default_graph()
1616
return tf.Graph();
1717
}
1818

19+
public static Tensor convert_to_tensor()
20+
{
21+
return internal_convert_to_tensor();
22+
}
23+
24+
private static Tensor internal_convert_to_tensor()
25+
{
26+
return null;
27+
}
28+
29+
30+
1931
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs)
2032
{
2133
var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name);
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class RefVariable : Variable
8+
{
9+
public bool _in_graph_mode = true;
10+
11+
public RefVariable(object initial_value,
12+
TF_DataType trainable,
13+
bool validate_shape = true) :
14+
base(initial_value, trainable, validate_shape)
15+
{
16+
17+
}
18+
19+
private void _init_from_args()
20+
{
21+
22+
}
23+
}
24+
}

src/TensorFlowNET.Core/Tensors/Variable.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,21 @@
44

55
namespace Tensorflow
66
{
7+
/// <summary>
8+
/// A variable maintains state in the graph across calls to `run()`. You add a
9+
/// variable to the graph by constructing an instance of the class `Variable`.
10+
///
11+
/// The `Variable()` constructor requires an initial value for the variable,
12+
/// which can be a `Tensor` of any type and shape. The initial value defines the
13+
/// type and shape of the variable. After construction, the type and shape of
14+
/// the variable are fixed. The value can be changed using one of the assign methods.
15+
/// https://tensorflow.org/guide/variables
16+
/// </summary>
717
public class Variable
818
{
19+
public Variable(object initial_value, TF_DataType trainable, bool validate_shape = true)
20+
{
21+
22+
}
923
}
1024
}

src/TensorFlowNET.Core/tf.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,17 @@ namespace Tensorflow
1111
public static class tf
1212
{
1313
public static TF_DataType float32 = TF_DataType.TF_FLOAT;
14+
public static TF_DataType chars = TF_DataType.TF_STRING;
1415

1516
public static Context context = new Context();
1617

1718
public static Graph g = new Graph(c_api.TF_NewGraph());
1819

20+
public static object Variable<T>(T data, TF_DataType dtype)
21+
{
22+
return new Variable(null, TF_DataType.DtInvalid);
23+
}
24+
1925
public static unsafe Tensor add(Tensor a, Tensor b)
2026
{
2127
return gen_math_ops.add(a, b);
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
using Tensorflow;
6+
7+
namespace TensorFlowNET.UnitTest
8+
{
9+
[TestClass]
10+
public class VariableTest
11+
{
12+
public void Creating()
13+
{
14+
var mammal = tf.Variable("Elephant", tf.chars);
15+
}
16+
}
17+
}

0 commit comments

Comments
 (0)