Skip to content

Commit dcfaa77

Browse files
committed
Create EagerTensor from NDArray.
1 parent 855eba9 commit dcfaa77

5 files changed

Lines changed: 44 additions & 6 deletions

File tree

src/TensorFlowNET.Core/APIs/tf.linalg.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public Tensor diag(Tensor diagonal, string name = null)
2222
=> gen_array_ops.diag(diagonal, name: name);
2323

2424
public Tensor matmul(Tensor a, Tensor b)
25-
=> gen_math_ops.mat_mul(a, b);
25+
=> math_ops.matmul(a, b);
2626

2727
public Tensor batch_matmul(Tensor x, Tensor y)
2828
=> gen_math_ops.batch_mat_mul(x, y);

src/TensorFlowNET.Core/Eager/EagerTensor.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System;
1+
using NumSharp;
2+
using System;
23
using System.Collections.Generic;
34
using System.Text;
45

@@ -18,6 +19,10 @@ public EagerTensor(int value, string device_name) : base(value)
1819
{
1920
}
2021

22+
public EagerTensor(NDArray value, string device_name) : base(value)
23+
{
24+
}
25+
2126
public override string ToString()
2227
{
2328
switch (rank)

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,14 @@ public static Tensor floor_div(Tensor x, Tensor y, string name = null)
638638
/// <returns></returns>
639639
public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false, string name = null)
640640
{
641+
if (tf.context.executing_eagerly())
642+
{
643+
var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, tf.context.device_name,
644+
"MatMul", name, null,
645+
a, b, "transpose_a", transpose_a, "transpose_b", transpose_b);
646+
return _result;
647+
}
648+
641649
var _op = _op_def_lib._apply_op_helper("MatMul", name, args: new { a, b, transpose_a, transpose_b });
642650

643651
return _op.output;
@@ -738,17 +746,37 @@ public static Tensor _sum<Tx, Ty>(Tx input, Ty axis = default, bool keep_dims =
738746
{
739747
if (tf.context.executing_eagerly())
740748
{
741-
var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, tf.context.device_name,
742-
"Sum", name, null,
743-
input, axis, "keep_dims", keep_dims);
744-
return _result;
749+
try
750+
{
751+
var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, tf.context.device_name,
752+
"Sum", name, null,
753+
input, axis, "keep_dims", keep_dims);
754+
return _result;
755+
}
756+
catch (Exception)
757+
{
758+
return _sum_eager_fallback(input as Tensor[], axis as Tensor,
759+
keep_dims: keep_dims, name: name, ctx: tf.context);
760+
}
745761
}
746762

747763
var _op = _op_def_lib._apply_op_helper("Sum", name, args: new { input, reduction_indices = axis, keep_dims });
748764

749765
return _op.outputs[0];
750766
}
751767

768+
private static Tensor _sum_eager_fallback(Tensor[] inputs, Tensor axis, bool keep_dims = false, string name = null, Context ctx = null)
769+
{
770+
var (_attr_T, input) = _execute.args_to_matching_eager(inputs, ctx);
771+
var (_attr_Tidx, axis1) = _execute.args_to_matching_eager(new[] { axis }, ctx, TF_DataType.TF_INT32);
772+
var _inputs_flat = new Tensor[] { input, axis1 };
773+
774+
var _attrs = new object[] { "keep_dims", keep_dims, "T", _attr_T, "Tidx", _attr_Tidx };
775+
776+
var _result = _execute.execute(ctx, "Sum", _inputs_flat, _attrs, name: name);
777+
return _result;
778+
}
779+
752780
/// <summary>
753781
/// Creates a sequence of numbers.
754782
/// </summary>

src/TensorFlowNET.Core/Tensors/Tensor.Value.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ public NDArray numpy()
163163
return StringData();
164164
case TF_DataType.TF_INT32:
165165
return ToArray<int>();
166+
case TF_DataType.TF_FLOAT:
167+
return ToArray<float>();
166168
default:
167169
return BufferToArray();
168170
}

src/TensorFlowNET.Core/Tensors/constant_op.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using NumSharp;
1718
using System;
1819
using System.Collections.Generic;
1920
using Tensorflow.Eager;
@@ -84,6 +85,8 @@ private static EagerTensor convert_to_eager_tensor(object value, Context ctx, TF
8485
{
8586
switch (value)
8687
{
88+
case NDArray nd:
89+
return new EagerTensor(nd, ctx.device_name);
8790
case string str:
8891
return new EagerTensor(str, ctx.device_name);
8992
case int int32:

0 commit comments

Comments
 (0)