Skip to content

Commit 5556626

Browse files
committed
[MINOR] Fix test formatting (method names, tab indentation)
1 parent de7e9a0 commit 5556626

1 file changed

Lines changed: 147 additions & 167 deletions

File tree

src/test/java/org/apache/sysds/test/component/tensor/TransposeLinDataTest.java

Lines changed: 147 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -28,170 +28,150 @@
2828

2929
public class TransposeLinDataTest {
3030

31-
@Test
32-
public void Testrightelem(){
33-
int[] shape = {2, 3, 4};
34-
TensorBlock tensor = TensorUtils.createArangeTensor(shape);
35-
36-
Assert.assertArrayEquals(new int[]{2, 3, 4}, tensor.getDims());
37-
Assert.assertEquals(0.0, tensor.get(new int[]{0, 0, 0}));
38-
Assert.assertEquals(23.0, tensor.get(new int[]{1, 2, 3}));
39-
Assert.assertEquals(6.0, tensor.get(new int[]{0, 1, 2}));
40-
Assert.assertEquals(12.0, tensor.get(new int[]{1, 0, 0}));
41-
printTensor(tensor);
42-
43-
44-
int[] permutation = {1, 0, 2};
45-
TensorBlock outTensor = PermuteIt.permute(tensor, permutation);
46-
printTensor(outTensor);
47-
48-
Assert.assertArrayEquals(new int[]{3, 2, 4}, outTensor.getDims());
49-
Assert.assertEquals(0.0, outTensor.get(new int[]{0,0,0}));
50-
Assert.assertEquals(23.0, outTensor.get(new int[]{2, 1, 3}));
51-
Assert.assertEquals(12.0, outTensor.get(new int[]{0, 1, 0}));
52-
Assert.assertEquals(17.0, outTensor.get(new int[]{1, 1, 1}));
53-
54-
55-
int[] second_permutation = {2, 1, 0};
56-
TensorBlock perm2Block = PermuteIt.permute(tensor, second_permutation);
57-
printTensor(perm2Block);
58-
59-
Assert.assertArrayEquals(new int[]{4, 3, 2}, perm2Block.getDims());
60-
Assert.assertEquals(0.0, perm2Block.get(new int[]{0, 0, 0}));
61-
Assert.assertEquals(12.0, perm2Block.get(new int[]{0, 0, 1}));
62-
Assert.assertEquals(11.0, perm2Block.get(new int[]{3, 2, 0}));
63-
Assert.assertEquals(23.0, perm2Block.get(new int[]{3, 2, 1}));
64-
65-
}
66-
67-
68-
69-
70-
public class TensorUtils {
71-
72-
public static TensorBlock createArangeTensor(int[] shape) {
73-
TensorBlock tb = new TensorBlock(ValueType.FP64, shape);
74-
tb.allocateBlock();
75-
double[] counter = { 0.0 };
76-
int[] currentIndices = new int[shape.length];
77-
78-
fillRecursively(tb, shape, 0, currentIndices, counter);
79-
80-
return tb;
81-
}
82-
83-
private static void fillRecursively(TensorBlock tb, int[] shape, int dim, int[] currentIndices, double[] counter) {
84-
if (dim == shape.length) {
85-
tb.set(currentIndices, counter[0]);
86-
counter[0]++;
87-
return;
88-
}
89-
90-
for (int i = 0; i < shape[dim]; i++) {
91-
currentIndices[dim] = i;
92-
93-
fillRecursively(tb, shape, dim + 1, currentIndices, counter);
94-
}
95-
}
96-
}
97-
98-
99-
100-
public class PermuteIt {
101-
102-
103-
public static TensorBlock permute(TensorBlock tensor, int[] permute_dims) {
104-
105-
int anz_dims = tensor.getNumDims();
106-
int[] dims = tensor.getDims();
107-
ValueType tensorType = tensor.getValueType();
108-
109-
int[] out_shape = new int[anz_dims];
110-
111-
for (int idx = 0; idx < anz_dims; idx++){
112-
out_shape[idx] = dims[permute_dims[idx]];
113-
}
114-
115-
TensorBlock outTensor = new TensorBlock(tensorType, out_shape);
116-
outTensor.allocateBlock();
117-
118-
int[] inIndex = new int[anz_dims];
119-
int[] outIndex = new int[anz_dims];
120-
121-
rekursion(tensor, outTensor, permute_dims, dims, 0, inIndex, outIndex);
122-
return outTensor;
123-
}
124-
125-
public static void rekursion(TensorBlock inTensor,
126-
TensorBlock outTensor,
127-
int[] permutation,
128-
int[] inShape,
129-
int dim,
130-
int[] inIndex,
131-
int[]outIndex
132-
){
133-
134-
if (dim == inShape.length) {
135-
for(int idx = 0; idx < permutation.length; idx++){
136-
outIndex[idx] = inIndex[permutation[idx]];
137-
}
138-
double val = (double) inTensor.get(inIndex);
139-
outTensor.set(outIndex, val);
140-
return;
141-
}
142-
143-
for(int idx = 0; idx < inShape[dim]; idx++){
144-
inIndex[dim] = idx;
145-
rekursion(inTensor, outTensor, permutation, inShape, dim+1, inIndex, outIndex);
146-
}
147-
148-
}
149-
150-
}
151-
152-
153-
public static void printTensor(TensorBlock tb) {
154-
StringBuilder sb = new StringBuilder();
155-
int[] shape = tb.getDims();
156-
int[] currentIndices = new int[shape.length];
157-
158-
sb.append("Tensor(").append(Arrays.toString(shape)).append("):\n");
159-
printRecursive(tb, shape, 0, currentIndices, sb, 0);
160-
161-
System.out.println(sb.toString());
162-
}
163-
164-
private static void printRecursive(TensorBlock tb, int[] shape, int dim, int[] indices, StringBuilder sb, int indent) {
165-
for (int k = 0; k < indent; k++) sb.append(" ");
166-
167-
sb.append("[");
168-
169-
if (dim == shape.length - 1) {
170-
for (int i = 0; i < shape[dim]; i++) {
171-
indices[dim] = i;
172-
double val = (double) tb.get(indices);
173-
sb.append(String.format("%.1f", val));
174-
if (i < shape[dim] - 1) sb.append(", ");
175-
}
176-
sb.append("]");
177-
}
178-
179-
else {
180-
sb.append("\n");
181-
for (int i = 0; i < shape[dim]; i++) {
182-
indices[dim] = i;
183-
printRecursive(tb, shape, dim + 1, indices, sb, indent + 2);
184-
185-
if (i < shape[dim] - 1) {
186-
sb.append(",");
187-
sb.append("\n");
188-
if (shape.length - dim > 2) sb.append("\n");
189-
}
190-
}
191-
sb.append("\n");
192-
for (int k = 0; k < indent; k++) sb.append(" ");
193-
sb.append("]");
194-
}
195-
}
196-
197-
}
31+
@Test
32+
public void testRightElem(){
33+
int[] shape = {2, 3, 4};
34+
TensorBlock tensor = TensorUtils.createArangeTensor(shape);
35+
36+
Assert.assertArrayEquals(new int[]{2, 3, 4}, tensor.getDims());
37+
Assert.assertEquals(0.0, tensor.get(new int[]{0, 0, 0}));
38+
Assert.assertEquals(23.0, tensor.get(new int[]{1, 2, 3}));
39+
Assert.assertEquals(6.0, tensor.get(new int[]{0, 1, 2}));
40+
Assert.assertEquals(12.0, tensor.get(new int[]{1, 0, 0}));
41+
printTensor(tensor);
42+
43+
44+
int[] permutation = {1, 0, 2};
45+
TensorBlock outTensor = PermuteIt.permute(tensor, permutation);
46+
printTensor(outTensor);
47+
48+
Assert.assertArrayEquals(new int[]{3, 2, 4}, outTensor.getDims());
49+
Assert.assertEquals(0.0, outTensor.get(new int[]{0,0,0}));
50+
Assert.assertEquals(23.0, outTensor.get(new int[]{2, 1, 3}));
51+
Assert.assertEquals(12.0, outTensor.get(new int[]{0, 1, 0}));
52+
Assert.assertEquals(17.0, outTensor.get(new int[]{1, 1, 1}));
53+
54+
int[] second_permutation = {2, 1, 0};
55+
TensorBlock perm2Block = PermuteIt.permute(tensor, second_permutation);
56+
printTensor(perm2Block);
57+
58+
Assert.assertArrayEquals(new int[]{4, 3, 2}, perm2Block.getDims());
59+
Assert.assertEquals(0.0, perm2Block.get(new int[]{0, 0, 0}));
60+
Assert.assertEquals(12.0, perm2Block.get(new int[]{0, 0, 1}));
61+
Assert.assertEquals(11.0, perm2Block.get(new int[]{3, 2, 0}));
62+
Assert.assertEquals(23.0, perm2Block.get(new int[]{3, 2, 1}));
63+
}
64+
65+
public class TensorUtils {
66+
67+
public static TensorBlock createArangeTensor(int[] shape) {
68+
TensorBlock tb = new TensorBlock(ValueType.FP64, shape);
69+
tb.allocateBlock();
70+
double[] counter = { 0.0 };
71+
int[] currentIndices = new int[shape.length];
72+
73+
fillRecursively(tb, shape, 0, currentIndices, counter);
74+
75+
return tb;
76+
}
77+
78+
private static void fillRecursively(TensorBlock tb, int[] shape, int dim, int[] currentIndices, double[] counter) {
79+
if (dim == shape.length) {
80+
tb.set(currentIndices, counter[0]);
81+
counter[0]++;
82+
return;
83+
}
84+
85+
for (int i = 0; i < shape[dim]; i++) {
86+
currentIndices[dim] = i;
87+
88+
fillRecursively(tb, shape, dim + 1, currentIndices, counter);
89+
}
90+
}
91+
}
92+
93+
public class PermuteIt {
94+
public static TensorBlock permute(TensorBlock tensor, int[] permute_dims) {
95+
int anz_dims = tensor.getNumDims();
96+
int[] dims = tensor.getDims();
97+
ValueType tensorType = tensor.getValueType();
98+
99+
int[] out_shape = new int[anz_dims];
100+
101+
for (int idx = 0; idx < anz_dims; idx++){
102+
out_shape[idx] = dims[permute_dims[idx]];
103+
}
104+
105+
TensorBlock outTensor = new TensorBlock(tensorType, out_shape);
106+
outTensor.allocateBlock();
107+
108+
int[] inIndex = new int[anz_dims];
109+
int[] outIndex = new int[anz_dims];
110+
111+
recursion(tensor, outTensor, permute_dims, dims, 0, inIndex, outIndex);
112+
return outTensor;
113+
}
114+
115+
public static void recursion(TensorBlock inTensor, TensorBlock outTensor,
116+
int[] permutation, int[] inShape, int dim, int[] inIndex, int[]outIndex)
117+
{
118+
if (dim == inShape.length) {
119+
for(int idx = 0; idx < permutation.length; idx++){
120+
outIndex[idx] = inIndex[permutation[idx]];
121+
}
122+
double val = (double) inTensor.get(inIndex);
123+
outTensor.set(outIndex, val);
124+
return;
125+
}
126+
127+
for(int idx = 0; idx < inShape[dim]; idx++){
128+
inIndex[dim] = idx;
129+
recursion(inTensor, outTensor, permutation, inShape, dim+1, inIndex, outIndex);
130+
}
131+
}
132+
}
133+
134+
public static void printTensor(TensorBlock tb) {
135+
StringBuilder sb = new StringBuilder();
136+
int[] shape = tb.getDims();
137+
int[] currentIndices = new int[shape.length];
138+
139+
sb.append("Tensor(").append(Arrays.toString(shape)).append("):\n");
140+
printRecursive(tb, shape, 0, currentIndices, sb, 0);
141+
142+
System.out.println(sb.toString());
143+
}
144+
145+
private static void printRecursive(TensorBlock tb, int[] shape, int dim, int[] indices, StringBuilder sb, int indent) {
146+
for (int k = 0; k < indent; k++) sb.append(" ");
147+
148+
sb.append("[");
149+
150+
if (dim == shape.length - 1) {
151+
for (int i = 0; i < shape[dim]; i++) {
152+
indices[dim] = i;
153+
double val = (double) tb.get(indices);
154+
sb.append(String.format("%.1f", val));
155+
if (i < shape[dim] - 1) sb.append(", ");
156+
}
157+
sb.append("]");
158+
}
159+
160+
else {
161+
sb.append("\n");
162+
for (int i = 0; i < shape[dim]; i++) {
163+
indices[dim] = i;
164+
printRecursive(tb, shape, dim + 1, indices, sb, indent + 2);
165+
166+
if (i < shape[dim] - 1) {
167+
sb.append(",");
168+
sb.append("\n");
169+
if (shape.length - dim > 2) sb.append("\n");
170+
}
171+
}
172+
sb.append("\n");
173+
for (int k = 0; k < indent; k++) sb.append(" ");
174+
sb.append("]");
175+
}
176+
}
177+
}

0 commit comments

Comments
 (0)