2323import org .tensorflow .framework .data .impl .TakeDataset ;
2424import org .tensorflow .framework .data .impl .TensorSliceDataset ;
2525import org .tensorflow .framework .data .impl .TextLineDataset ;
26+ import org .tensorflow .ndarray .Shape ;
2627import org .tensorflow .op .Op ;
2728import org .tensorflow .op .Ops ;
28- import org .tensorflow .ndarray . Shape ;
29+ import org .tensorflow .types . family . TType ;
2930
3031import java .util .ArrayList ;
3132import java .util .Arrays ;
3233import java .util .Iterator ;
3334import java .util .List ;
3435import java .util .function .Function ;
35- import org .tensorflow .types .family .TType ;
3636
3737/**
3838 * Represents a potentially large list of independent elements (samples), and allows iteration and
3939 * transformations to be performed across these elements.
4040 */
4141public abstract class Dataset implements Iterable <List <Operand <?>>> {
4242 protected Ops tf ;
43- private Operand <?> variant ;
44- private List <Class <? extends TType >> outputTypes ;
45- private List <Shape > outputShapes ;
43+ private final Operand <?> variant ;
44+ private final List <Class <? extends TType >> outputTypes ;
45+ private final List <Shape > outputShapes ;
4646
47+ /**
48+ * Creates a Dataset
49+ *
50+ * @param tf The TensorFlow Ops
51+ * @param variant the Operand that represents the dataset.
52+ * @param outputTypes A list of classes corresponding to the tensor type of each component of a
53+ * dataset element.
54+ * @param outputShapes A list of `Shape` objects corresponding to the shapes of each component of
55+ * a dataset element.
56+ */
4757 public Dataset (
48- Ops tf , Operand <?> variant , List <Class <? extends TType >> outputTypes , List <Shape > outputShapes ) {
58+ Ops tf ,
59+ Operand <?> variant ,
60+ List <Class <? extends TType >> outputTypes ,
61+ List <Shape > outputShapes ) {
4962 if (tf == null ) {
5063 throw new IllegalArgumentException ("Ops accessor cannot be null." );
5164 }
@@ -61,13 +74,65 @@ public Dataset(
6174 this .outputShapes = outputShapes ;
6275 }
6376
77+ /**
78+ * Creates a dataset from another dataset
79+ *
80+ * @param other the other dataset
81+ */
6482 protected Dataset (Dataset other ) {
6583 this .tf = other .tf ;
6684 this .variant = other .variant ;
6785 this .outputTypes = other .outputTypes ;
6886 this .outputShapes = other .outputShapes ;
6987 }
7088
89+ /**
90+ * Creates an in-memory `Dataset` whose elements are slices of the given tensors. Each element of
91+ * this dataset will be a {@code List<Operand<?>>}, representing slices (e.g. batches) of the
92+ * provided tensors.
93+ *
94+ * @param tf Ops Accessor
95+ * @param tensors A list of {@code Operand<?>} representing components of this dataset (e.g.
96+ * features, labels)
97+ * @param outputTypes A list of tensor type classes representing the data type of each component
98+ * of this dataset.
99+ * @return A new `Dataset`
100+ */
101+ public static Dataset fromTensorSlices (
102+ Ops tf , List <Operand <?>> tensors , List <Class <? extends TType >> outputTypes ) {
103+ return new TensorSliceDataset (tf , tensors , outputTypes );
104+ }
105+
106+ /**
107+ * Creates a Dataset comprising records from one or more TFRecord files.
108+ *
109+ * @param tf the TensorFlow Ops
110+ * @param filename the name of the file containing the TFRecords
111+ * @param compressionType the compression type, either "" (no compression), "ZLIB", or "GZIP"
112+ * @param bufferSize the number of bytes in the read buffer
113+ * @return A Dataset comprising records from a TFRecord file.
114+ */
115+ public static Dataset tfRecordDataset (
116+ Ops tf , String filename , String compressionType , long bufferSize ) {
117+ return new TFRecordDataset (
118+ tf , tf .constant (filename ), tf .constant (compressionType ), tf .constant (bufferSize ));
119+ }
120+
121+ /**
122+ * Creates a Dataset comprising lines from one or more text files.
123+ *
124+ * @param tf the TensorFlow Ops
125+ * @param filename the name of the file containing the text linea
126+ * @param compressionType the compression type, either "" (no compression), "ZLIB", or "GZIP"
127+ * @param bufferSize the number of bytes in the read buffer
128+ * @return A Dataset comprising lines from a text file.
129+ */
130+ public static Dataset textLineDataset (
131+ Ops tf , String filename , String compressionType , long bufferSize ) {
132+ return new TextLineDataset (
133+ tf , tf .constant (filename ), tf .constant (compressionType ), tf .constant (bufferSize ));
134+ }
135+
71136 /**
72137 * Groups elements of this dataset into batches.
73138 *
@@ -127,11 +192,12 @@ public final Dataset take(long count) {
127192 * Returns a new Dataset which maps a function across all elements from this dataset, on a single
128193 * component of each element.
129194 *
130- * <p>For example, suppose each element is a {@code List<Operand<?>>} with 2 components: (features,
131- * labels).
195+ * <p>For example, suppose each element is a {@code List<Operand<?>>} with 2 components:
196+ * (features, labels).
132197 *
133- * <p>Calling {@code dataset.mapOneComponent(0, features -> tf.math.mul(features, tf.constant(2)))} will
134- * map the function over the `features` component of each element, multiplying each by 2.
198+ * <p>Calling {@code dataset.mapOneComponent(0, features -> tf.math.mul(features,
199+ * tf.constant(2)))} will map the function over the `features` component of each element,
200+ * multiplying each by 2.
135201 *
136202 * @param index The index of the component to transform.
137203 * @param mapper The function to apply to the target component.
@@ -150,8 +216,8 @@ public Dataset mapOneComponent(int index, Function<Operand<?>, Operand<?>> mappe
150216 * Returns a new Dataset which maps a function across all elements from this dataset, on all
151217 * components of each element.
152218 *
153- * <p>For example, suppose each element is a {@code List<Operand<?>>} with 2 components: (features,
154- * labels).
219+ * <p>For example, suppose each element is a {@code List<Operand<?>>} with 2 components:
220+ * (features, labels).
155221 *
156222 * <p>Calling {@code dataset.mapAllComponents(component -> tf.math.mul(component,
157223 * tf.constant(2)))} will map the function over the both the `features` and `labels` components of
@@ -172,8 +238,8 @@ public Dataset mapAllComponents(Function<Operand<?>, Operand<?>> mapper) {
172238 /**
173239 * Returns a new Dataset which maps a function over all elements returned by this dataset.
174240 *
175- * <p>For example, suppose each element is a {@code List<Operand<?>>} with 2 components: (features,
176- * labels).
241+ * <p>For example, suppose each element is a {@code List<Operand<?>>} with 2 components:
242+ * (features, labels).
177243 *
178244 * <p>Calling
179245 *
@@ -254,53 +320,42 @@ public DatasetIterator makeOneShotIterator() {
254320 }
255321
256322 /**
257- * Creates an in-memory `Dataset` whose elements are slices of the given tensors. Each element of
258- * this dataset will be a {@code List<Operand<?>>}, representing slices (e.g. batches) of the
259- * provided tensors.
323+ * Gets the variant tensor representing this dataset.
260324 *
261- * @param tf Ops Accessor
262- * @param tensors A list of {@code Operand<?>} representing components of this dataset (e.g.
263- * features, labels)
264- * @param outputTypes A list of tensor type classes representing the data type of each component of
265- * this dataset.
266- * @return A new `Dataset`
325+ * @return the variant tensor representing this dataset.
267326 */
268- public static Dataset fromTensorSlices (
269- Ops tf , List <Operand <?>> tensors , List <Class <? extends TType >> outputTypes ) {
270- return new TensorSliceDataset (tf , tensors , outputTypes );
271- }
272-
273- public static Dataset tfRecordDataset (
274- Ops tf , String filename , String compressionType , long bufferSize ) {
275- return new TFRecordDataset (
276- tf , tf .constant (filename ), tf .constant (compressionType ), tf .constant (bufferSize ));
277- }
278-
279- public static Dataset textLineDataset (
280- Ops tf , String filename , String compressionType , long bufferSize ) {
281- return new TextLineDataset (
282- tf , tf .constant (filename ), tf .constant (compressionType ), tf .constant (bufferSize ));
283- }
284-
285- /** Get the variant tensor representing this dataset. */
286327 public Operand <?> getVariant () {
287328 return variant ;
288329 }
289330
290- /** Get a list of output types for each component of this dataset. */
331+ /**
332+ * Gets a list of output types for each component of this dataset.
333+ *
334+ * @return the list of output types for each component of this dataset.
335+ */
291336 public List <Class <? extends TType >> getOutputTypes () {
292337 return this .outputTypes ;
293338 }
294339
295- /** Get a list of shapes for each component of this dataset. */
340+ /**
341+ * Gets a list of shapes for each component of this dataset.
342+ *
343+ * @return the list of shapes for each component of this dataset.
344+ */
296345 public List <Shape > getOutputShapes () {
297346 return this .outputShapes ;
298347 }
299348
349+ /**
350+ * Gets the TensorFlow Ops Instance
351+ *
352+ * @return the TensorFlow Ops Instance
353+ */
300354 public Ops getOpsInstance () {
301355 return this .tf ;
302356 }
303357
358+ /** {@inheritDoc} */
304359 @ Override
305360 public String toString () {
306361 return "Dataset{"
0 commit comments