Skip to content

Commit 0ece291

Browse files
committed
tf.PaddingFIFOQueue SciSharp#396
1 parent c2138b2 commit 0ece291

6 files changed

Lines changed: 244 additions & 4 deletions

File tree

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
18+
using Tensorflow.Queues;
19+
20+
namespace Tensorflow
21+
{
22+
public partial class tensorflow
23+
{
24+
/// <summary>
25+
/// A FIFOQueue that supports batching variable-sized tensors by padding.
26+
/// </summary>
27+
/// <param name="capacity"></param>
28+
/// <param name="dtypes"></param>
29+
/// <param name="shapes"></param>
30+
/// <param name="names"></param>
31+
/// <param name="shared_name"></param>
32+
/// <param name="name"></param>
33+
/// <returns></returns>
34+
public PaddingFIFOQueue PaddingFIFOQueue(int capacity,
35+
TF_DataType[] dtypes,
36+
TensorShape[] shapes,
37+
string[] names = null,
38+
string shared_name = null,
39+
string name = "padding_fifo_queue")
40+
=> new PaddingFIFOQueue(capacity,
41+
dtypes,
42+
shapes,
43+
names,
44+
shared_name: shared_name,
45+
name: name);
46+
}
47+
}

src/TensorFlowNET.Core/Operations/OpDefLibrary.cs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
using System.Linq;
2020
using static Tensorflow.OpDef.Types;
2121
using static Tensorflow.Binding;
22+
using Google.Protobuf;
2223

2324
namespace Tensorflow
2425
{
@@ -194,7 +195,9 @@ public Operation _apply_op_helper(string op_type_name, string name = null, Dicti
194195
if (attrs.ContainsKey(key))
195196
{
196197
attr_protos[key] = SetAttrValue(op_def, attr_def, attrs[key]);
197-
} else {
198+
}
199+
else
200+
{
198201
if (attr_def.DefaultValue == null)
199202
{
200203
throw new TypeError("Missing required positional argument " + key);
@@ -311,6 +314,16 @@ private void SetAttrs(string op_type_name,
311314
input_types.AddRange(base_types);
312315
}
313316

317+
public ByteString _MakeStr(string value, AttrDef attr_def)
318+
{
319+
return ByteString.CopyFromUtf8(value ?? string.Empty);
320+
}
321+
322+
public TensorShapeProto _MakeShape(TensorShape shape, AttrDef attr_def)
323+
{
324+
return shape.as_proto();
325+
}
326+
314327
public DataType _MakeType(TF_DataType v, AttrDef attr_def)
315328
{
316329
return v.as_base_dtype().as_datatype_enum();
@@ -330,7 +343,7 @@ private AttrValue SetAttrValue(OpDef op_def, AttrDef attr_def, object value)
330343
switch (attr_def.Type)
331344
{
332345
case "string":
333-
attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value);
346+
attr_value.S = _MakeStr((string)value, attr_def);
334347
break;
335348
case "type":
336349
attr_value.Type = _MakeType((TF_DataType)value, attr_def);
@@ -363,6 +376,9 @@ private AttrValue SetAttrValue(OpDef op_def, AttrDef attr_def, object value)
363376
else if (value is int[] val3)
364377
attr_value.Shape = tensor_util.as_shape(val3);
365378

379+
break;
380+
case "list(shape)":
381+
attr_value.List.Shape.AddRange((value as TensorShape[]).Select(x => _MakeShape(x, attr_def)));
366382
break;
367383
default:
368384
throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos.");
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using Tensorflow.Framework;
6+
using static Tensorflow.Binding;
7+
8+
namespace Tensorflow.Queues
9+
{
10+
/// <summary>
11+
/// A FIFOQueue that supports batching variable-sized tensors by padding.
12+
/// </summary>
13+
public class PaddingFIFOQueue : QueueBase
14+
{
15+
public PaddingFIFOQueue(int capacity,
16+
TF_DataType[] dtypes,
17+
TensorShape[] shapes,
18+
string[] names = null,
19+
string shared_name = null,
20+
string name = "padding_fifo_queue")
21+
: base(dtypes: dtypes, shapes: shapes, names: names)
22+
{
23+
_queue_ref = gen_data_flow_ops.padding_fifo_queue_v2(
24+
component_types: dtypes,
25+
shapes: shapes,
26+
capacity: capacity,
27+
shared_name: shared_name,
28+
name: name);
29+
30+
_name = _queue_ref.op.name.Split('/').Last();
31+
}
32+
}
33+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using static Tensorflow.Binding;
6+
7+
namespace Tensorflow.Queues
8+
{
9+
public class QueueBase
10+
{
11+
protected TF_DataType[] _dtypes;
12+
protected TensorShape[] _shapes;
13+
protected string[] _names;
14+
protected Tensor _queue_ref;
15+
protected string _name;
16+
17+
public QueueBase(TF_DataType[] dtypes, TensorShape[] shapes, string[] names)
18+
{
19+
_dtypes = dtypes;
20+
_shapes = shapes;
21+
_names = names;
22+
}
23+
24+
public Operation enqueue(Tensor val, string name = null)
25+
{
26+
return tf_with(ops.name_scope(name, $"{_name}_enqueue", val), scope =>
27+
{
28+
var vals = new[] { val };
29+
if (_queue_ref.dtype == TF_DataType.TF_RESOURCE)
30+
return gen_data_flow_ops.queue_enqueue_v2(_queue_ref, vals, name: scope);
31+
else
32+
return gen_data_flow_ops.queue_enqueue(_queue_ref, vals, name: scope);
33+
});
34+
}
35+
36+
public Tensor[] dequeue_many(int n, string name = null)
37+
{
38+
if (name == null)
39+
name = $"{_name}_DequeueMany";
40+
41+
var ret = gen_data_flow_ops.queue_dequeue_many_v2(_queue_ref, n: n, component_types: _dtypes, name: name);
42+
//var op = ret[0].op;
43+
//var cv = tensor_util.constant_value(op.inputs[1]);
44+
//var batch_dim = new Dimension(cv);
45+
46+
return _dequeue_return_value(ret);
47+
}
48+
49+
public Tensor[] _dequeue_return_value(Tensor[] tensors)
50+
{
51+
if (_names != null)
52+
throw new NotImplementedException("");
53+
return tensors;
54+
}
55+
}
56+
}

src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,9 @@ public class gen_data_flow_ops
2222

2323
public static Tensor dynamic_stitch(Tensor[] indices, Tensor[] data, string name = null)
2424
{
25-
var _attr_N = indices.Length;
2625
var _op = _op_def_lib._apply_op_helper("DynamicStitch", name, new { indices, data });
2726

28-
return _op.outputs[0];
27+
return _op.output;
2928
}
3029

3130
public static (Tensor, Tensor) tensor_array_v3(Tensor size, TF_DataType dtype = TF_DataType.DtInvalid,
@@ -45,5 +44,58 @@ public static (Tensor, Tensor) tensor_array_v3(Tensor size, TF_DataType dtype =
4544

4645
return (null, null);
4746
}
47+
48+
public static Tensor padding_fifo_queue_v2(TF_DataType[] component_types, TensorShape[] shapes,
49+
int capacity = -1, string container = "", string shared_name = "",
50+
string name = null)
51+
{
52+
var _op = _op_def_lib._apply_op_helper("PaddingFIFOQueueV2", name, new
53+
{
54+
component_types,
55+
shapes,
56+
capacity,
57+
container,
58+
shared_name
59+
});
60+
61+
return _op.output;
62+
}
63+
64+
public static Operation queue_enqueue(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null)
65+
{
66+
var _op = _op_def_lib._apply_op_helper("QueueEnqueue", name, new
67+
{
68+
handle,
69+
components,
70+
timeout_ms
71+
});
72+
73+
return _op;
74+
}
75+
76+
public static Operation queue_enqueue_v2(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null)
77+
{
78+
var _op = _op_def_lib._apply_op_helper("QueueEnqueueV2", name, new
79+
{
80+
handle,
81+
components,
82+
timeout_ms
83+
});
84+
85+
return _op;
86+
}
87+
88+
public static Tensor[] queue_dequeue_many_v2(Tensor handle, int n, TF_DataType[] component_types, int timeout_ms = -1, string name = null)
89+
{
90+
var _op = _op_def_lib._apply_op_helper("QueueDequeueManyV2", name, new
91+
{
92+
handle,
93+
n,
94+
component_types,
95+
timeout_ms
96+
});
97+
98+
return _op.outputs;
99+
}
48100
}
49101
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Text;
6+
using Tensorflow;
7+
using static Tensorflow.Binding;
8+
9+
namespace TensorFlowNET.UnitTest
10+
{
11+
[TestClass]
12+
public class QueueTest
13+
{
14+
[TestMethod]
15+
public void PaddingFIFOQueue()
16+
{
17+
var numbers = tf.placeholder(tf.int32);
18+
var queue = tf.PaddingFIFOQueue(capacity: 10, dtypes: new[] { tf.int32 }, shapes: new[] { new TensorShape(-1) });
19+
var enqueue = queue.enqueue(numbers);
20+
var dequeue_many = queue.dequeue_many(n: 3);
21+
22+
using(var sess = tf.Session())
23+
{
24+
sess.run(enqueue, (numbers, new[] { 1 }));
25+
sess.run(enqueue, (numbers, new[] { 2, 3 }));
26+
sess.run(enqueue, (numbers, new[] { 3, 4, 5 }));
27+
28+
var result = sess.run(dequeue_many[0]);
29+
30+
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0 }, result[0].ToArray<int>()));
31+
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3, 0 }, result[1].ToArray<int>()));
32+
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 4, 5 }, result[2].ToArray<int>()));
33+
}
34+
}
35+
}
36+
}

0 commit comments

Comments
 (0)