Skip to content

Commit c4a7bfb

Browse files
committed
Moved isCompatibleWith to Shape
1 parent 505d0d6 commit c4a7bfb

2 files changed

Lines changed: 124 additions & 136 deletions

File tree

  • ndarray/src/main/java/org/tensorflow/ndarray
  • tensorflow-framework/src/main/java/org/tensorflow/framework/utils

ndarray/src/main/java/org/tensorflow/ndarray/Shape.java

Lines changed: 124 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ public static Shape scalar() {
5252
/**
5353
* Create a Shape representing a scalar or an N-dimensional value.
5454
*
55-
* <p>Creates a Shape representing a scalar or an N-dimensional value (N being at least 1),
56-
* with the provided size for each dimension. A -1 indicates that the size of the corresponding
57-
* dimension is unknown. If no sizes are provided, a Shape representing a scalar is created.
58-
* For example:
55+
* <p>Creates a Shape representing a scalar or an N-dimensional value (N being at least 1), with
56+
* the provided size for each dimension. A -1 indicates that the size of the corresponding
57+
* dimension is unknown. If no sizes are provided, a Shape representing a scalar is created. For
58+
* example:
5959
*
6060
* <pre>{@code
6161
* // A 2-element vector.
@@ -84,11 +84,11 @@ public static Shape of(long... dimensionSizes) {
8484
/**
8585
* Returns the total number of elements a Tensor with this Shape would have.
8686
*
87-
* <p>If {@link Shape#isUnknown()} is true or {@link Shape#hasUnknownDimension()} is true,
88-
* {@link Shape#UNKNOWN_SIZE} is returned.
87+
* <p>If {@link Shape#isUnknown()} is true or {@link Shape#hasUnknownDimension()} is true, {@link
88+
* Shape#UNKNOWN_SIZE} is returned.
8989
*
9090
* @return The total number of elements a Tensor with this shape would have if it can be
91-
* calculated, else {@link Shape#UNKNOWN_SIZE}.
91+
* calculated, else {@link Shape#UNKNOWN_SIZE}.
9292
*/
9393
public long size() {
9494
if (size == null) {
@@ -104,12 +104,11 @@ public long size() {
104104
* an unknown size, {@link Shape#UNKNOWN_SIZE} is returned.
105105
*
106106
* @param i the index of the dimension to get the size for. If this Shape has a known number of
107-
* dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative,
108-
* in which case the position is counted from the end of the shape. E.g.:
109-
* {@code size(-1)} returns the size of the last dimension, {@code size(-2)} the size of
110-
* the second to last dimension etc.
107+
* dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative, in which
108+
* case the position is counted from the end of the shape. E.g.: {@code size(-1)} returns the
109+
* size of the last dimension, {@code size(-2)} the size of the second to last dimension etc.
111110
* @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE}
112-
* otherwise.
111+
* otherwise.
113112
*/
114113
public long size(int i) {
115114
if (dimensionSizes == null) {
@@ -163,8 +162,8 @@ public boolean isUnknown() {
163162
}
164163

165164
/**
166-
* Returns a defensive copy of the this Shape's axes. Changes to the returned array to not
167-
* change this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
165+
* Returns a defensive copy of the this Shape's axes. Changes to the returned array to not change
166+
* this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
168167
*/
169168
public long[] asArray() {
170169
if (this.dimensionSizes == null) {
@@ -182,15 +181,17 @@ public int hashCode() {
182181
/**
183182
* Equals implementation for Shapes. Two Shapes are considered equal iff:
184183
*
185-
* <p><ul>
186-
* <li>the number of dimensions is defined and equal for both
187-
* <li>the size of each dimension is defined and equal for both
184+
* <p>
185+
*
186+
* <ul>
187+
* <li>the number of dimensions is defined and equal for both
188+
* <li>the size of each dimension is defined and equal for both
188189
* </ul>
189190
*
190191
* <p>If either Shape has unknown dimensions (even if they are the same in both) or if either
191-
* shape has an unknown number of dimensions (even if both return {@code true} for
192-
* {@link Shape#isUnknown()}), they are not considered equal! However, a shape will always
193-
* equal itself, even if it is unknown or contains unknown dimensions.
192+
* shape has an unknown number of dimensions (even if both return {@code true} for {@link
193+
* Shape#isUnknown()}), they are not considered equal! However, a shape will always equal itself,
194+
* even if it is unknown or contains unknown dimensions.
194195
*/
195196
@Override
196197
public boolean equals(Object obj) {
@@ -229,17 +230,17 @@ public Shape head() {
229230
}
230231

231232
/**
232-
* Returns an n-dimensional Shape with the dimensions matching the first n dimensions
233-
* of this shape
233+
* Returns an n-dimensional Shape with the dimensions matching the first n dimensions of this
234+
* shape
234235
*
235236
* @param n the number of leading dimensions to get, must be <= than {@link Shape#numDimensions()}
236-
* @return an n-dimensional Shape with the first n dimensions matching the first n dimensions
237-
* of this Shape
237+
* @return an n-dimensional Shape with the first n dimensions matching the first n dimensions of
238+
* this Shape
238239
*/
239240
public Shape take(int n) {
240241
if (n > numDimensions()) {
241-
throw new ArrayIndexOutOfBoundsException("Cannot take " + n +
242-
" dimensions, shape has only " + numDimensions() + ".");
242+
throw new ArrayIndexOutOfBoundsException(
243+
"Cannot take " + n + " dimensions, shape has only " + numDimensions() + ".");
243244
}
244245
long[] newDimensions = new long[n];
245246
System.arraycopy(dimensionSizes, 0, newDimensions, 0, n);
@@ -253,18 +254,18 @@ public Shape tail() {
253254
}
254255

255256
/**
256-
* Returns an n-dimensional Shape with the dimensions matching the last n dimensions
257-
* of this Shape.
257+
* Returns an n-dimensional Shape with the dimensions matching the last n dimensions of this
258+
* Shape.
258259
*
259-
* @param n the number of trailing dimensions to get, must be <= than
260-
* {@link Shape#numDimensions()}
260+
* @param n the number of trailing dimensions to get, must be <= than {@link
261+
* Shape#numDimensions()}
261262
* @return an n-dimensional shape with the dimensions matching the last n dimensions of this
262-
* Shape, never null
263+
* Shape, never null
263264
*/
264265
public Shape takeLast(int n) {
265266
if (n > numDimensions()) {
266-
throw new ArrayIndexOutOfBoundsException("Cannot take last " + n +
267-
" dimensions, shape has only " + numDimensions() + ".");
267+
throw new ArrayIndexOutOfBoundsException(
268+
"Cannot take last " + n + " dimensions, shape has only " + numDimensions() + ".");
268269
}
269270
long[] newDimensions = new long[n];
270271
System.arraycopy(dimensionSizes, numDimensions() - n, newDimensions, 0, n);
@@ -276,8 +277,8 @@ public Shape takeLast(int n) {
276277
* {@link Shape#isUnknown()} must be {@code false}.
277278
*
278279
* @param firstDimension the dimension to prepend
279-
* @return a new shape with the given dimension first, followed by this Shape's dimensions,
280-
* never null
280+
* @return a new shape with the given dimension first, followed by this Shape's dimensions, never
281+
* null
281282
*/
282283
public Shape prepend(long firstDimension) {
283284
long[] newDimensions = new long[dimensionSizes.length + 1];
@@ -288,8 +289,8 @@ public Shape prepend(long firstDimension) {
288289
}
289290

290291
/**
291-
* Returns a new Shape, with a new last dimension added. In order for this call to succeed,
292-
* {@link Shape#isUnknown()} must be {@code false}.
292+
* Returns a new Shape, with a new last dimension added. In order for this call to succeed, {@link
293+
* Shape#isUnknown()} must be {@code false}.
293294
*
294295
* @param lastDimension the dimension to append
295296
* @return a new Shape with this Shape's dimensions followed by the given dimension, never null
@@ -303,38 +304,36 @@ public Shape append(long lastDimension) {
303304
}
304305

305306
/**
306-
* Returns a new Shape, with another Shape's dimensions prepended.
307-
* For both this Shape and the other Shape, {@link Shape#isUnknown()} must return false.
308-
* E.g. {@code Shape.of(3,4).prepend(Shape.of(1,2)) => Shape.of(1,2,3,4) }
307+
* Returns a new Shape, with another Shape's dimensions prepended. For both this Shape and the
308+
* other Shape, {@link Shape#isUnknown()} must return false. E.g. {@code
309+
* Shape.of(3,4).prepend(Shape.of(1,2)) => Shape.of(1,2,3,4) }
309310
*
310311
* @param other another Shape, must not be {@code null}, must not be unknown
311-
* @return A new Shape consisting of the given Shapes's dimensions followed by this Shape's
312-
* dimensions, never null
312+
* @return A new Shape consisting of the given Shape's dimensions followed by this Shape's
313+
* dimensions, never null
313314
*/
314315
public Shape prepend(Shape other) {
315316
long[] newDimensions = new long[other.dimensionSizes.length + dimensionSizes.length];
316-
System.arraycopy(other.dimensionSizes, 0,
317-
newDimensions, 0, other.dimensionSizes.length);
318-
System.arraycopy(dimensionSizes, 0,
319-
newDimensions, other.dimensionSizes.length, dimensionSizes.length);
317+
System.arraycopy(other.dimensionSizes, 0, newDimensions, 0, other.dimensionSizes.length);
318+
System.arraycopy(
319+
dimensionSizes, 0, newDimensions, other.dimensionSizes.length, dimensionSizes.length);
320320
return Shape.of(newDimensions);
321321
}
322322

323323
/**
324-
* Returns a new Shape, with another Shapes' dimensions appended.
325-
* For both this Shape and the other Shape, {@link Shape#isUnknown()} must return false.
326-
* E.g. @code Shape.of(3,4).append(Shape.of(1,2)) => Shape.of(3,4,1,2) }
324+
* Returns a new Shape, with another Shapes' dimensions appended. For both this Shape and the
325+
* other Shape, {@link Shape#isUnknown()} must return false. E.g. @code
326+
* Shape.of(3,4).append(Shape.of(1,2)) => Shape.of(3,4,1,2) }
327327
*
328328
* @param other another Shape, must not be {@code null}, must not be unknown
329-
* @return A new Shape consisting of this Shapes's dimensions followed by the given Shape's
330-
* dimensions
329+
* @return A new Shape consisting of this Shape's dimensions followed by the given Shape's
330+
* dimensions
331331
*/
332332
public Shape append(Shape other) {
333333
long[] newDimensions = new long[dimensionSizes.length + other.dimensionSizes.length];
334-
System.arraycopy(dimensionSizes, 0,
335-
newDimensions, 0, dimensionSizes.length);
336-
System.arraycopy(other.dimensionSizes, 0,
337-
newDimensions, dimensionSizes.length, other.dimensionSizes.length);
334+
System.arraycopy(dimensionSizes, 0, newDimensions, 0, dimensionSizes.length);
335+
System.arraycopy(
336+
other.dimensionSizes, 0, newDimensions, dimensionSizes.length, other.dimensionSizes.length);
338337
return Shape.of(newDimensions);
339338
}
340339

@@ -351,4 +350,74 @@ private static long computeSize(long[] dimensionSizes) {
351350
}
352351
return computedSize;
353352
}
353+
354+
/**
355+
* Determines whether another shape is compatible with this one.
356+
*
357+
* <p>
358+
*
359+
* <p>Two possibly-partially-defined shapes are compatible if there exists a fully-defined shape
360+
* that both shapes can represent. Thus, compatibility allows the shape inference code to reason
361+
* about partially-defined shapes. For example:
362+
*
363+
* <ul>
364+
* <li><code>Shape.unknown()</code> is compatible with all shapes.
365+
* <li><code>Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> is compatible with all two-dimensional
366+
* shapes, such as <code>Shape(32, 784)</code>, and also <code>Shape.unknown()</code>. It is
367+
* not compatible with, for example, <code>Shape(UNKNOWN_SIZE)</code> or <code>
368+
* Shape(UNKNOWN_SIZE, UNKNOWN_SIZE, UNKNOWN_SIZE)</code>.
369+
* <li><code>Shape(32, UNKNOWN_SIZE)</code> is compatible with all two-dimensional shapes with
370+
* size 32 in the 0th dimension, and also <code>Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> and
371+
* <code>Shape.unknown()</code>. It is not compatible with, for example, <code>Shape(32)
372+
* </code>, <code>Shape(32, UNKNOWN_SIZE, 1)</code> or <code>Shape(64, UNKNOWN_SIZE)</code>.
373+
* <li><code>Shape(32, 784)</code> is compatible with itself, and also <code>
374+
* Shape(32, UNKNOWN_SIZE)</code>, <code>Shape(UNKNOWN_SIZE, 784)</code>, <code>
375+
* Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> and <code>Shape.unknown()</code>. It is not
376+
* compatible with, for example, <code>Shape(32, 1, 784)</code> or <code>Shape(UNKNOWN_SIZE)
377+
* </code>.
378+
* </ul>
379+
*
380+
* <p>The compatibility relation is reflexive and symmetric, but not transitive. For example,
381+
* <code>Shape(32, 784)</code> is compatible with <code>Shape.unknown()</code>, and <code>
382+
* Shape.unknown()</code> is compatible with <code>Shape(4, 4)</code>, but <code>Shape(32, 784)
383+
* </code> is not compatible with <code>Shape(4, 4)</code>.
384+
*
385+
* <p>Compatibility is not the same as broadcasting. Compatible shapes must have the same number
386+
* of dimensions and for each dimension pair, one dimension has to equal the other dimensions or
387+
* at least one of the dimensions in the pair has to be UNKNOWN_SIZE.
388+
*
389+
* <p>Broadcasting allows different dimensions, but paired dimensions have to either be equal, or
390+
* one dimension must be 1. If one shape has less dimensions than another shape, the smaller shape
391+
* is "stretched" with dimensions of 1.
392+
*
393+
* @param shape The other shape
394+
* @return true, if the two shapes are compatible.
395+
*/
396+
public boolean isCompatibleWith(Shape shape) {
397+
if (!this.isUnknown() && !shape.isUnknown()) {
398+
if (numDimensions() != shape.numDimensions()) {
399+
return false;
400+
}
401+
for (int i = 0; i < numDimensions(); i++) {
402+
if (!isCompatible(size(i), shape.size(i))) {
403+
return false;
404+
}
405+
}
406+
}
407+
return true;
408+
}
409+
410+
/**
411+
* Test to see if two shape dimensions are compatible.
412+
*
413+
* <p>The dimensions are compatible if either dimension is <code>Shape.UNKNOWN_SIZE</code> or both
414+
* dimensions are equal
415+
*
416+
* @param dim the first dimension
417+
* @param otherDim the second dimension
418+
* @return true, if both dimensions are compatible
419+
*/
420+
public static boolean isCompatible(long dim, long otherDim) {
421+
return dim == Shape.UNKNOWN_SIZE || otherDim == Shape.UNKNOWN_SIZE || dim == otherDim;
422+
}
354423
}

tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java

Lines changed: 0 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -123,73 +123,6 @@ public static <T extends TNumber> Shape getShape(Tensor<T> tensor) {
123123
return data.shape();
124124
}
125125

126-
/**
127-
* Determines whether two shapes are compatible.
128-
*
129-
* <p>
130-
*
131-
* <p>Two possibly-partially-defined shapes are compatible if there exists a fully-defined shape
132-
* that both shapes can represent. Thus, compatibility allows the shape inference code to reason
133-
* about partially-defined shapes. For example:
134-
*
135-
* <ul>
136-
* <li><code>Shape.unknown()</code> is compatible with all shapes.
137-
* <li><code>Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> is compatible with all two-dimensional
138-
* shapes, such as <code>Shape(32, 784)</code>, and also <code>Shape.unknown()</code>. It is
139-
* not compatible with, for example, <code>Shape(UNKNOWN_SIZE)</code> or <code>
140-
* Shape(UNKNOWN_SIZE, UNKNOWN_SIZE, UNKNOWN_SIZE)</code>.
141-
* <li><code>Shape(32, UNKNOWN_SIZE)</code> is compatible with all two-dimensional shapes with
142-
* size 32 in the 0th dimension, and also <code>Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> and
143-
* <code>Shape.unknown()</code>. It is not compatible with, for example, <code>Shape(32)
144-
* </code>, <code>Shape(32, UNKNOWN_SIZE, 1)</code> or <code>Shape(64, UNKNOWN_SIZE)</code>.
145-
* <li><code>Shape(32, 784)</code> is compatible with itself, and also <code>
146-
* Shape(32, UNKNOWN_SIZE)</code>, <code>Shape(UNKNOWN_SIZE, 784)</code>, <code>
147-
* Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> and <code>Shape.unknown()</code>. It is not
148-
* compatible with, for example, <code>Shape(32, 1, 784)</code> or <code>Shape(UNKNOWN_SIZE)
149-
* </code>.
150-
* </ul>
151-
*
152-
* <p>The compatibility relation is reflexive and symmetric, but not transitive. For example,
153-
* <code>Shape(32, 784)</code> is compatible with <code>Shape.unknown()</code>, and <code>
154-
* Shape.unknown()</code> is compatible with <code>Shape(4, 4)</code>, but <code>Shape(32, 784)
155-
* </code> is not compatible with <code>Shape(4, 4)</code>.
156-
*
157-
* <p>Compatibility is not the same as broadcasting. Compatible shapes must have the same number
158-
* of dimensions and for each dimension pair, one dimension has to equal the other dimensions or
159-
* at least one of the dimensions in the pair has to be UNKNOWN_SIZE.
160-
*
161-
* <p>Broadcasting allows different dimensions, but paired dimensions have to either be equal, or
162-
* one dimension must be 1. If one shape has less dimensions than another shape, the smaller shape
163-
* is "stretched" with dimensions of 1. See {@link org.tensorflow.op.Ops#broadcastTo}.
164-
*
165-
* @param a The first shape
166-
* @param b The second shape
167-
* @return true, if the two shapes are compatible.
168-
*/
169-
public static boolean isCompatibleWith(Shape a, Shape b) {
170-
if (isUnknownShape(a) && isUnknownShape(b)) {
171-
if (a.numDimensions() != b.numDimensions()) {
172-
return false;
173-
}
174-
for (int i = 0; i < a.numDimensions(); i++) {
175-
if (!isCompatible(a.size(i), b.size(i))) {
176-
return false;
177-
}
178-
}
179-
}
180-
return true;
181-
}
182-
183-
/**
184-
* Determines if a shape is an unknown shape as provided in <code>Shape.unknown()</code>.
185-
*
186-
* @param a the shape to test.
187-
* @return true if the shape is an unknown shape
188-
*/
189-
public static boolean isUnknownShape(Shape a) {
190-
return a.equals(Shape.unknown());
191-
}
192-
193126
/**
194127
* Reduces the shape by eliminating trailing Dimensions.
195128
*
@@ -215,18 +148,4 @@ public static Shape reduce(Shape shape, int axis) {
215148
newArray[axis - 1] = prod;
216149
return Shape.of(newArray);
217150
}
218-
219-
/**
220-
* Test to see if two shape dimensions are compatible.
221-
*
222-
* <p>The dimensions are compatible if either dimension is <code>Shape.UNKNOWN_SIZE</code> or both
223-
* dimensions are equal
224-
*
225-
* @param dim the first dimension
226-
* @param otherDim the second dimension
227-
* @return true, if both dimensions are compatible
228-
*/
229-
public static boolean isCompatible(long dim, long otherDim) {
230-
return dim == Shape.UNKNOWN_SIZE || otherDim == Shape.UNKNOWN_SIZE || dim == otherDim;
231-
}
232151
}

0 commit comments

Comments
 (0)