Skip to content

Commit afee297

Browse files
committed
add ResourceVarible operate functions
1 parent d6e0400 commit afee297

2 files changed

Lines changed: 45 additions & 16 deletions

File tree

src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,26 @@ namespace Tensorflow
2222
{
2323
public partial class ResourceVariable
2424
{
25-
public static Tensor operator +(ResourceVariable x, int y) => op_helper("add", x, y);
26-
public static Tensor operator +(ResourceVariable x, float y) => op_helper("add", x, y);
27-
public static Tensor operator +(ResourceVariable x, double y) => op_helper("add", x, y);
28-
29-
public static Tensor operator -(ResourceVariable x, int y) => op_helper("sub", x, y);
30-
public static Tensor operator -(ResourceVariable x, float y) => op_helper("sub", x, y);
31-
public static Tensor operator -(ResourceVariable x, double y) => op_helper("sub", x, y);
32-
public static Tensor operator -(ResourceVariable x, Tensor y) => op_helper("sub", x, y);
25+
public static OpDefLibrary _op_def_lib = new OpDefLibrary();
3326

34-
public static Tensor operator *(ResourceVariable x, ResourceVariable y) => op_helper("mul", x, y);
35-
public static Tensor operator *(ResourceVariable x, NDArray y) => op_helper("mul", x, y);
27+
public static ResourceVariable operator +(ResourceVariable x, int y) => op_helper("add", x, y);
28+
public static ResourceVariable operator +(ResourceVariable x, float y) => op_helper("add", x, y);
29+
public static ResourceVariable operator +(ResourceVariable x, double y) => op_helper("add", x, y);
30+
public static ResourceVariable operator +(ResourceVariable x, ResourceVariable y) => op_helper("add", x, y);
31+
public static ResourceVariable operator -(ResourceVariable x, int y) => op_helper("sub", x, y);
32+
public static ResourceVariable operator -(ResourceVariable x, float y) => op_helper("sub", x, y);
33+
public static ResourceVariable operator -(ResourceVariable x, double y) => op_helper("sub", x, y);
34+
public static ResourceVariable operator -(ResourceVariable x, Tensor y) => op_helper("sub", x, y);
35+
public static ResourceVariable operator -(ResourceVariable x, ResourceVariable y) => op_helper("sub", x, y);
3636

37-
public static Tensor operator <(ResourceVariable x, Tensor y) => gen_math_ops.less(x.value(), y);
37+
public static ResourceVariable operator *(ResourceVariable x, ResourceVariable y) => op_helper("mul", x, y);
38+
public static ResourceVariable operator *(ResourceVariable x, NDArray y) => op_helper("mul", x, y);
3839

39-
public static Tensor operator >(ResourceVariable x, Tensor y) => gen_math_ops.greater(x.value(), y);
40+
public static ResourceVariable operator <(ResourceVariable x, Tensor y) => less(x.value(), y);
4041

41-
private static Tensor op_helper<T>(string default_name, ResourceVariable x, T y)
42+
public static ResourceVariable operator >(ResourceVariable x, Tensor y) => greater(x.value(), y);
43+
44+
private static ResourceVariable op_helper<T>(string default_name, ResourceVariable x, T y)
4245
=> tf_with(ops.name_scope(null, default_name, new { x, y }), scope =>
4346
{
4447
string name = scope;
@@ -64,7 +67,33 @@ private static Tensor op_helper<T>(string default_name, ResourceVariable x, T y)
6467

6568
// x.assign(result);
6669
// result.ResourceVar = x;
67-
return result;
70+
return tf.Variable(result);
6871
});
72+
73+
private static ResourceVariable less<Tx, Ty>(Tx x, Ty y, string name = null)
74+
{
75+
if (tf.context.executing_eagerly())
76+
{
77+
var results = EagerTensorPass.Create();
78+
var inputs = EagerTensorPass.From(x, y);
79+
Status status = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
80+
"Less", name,
81+
inputs.Points, inputs.Length,
82+
null, null,
83+
results.Points, results.Length);
84+
status.Check(true);
85+
return tf.Variable(results[0].Resolve());
86+
}
87+
88+
var _op = _op_def_lib._apply_op_helper("Less", name: name, args: new { x, y });
89+
90+
return tf.Variable(_op.outputs[0]);
91+
}
92+
private static ResourceVariable greater<Tx, Ty>(Tx x, Ty y, string name = null)
93+
{
94+
var _op = _op_def_lib._apply_op_helper("Greater", name: name, args: new { x, y });
95+
96+
return tf.Variable(_op.outputs[0]);
97+
}
6998
}
7099
}

test/TensorFlowNET.UnitTest/Basics/VariableTest.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,10 @@ public void Assign2()
5656
public void Accumulation()
5757
{
5858
var x = tf.Variable(10, name: "x");
59-
/*for (int i = 0; i < 5; i++)
59+
for (int i = 0; i < 5; i++)
6060
x = x + 1;
6161

62-
Assert.AreEqual(15, (int)x.numpy());*/
62+
Assert.AreEqual(15, (int)x.numpy());
6363
}
6464

6565
[TestMethod]

0 commit comments

Comments
 (0)