Skip to content

Commit 1c5731f

Browse files
committed
fix internal_convert_n_to_tensor return type.
1 parent 5ccef1b commit 1c5731f

2 files changed

Lines changed: 14 additions & 21 deletions

File tree

src/TensorFlowNET.Core/Operations/OpDefLibrary.cs

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -84,32 +84,25 @@ public Operation _apply_op_helper(string op_type_name, string name = "", dynamic
8484
dtype = dtype.as_base_dtype();
8585

8686
values = ops.internal_convert_n_to_tensor(values, name: input_arg.Name, dtype: dtype, preferred_dtype: default_dtype, as_ref: input_arg.IsRef);
87-
88-
inputs.AddRange(values as Tensor[]);
8987
}
9088
else
9189
{
92-
if (!(values is Tensor))
90+
if (keywords[input_name] is Tensor)
9391
{
94-
keywords[input_name] = constant_op.constant(values, input_name);
9592
}
96-
97-
if (keywords[input_name] is Tensor value)
93+
else
9894
{
99-
if (keywords.ContainsKey(input_name))
100-
{
101-
inputs.Add(value);
102-
}
103-
104-
if (!String.IsNullOrEmpty(input_arg.TypeAttr))
105-
{
106-
attrs[input_arg.TypeAttr] = value.dtype;
107-
}
95+
keywords[input_name] = ops.internal_convert_to_tensor(values, name: input_name);
96+
}
10897

109-
values = new Tensor[] { value };
98+
if (!String.IsNullOrEmpty(input_arg.TypeAttr))
99+
{
100+
attrs[input_arg.TypeAttr] = (keywords[input_name] as Tensor).dtype;
110101
}
102+
values = new Tensor[] { keywords[input_name] as Tensor };
111103
}
112104

105+
inputs.AddRange(values as Tensor[]);
113106
base_types.AddRange((values as Tensor[]).Select(x => x.dtype.as_base_dtype()));
114107
input_types.AddRange(base_types);
115108
}

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -310,11 +310,11 @@ public static Session get_default_session()
310310
};
311311
}
312312

313-
public static T[] internal_convert_n_to_tensor<T>(T[] values, DataType dtype = DataType.DtInvalid,
313+
public static Tensor[] internal_convert_n_to_tensor<T>(T[] values, DataType dtype = DataType.DtInvalid,
314314
string name = "", DataType preferred_dtype = DataType.DtInvalid,
315315
bool as_ref = false)
316316
{
317-
var ret = new List<T>();
317+
var ret = new List<Tensor>();
318318

319319
foreach((int i, T value) in Python.enumerate(values))
320320
{
@@ -325,16 +325,16 @@ public static T[] internal_convert_n_to_tensor<T>(T[] values, DataType dtype = D
325325
return ret.ToArray();
326326
}
327327

328-
public static T internal_convert_to_tensor<T>(T value, DataType dtype = DataType.DtInvalid,
328+
public static Tensor internal_convert_to_tensor<T>(T value, DataType dtype = DataType.DtInvalid,
329329
string name = "", DataType preferred_dtype = DataType.DtInvalid,
330330
bool as_ref = false)
331331
{
332332
switch (typeof(T).Name)
333333
{
334334
case "Tensor":
335-
return value;
335+
return value as Tensor;
336336
default:
337-
throw new NotImplementedException("internal_convert_to_tensor");
337+
return constant_op.constant(np.array(value), name);
338338
}
339339
}
340340
}

0 commit comments

Comments
 (0)