@@ -52,10 +52,10 @@ public static Shape scalar() {
5252 /**
5353 * Create a Shape representing a scalar or an N-dimensional value.
5454 *
55- * <p>Creates a Shape representing a scalar or an N-dimensional value (N being at least 1),
56- * with the provided size for each dimension. A -1 indicates that the size of the corresponding
57- * dimension is unknown. If no sizes are provided, a Shape representing a scalar is created.
58- * For example:
55+ * <p>Creates a Shape representing a scalar or an N-dimensional value (N being at least 1), with
56+ * the provided size for each dimension. A -1 indicates that the size of the corresponding
57+ * dimension is unknown. If no sizes are provided, a Shape representing a scalar is created. For
58+ * example:
5959 *
6060 * <pre>{@code
6161 * // A 2-element vector.
@@ -84,11 +84,11 @@ public static Shape of(long... dimensionSizes) {
8484 /**
8585 * Returns the total number of elements a Tensor with this Shape would have.
8686 *
87- * <p>If {@link Shape#isUnknown()} is true or {@link Shape#hasUnknownDimension()} is true,
88- * {@link Shape#UNKNOWN_SIZE} is returned.
87+ * <p>If {@link Shape#isUnknown()} is true or {@link Shape#hasUnknownDimension()} is true, {@link
88+ * Shape#UNKNOWN_SIZE} is returned.
8989 *
9090 * @return The total number of elements a Tensor with this shape would have if it can be
91- * calculated, else {@link Shape#UNKNOWN_SIZE}.
91+ * calculated, else {@link Shape#UNKNOWN_SIZE}.
9292 */
9393 public long size () {
9494 if (size == null ) {
@@ -104,12 +104,11 @@ public long size() {
104104 * an unknown size, {@link Shape#UNKNOWN_SIZE} is returned.
105105 *
106106 * @param i the index of the dimension to get the size for. If this Shape has a known number of
107- * dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative,
108- * in which case the position is counted from the end of the shape. E.g.:
109- * {@code size(-1)} returns the size of the last dimension, {@code size(-2)} the size of
110- * the second to last dimension etc.
107+ * dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative, in which
108+ * case the position is counted from the end of the shape. E.g.: {@code size(-1)} returns the
109+ * size of the last dimension, {@code size(-2)} the size of the second to last dimension etc.
111110 * @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE}
112- * otherwise.
111+ * otherwise.
113112 */
114113 public long size (int i ) {
115114 if (dimensionSizes == null ) {
@@ -163,8 +162,8 @@ public boolean isUnknown() {
163162 }
164163
165164 /**
166- * Returns a defensive copy of the this Shape's axes. Changes to the returned array to not
167- * change this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
165+ * Returns a defensive copy of the this Shape's axes. Changes to the returned array to not change
166+ * this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
168167 */
169168 public long [] asArray () {
170169 if (this .dimensionSizes == null ) {
@@ -182,15 +181,17 @@ public int hashCode() {
182181 /**
183182 * Equals implementation for Shapes. Two Shapes are considered equal iff:
184183 *
185- * <p><ul>
186- * <li>the number of dimensions is defined and equal for both
187- * <li>the size of each dimension is defined and equal for both
184+ * <p>
185+ *
186+ * <ul>
187+ * <li>the number of dimensions is defined and equal for both
188+ * <li>the size of each dimension is defined and equal for both
188189 * </ul>
189190 *
190191 * <p>If either Shape has unknown dimensions (even if they are the same in both) or if either
191- * shape has an unknown number of dimensions (even if both return {@code true} for
192- * {@link Shape#isUnknown()}), they are not considered equal! However, a shape will always
193- * equal itself, even if it is unknown or contains unknown dimensions.
192+ * shape has an unknown number of dimensions (even if both return {@code true} for {@link
193+ * Shape#isUnknown()}), they are not considered equal! However, a shape will always equal itself,
194+ * even if it is unknown or contains unknown dimensions.
194195 */
195196 @ Override
196197 public boolean equals (Object obj ) {
@@ -229,17 +230,17 @@ public Shape head() {
229230 }
230231
231232 /**
232- * Returns an n-dimensional Shape with the dimensions matching the first n dimensions
233- * of this shape
233+ * Returns an n-dimensional Shape with the dimensions matching the first n dimensions of this
234+ * shape
234235 *
235236 * @param n the number of leading dimensions to get, must be <= than {@link Shape#numDimensions()}
236- * @return an n-dimensional Shape with the first n dimensions matching the first n dimensions
237- * of this Shape
237+ * @return an n-dimensional Shape with the first n dimensions matching the first n dimensions of
238+ * this Shape
238239 */
239240 public Shape take (int n ) {
240241 if (n > numDimensions ()) {
241- throw new ArrayIndexOutOfBoundsException ("Cannot take " + n +
242- " dimensions, shape has only " + numDimensions () + "." );
242+ throw new ArrayIndexOutOfBoundsException (
243+ "Cannot take " + n + " dimensions, shape has only " + numDimensions () + "." );
243244 }
244245 long [] newDimensions = new long [n ];
245246 System .arraycopy (dimensionSizes , 0 , newDimensions , 0 , n );
@@ -253,18 +254,18 @@ public Shape tail() {
253254 }
254255
255256 /**
256- * Returns an n-dimensional Shape with the dimensions matching the last n dimensions
257- * of this Shape.
257+ * Returns an n-dimensional Shape with the dimensions matching the last n dimensions of this
258+ * Shape.
258259 *
259- * @param n the number of trailing dimensions to get, must be <= than
260- * {@link Shape#numDimensions()}
260+ * @param n the number of trailing dimensions to get, must be <= than {@link
261+ * Shape#numDimensions()}
261262 * @return an n-dimensional shape with the dimensions matching the last n dimensions of this
262- * Shape, never null
263+ * Shape, never null
263264 */
264265 public Shape takeLast (int n ) {
265266 if (n > numDimensions ()) {
266- throw new ArrayIndexOutOfBoundsException ("Cannot take last " + n +
267- " dimensions, shape has only " + numDimensions () + "." );
267+ throw new ArrayIndexOutOfBoundsException (
268+ "Cannot take last " + n + " dimensions, shape has only " + numDimensions () + "." );
268269 }
269270 long [] newDimensions = new long [n ];
270271 System .arraycopy (dimensionSizes , numDimensions () - n , newDimensions , 0 , n );
@@ -276,8 +277,8 @@ public Shape takeLast(int n) {
276277 * {@link Shape#isUnknown()} must be {@code false}.
277278 *
278279 * @param firstDimension the dimension to prepend
279- * @return a new shape with the given dimension first, followed by this Shape's dimensions,
280- * never null
280+ * @return a new shape with the given dimension first, followed by this Shape's dimensions, never
281+ * null
281282 */
282283 public Shape prepend (long firstDimension ) {
283284 long [] newDimensions = new long [dimensionSizes .length + 1 ];
@@ -288,8 +289,8 @@ public Shape prepend(long firstDimension) {
288289 }
289290
290291 /**
291- * Returns a new Shape, with a new last dimension added. In order for this call to succeed,
292- * {@link Shape#isUnknown()} must be {@code false}.
292+ * Returns a new Shape, with a new last dimension added. In order for this call to succeed, {@link
293+ * Shape#isUnknown()} must be {@code false}.
293294 *
294295 * @param lastDimension the dimension to append
295296 * @return a new Shape with this Shape's dimensions followed by the given dimension, never null
@@ -303,38 +304,36 @@ public Shape append(long lastDimension) {
303304 }
304305
305306 /**
306- * Returns a new Shape, with another Shape's dimensions prepended.
307- * For both this Shape and the other Shape, {@link Shape#isUnknown()} must return false.
308- * E.g. {@code Shape.of(3,4).prepend(Shape.of(1,2)) => Shape.of(1,2,3,4) }
307+ * Returns a new Shape, with another Shape's dimensions prepended. For both this Shape and the
308+ * other Shape, {@link Shape#isUnknown()} must return false. E.g. {@code
309+ * Shape.of(3,4).prepend(Shape.of(1,2)) => Shape.of(1,2,3,4) }
309310 *
310311 * @param other another Shape, must not be {@code null}, must not be unknown
311- * @return A new Shape consisting of the given Shapes 's dimensions followed by this Shape's
312- * dimensions, never null
312+ * @return A new Shape consisting of the given Shape 's dimensions followed by this Shape's
313+ * dimensions, never null
313314 */
314315 public Shape prepend (Shape other ) {
315316 long [] newDimensions = new long [other .dimensionSizes .length + dimensionSizes .length ];
316- System .arraycopy (other .dimensionSizes , 0 ,
317- newDimensions , 0 , other .dimensionSizes .length );
318- System .arraycopy (dimensionSizes , 0 ,
319- newDimensions , other .dimensionSizes .length , dimensionSizes .length );
317+ System .arraycopy (other .dimensionSizes , 0 , newDimensions , 0 , other .dimensionSizes .length );
318+ System .arraycopy (
319+ dimensionSizes , 0 , newDimensions , other .dimensionSizes .length , dimensionSizes .length );
320320 return Shape .of (newDimensions );
321321 }
322322
323323 /**
324- * Returns a new Shape, with another Shapes' dimensions appended.
325- * For both this Shape and the other Shape, {@link Shape#isUnknown()} must return false.
326- * E.g. @code Shape.of(3,4).append(Shape.of(1,2)) => Shape.of(3,4,1,2) }
324+ * Returns a new Shape, with another Shapes' dimensions appended. For both this Shape and the
325+ * other Shape, {@link Shape#isUnknown()} must return false. E.g. @code
326+ * Shape.of(3,4).append(Shape.of(1,2)) => Shape.of(3,4,1,2) }
327327 *
328328 * @param other another Shape, must not be {@code null}, must not be unknown
329- * @return A new Shape consisting of this Shapes 's dimensions followed by the given Shape's
330- * dimensions
329+ * @return A new Shape consisting of this Shape 's dimensions followed by the given Shape's
330+ * dimensions
331331 */
332332 public Shape append (Shape other ) {
333333 long [] newDimensions = new long [dimensionSizes .length + other .dimensionSizes .length ];
334- System .arraycopy (dimensionSizes , 0 ,
335- newDimensions , 0 , dimensionSizes .length );
336- System .arraycopy (other .dimensionSizes , 0 ,
337- newDimensions , dimensionSizes .length , other .dimensionSizes .length );
334+ System .arraycopy (dimensionSizes , 0 , newDimensions , 0 , dimensionSizes .length );
335+ System .arraycopy (
336+ other .dimensionSizes , 0 , newDimensions , dimensionSizes .length , other .dimensionSizes .length );
338337 return Shape .of (newDimensions );
339338 }
340339
@@ -351,4 +350,74 @@ private static long computeSize(long[] dimensionSizes) {
351350 }
352351 return computedSize ;
353352 }
353+
354+ /**
355+ * Determines whether another shape is compatible with this one.
356+ *
357+ * <p>
358+ *
359+ * <p>Two possibly-partially-defined shapes are compatible if there exists a fully-defined shape
360+ * that both shapes can represent. Thus, compatibility allows the shape inference code to reason
361+ * about partially-defined shapes. For example:
362+ *
363+ * <ul>
364+ * <li><code>Shape.unknown()</code> is compatible with all shapes.
365+ * <li><code>Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> is compatible with all two-dimensional
366+ * shapes, such as <code>Shape(32, 784)</code>, and also <code>Shape.unknown()</code>. It is
367+ * not compatible with, for example, <code>Shape(UNKNOWN_SIZE)</code> or <code>
368+ * Shape(UNKNOWN_SIZE, UNKNOWN_SIZE, UNKNOWN_SIZE)</code>.
369+ * <li><code>Shape(32, UNKNOWN_SIZE)</code> is compatible with all two-dimensional shapes with
370+ * size 32 in the 0th dimension, and also <code>Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> and
371+ * <code>Shape.unknown()</code>. It is not compatible with, for example, <code>Shape(32)
372+ * </code>, <code>Shape(32, UNKNOWN_SIZE, 1)</code> or <code>Shape(64, UNKNOWN_SIZE)</code>.
373+ * <li><code>Shape(32, 784)</code> is compatible with itself, and also <code>
374+ * Shape(32, UNKNOWN_SIZE)</code>, <code>Shape(UNKNOWN_SIZE, 784)</code>, <code>
375+ * Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> and <code>Shape.unknown()</code>. It is not
376+ * compatible with, for example, <code>Shape(32, 1, 784)</code> or <code>Shape(UNKNOWN_SIZE)
377+ * </code>.
378+ * </ul>
379+ *
380+ * <p>The compatibility relation is reflexive and symmetric, but not transitive. For example,
381+ * <code>Shape(32, 784)</code> is compatible with <code>Shape.unknown()</code>, and <code>
382+ * Shape.unknown()</code> is compatible with <code>Shape(4, 4)</code>, but <code>Shape(32, 784)
383+ * </code> is not compatible with <code>Shape(4, 4)</code>.
384+ *
385+ * <p>Compatibility is not the same as broadcasting. Compatible shapes must have the same number
386+ * of dimensions and for each dimension pair, one dimension has to equal the other dimensions or
387+ * at least one of the dimensions in the pair has to be UNKNOWN_SIZE.
388+ *
389+ * <p>Broadcasting allows different dimensions, but paired dimensions have to either be equal, or
390+ * one dimension must be 1. If one shape has less dimensions than another shape, the smaller shape
391+ * is "stretched" with dimensions of 1.
392+ *
393+ * @param shape The other shape
394+ * @return true, if the two shapes are compatible.
395+ */
396+ public boolean isCompatibleWith (Shape shape ) {
397+ if (!this .isUnknown () && !shape .isUnknown ()) {
398+ if (numDimensions () != shape .numDimensions ()) {
399+ return false ;
400+ }
401+ for (int i = 0 ; i < numDimensions (); i ++) {
402+ if (!isCompatible (size (i ), shape .size (i ))) {
403+ return false ;
404+ }
405+ }
406+ }
407+ return true ;
408+ }
409+
410+ /**
411+ * Test to see if two shape dimensions are compatible.
412+ *
413+ * <p>The dimensions are compatible if either dimension is <code>Shape.UNKNOWN_SIZE</code> or both
414+ * dimensions are equal
415+ *
416+ * @param dim the first dimension
417+ * @param otherDim the second dimension
418+ * @return true, if both dimensions are compatible
419+ */
420+ public static boolean isCompatible (long dim , long otherDim ) {
421+ return dim == Shape .UNKNOWN_SIZE || otherDim == Shape .UNKNOWN_SIZE || dim == otherDim ;
422+ }
354423}
0 commit comments