Skip to content

Commit afaf0c8

Browse files
committed
fix input list into op
1 parent 71fef9a commit afaf0c8

1 file changed

Lines changed: 36 additions & 13 deletions

File tree

src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ namespace Tensorflow.Eager
1010
/// </summary>
1111
public class pywrap_tfe_src
1212
{
13+
static int kFastPathExecuteInputStartIndex = 0;
1314
public static EagerTensor TFE_Py_FastPathExecute(Context ctx,
1415
string device_name,
1516
string opName,
@@ -28,7 +29,7 @@ public static EagerTensor TFE_Py_FastPathExecute(Context ctx,
2829

2930
// Set non-inferred attrs, including setting defaults if the attr is passed in
3031
// as None.
31-
for (int i = op_def.InputArg.Count; i < args_size; i += 2)
32+
for (int i = kFastPathExecuteInputStartIndex + op_def.InputArg.Count; i < args_size; i += 2)
3233
{
3334
var attr_name = args[i].ToString();
3435
var attr_value = args[i + 1];
@@ -38,20 +39,39 @@ public static EagerTensor TFE_Py_FastPathExecute(Context ctx,
3839
if(attr_name == attr.Name)
3940
{
4041
SetOpAttrWithDefaults(ctx, op, attr, attr_name, attr_value, attr_list_sizes, status);
42+
status.Check(true);
4143
break;
4244
}
4345
}
4446
}
4547

4648
c_api.TFE_OpSetDevice(op, device_name, status);
49+
status.Check(true);
4750

51+
// Add inferred attrs and inputs.
4852
for (int i = 0; i < op_def.InputArg.Count; i++)
4953
{
5054
var input_arg = op_def.InputArg[i];
55+
int len = (args[kFastPathExecuteInputStartIndex + i] as object[]).Length;
5156
if (!string.IsNullOrEmpty(input_arg.NumberAttr))
5257
{
53-
c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, 0);
54-
attr_list_sizes[input_arg.NumberAttr] = 0;
58+
c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len);
59+
attr_list_sizes[input_arg.NumberAttr] = len;
60+
61+
if (len > 0)
62+
{
63+
var fast_input_array = (object[])args[i];
64+
// First item adds the type attr.
65+
if (!AddInputToOp(fast_input_array[i], true, input_arg, op, status))
66+
return null;
67+
68+
for (var j = 1; j < len; j++)
69+
{
70+
// Since the list is homogeneous, we don't need to re-add the attr.
71+
if (!AddInputToOp(fast_input_array[j], false, input_arg, op, status))
72+
return null;
73+
}
74+
}
5575
}
5676
else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
5777
{
@@ -60,14 +80,7 @@ public static EagerTensor TFE_Py_FastPathExecute(Context ctx,
6080
else
6181
{
6282
// The item is a single item.
63-
switch (args[i])
64-
{
65-
case Tensor inputTensor:
66-
AddInputToOp(inputTensor, true, input_arg, op, status);
67-
break;
68-
default:
69-
throw new NotImplementedException("");
70-
}
83+
AddInputToOp(args[i], true, input_arg, op, status);
7184
}
7285
}
7386

@@ -106,13 +119,23 @@ public static EagerTensor TFE_Py_FastPathExecute(Context ctx,
106119
/// <param name="op"></param>
107120
/// <param name="status"></param>
108121
/// <returns></returns>
109-
private static bool AddInputToOp(Tensor input,
122+
private static bool AddInputToOp(object inputs,
110123
bool add_type_attr,
111124
ArgDef input_arg,
112125
IntPtr op,
113126
Status status)
114127
{
115-
var input_handle = c_api.TFE_NewTensorHandle(input, status);
128+
IntPtr input_handle = IntPtr.Zero;
129+
130+
switch (inputs)
131+
{
132+
case Tensor input:
133+
input_handle = c_api.TFE_NewTensorHandle(input, status);
134+
break;
135+
default:
136+
throw new NotImplementedException("");
137+
}
138+
116139

117140
if(add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr))
118141
{

0 commit comments

Comments
 (0)