Skip to content

Commit 9da3b9e

Browse files
committed
Add ShapeOps Enhancement to work with org.tensorflow.op.core.Shape directly
1 parent f056482 commit 9da3b9e

2 files changed

Lines changed: 522 additions & 0 deletions

File tree

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
/*
2+
* To change this license header, choose License Headers in Project Properties.
3+
* To change this template file, choose Tools | Templates
4+
* and open the template in the editor.
5+
*/
6+
package org.tensorflow.keras.utils;
7+
8+
import java.util.Arrays;
9+
import org.tensorflow.DataType;
10+
import org.tensorflow.Operand;
11+
import org.tensorflow.op.Scope;
12+
import org.tensorflow.op.core.Concat;
13+
import org.tensorflow.op.core.ExpandDims;
14+
import org.tensorflow.op.core.Reshape;
15+
import org.tensorflow.op.core.Shape;
16+
import org.tensorflow.op.core.Size;
17+
import org.tensorflow.op.core.Slice;
18+
import org.tensorflow.op.math.FloorMod;
19+
import org.tensorflow.op.core.Constant;
20+
import org.tensorflow.op.core.Gather;
21+
import org.tensorflow.op.core.OnesLike;
22+
import org.tensorflow.op.core.ReduceProd;
23+
import org.tensorflow.op.core.Where;
24+
import org.tensorflow.op.dtypes.Cast;
25+
import org.tensorflow.op.math.NotEqual;
26+
import org.tensorflow.op.math.Sub;
27+
import org.tensorflow.types.TBool;
28+
import org.tensorflow.types.TInt32;
29+
import org.tensorflow.types.family.TNumber;
30+
import org.tensorflow.types.family.TType;
31+
32+
/**
33+
*
34+
* @author Jim Clarke
35+
* @param <T> the type of operand
36+
*/
37+
38+
// TODO should shape be based on TInt64 ???
39+
public class ShapeOps<U extends TNumber > {
40+
private final Scope scope;
41+
private final DataType<U> dType;
42+
43+
/**
44+
* Create a ShapeUtils with a DataType of TInt32
45+
*
46+
* @param scope is a scope used to add the underlying operations
47+
* @return the ShapeUtils
48+
*/
49+
public static ShapeOps<TInt32> create(Scope scope) {
50+
return create(scope, TInt32.DTYPE);
51+
}
52+
53+
/**
54+
*
55+
* @param <U> the Shape type
56+
* @param scope is a scope used to add the underlying operations
57+
* @param dType the Shape datatype
58+
* @return the ShapeUtils
59+
*/
60+
public static <U extends TNumber> ShapeOps<U> create(Scope scope, DataType<U> dType) {
61+
return new ShapeOps<>(scope, dType);
62+
}
63+
64+
/**
65+
* The constructor for ShapeUtils
66+
* @param scope is a scope used to add the underlying operations
67+
* @param dType the Shape datatype
68+
*/
69+
private ShapeOps(Scope scope, DataType<U> dType) {
70+
this.scope = scope;
71+
this.dType = dType;
72+
}
73+
74+
public Scope scope() {
75+
return this.scope;
76+
}
77+
78+
public DataType<U> datatype() {
79+
return this.dType;
80+
}
81+
82+
/**
83+
* flatten the shape to 1 dimension
84+
*
85+
* @param <T> the type of operand
86+
* @param operand the operand to flatten
87+
* @return the reshaped operand
88+
*/
89+
public <T extends TType> Operand<T> flatten(Operand<T> operand) {
90+
Operand<U> flatShape = flatten(Shape.create(scope, operand, dType));
91+
return Reshape.create(scope, operand, flatShape);
92+
}
93+
94+
/**
95+
* flatten the shape to 1 dimension
96+
*
97+
* @param shape the TensorFlow shape
98+
* @return the flattened shape
99+
* @see reduceDims
100+
*/
101+
public Operand<U> flatten(Shape<U> shape) {
102+
return ExpandDims.create(scope,
103+
Cast.create(scope, size(shape), dType),
104+
Cast.create(scope, Constant.scalarOf(scope, -1), dType));
105+
106+
}
107+
108+
109+
/**
110+
* get the size represented by the TensorFlow shape
111+
*
112+
* @param shape the TensorFlow shape
113+
* @return the size
114+
*/
115+
public Operand<U> size(Shape<U> shape) {
116+
Slice<U> dims = Slice.create(scope, shape,
117+
Cast.create(scope, Constant.arrayOf(scope, (new int[]{0})), dType),
118+
ExpandDims.create(scope, Cast.create(scope, Constant.scalarOf(scope, -1), dType), Constant.scalarOf(scope, -1)));
119+
ReduceProd<U> total = ReduceProd.create(scope, dims, Constant.scalarOf(scope, 0));
120+
return total;
121+
}
122+
123+
/**
124+
* get the size of the specified dimension in the shape
125+
*
126+
* @param shape the TensorFlow shape
127+
* @param dim the dimension
128+
* @return the size of the specified dimension
129+
*/
130+
public Operand<U> size(Shape<U> shape, Operand<U> dim) {
131+
Slice<U> dims = Slice.create(scope, shape,
132+
ExpandDims.create(scope, dim, Cast.create(scope, Constant.scalarOf(scope, -1), dType)),
133+
ExpandDims.create(scope,
134+
Cast.create(scope, Constant.scalarOf(scope, 1), dType),
135+
Cast.create(scope, Constant.scalarOf(scope, -1), dType)));
136+
return dims;
137+
}
138+
139+
/**
140+
* get the size of the specified dimension for the shape of the tensor
141+
*
142+
* @param input the operand
143+
* @param dim the dimension
144+
* @return the size of the specified dimension
145+
*/
146+
public Operand<U> size(Operand input, Operand<U> dim) {
147+
return size(Shape.create(scope, input, dType), dim);
148+
}
149+
150+
/**
151+
* get the number of dimensions of the shape object
152+
*
153+
* @param shape the shape
154+
* @return the number of dimensions
155+
* @see tf.rank
156+
*/
157+
public Operand<U> numDimensions(Shape<U> shape) {
158+
return Size.create(scope, shape, dType);
159+
}
160+
161+
/**
162+
* reshapes the operand to the specified axis,
163+
* @param <T> the type of Operand
164+
* @param operand the operand
165+
* @param axis the axis
166+
* @return the reshaped operand
167+
*/
168+
public <T extends TType> Operand<T> reduceDims(Operand<T> operand , Operand<U> axis) {
169+
Shape<U> newShape = Shape.create(scope, operand, dType);
170+
return Reshape.create(scope, operand, reduceDims(newShape, axis));
171+
}
172+
173+
/**
174+
* reduces the shape to the specified axis,
175+
* @param shape the TensorFlow shape
176+
* @param axis the axis
177+
* @return the reduced shape
178+
*/
179+
public Operand<U> reduceDims(Shape<U> shape , Operand<U> axis) {
180+
Size<U> rank = Size.create(scope, shape, dType);
181+
axis = FloorMod.create(scope, axis, rank);
182+
Sub<U> remainder = Sub.create(scope, rank, axis);
183+
184+
Operand<U> dims1 = Slice.create(scope, shape,
185+
Cast.create(scope, Constant.arrayOf(scope, (new int[]{0})), dType),
186+
ExpandDims.create(scope, axis, Constant.scalarOf(scope, -1)));
187+
188+
Operand<U> dims2 = Slice.create(scope, shape,
189+
ExpandDims.create(scope, axis, Constant.scalarOf(scope, -1)),
190+
ExpandDims.create(scope, Cast.create(scope, Constant.scalarOf(scope, -1), dType), Constant.scalarOf(scope, -1)));
191+
192+
Operand<U> prod = ReduceProd.create(scope, dims2, Constant.scalarOf(scope, 0), ReduceProd.keepDims(Boolean.TRUE));
193+
Concat<U> concat = Concat.create( scope,
194+
Arrays.asList(dims1, prod), Constant.scalarOf(scope, 0));
195+
196+
return concat;
197+
}
198+
199+
200+
201+
/**
202+
* Removes dimensions of size 1 from the shape
203+
* @param shape
204+
* @return the squeezed shape
205+
* @see tf.squeeze
206+
*/
207+
public Operand<U> squeeze(Shape<U> shape) {
208+
Operand<TBool> mask = NotEqual.create(scope, shape,
209+
Cast.create(scope, OnesLike.create(scope, shape), dType));
210+
211+
Gather<U> gather = Gather.create(scope, shape,
212+
Where.create(scope, mask),
213+
Constant.scalarOf(scope, 0));
214+
215+
return gather;
216+
}
217+
}

0 commit comments

Comments
 (0)