Skip to content

Commit 5ccef1b

Browse files
committed
Allow list parameter in OpDefLibrary
1 parent 4a80846 commit 5ccef1b

8 files changed

Lines changed: 149 additions & 18 deletions

File tree

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class TypeError : Exception
8+
{
9+
public TypeError() : base()
10+
{
11+
12+
}
13+
14+
public TypeError(string message) : base(message)
15+
{
16+
17+
}
18+
}
19+
}

src/TensorFlowNET.Core/Operations/OpDefLibrary.cs

Lines changed: 71 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.ComponentModel;
44
using System.Dynamic;
55
using System.IO;
6+
using System.Linq;
67
using System.Runtime.InteropServices;
78
using System.Text;
89
using static Tensorflow.OpDef.Types;
@@ -41,6 +42,7 @@ public Operation _apply_op_helper(string op_type_name, string name = "", dynamic
4142
var attrs = new Dictionary<string, object>();
4243
var inputs = new List<Tensor>();
4344
var input_types = new List<TF_DataType>();
45+
var base_types = new List<TF_DataType>();
4446

4547
Operation op = null;
4648
Python.with<ops.name_scope>(new ops.name_scope(name), scope =>
@@ -49,34 +51,67 @@ public Operation _apply_op_helper(string op_type_name, string name = "", dynamic
4951
foreach (var input_arg in op_def.InputArg)
5052
{
5153
var input_name = input_arg.Name;
52-
if (keywords[input_name] is double int_value)
54+
var values = keywords[input_name];
55+
// Goals:
56+
// * Convert values to Tensors if it contains constants.
57+
// * Verify that values is a list if that matches the input_arg's
58+
// type.
59+
// * If the input_arg's type is determined by attrs, either set
60+
// those attrs and validate those attr values are legal (if
61+
// they have not yet been set) or validate the input matches
62+
// the type indicated by the attrs (if they have already been
63+
// inferred via an earlier input).
64+
// * If the input_arg has an explicit type, make sure the input
65+
// conforms.
66+
67+
if (_IsListParameter(input_arg))
5368
{
54-
keywords[input_name] = constant_op.constant(int_value, input_name);
55-
}
69+
DataType dtype = DataType.DtInvalid;
70+
DataType default_dtype = DataType.DtInvalid;
5671

57-
if (keywords[input_name] is Tensor value)
58-
{
59-
if (keywords.ContainsKey(input_name))
72+
if (!_IsListValue(values))
73+
throw new TypeError($"Expected list for '{input_name}' argument to '{op_type_name}' Op, not {values}.");
74+
if(input_arg.Type != DataType.DtInvalid)
6075
{
61-
inputs.Add(value);
76+
dtype = input_arg.Type;
6277
}
63-
64-
if (!String.IsNullOrEmpty(input_arg.TypeAttr))
78+
else if (!String.IsNullOrEmpty(input_arg.NumberAttr))
6579
{
66-
attrs[input_arg.TypeAttr] = value.dtype;
80+
6781
}
6882

69-
if (input_arg.IsRef)
70-
{
83+
if(input_arg.IsRef && dtype != DataType.DtInvalid)
84+
dtype = dtype.as_base_dtype();
7185

86+
values = ops.internal_convert_n_to_tensor(values, name: input_arg.Name, dtype: dtype, preferred_dtype: default_dtype, as_ref: input_arg.IsRef);
87+
88+
inputs.AddRange(values as Tensor[]);
89+
}
90+
else
91+
{
92+
if (!(values is Tensor))
93+
{
94+
keywords[input_name] = constant_op.constant(values, input_name);
7295
}
73-
else
96+
97+
if (keywords[input_name] is Tensor value)
7498
{
75-
var base_type = value.dtype.as_base_dtype();
99+
if (keywords.ContainsKey(input_name))
100+
{
101+
inputs.Add(value);
102+
}
103+
104+
if (!String.IsNullOrEmpty(input_arg.TypeAttr))
105+
{
106+
attrs[input_arg.TypeAttr] = value.dtype;
107+
}
76108

77-
input_types.Add(base_type);
109+
values = new Tensor[] { value };
78110
}
79111
}
112+
113+
base_types.AddRange((values as Tensor[]).Select(x => x.dtype.as_base_dtype()));
114+
input_types.AddRange(base_types);
80115
}
81116

82117
// Process remaining attrs
@@ -152,6 +187,27 @@ public DataType _MakeType(TF_DataType v, AttrDef attr_def)
152187
return v.as_base_dtype().as_datatype_enum();
153188
}
154189

190+
private bool _IsListParameter(ArgDef arg)
191+
{
192+
if (!String.IsNullOrEmpty(arg.NumberAttr))
193+
return true;
194+
else if (!String.IsNullOrEmpty(arg.TypeListAttr))
195+
return true;
196+
else
197+
return false;
198+
}
199+
200+
private bool _IsListValue(object v)
201+
{
202+
switch (v)
203+
{
204+
case Tensor[] val:
205+
return true;
206+
default:
207+
return false;
208+
}
209+
}
210+
155211
private Dictionary<string, object> ConvertToDict(dynamic dyn)
156212
{
157213
var dictionary = new Dictionary<string, object>();

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ public static Tensor rank(Tensor input, string name = "")
5050
/// Creates a tensor filled with a scalar value.
5151
/// </summary>
5252
/// <param name="dims">A `Tensor`.</param>
53-
/// <param name="value">A `Tensor`.</param>
53+
/// <param name="value">A `Tensor`. 0-D (scalar). Value to fill the returned tensor.</param>
5454
/// <param name="name">A name for the operation (optional).</param>
5555
/// <returns>A `Tensor`. Has the same type as `value`.</returns>
56-
public static Tensor fill(Tensor dims, Tensor value, string name = "")
56+
public static Tensor fill<T>(Tensor dims, T value, string name = "")
5757
{
5858
var _op = _op_def_lib._apply_op_helper("Fill", name, new { dims, value });
5959

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class gen_data_flow_ops
8+
{
9+
public static OpDefLibrary _op_def_lib = new OpDefLibrary();
10+
11+
public static Tensor dynamic_stitch(Tensor[] indices, Tensor[] data, string name = "")
12+
{
13+
var _attr_N = indices.Length;
14+
var _op = _op_def_lib._apply_op_helper("DynamicStitch", name, new { indices, data });
15+
16+
return _op.outputs[0];
17+
}
18+
}
19+
}

src/TensorFlowNET.Core/Operations/math_ops.py.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ public static Tensor reduced_shape(Tensor input_shape, Tensor axes)
2020
var input_rank = array_ops.size(input_shape);
2121
axes = (axes + input_rank) % input_rank;
2222
var axes_shape = array_ops.shape(axes);
23+
var a1 = new Tensor[] { input_rank, axes };
24+
var a2 = new Tensor[] { input_shape, gen_array_ops.fill(axes_shape, 1) };
2325

24-
return null;
26+
return gen_data_flow_ops.dynamic_stitch(a1, a2);
2527
}
2628

2729
/// <summary>

src/TensorFlowNET.Core/Python.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ public static TOut with<TIn, TOut>(IPython py, Func<TIn, TOut> action) where TIn
8383
for (int i = 0; i < t1.Count; i++)
8484
yield return (t1[i], t2[i]);
8585
}
86+
87+
public static IEnumerable<(int, T)> enumerate<T>(IList<T> values)
88+
{
89+
for (int i = 0; i < values.Count; i++)
90+
yield return (i, values[i]);
91+
}
8692
}
8793

8894
public interface IPython : IDisposable

src/TensorFlowNET.Core/ops.name_scope.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ public void __enter__()
4646
public void Dispose()
4747
{
4848
var g = get_default_graph();
49+
Console.WriteLine($"name_scope: {g._name_stack} -> {old_stack}");
4950
g._name_stack = old_stack;
5051
}
5152

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,5 +309,33 @@ public static Session get_default_session()
309309
return (p1.GetValue(result, null) as Tensor, p2.GetValue(result, null) as Tensor);*/
310310
};
311311
}
312+
313+
public static T[] internal_convert_n_to_tensor<T>(T[] values, DataType dtype = DataType.DtInvalid,
314+
string name = "", DataType preferred_dtype = DataType.DtInvalid,
315+
bool as_ref = false)
316+
{
317+
var ret = new List<T>();
318+
319+
foreach((int i, T value) in Python.enumerate(values))
320+
{
321+
string n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}";
322+
ret.Add(internal_convert_to_tensor(value, dtype: dtype, name: n, as_ref: as_ref, preferred_dtype: preferred_dtype));
323+
}
324+
325+
return ret.ToArray();
326+
}
327+
328+
public static T internal_convert_to_tensor<T>(T value, DataType dtype = DataType.DtInvalid,
329+
string name = "", DataType preferred_dtype = DataType.DtInvalid,
330+
bool as_ref = false)
331+
{
332+
switch (typeof(T).Name)
333+
{
334+
case "Tensor":
335+
return value;
336+
default:
337+
throw new NotImplementedException("internal_convert_to_tensor");
338+
}
339+
}
312340
}
313341
}

0 commit comments

Comments
 (0)