Skip to content

Commit 9c86cbb

Browse files
dhruvrajan authored and karllessard committed
Add functionality to map over dataset elements in both graph and eager mode.
1 parent eeb6778 commit 9c86cbb

14 files changed

Lines changed: 434 additions & 64 deletions

File tree

tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020 The TensorFlow Authors. All rights reserved.
2+
* Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,10 +17,8 @@
1717

1818
import org.tensorflow.DataType;
1919
import org.tensorflow.Operand;
20-
import org.tensorflow.framework.data.impl.BatchDataset;
21-
import org.tensorflow.framework.data.impl.SkipDataset;
22-
import org.tensorflow.framework.data.impl.TakeDataset;
23-
import org.tensorflow.framework.data.impl.TensorSliceDataset;
20+
import org.tensorflow.Output;
21+
import org.tensorflow.framework.data.impl.*;
2422
import org.tensorflow.op.Op;
2523
import org.tensorflow.op.Ops;
2624
import org.tensorflow.tools.Shape;
@@ -118,29 +116,7 @@ public final Dataset take(long count) {
118116
*/
119117
@Override
120118
public Iterator<List<Operand<?>>> iterator() {
121-
122-
if (!tf.scope().env().isEager()) {
123-
throw new UnsupportedOperationException(
124-
"Cannot iterate through a " + "dataset in graph mode.");
125-
}
126-
127-
DatasetIterator iterator = makeOneShotIterator();
128-
129-
return new Iterator<List<Operand<?>>>() {
130-
private DatasetOptional nextOptional = iterator.getNextAsOptional();
131-
132-
@Override
133-
public boolean hasNext() {
134-
return nextOptional.hasValue().data().getBoolean();
135-
}
136-
137-
@Override
138-
public List<Operand<?>> next() {
139-
List<Operand<?>> result = nextOptional.getValue();
140-
nextOptional = iterator.getNextAsOptional();
141-
return result;
142-
}
143-
};
119+
return makeOneShotIterator().iterator();
144120
}
145121

146122
/**
@@ -152,7 +128,9 @@ public List<Operand<?>> next() {
152128
* @return A new `DatasetIterator` based on this dataset's structure.
153129
*/
154130
public DatasetIterator makeInitializeableIterator() {
155-
return DatasetIterator.fromStructure(tf, outputTypes, outputShapes);
131+
DatasetIterator iterator = DatasetIterator.fromStructure(tf, outputTypes, outputShapes);
132+
iterator.makeInitializer(this);
133+
return iterator;
156134
}
157135

158136
/**
@@ -198,6 +176,18 @@ public static Dataset fromTensorSlices(
198176
return new TensorSliceDataset(tf, tensors, outputTypes);
199177
}
200178

179+
public static Dataset tfRecordDataset(
180+
Ops tf, String filename, String compressionType, long bufferSize) {
181+
return new TFRecordDataset(
182+
tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize));
183+
}
184+
185+
public static Dataset textLineDataset(
186+
Ops tf, String filename, String compressionType, long bufferSize) {
187+
return new TextLineDataset(
188+
tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize));
189+
}
190+
201191
/** Get the variant tensor representing this dataset. */
202192
public abstract Operand<?> getVariant();
203193

@@ -210,4 +200,8 @@ public List<DataType<?>> getOutputTypes() {
210200
public List<Shape> getOutputShapes() {
211201
return this.outputShapes;
212202
}
203+
204+
public Ops getOpsInstance() {
205+
return this.tf;
206+
}
213207
}

tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java

Lines changed: 131 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020 The TensorFlow Authors. All rights reserved.
2+
* Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -18,12 +18,18 @@
1818
import org.tensorflow.DataType;
1919
import org.tensorflow.Graph;
2020
import org.tensorflow.Operand;
21+
import org.tensorflow.Output;
22+
import org.tensorflow.framework.data.impl.MapIterator;
2123
import org.tensorflow.op.Op;
2224
import org.tensorflow.op.Ops;
2325
import org.tensorflow.tools.Shape;
26+
import org.tensorflow.types.family.TType;
2427

2528
import java.util.ArrayList;
29+
import java.util.Iterator;
2630
import java.util.List;
31+
import java.util.function.BiFunction;
32+
import java.util.stream.Collectors;
2733

2834
/**
2935
* Represents the state of an iteration through a tf.data Dataset. DatasetIterator is not a
@@ -34,8 +40,8 @@
3440
*
3541
* <pre>{@code
3642
* // Create input tensors
37-
* Operand<?> features = tf.constant( ... );
38-
* Operand<?> labels = tf.constant( ... );
43+
* Operand<?> XTensor = tf.constant( ... );
44+
* Operand<?> yTensor = tf.constant( ... );
3945
*
4046
*
4147
* Dataset dataset = Dataset
@@ -44,11 +50,11 @@
4450
*
4551
* DatasetIterator iterator = dataset.makeInitializeableIterator();
4652
* List<Operand<?>> components = iterator.getNext();
47-
* Operand<?> featureBatch = components.get(0);
48-
* Operand<?> labelBatch = components.get(1);
53+
* Operand<?> XBatch = components.get(0);
54+
* Operand<?> yBatch = components.get(1);
4955
*
5056
* // Build a TensorFlow graph that does something on each element.
51-
* loss = computeModelLoss(featureBatch, labelBatch);
57+
* loss = computeModelLoss(XBatch, yBatch);
5258
*
5359
* optimizer = ... // create an optimizer
5460
* trainOp = optimizer.minimize(loss);
@@ -59,7 +65,7 @@
5965
* try {
6066
* session
6167
* .addTarget(trainOp)
62-
* .fetch(loss)
68+
* .fetch( ... )
6369
* .run();
6470
*
6571
* ...
@@ -76,37 +82,37 @@
7682
*
7783
* <pre>{@code
7884
* // Create input tensors
79-
* Operand<?> features = tf.constant( ... );
80-
* Operand<?> labels = tf.constant( ... );
85+
* Operand<?> XTensor = tf.constant( ... );
86+
* Operand<?> yTensor = tf.constant( ... );
8187
*
8288
* int BATCH_SIZE = ...
8389
*
8490
* Dataset dataset = Dataset
85-
* .fromTensorSlices(features, labels)
91+
* .fromTensorSlices(XTensor, yTensor)
8692
* .batch(BATCH_SIZE);
8793
* DatasetIterator iterator = dataset.makeIterator();
8894
*
8995
* Optimizer optimizer = ... // create an optimizer
9096
*
9197
* for (List<Operand<?>> components : dataset) {
92-
* Operand<?> featureBatch = components.get(0);
93-
* Operand<?> labelBatch = components.get(1);
98+
* Operand<?> XBatch = components.get(0);
99+
* Operand<?> yBatch = components.get(1);
94100
*
95-
* loss = computeModelLoss(featureBatch, labelBatch);
101+
* loss = computeModelLoss(XBatch, yBatch);
96102
* trainOp = optimizer.minimize(loss);
97103
* }
98104
* }</pre>
99105
*/
100-
public class DatasetIterator {
106+
public class DatasetIterator implements Iterable<List<Operand<?>>> {
101107
public static final String EMPTY_SHARED_NAME = "";
102108

103-
private Ops tf;
109+
protected Ops tf;
104110

105111
private Operand<?> iteratorResource;
106112
private Op initializer;
107113

108-
private List<DataType<?>> outputTypes;
109-
private List<Shape> outputShapes;
114+
protected List<DataType<?>> outputTypes;
115+
protected List<Shape> outputShapes;
110116

111117
/**
112118
* @param tf Ops accessor corresponding to the same `ExecutionEnvironment` as the
@@ -119,7 +125,7 @@ public class DatasetIterator {
119125
* @param outputShapes A list of `Shape` objects corresponding to the shapes of each component of
120126
* a dataset element.
121127
*/
122-
private DatasetIterator(
128+
protected DatasetIterator(
123129
Ops tf,
124130
Operand<?> iteratorResource,
125131
Op initializer,
@@ -133,7 +139,7 @@ private DatasetIterator(
133139
this.outputShapes = outputShapes;
134140
}
135141

136-
private DatasetIterator(
142+
protected DatasetIterator(
137143
Ops tf,
138144
Operand<?> iteratorResource,
139145
List<DataType<?>> outputTypes,
@@ -144,6 +150,14 @@ private DatasetIterator(
144150
this.outputShapes = outputShapes;
145151
}
146152

153+
protected DatasetIterator(DatasetIterator other) {
154+
this.tf = other.tf;
155+
this.iteratorResource = other.iteratorResource;
156+
this.initializer = other.initializer;
157+
this.outputTypes = other.outputTypes;
158+
this.outputShapes = other.outputShapes;
159+
}
160+
147161
/**
148162
* Returns a list of `Operand<?>` representing the components of the next dataset element.
149163
*
@@ -159,7 +173,7 @@ private DatasetIterator(
159173
public List<Operand<?>> getNext() {
160174
List<Operand<?>> components = new ArrayList<>();
161175
tf.data
162-
.iteratorGetNext(getIteratorResource(), getOutputTypes(), getOutputShapes())
176+
.iteratorGetNext(getIteratorResource(), outputTypes, outputShapes)
163177
.iterator()
164178
.forEachRemaining(components::add);
165179
return components;
@@ -179,7 +193,7 @@ public List<Operand<?>> getNext() {
179193
public DatasetOptional getNextAsOptional() {
180194
Operand<?> optionalVariant =
181195
tf.data
182-
.iteratorGetNextAsOptional(getIteratorResource(), getOutputTypes(), getOutputShapes())
196+
.iteratorGetNextAsOptional(getIteratorResource(), outputTypes, outputShapes)
183197
.optional();
184198
return new DatasetOptional(tf, optionalVariant, outputTypes, outputShapes);
185199
}
@@ -205,8 +219,9 @@ public Op makeInitializer(Dataset dataset) {
205219
"Dataset must share the same" + "ExecutionEnvironment as this iterator.");
206220
}
207221

208-
if (!dataset.getOutputShapes().equals(getOutputShapes())
209-
|| !dataset.getOutputTypes().equals(getOutputTypes())) {
222+
if (!dataset.getOutputShapes().equals(outputShapes)
223+
|| !dataset.getOutputTypes().equals(outputTypes)) {
224+
210225
throw new IllegalArgumentException(
211226
"Dataset structure (types, " + "output shapes) must match this iterator.");
212227
}
@@ -235,6 +250,72 @@ public static DatasetIterator fromStructure(
235250
return new DatasetIterator(tf, iteratorResource, outputTypes, outputShapes);
236251
}
237252

253+
/**
254+
* Returns a new DatasetIterator which maps a function across all elements from this iterator, on
255+
* a single component of each element.
256+
*
257+
* <p>For example, suppose each element is a `List<Operand<?>>` with 2 components: (features,
258+
* labels).
259+
*
260+
* <p>Calling `iterator.mapOneComponent(0, (tf, features) -> tf.math.mul(features,
261+
* tf.constant(2)))` will map the function over the `features` component of each element,
262+
* multiplying each by 2.
263+
*
264+
* @param index The index of the component to transform.
265+
* @param mapper The function to apply to the target component.
266+
* @return A new DatasetIterator applying `mapper` to the component at the chosen index.
267+
*/
268+
public DatasetIterator mapOneComponent(
269+
int index, BiFunction<Ops, Operand<?>, Operand<?>> mapper) {
270+
return map(
271+
(tf, outputs) -> {
272+
List<Operand<?>> newComponents = new ArrayList<>(outputs);
273+
newComponents.set(index, mapper.apply(tf, outputs.get(index)));
274+
return newComponents;
275+
});
276+
}
277+
278+
/**
279+
* Returns a new DatasetIterator which maps a function across all elements from this iterator, on
280+
* all components of each element.
281+
*
282+
* <p>For example, suppose each element is a `List<Operand<?>>` with 2 components: (features,
283+
* labels).
284+
*
285+
* <p>Calling `iterator.mapAllComponents((tf, component) -> tf.math.mul(component,
286+
* tf.constant(2)))` will map the function over both the `features` and `labels` components of
287+
* each element, multiplying them all by 2.
288+
*
289+
* @param mapper The function to apply to each component
290+
* @return A new DatasetIterator applying `mapper` to all components of each element.
291+
*/
292+
public DatasetIterator mapAllComponents(BiFunction<Ops, Operand<?>, Operand<?>> mapper) {
293+
return map(
294+
(tf, outputs) ->
295+
outputs.stream().map(op -> mapper.apply(tf, op)).collect(Collectors.toList()));
296+
}
297+
298+
/**
299+
* Returns a new DatasetIterator which maps a function over all elements returned by this
300+
* iterator.
301+
*
302+
* <p>For example, suppose each element is a `List<Operand<?>>` with 2 components: (features,
303+
* labels).
304+
*
305+
* <p>Calling ``` iterator.map((tf, components) -> { Operand<?> features = components.get(0);
306+
* Operand<?> labels = components.get(1);
307+
*
308+
* <p>return Arrays.asList( tf.math.mul(features, tf.constant(2)), tf.math.mul(labels,
309+
* tf.constant(5)) ); }) ``` will map the function over the `features` and `labels` components,
310+
* multiplying features by 2, and multiplying the labels by 5.
311+
*
312+
* @param mapper The function to apply to each element of this iterator.
313+
* @return A new DatasetIterator applying `mapper` to each element of this iterator.
314+
*/
315+
public DatasetIterator map(BiFunction<Ops, List<Operand<?>>, List<Operand<?>>> mapper) {
316+
return new MapIterator(this, mapper);
317+
}
318+
238319
public Operand<?> getIteratorResource() {
239320
return iteratorResource;
240321
}
@@ -243,11 +324,34 @@ public Op getInitializer() {
243324
return initializer;
244325
}
245326

246-
public List<DataType<?>> getOutputTypes() {
247-
return outputTypes;
327+
public Ops getOpsInstance() {
328+
return tf;
248329
}
249330

250-
public List<Shape> getOutputShapes() {
251-
return outputShapes;
331+
@Override
332+
public Iterator<List<Operand<?>>> iterator() {
333+
334+
if (!tf.scope().env().isEager()) {
335+
throw new UnsupportedOperationException(
336+
"Cannot use foreach iteration through a dataset in graph mode.");
337+
}
338+
339+
DatasetIterator iterator = this;
340+
341+
return new Iterator<List<Operand<?>>>() {
342+
private DatasetOptional nextOptional = iterator.getNextAsOptional();
343+
344+
@Override
345+
public boolean hasNext() {
346+
return nextOptional.hasValue().data().getBoolean();
347+
}
348+
349+
@Override
350+
public List<Operand<?>> next() {
351+
List<Operand<?>> result = nextOptional.getValue();
352+
nextOptional = iterator.getNextAsOptional();
353+
return result;
354+
}
355+
};
252356
}
253357
}

0 commit comments

Comments
 (0)