11/*
2- * Copyright 2020 The TensorFlow Authors . All rights reserved.
2+ * Copyright (c) 2020, Oracle and/or its affiliates . All rights reserved.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
1818import org .tensorflow .DataType ;
1919import org .tensorflow .Graph ;
2020import org .tensorflow .Operand ;
21+ import org .tensorflow .Output ;
22+ import org .tensorflow .framework .data .impl .MapIterator ;
2123import org .tensorflow .op .Op ;
2224import org .tensorflow .op .Ops ;
2325import org .tensorflow .tools .Shape ;
26+ import org .tensorflow .types .family .TType ;
2427
2528import java .util .ArrayList ;
29+ import java .util .Iterator ;
2630import java .util .List ;
31+ import java .util .function .BiFunction ;
32+ import java .util .stream .Collectors ;
2733
2834/**
2935 * Represents the state of an iteration through a tf.data Dataset. DatasetIterator is not a
3440 *
3541 * <pre>{@code
3642 * // Create input tensors
37- * Operand<?> features = tf.constant( ... );
38- * Operand<?> labels = tf.constant( ... );
43+ * Operand<?> XTensor = tf.constant( ... );
44+ * Operand<?> yTensor = tf.constant( ... );
3945 *
4046 *
4147 * Dataset dataset = Dataset
4450 *
4551 * DatasetIterator iterator = dataset.makeInitializeableIterator();
4652 * List<Operand<?>> components = iterator.getNext();
47- * Operand<?> featureBatch = components.get(0);
48- * Operand<?> labelBatch = components.get(1);
53+ * Operand<?> XBatch = components.get(0);
54+ * Operand<?> yBatch = components.get(1);
4955 *
5056 * // Build a TensorFlow graph that does something on each element.
51- * loss = computeModelLoss(featureBatch, labelBatch );
57+ * loss = computeModelLoss(XBatch, yBatch );
5258 *
5359 * optimizer = ... // create an optimizer
5460 * trainOp = optimizer.minimize(loss);
5965 * try {
6066 * session
6167 * .addTarget(trainOp)
62- * .fetch(loss )
68+ * .fetch( ... )
6369 * .run();
6470 *
6571 * ...
7682 *
7783 * <pre>{@code
7884 * // Create input tensors
79- * Operand<?> features = tf.constant( ... );
80- * Operand<?> labels = tf.constant( ... );
85+ * Operand<?> XTensor = tf.constant( ... );
86+ * Operand<?> yTensor = tf.constant( ... );
8187 *
8288 * int BATCH_SIZE = ...
8389 *
8490 * Dataset dataset = Dataset
85- * .fromTensorSlices(features, labels )
91+ * .fromTensorSlices(XTensor, yTensor )
8692 * .batch(BATCH_SIZE);
8793 * DatasetIterator iterator = dataset.makeIterator();
8894 *
8995 * Optimizer optimizer = ... // create an optimizer
9096 *
9197 * for (List<Operand<?>> components : dataset) {
92- * Operand<?> featureBatch = components.get(0);
93- * Operand<?> labelBatch = components.get(1);
98+ * Operand<?> XBatch = components.get(0);
99+ * Operand<?> yBatch = components.get(1);
94100 *
95- * loss = computeModelLoss(featureBatch, labelBatch );
101+ * loss = computeModelLoss(XBatch, yBatch );
96102 * trainOp = optimizer.minimize(loss);
97103 * }
98104 * }</pre>
99105 */
100- public class DatasetIterator {
106+ public class DatasetIterator implements Iterable < List < Operand <?>>> {
101107 public static final String EMPTY_SHARED_NAME = "" ;
102108
103- private Ops tf ;
109+ protected Ops tf ;
104110
105111 private Operand <?> iteratorResource ;
106112 private Op initializer ;
107113
108- private List <DataType <?>> outputTypes ;
109- private List <Shape > outputShapes ;
114+ protected List <DataType <?>> outputTypes ;
115+ protected List <Shape > outputShapes ;
110116
111117 /**
112118 * @param tf Ops accessor corresponding to the same `ExecutionEnvironment` as the
@@ -119,7 +125,7 @@ public class DatasetIterator {
119125 * @param outputShapes A list of `Shape` objects corresponding to the shapes of each component of
120126 * a dataset element.
121127 */
122- private DatasetIterator (
128+ protected DatasetIterator (
123129 Ops tf ,
124130 Operand <?> iteratorResource ,
125131 Op initializer ,
@@ -133,7 +139,7 @@ private DatasetIterator(
133139 this .outputShapes = outputShapes ;
134140 }
135141
136- private DatasetIterator (
142+ protected DatasetIterator (
137143 Ops tf ,
138144 Operand <?> iteratorResource ,
139145 List <DataType <?>> outputTypes ,
@@ -144,6 +150,14 @@ private DatasetIterator(
144150 this .outputShapes = outputShapes ;
145151 }
146152
/**
 * Creates an iterator whose fields all refer to the same objects as those of {@code other}.
 *
 * <p>Every field is copied by reference; nothing is cloned or re-created.
 *
 * @param other the iterator whose fields are copied into this instance.
 */
protected DatasetIterator(DatasetIterator other) {
  // Element structure (shapes and types of each component).
  this.outputShapes = other.outputShapes;
  this.outputTypes = other.outputTypes;

  // Handles: initializer op, iterator resource, and the Ops accessor.
  this.initializer = other.initializer;
  this.iteratorResource = other.iteratorResource;
  this.tf = other.tf;
}
160+
147161 /**
148162 * Returns a list of `Operand<?>` representing the components of the next dataset element.
149163 *
@@ -159,7 +173,7 @@ private DatasetIterator(
159173 public List <Operand <?>> getNext () {
160174 List <Operand <?>> components = new ArrayList <>();
161175 tf .data
162- .iteratorGetNext (getIteratorResource (), getOutputTypes (), getOutputShapes () )
176+ .iteratorGetNext (getIteratorResource (), outputTypes , outputShapes )
163177 .iterator ()
164178 .forEachRemaining (components ::add );
165179 return components ;
@@ -179,7 +193,7 @@ public List<Operand<?>> getNext() {
179193 public DatasetOptional getNextAsOptional () {
180194 Operand <?> optionalVariant =
181195 tf .data
182- .iteratorGetNextAsOptional (getIteratorResource (), getOutputTypes (), getOutputShapes () )
196+ .iteratorGetNextAsOptional (getIteratorResource (), outputTypes , outputShapes )
183197 .optional ();
184198 return new DatasetOptional (tf , optionalVariant , outputTypes , outputShapes );
185199 }
@@ -205,8 +219,9 @@ public Op makeInitializer(Dataset dataset) {
205219 "Dataset must share the same" + "ExecutionEnvironment as this iterator." );
206220 }
207221
208- if (!dataset .getOutputShapes ().equals (getOutputShapes ())
209- || !dataset .getOutputTypes ().equals (getOutputTypes ())) {
222+ if (!dataset .getOutputShapes ().equals (outputShapes )
223+ || !dataset .getOutputTypes ().equals (outputTypes )) {
224+
210225 throw new IllegalArgumentException (
211226 "Dataset structure (types, " + "output shapes) must match this iterator." );
212227 }
@@ -235,6 +250,72 @@ public static DatasetIterator fromStructure(
235250 return new DatasetIterator (tf , iteratorResource , outputTypes , outputShapes );
236251 }
237252
253+ /**
254+ * Returns a new DatasetIterator which maps a function across all elements from this iterator, on
255+ * a single component of each element.
256+ *
257+ * <p>For example, suppose each element is a `List<Operand<?>>` with 2 components: (features,
258+ * labels).
259+ *
260+ * <p>Calling `iterator.mapOneComponent(0, (tf, features) -> tf.math.mul(features,
261+ * tf.constant(2)))` will map the function over the `features` component of each element,
262+ * multiplying each by 2.
263+ *
264+ * @param index The index of the component to transform.
265+ * @param mapper The function to apply to the target component.
266+ * @return A new DatasetIterator applying `mapper` to the component at the chosen index.
267+ */
268+ public DatasetIterator mapOneComponent (
269+ int index , BiFunction <Ops , Operand <?>, Operand <?>> mapper ) {
270+ return map (
271+ (tf , outputs ) -> {
272+ List <Operand <?>> newComponents = new ArrayList <>(outputs );
273+ newComponents .set (index , mapper .apply (tf , outputs .get (index )));
274+ return newComponents ;
275+ });
276+ }
277+
278+ /**
279+ * Returns a new DatasetIterator which maps a function across all elements from this iterator, on
280+ * all components of each element.
281+ *
282+ * <p>For example, suppose each element is a `List<Operand<?>>` with 2 components: (features,
283+ * labels).
284+ *
285+ * <p>Calling `iterator.mapAllComponents((tf, component) -> tf.math.mul(component,
286+ * tf.constant(2)))` will map the function over both the `features` and `labels` components of
287+ * each element, multiplying them all by 2.
288+ *
289+ * @param mapper The function to apply to each component
290+ * @return A new DatasetIterator applying `mapper` to all components of each element.
291+ */
292+ public DatasetIterator mapAllComponents (BiFunction <Ops , Operand <?>, Operand <?>> mapper ) {
293+ return map (
294+ (tf , outputs ) ->
295+ outputs .stream ().map (op -> mapper .apply (tf , op )).collect (Collectors .toList ()));
296+ }
297+
298+ /**
299+ * Returns a new DatasetIterator which maps a function over all elements returned by this
300+ * iterator.
301+ *
302+ * <p>For example, suppose each element is a `List<Operand<?>>` with 2 components: (features,
303+ * labels).
304+ *
305+ * <p>Calling ``` iterator.map((tf, components) -> { Operand<?> features = components.get(0);
306+ * Operand<?> labels = components.get(1);
307+ *
308+ * <p>return Arrays.asList( tf.math.mul(features, tf.constant(2)), tf.math.mul(labels,
309+ * tf.constant(5)) ); }) ``` will map the function over the `features` and `labels` components,
310+ * multiplying features by 2, and multiplying the labels by 5.
311+ *
312+ * @param mapper The function to apply to each element of this iterator.
313+ * @return A new DatasetIterator applying `mapper` to each element of this iterator.
314+ */
315+ public DatasetIterator map (BiFunction <Ops , List <Operand <?>>, List <Operand <?>>> mapper ) {
316+ return new MapIterator (this , mapper );
317+ }
318+
238319 public Operand <?> getIteratorResource () {
239320 return iteratorResource ;
240321 }
@@ -243,11 +324,34 @@ public Op getInitializer() {
243324 return initializer ;
244325 }
245326
246- public List < DataType <?>> getOutputTypes () {
247- return outputTypes ;
327+ public Ops getOpsInstance () {
328+ return tf ;
248329 }
249330
250- public List <Shape > getOutputShapes () {
251- return outputShapes ;
331+ @ Override
332+ public Iterator <List <Operand <?>>> iterator () {
333+
334+ if (!tf .scope ().env ().isEager ()) {
335+ throw new UnsupportedOperationException (
336+ "Cannot use foreach iteration through a dataset in graph mode." );
337+ }
338+
339+ DatasetIterator iterator = this ;
340+
341+ return new Iterator <List <Operand <?>>>() {
342+ private DatasetOptional nextOptional = iterator .getNextAsOptional ();
343+
344+ @ Override
345+ public boolean hasNext () {
346+ return nextOptional .hasValue ().data ().getBoolean ();
347+ }
348+
349+ @ Override
350+ public List <Operand <?>> next () {
351+ List <Operand <?>> result = nextOptional .getValue ();
352+ nextOptional = iterator .getNextAsOptional ();
353+ return result ;
354+ }
355+ };
252356 }
253357}
0 commit comments