Skip to content
This repository was archived by the owner on Jul 15, 2025. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Skeleton for the NdArray hydration API
  • Loading branch information
karllessard committed Dec 5, 2021
commit 07e95bf17d990590384512393622f888e065ad64
26 changes: 26 additions & 0 deletions ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.tensorflow.ndarray;

import java.util.function.Consumer;
import org.tensorflow.ndarray.buffer.BooleanDataBuffer;
import org.tensorflow.ndarray.buffer.ByteDataBuffer;
import org.tensorflow.ndarray.buffer.DataBuffer;
Expand All @@ -25,6 +26,7 @@
import org.tensorflow.ndarray.buffer.IntDataBuffer;
import org.tensorflow.ndarray.buffer.LongDataBuffer;
import org.tensorflow.ndarray.buffer.ShortDataBuffer;
import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydrator;
import org.tensorflow.ndarray.impl.dense.BooleanDenseNdArray;
import org.tensorflow.ndarray.impl.dense.ByteDenseNdArray;
import org.tensorflow.ndarray.impl.dense.DenseNdArray;
Expand All @@ -33,6 +35,7 @@
import org.tensorflow.ndarray.impl.dense.IntDenseNdArray;
import org.tensorflow.ndarray.impl.dense.LongDenseNdArray;
import org.tensorflow.ndarray.impl.dense.ShortDenseNdArray;
import org.tensorflow.ndarray.impl.dense.hydrator.DoubleDenseNdArrayHydrator;
import org.tensorflow.ndarray.impl.dimension.DimensionalSpace;
import org.tensorflow.ndarray.impl.sparse.BooleanSparseNdArray;
import org.tensorflow.ndarray.impl.sparse.ByteSparseNdArray;
Expand All @@ -41,6 +44,7 @@
import org.tensorflow.ndarray.impl.sparse.IntSparseNdArray;
import org.tensorflow.ndarray.impl.sparse.LongSparseNdArray;
import org.tensorflow.ndarray.impl.sparse.ShortSparseNdArray;
import org.tensorflow.ndarray.impl.sparse.hydrator.DoubleSparseNdArrayHydrator;

/** Utility class for instantiating {@link NdArray} objects. */
public final class NdArrays {
Expand Down Expand Up @@ -555,6 +559,20 @@ public static DoubleNdArray ofDoubles(Shape shape) {
return wrap(shape, DataBuffers.ofDoubles(shape.size()));
}

/**
* Creates an N-dimensional array of doubles of the given shape, with data hydration
*
* @param shape shape of the array
* @param hydrate initialize the data of the created array, using a hydrator
* @return new double N-dimensional array
* @throws IllegalArgumentException if shape is null or has unknown dimensions
*/
public static DoubleNdArray ofDoubles(Shape shape, Consumer<DoubleNdArrayHydrator> hydrate) {
Comment thread
karllessard marked this conversation as resolved.
DoubleDenseNdArray array = (DoubleDenseNdArray)ofDoubles(shape);
hydrate.accept(new DoubleDenseNdArrayHydrator(array));
return array;
}

/**
* Wraps a buffer in a double N-dimensional array of a given shape.
*
Expand All @@ -568,6 +586,14 @@ public static DoubleNdArray wrap(Shape shape, DoubleDataBuffer buffer) {
return DoubleDenseNdArray.create(buffer, shape);
}

public static DoubleSparseNdArray sparseOfDoubles(long numValues, Shape shape, Consumer<DoubleNdArrayHydrator> hydrate) {
LongNdArray indices = ofLongs(Shape.of(numValues, shape.numDimensions()));
DoubleNdArray values = ofDoubles(Shape.of(numValues));
DoubleSparseNdArray array = DoubleSparseNdArray.create(indices, values, DimensionalSpace.create(shape));
hydrate.accept(new DoubleSparseNdArrayHydrator(array));
return array;
}

/**
* Creates a Sparse array of double values with a default value of zero
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package org.tensorflow.ndarray.hydrator;
Comment thread
karllessard marked this conversation as resolved.

import org.tensorflow.ndarray.DoubleNdArray;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.Shape;

public interface DoubleNdArrayHydrator extends NdArrayHydrator<Double> {

interface Scalars extends NdArrayHydrator.Scalars<Double> {

@Override
Scalars at(long... coordinates);

Scalars put(double scalar);
}

interface Vectors extends NdArrayHydrator.Vectors<Double> {

@Override
Vectors at(long... coordinates);

Vectors put(double... vector);
}

@Override
Scalars byScalars(long... coordinates);

@Override
Vectors byVectors(long... coordinates);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package org.tensorflow.ndarray.hydrator;
Comment thread
karllessard marked this conversation as resolved.

import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.ByteDataBuffer;

public interface NdArrayHydrator<T> {

interface Scalars<T> {

<U extends Scalars<T>> U at(long... coordinates);

<U extends Scalars<T>> U putObject(T scalar);
}

interface Vectors<T> {

<U extends Vectors<T>> U at(long... coordinates);

<U extends Vectors<T>> U putObjects(T... vector);
}

interface Elements<T> {

<U extends Elements<T>> U at(long... coordinates);

<U extends Elements<T>> U put(NdArray<T> vector);
}

<U extends Scalars<T>> U byScalars(long... coordinates);

<U extends Vectors<T>> U byVectors(long... coordinates);

<U extends Elements<T>> U byElements(long... coordinates);
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
@SuppressWarnings("unchecked")
public abstract class AbstractDenseNdArray<T, U extends NdArray<T>> extends AbstractNdArray<T, U> {

abstract public DataBuffer<T> buffer();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copyright 2022 should probably be added to this file, e.g. "Copyright 2019, 2022, The TensorFlow Authors. All Rights Reserved.".

Also javadoc for the new public method? And why does it need to be public now?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I access directly the array buffer in the initializer as I know that at this point it is safe to do it. Now, I could also pass a reference to that buffer at the initializer construction I suppose, if we prefer to keep it hidden from malicious usage. Though in some cases I missed the fact that I cannot access the buffer of a dense array... so I'm a bit undecided on this, your thoughts?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I would default to keeping it more private unless there is a strong reason not to. It provides us more flexibility in the future, and making things public is easier than making them private again.


@Override
public NdArraySequence<U> elements(int dimensionIdx) {
if (dimensionIdx >= shape().numDimensions()) {
Expand Down Expand Up @@ -136,8 +138,6 @@ protected AbstractDenseNdArray(DimensionalSpace dimensions) {
super(dimensions);
}

abstract protected DataBuffer<T> buffer();

abstract U instantiate(DataBuffer<T> buffer, DimensionalSpace dimensions);

long positionOf(long[] coords, boolean isValue) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ public static BooleanNdArray create(BooleanDataBuffer buffer, Shape shape) {
return new BooleanDenseNdArray(buffer, shape);
}

@Override
public BooleanDataBuffer buffer() {
return buffer;
}

@Override
public boolean getBoolean(long... indices) {
return buffer.getBoolean(positionOf(indices, true));
Expand Down Expand Up @@ -77,11 +82,6 @@ BooleanDenseNdArray instantiate(DataBuffer<Boolean> buffer, DimensionalSpace dim
return new BooleanDenseNdArray((BooleanDataBuffer)buffer, dimensions);
}

@Override
protected BooleanDataBuffer buffer() {
return buffer;
}

private final BooleanDataBuffer buffer;

private BooleanDenseNdArray(BooleanDataBuffer buffer, DimensionalSpace dimensions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ public static ByteNdArray create(ByteDataBuffer buffer, Shape shape) {
return new ByteDenseNdArray(buffer, shape);
}

@Override
public ByteDataBuffer buffer() {
return buffer;
}

@Override
public byte getByte(long... indices) {
return buffer.getByte(positionOf(indices, true));
Expand Down Expand Up @@ -77,11 +82,6 @@ ByteDenseNdArray instantiate(DataBuffer<Byte> buffer, DimensionalSpace dimension
return new ByteDenseNdArray((ByteDataBuffer)buffer, dimensions);
}

@Override
protected ByteDataBuffer buffer() {
return buffer;
}

private final ByteDataBuffer buffer;

private ByteDenseNdArray(ByteDataBuffer buffer, DimensionalSpace dimensions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ DenseNdArray<T> instantiate(DataBuffer<T> buffer, DimensionalSpace dimensions) {
}

@Override
protected DataBuffer<T> buffer() {
public DataBuffer<T> buffer() {
return buffer;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ public static DoubleNdArray create(DoubleDataBuffer buffer, Shape shape) {
return new DoubleDenseNdArray(buffer, shape);
}

@Override
public DoubleDataBuffer buffer() {
return buffer;
}

@Override
public double getDouble(long... indices) {
return buffer.getDouble(positionOf(indices, true));
Expand Down Expand Up @@ -77,11 +82,6 @@ DoubleDenseNdArray instantiate(DataBuffer<Double> buffer, DimensionalSpace dimen
return new DoubleDenseNdArray((DoubleDataBuffer)buffer, dimensions);
}

@Override
protected DoubleDataBuffer buffer() {
return buffer;
}

private final DoubleDataBuffer buffer;

private DoubleDenseNdArray(DoubleDataBuffer buffer, DimensionalSpace dimensions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ public static FloatNdArray create(FloatDataBuffer buffer, Shape shape) {
return new FloatDenseNdArray(buffer, shape);
}

@Override
public FloatDataBuffer buffer() {
return buffer;
}

@Override
public float getFloat(long... indices) {
return buffer.getFloat(positionOf(indices, true));
Expand Down Expand Up @@ -77,11 +82,6 @@ FloatDenseNdArray instantiate(DataBuffer<Float> buffer, DimensionalSpace dimensi
return new FloatDenseNdArray((FloatDataBuffer) buffer, dimensions);
}

@Override
public FloatDataBuffer buffer() {
return buffer;
}

private final FloatDataBuffer buffer;

private FloatDenseNdArray(FloatDataBuffer buffer, DimensionalSpace dimensions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ public static IntNdArray create(IntDataBuffer buffer, Shape shape) {
return new IntDenseNdArray(buffer, shape);
}

@Override
public IntDataBuffer buffer() {
return buffer;
}

@Override
public int getInt(long... indices) {
return buffer.getInt(positionOf(indices, true));
Expand Down Expand Up @@ -77,11 +82,6 @@ IntDenseNdArray instantiate(DataBuffer<Integer> buffer, DimensionalSpace dimensi
return new IntDenseNdArray((IntDataBuffer)buffer, dimensions);
}

@Override
protected IntDataBuffer buffer() {
return buffer;
}

private final IntDataBuffer buffer;

private IntDenseNdArray(IntDataBuffer buffer, DimensionalSpace dimensions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ public static LongNdArray create(LongDataBuffer buffer, Shape shape) {
return new LongDenseNdArray(buffer, shape);
}

@Override
public LongDataBuffer buffer() {
return buffer;
}

@Override
public long getLong(long... indices) {
return buffer.getLong(positionOf(indices, true));
Expand Down Expand Up @@ -77,11 +82,6 @@ LongDenseNdArray instantiate(DataBuffer<Long> buffer, DimensionalSpace dimension
return new LongDenseNdArray((LongDataBuffer)buffer, dimensions);
}

@Override
protected LongDataBuffer buffer() {
return buffer;
}

private final LongDataBuffer buffer;

private LongDenseNdArray(LongDataBuffer buffer, DimensionalSpace dimensions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ public static ShortNdArray create(ShortDataBuffer buffer, Shape shape) {
return new ShortDenseNdArray(buffer, shape);
}

@Override
public ShortDataBuffer buffer() {
return buffer;
}

@Override
public short getShort(long... indices) {
return buffer.getShort(positionOf(indices, true));
Expand Down Expand Up @@ -77,11 +82,6 @@ ShortDenseNdArray instantiate(DataBuffer<Short> buffer, DimensionalSpace dimensi
return new ShortDenseNdArray((ShortDataBuffer)buffer, dimensions);
}

@Override
protected ShortDataBuffer buffer() {
return buffer;
}

private final ShortDataBuffer buffer;

private ShortDenseNdArray(ShortDataBuffer buffer, DimensionalSpace dimensions) {
Expand Down
Loading