Skip to content

Commit 3fc14e4

Browse files
committed
Add JavaDoc to Dataset classes, fix existing Java Doc
Fix Javadoc in Constraint
1 parent 11748ae commit 3fc14e4

13 files changed

Lines changed: 344 additions & 108 deletions

File tree

tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ public Constraint(Ops tf) {
4242
*
4343
* @param weights the weights
4444
* @return the constrained weights
45+
* @param <T> the date type for the weights and return value
4546
*/
4647
public abstract <T extends TNumber> Operand<T> call(Operand<T> weights);
4748

tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java

Lines changed: 98 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -23,29 +23,42 @@
2323
import org.tensorflow.framework.data.impl.TakeDataset;
2424
import org.tensorflow.framework.data.impl.TensorSliceDataset;
2525
import org.tensorflow.framework.data.impl.TextLineDataset;
26+
import org.tensorflow.ndarray.Shape;
2627
import org.tensorflow.op.Op;
2728
import org.tensorflow.op.Ops;
28-
import org.tensorflow.ndarray.Shape;
29+
import org.tensorflow.types.family.TType;
2930

3031
import java.util.ArrayList;
3132
import java.util.Arrays;
3233
import java.util.Iterator;
3334
import java.util.List;
3435
import java.util.function.Function;
35-
import org.tensorflow.types.family.TType;
3636

3737
/**
3838
* Represents a potentially large list of independent elements (samples), and allows iteration and
3939
* transformations to be performed across these elements.
4040
*/
4141
public abstract class Dataset implements Iterable<List<Operand<?>>> {
4242
protected Ops tf;
43-
private Operand<?> variant;
44-
private List<Class<? extends TType>> outputTypes;
45-
private List<Shape> outputShapes;
43+
private final Operand<?> variant;
44+
private final List<Class<? extends TType>> outputTypes;
45+
private final List<Shape> outputShapes;
4646

47+
/**
48+
* Creates a Dataset
49+
*
50+
* @param tf The TensorFlow Ops
51+
* @param variant the Operand that represents the dataset.
52+
* @param outputTypes A list of classes corresponding to the tensor type of each component of a
53+
* dataset element.
54+
* @param outputShapes A list of `Shape` objects corresponding to the shapes of each component of
55+
* a dataset element.
56+
*/
4757
public Dataset(
48-
Ops tf, Operand<?> variant, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
58+
Ops tf,
59+
Operand<?> variant,
60+
List<Class<? extends TType>> outputTypes,
61+
List<Shape> outputShapes) {
4962
if (tf == null) {
5063
throw new IllegalArgumentException("Ops accessor cannot be null.");
5164
}
@@ -61,13 +74,65 @@ public Dataset(
6174
this.outputShapes = outputShapes;
6275
}
6376

77+
/**
78+
* Creates a dataset from another dataset
79+
*
80+
* @param other the other dataset
81+
*/
6482
protected Dataset(Dataset other) {
6583
this.tf = other.tf;
6684
this.variant = other.variant;
6785
this.outputTypes = other.outputTypes;
6886
this.outputShapes = other.outputShapes;
6987
}
7088

89+
/**
90+
* Creates an in-memory `Dataset` whose elements are slices of the given tensors. Each element of
91+
* this dataset will be a {@code List<Operand<?>>}, representing slices (e.g. batches) of the
92+
* provided tensors.
93+
*
94+
* @param tf Ops Accessor
95+
* @param tensors A list of {@code Operand<?>} representing components of this dataset (e.g.
96+
* features, labels)
97+
* @param outputTypes A list of tensor type classes representing the data type of each component
98+
* of this dataset.
99+
* @return A new `Dataset`
100+
*/
101+
public static Dataset fromTensorSlices(
102+
Ops tf, List<Operand<?>> tensors, List<Class<? extends TType>> outputTypes) {
103+
return new TensorSliceDataset(tf, tensors, outputTypes);
104+
}
105+
106+
/**
107+
* Creates a Dataset comprising records from one or more TFRecord files.
108+
*
109+
* @param tf the TensorFlow Ops
110+
* @param filename the name of the file containing the TFRecords
111+
* @param compressionType the compression type, either "" (no compression), "ZLIB", or "GZIP"
112+
* @param bufferSize the number of bytes in the read buffer
113+
* @return A Dataset comprising records from a TFRecord file.
114+
*/
115+
public static Dataset tfRecordDataset(
116+
Ops tf, String filename, String compressionType, long bufferSize) {
117+
return new TFRecordDataset(
118+
tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize));
119+
}
120+
121+
/**
122+
* Creates a Dataset comprising lines from one or more text files.
123+
*
124+
* @param tf the TensorFlow Ops
125+
* @param filename the name of the file containing the text linea
126+
* @param compressionType the compression type, either "" (no compression), "ZLIB", or "GZIP"
127+
* @param bufferSize the number of bytes in the read buffer
128+
* @return A Dataset comprising lines from a text file.
129+
*/
130+
public static Dataset textLineDataset(
131+
Ops tf, String filename, String compressionType, long bufferSize) {
132+
return new TextLineDataset(
133+
tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize));
134+
}
135+
71136
/**
72137
* Groups elements of this dataset into batches.
73138
*
@@ -127,11 +192,12 @@ public final Dataset take(long count) {
127192
* Returns a new Dataset which maps a function across all elements from this dataset, on a single
128193
* component of each element.
129194
*
130-
* <p>For example, suppose each element is a {@code List<Operand<?>>} with 2 components: (features,
131-
* labels).
195+
* <p>For example, suppose each element is a {@code List<Operand<?>>} with 2 components:
196+
* (features, labels).
132197
*
133-
* <p>Calling {@code dataset.mapOneComponent(0, features -> tf.math.mul(features, tf.constant(2)))} will
134-
* map the function over the `features` component of each element, multiplying each by 2.
198+
* <p>Calling {@code dataset.mapOneComponent(0, features -> tf.math.mul(features,
199+
* tf.constant(2)))} will map the function over the `features` component of each element,
200+
* multiplying each by 2.
135201
*
136202
* @param index The index of the component to transform.
137203
* @param mapper The function to apply to the target component.
@@ -150,8 +216,8 @@ public Dataset mapOneComponent(int index, Function<Operand<?>, Operand<?>> mappe
150216
* Returns a new Dataset which maps a function across all elements from this dataset, on all
151217
* components of each element.
152218
*
153-
* <p>For example, suppose each element is a {@code List<Operand<?>>} with 2 components: (features,
154-
* labels).
219+
* <p>For example, suppose each element is a {@code List<Operand<?>>} with 2 components:
220+
* (features, labels).
155221
*
156222
* <p>Calling {@code dataset.mapAllComponents(component -> tf.math.mul(component,
157223
* tf.constant(2)))} will map the function over the both the `features` and `labels` components of
@@ -172,8 +238,8 @@ public Dataset mapAllComponents(Function<Operand<?>, Operand<?>> mapper) {
172238
/**
173239
* Returns a new Dataset which maps a function over all elements returned by this dataset.
174240
*
175-
* <p>For example, suppose each element is a {@code List<Operand<?>>} with 2 components: (features,
176-
* labels).
241+
* <p>For example, suppose each element is a {@code List<Operand<?>>} with 2 components:
242+
* (features, labels).
177243
*
178244
* <p>Calling
179245
*
@@ -254,53 +320,42 @@ public DatasetIterator makeOneShotIterator() {
254320
}
255321

256322
/**
257-
* Creates an in-memory `Dataset` whose elements are slices of the given tensors. Each element of
258-
* this dataset will be a {@code List<Operand<?>>}, representing slices (e.g. batches) of the
259-
* provided tensors.
323+
* Gets the variant tensor representing this dataset.
260324
*
261-
* @param tf Ops Accessor
262-
* @param tensors A list of {@code Operand<?>} representing components of this dataset (e.g.
263-
* features, labels)
264-
* @param outputTypes A list of tensor type classes representing the data type of each component of
265-
* this dataset.
266-
* @return A new `Dataset`
325+
* @return the variant tensor representing this dataset.
267326
*/
268-
public static Dataset fromTensorSlices(
269-
Ops tf, List<Operand<?>> tensors, List<Class<? extends TType>> outputTypes) {
270-
return new TensorSliceDataset(tf, tensors, outputTypes);
271-
}
272-
273-
public static Dataset tfRecordDataset(
274-
Ops tf, String filename, String compressionType, long bufferSize) {
275-
return new TFRecordDataset(
276-
tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize));
277-
}
278-
279-
public static Dataset textLineDataset(
280-
Ops tf, String filename, String compressionType, long bufferSize) {
281-
return new TextLineDataset(
282-
tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize));
283-
}
284-
285-
/** Get the variant tensor representing this dataset. */
286327
public Operand<?> getVariant() {
287328
return variant;
288329
}
289330

290-
/** Get a list of output types for each component of this dataset. */
331+
/**
332+
* Gets a list of output types for each component of this dataset.
333+
*
334+
* @return the list of output types for each component of this dataset.
335+
*/
291336
public List<Class<? extends TType>> getOutputTypes() {
292337
return this.outputTypes;
293338
}
294339

295-
/** Get a list of shapes for each component of this dataset. */
340+
/**
341+
* Gets a list of shapes for each component of this dataset.
342+
*
343+
* @return the list of shapes for each component of this dataset.
344+
*/
296345
public List<Shape> getOutputShapes() {
297346
return this.outputShapes;
298347
}
299348

349+
/**
350+
* Gets the TensorFlow Ops Instance
351+
*
352+
* @return the TensorFlow Ops Instance
353+
*/
300354
public Ops getOpsInstance() {
301355
return this.tf;
302356
}
303357

358+
/** {@inheritDoc} */
304359
@Override
305360
public String toString() {
306361
return "Dataset{"

tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java

Lines changed: 58 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717

1818
import org.tensorflow.Graph;
1919
import org.tensorflow.Operand;
20+
import org.tensorflow.ndarray.Shape;
2021
import org.tensorflow.op.Op;
2122
import org.tensorflow.op.Ops;
22-
import org.tensorflow.ndarray.Shape;
23+
import org.tensorflow.types.family.TType;
2324

2425
import java.util.ArrayList;
2526
import java.util.Iterator;
2627
import java.util.List;
27-
import org.tensorflow.types.family.TType;
2828

2929
/**
3030
* Represents the state of an iteration through a tf.data Datset. DatasetIterator is not a
@@ -102,21 +102,21 @@ public class DatasetIterator implements Iterable<List<Operand<?>>> {
102102
public static final String EMPTY_SHARED_NAME = "";
103103

104104
protected Ops tf;
105-
106-
private Operand<?> iteratorResource;
107-
private Op initializer;
108-
109105
protected List<Class<? extends TType>> outputTypes;
110106
protected List<Shape> outputShapes;
107+
private final Operand<?> iteratorResource;
108+
private Op initializer;
111109

112110
/**
111+
* Creates a DatasetIterator
112+
*
113113
* @param tf Ops accessor corresponding to the same `ExecutionEnvironment` as the
114114
* `iteratorResource`.
115115
* @param iteratorResource An Operand representing the iterator (e.g. constructed from
116116
* `tf.data.iterator` or `tf.data.anonymousIterator`)
117117
* @param initializer An `Op` that should be run to initialize this iterator
118-
* @param outputTypes A list of classes corresponding to the tensor type of each component of
119-
* a dataset element.
118+
* @param outputTypes A list of classes corresponding to the tensor type of each component of a
119+
* dataset element.
120120
* @param outputShapes A list of `Shape` objects corresponding to the shapes of each component of
121121
* a dataset element.
122122
*/
@@ -134,6 +134,18 @@ public DatasetIterator(
134134
this.outputShapes = outputShapes;
135135
}
136136

137+
/**
138+
* Creates a DatasetIterator
139+
*
140+
* @param tf Ops accessor corresponding to the same `ExecutionEnvironment` as the
141+
* `iteratorResource`.
142+
* @param iteratorResource An Operand representing the iterator (e.g. constructed from
143+
* `tf.data.iterator` or `tf.data.anonymousIterator`)
144+
* @param outputTypes A list of classes corresponding to the tensor type of each component of a
145+
* dataset element.
146+
* @param outputShapes A list of `Shape` objects corresponding to the shapes of each component of
147+
* a dataset element.
148+
*/
137149
public DatasetIterator(
138150
Ops tf,
139151
Operand<?> iteratorResource,
@@ -145,6 +157,11 @@ public DatasetIterator(
145157
this.outputShapes = outputShapes;
146158
}
147159

160+
/**
161+
* Creates a DatasetIterator from another DatasetIterator
162+
*
163+
* @param other the other DatasetIterator
164+
*/
148165
protected DatasetIterator(DatasetIterator other) {
149166
this.tf = other.tf;
150167
this.iteratorResource = other.iteratorResource;
@@ -153,6 +170,26 @@ protected DatasetIterator(DatasetIterator other) {
153170
this.outputShapes = other.outputShapes;
154171
}
155172

173+
/**
174+
* Creates a new iterator from a "structure" defined by `outputShapes` and `outputTypes`.
175+
*
176+
* @param tf Ops accessor
177+
* @param outputTypes A list of classes repesenting the tensor type of each component of a dataset
178+
* element.
179+
* @param outputShapes A list of Shape objects representing the shape of each component of a
180+
* dataset element.
181+
* @return A new DatasetIterator
182+
*/
183+
public static DatasetIterator fromStructure(
184+
Ops tf, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
185+
Operand<?> iteratorResource =
186+
tf.scope().env() instanceof Graph
187+
? tf.data.iterator(EMPTY_SHARED_NAME, "", outputTypes, outputShapes)
188+
: tf.data.anonymousIterator(outputTypes, outputShapes).handle();
189+
190+
return new DatasetIterator(tf, iteratorResource, outputTypes, outputShapes);
191+
}
192+
156193
/**
157194
* Returns a list of {@code Operand<?>} representing the components of the next dataset element.
158195
*
@@ -226,37 +263,33 @@ public Op makeInitializer(Dataset dataset) {
226263
}
227264

228265
/**
229-
* Creates a new iterator from a "structure" defined by `outputShapes` and `outputTypes`.
266+
* Gets the iteratorResource
230267
*
231-
* @param tf Ops accessor
232-
* @param outputTypes A list of classes repesenting the tensor type of each component of a
233-
* dataset element.
234-
* @param outputShapes A list of Shape objects representing the shape of each component of a
235-
* dataset element.
236-
* @return A new DatasetIterator
268+
* @return the iteratorResource
237269
*/
238-
public static DatasetIterator fromStructure(
239-
Ops tf, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
240-
Operand<?> iteratorResource =
241-
tf.scope().env() instanceof Graph
242-
? tf.data.iterator(EMPTY_SHARED_NAME, "", outputTypes, outputShapes)
243-
: tf.data.anonymousIterator(outputTypes, outputShapes).handle();
244-
245-
return new DatasetIterator(tf, iteratorResource, outputTypes, outputShapes);
246-
}
247-
248270
public Operand<?> getIteratorResource() {
249271
return iteratorResource;
250272
}
251273

274+
/**
275+
* Gets the initializer
276+
*
277+
* @return the initializer
278+
*/
252279
public Op getInitializer() {
253280
return initializer;
254281
}
255282

283+
/**
284+
* Gets the TensorFlow Ops Instance
285+
*
286+
* @return the TensorFlow Ops Instance
287+
*/
256288
public Ops getOpsInstance() {
257289
return tf;
258290
}
259291

292+
/** {@inheritDoc} */
260293
@Override
261294
public Iterator<List<Operand<?>>> iterator() {
262295

0 commit comments

Comments
 (0)