Skip to content

Commit 8b0e5cf

Browse files
committed
added RandomShuffleQueue and docs updated.
1 parent a596dbe commit 8b0e5cf

9 files changed

Lines changed: 202 additions & 7 deletions

File tree

docs/source/Queue.md

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,32 @@ Creates a queue that dequeues elements in a first-in first-out order. A `FIFOQue
5858

5959
A FIFOQueue that supports batching variable-sized tensors by padding. A `PaddingFIFOQueue` may contain components with dynamic shape, while also supporting `dequeue_many`. A `PaddingFIFOQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are described by the `shapes` argument.
6060

61+
```chsarp
62+
[TestMethod]
63+
public void PaddingFIFOQueue()
64+
{
65+
var numbers = tf.placeholder(tf.int32);
66+
var queue = tf.PaddingFIFOQueue(10, tf.int32, new TensorShape(-1));
67+
var enqueue = queue.enqueue(numbers);
68+
var dequeue_many = queue.dequeue_many(n: 3);
69+
70+
using(var sess = tf.Session())
71+
{
72+
sess.run(enqueue, (numbers, new[] { 1 }));
73+
sess.run(enqueue, (numbers, new[] { 2, 3 }));
74+
sess.run(enqueue, (numbers, new[] { 3, 4, 5 }));
75+
76+
var result = sess.run(dequeue_many[0]);
77+
78+
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0 }, result[0].ToArray<int>()));
79+
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3, 0 }, result[1].ToArray<int>()));
80+
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 4, 5 }, result[2].ToArray<int>()));
81+
}
82+
}
83+
```
84+
85+
86+
6187
#### PriorityQueue
6288

6389
A queue implementation that dequeues elements in prioritized order. A `PriorityQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `PriorityQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `types`, and whose shapes are optionally described by the `shapes` argument.
@@ -93,6 +119,28 @@ public void PriorityQueue()
93119

94120
A queue implementation that dequeues elements in a random order. A `RandomShuffleQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `RandomShuffleQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are optionally described by the `shapes` argument.
95121

122+
```csharp
123+
[TestMethod]
124+
public void RandomShuffleQueue()
125+
{
126+
var queue = tf.RandomShuffleQueue(10, min_after_dequeue: 1, dtype: tf.int32);
127+
var init = queue.enqueue_many(new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
128+
var x = queue.dequeue();
129+
130+
string results = "";
131+
using (var sess = tf.Session())
132+
{
133+
init.run();
134+
135+
foreach(var i in range(9))
136+
results += (int)sess.run(x) + ".";
137+
138+
// output in random order
139+
// 1.2.3.4.5.6.7.8.9.
140+
}
141+
}
142+
```
143+
96144

97145

98146
Queue methods must run on the same device as the queue. `FIFOQueue` and `RandomShuffleQueue` are important TensorFlow objects for computing tensor asynchronously in a graph. For example, a typical input architecture is to use a `RandomShuffleQueue` to prepare inputs for training a model:

src/TensorFlowNET.Core/APIs/tf.queue.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,5 +108,20 @@ public PriorityQueue PriorityQueue(int capacity,
108108
new[] { shape ?? new TensorShape() },
109109
shared_name: shared_name,
110110
name: name);
111+
112+
public RandomShuffleQueue RandomShuffleQueue(int capacity,
113+
int min_after_dequeue,
114+
TF_DataType dtype,
115+
TensorShape shape = null,
116+
int? seed = null,
117+
string shared_name = null,
118+
string name = "random_shuffle_queue")
119+
=> new RandomShuffleQueue(capacity,
120+
min_after_dequeue: min_after_dequeue,
121+
new[] { dtype },
122+
new[] { shape ?? new TensorShape() },
123+
seed: seed,
124+
shared_name: shared_name,
125+
name: name);
111126
}
112127
}

src/TensorFlowNET.Core/Operations/Queues/FIFOQueue.cs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,20 @@
1-
using System;
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
218
using System.Collections.Generic;
319
using System.Linq;
420
using System.Text;

src/TensorFlowNET.Core/Operations/Queues/PaddingFIFOQueue.cs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,20 @@
1-
using System;
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
218
using System.Collections.Generic;
319
using System.Linq;
420
using System.Text;

src/TensorFlowNET.Core/Operations/Queues/PriorityQueue.cs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,20 @@
1-
using System;
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
218
using System.Collections.Generic;
319
using System.Linq;
420
using System.Text;

src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,20 @@
1-
using System;
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
218
using System.Collections.Generic;
319
using System.Linq;
420
using System.Text;

src/TensorFlowNET.Core/Operations/Queues/RandomShuffleQueue.cs

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,53 @@
1-
using System;
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
218
using System.Collections.Generic;
319
using System.Linq;
420
using System.Text;
521

622
namespace Tensorflow.Queues
723
{
24+
/// <summary>
25+
/// Create a queue that dequeues elements in a random order.
26+
/// </summary>
827
public class RandomShuffleQueue : QueueBase
928
{
1029
public RandomShuffleQueue(int capacity,
30+
int min_after_dequeue,
1131
TF_DataType[] dtypes,
1232
TensorShape[] shapes,
1333
string[] names = null,
34+
int? seed = null,
1435
string shared_name = null,
15-
string name = "randomshuffle_fifo_queue")
36+
string name = "random_shuffle_queue")
1637
: base(dtypes: dtypes, shapes: shapes, names: names)
1738
{
18-
_queue_ref = gen_data_flow_ops.padding_fifo_queue_v2(
39+
var(seed1, seed2) = random_seed.get_seed(seed);
40+
if (!seed1.HasValue && !seed2.HasValue)
41+
(seed1, seed2) = (0, 0);
42+
43+
44+
_queue_ref = gen_data_flow_ops.random_shuffle_queue_v2(
1945
component_types: dtypes,
2046
shapes: shapes,
2147
capacity: capacity,
48+
min_after_dequeue: min_after_dequeue,
49+
seed: seed1.Value,
50+
seed2: seed2.Value,
2251
shared_name: shared_name,
2352
name: name);
2453

src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,25 @@ public static Tensor priority_queue_v2(TF_DataType[] component_types, TensorShap
9393
return _op.output;
9494
}
9595

96+
public static Tensor random_shuffle_queue_v2(TF_DataType[] component_types, TensorShape[] shapes,
97+
int capacity = -1, int min_after_dequeue = 0, int seed = 0, int seed2 = 0,
98+
string container = "", string shared_name = "", string name = null)
99+
{
100+
var _op = _op_def_lib._apply_op_helper("RandomShuffleQueueV2", name, new
101+
{
102+
component_types,
103+
shapes,
104+
capacity,
105+
min_after_dequeue,
106+
seed,
107+
seed2,
108+
container,
109+
shared_name
110+
});
111+
112+
return _op.output;
113+
}
114+
96115
public static Operation queue_enqueue(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null)
97116
{
98117
var _op = _op_def_lib._apply_op_helper("QueueEnqueue", name, new

test/TensorFlowNET.UnitTest/QueueTest.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,5 +92,25 @@ public void PriorityQueue()
9292
Assert.AreEqual(result[0].GetInt64(), 4L);
9393
}
9494
}
95+
96+
[TestMethod]
97+
public void RandomShuffleQueue()
98+
{
99+
var queue = tf.RandomShuffleQueue(10, min_after_dequeue: 1, dtype: tf.int32);
100+
var init = queue.enqueue_many(new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
101+
var x = queue.dequeue();
102+
103+
string results = "";
104+
using (var sess = tf.Session())
105+
{
106+
init.run();
107+
108+
foreach(var i in range(9))
109+
results += (int)sess.run(x) + ".";
110+
111+
// output in random order
112+
Assert.IsFalse(results == "1.2.3.4.5.6.7.8.9.");
113+
}
114+
}
95115
}
96116
}

0 commit comments

Comments
 (0)