2828
2929public class TransposeLinDataTest {
3030
31- @ Test
32- public void Testrightelem (){
33- int [] shape = {2 , 3 , 4 };
34- TensorBlock tensor = TensorUtils .createArangeTensor (shape );
35-
36- Assert .assertArrayEquals (new int []{2 , 3 , 4 }, tensor .getDims ());
37- Assert .assertEquals (0.0 , tensor .get (new int []{0 , 0 , 0 }));
38- Assert .assertEquals (23.0 , tensor .get (new int []{1 , 2 , 3 }));
39- Assert .assertEquals (6.0 , tensor .get (new int []{0 , 1 , 2 }));
40- Assert .assertEquals (12.0 , tensor .get (new int []{1 , 0 , 0 }));
41- printTensor (tensor );
42-
43-
44- int [] permutation = {1 , 0 , 2 };
45- TensorBlock outTensor = PermuteIt .permute (tensor , permutation );
46- printTensor (outTensor );
47-
48- Assert .assertArrayEquals (new int []{3 , 2 , 4 }, outTensor .getDims ());
49- Assert .assertEquals (0.0 , outTensor .get (new int []{0 ,0 ,0 }));
50- Assert .assertEquals (23.0 , outTensor .get (new int []{2 , 1 , 3 }));
51- Assert .assertEquals (12.0 , outTensor .get (new int []{0 , 1 , 0 }));
52- Assert .assertEquals (17.0 , outTensor .get (new int []{1 , 1 , 1 }));
53-
54-
55- int [] second_permutation = {2 , 1 , 0 };
56- TensorBlock perm2Block = PermuteIt .permute (tensor , second_permutation );
57- printTensor (perm2Block );
58-
59- Assert .assertArrayEquals (new int []{4 , 3 , 2 }, perm2Block .getDims ());
60- Assert .assertEquals (0.0 , perm2Block .get (new int []{0 , 0 , 0 }));
61- Assert .assertEquals (12.0 , perm2Block .get (new int []{0 , 0 , 1 }));
62- Assert .assertEquals (11.0 , perm2Block .get (new int []{3 , 2 , 0 }));
63- Assert .assertEquals (23.0 , perm2Block .get (new int []{3 , 2 , 1 }));
64-
65- }
66-
67-
68-
69-
70- public class TensorUtils {
71-
72- public static TensorBlock createArangeTensor (int [] shape ) {
73- TensorBlock tb = new TensorBlock (ValueType .FP64 , shape );
74- tb .allocateBlock ();
75- double [] counter = { 0.0 };
76- int [] currentIndices = new int [shape .length ];
77-
78- fillRecursively (tb , shape , 0 , currentIndices , counter );
79-
80- return tb ;
81- }
82-
83- private static void fillRecursively (TensorBlock tb , int [] shape , int dim , int [] currentIndices , double [] counter ) {
84- if (dim == shape .length ) {
85- tb .set (currentIndices , counter [0 ]);
86- counter [0 ]++;
87- return ;
88- }
89-
90- for (int i = 0 ; i < shape [dim ]; i ++) {
91- currentIndices [dim ] = i ;
92-
93- fillRecursively (tb , shape , dim + 1 , currentIndices , counter );
94- }
95- }
96- }
97-
98-
99-
100- public class PermuteIt {
101-
102-
103- public static TensorBlock permute (TensorBlock tensor , int [] permute_dims ) {
104-
105- int anz_dims = tensor .getNumDims ();
106- int [] dims = tensor .getDims ();
107- ValueType tensorType = tensor .getValueType ();
108-
109- int [] out_shape = new int [anz_dims ];
110-
111- for (int idx = 0 ; idx < anz_dims ; idx ++){
112- out_shape [idx ] = dims [permute_dims [idx ]];
113- }
114-
115- TensorBlock outTensor = new TensorBlock (tensorType , out_shape );
116- outTensor .allocateBlock ();
117-
118- int [] inIndex = new int [anz_dims ];
119- int [] outIndex = new int [anz_dims ];
120-
121- rekursion (tensor , outTensor , permute_dims , dims , 0 , inIndex , outIndex );
122- return outTensor ;
123- }
124-
125- public static void rekursion (TensorBlock inTensor ,
126- TensorBlock outTensor ,
127- int [] permutation ,
128- int [] inShape ,
129- int dim ,
130- int [] inIndex ,
131- int []outIndex
132- ){
133-
134- if (dim == inShape .length ) {
135- for (int idx = 0 ; idx < permutation .length ; idx ++){
136- outIndex [idx ] = inIndex [permutation [idx ]];
137- }
138- double val = (double ) inTensor .get (inIndex );
139- outTensor .set (outIndex , val );
140- return ;
141- }
142-
143- for (int idx = 0 ; idx < inShape [dim ]; idx ++){
144- inIndex [dim ] = idx ;
145- rekursion (inTensor , outTensor , permutation , inShape , dim +1 , inIndex , outIndex );
146- }
147-
148- }
149-
150- }
151-
152-
153- public static void printTensor (TensorBlock tb ) {
154- StringBuilder sb = new StringBuilder ();
155- int [] shape = tb .getDims ();
156- int [] currentIndices = new int [shape .length ];
157-
158- sb .append ("Tensor(" ).append (Arrays .toString (shape )).append ("):\n " );
159- printRecursive (tb , shape , 0 , currentIndices , sb , 0 );
160-
161- System .out .println (sb .toString ());
162- }
163-
164- private static void printRecursive (TensorBlock tb , int [] shape , int dim , int [] indices , StringBuilder sb , int indent ) {
165- for (int k = 0 ; k < indent ; k ++) sb .append (" " );
166-
167- sb .append ("[" );
168-
169- if (dim == shape .length - 1 ) {
170- for (int i = 0 ; i < shape [dim ]; i ++) {
171- indices [dim ] = i ;
172- double val = (double ) tb .get (indices );
173- sb .append (String .format ("%.1f" , val ));
174- if (i < shape [dim ] - 1 ) sb .append (", " );
175- }
176- sb .append ("]" );
177- }
178-
179- else {
180- sb .append ("\n " );
181- for (int i = 0 ; i < shape [dim ]; i ++) {
182- indices [dim ] = i ;
183- printRecursive (tb , shape , dim + 1 , indices , sb , indent + 2 );
184-
185- if (i < shape [dim ] - 1 ) {
186- sb .append ("," );
187- sb .append ("\n " );
188- if (shape .length - dim > 2 ) sb .append ("\n " );
189- }
190- }
191- sb .append ("\n " );
192- for (int k = 0 ; k < indent ; k ++) sb .append (" " );
193- sb .append ("]" );
194- }
195- }
196-
197- }
31+ @ Test
32+ public void testRightElem (){
33+ int [] shape = {2 , 3 , 4 };
34+ TensorBlock tensor = TensorUtils .createArangeTensor (shape );
35+
36+ Assert .assertArrayEquals (new int []{2 , 3 , 4 }, tensor .getDims ());
37+ Assert .assertEquals (0.0 , tensor .get (new int []{0 , 0 , 0 }));
38+ Assert .assertEquals (23.0 , tensor .get (new int []{1 , 2 , 3 }));
39+ Assert .assertEquals (6.0 , tensor .get (new int []{0 , 1 , 2 }));
40+ Assert .assertEquals (12.0 , tensor .get (new int []{1 , 0 , 0 }));
41+ printTensor (tensor );
42+
43+
44+ int [] permutation = {1 , 0 , 2 };
45+ TensorBlock outTensor = PermuteIt .permute (tensor , permutation );
46+ printTensor (outTensor );
47+
48+ Assert .assertArrayEquals (new int []{3 , 2 , 4 }, outTensor .getDims ());
49+ Assert .assertEquals (0.0 , outTensor .get (new int []{0 ,0 ,0 }));
50+ Assert .assertEquals (23.0 , outTensor .get (new int []{2 , 1 , 3 }));
51+ Assert .assertEquals (12.0 , outTensor .get (new int []{0 , 1 , 0 }));
52+ Assert .assertEquals (17.0 , outTensor .get (new int []{1 , 1 , 1 }));
53+
54+ int [] second_permutation = {2 , 1 , 0 };
55+ TensorBlock perm2Block = PermuteIt .permute (tensor , second_permutation );
56+ printTensor (perm2Block );
57+
58+ Assert .assertArrayEquals (new int []{4 , 3 , 2 }, perm2Block .getDims ());
59+ Assert .assertEquals (0.0 , perm2Block .get (new int []{0 , 0 , 0 }));
60+ Assert .assertEquals (12.0 , perm2Block .get (new int []{0 , 0 , 1 }));
61+ Assert .assertEquals (11.0 , perm2Block .get (new int []{3 , 2 , 0 }));
62+ Assert .assertEquals (23.0 , perm2Block .get (new int []{3 , 2 , 1 }));
63+ }
64+
65+ public class TensorUtils {
66+
67+ public static TensorBlock createArangeTensor (int [] shape ) {
68+ TensorBlock tb = new TensorBlock (ValueType .FP64 , shape );
69+ tb .allocateBlock ();
70+ double [] counter = { 0.0 };
71+ int [] currentIndices = new int [shape .length ];
72+
73+ fillRecursively (tb , shape , 0 , currentIndices , counter );
74+
75+ return tb ;
76+ }
77+
78+ private static void fillRecursively (TensorBlock tb , int [] shape , int dim , int [] currentIndices , double [] counter ) {
79+ if (dim == shape .length ) {
80+ tb .set (currentIndices , counter [0 ]);
81+ counter [0 ]++;
82+ return ;
83+ }
84+
85+ for (int i = 0 ; i < shape [dim ]; i ++) {
86+ currentIndices [dim ] = i ;
87+
88+ fillRecursively (tb , shape , dim + 1 , currentIndices , counter );
89+ }
90+ }
91+ }
92+
93+ public class PermuteIt {
94+ public static TensorBlock permute (TensorBlock tensor , int [] permute_dims ) {
95+ int anz_dims = tensor .getNumDims ();
96+ int [] dims = tensor .getDims ();
97+ ValueType tensorType = tensor .getValueType ();
98+
99+ int [] out_shape = new int [anz_dims ];
100+
101+ for (int idx = 0 ; idx < anz_dims ; idx ++){
102+ out_shape [idx ] = dims [permute_dims [idx ]];
103+ }
104+
105+ TensorBlock outTensor = new TensorBlock (tensorType , out_shape );
106+ outTensor .allocateBlock ();
107+
108+ int [] inIndex = new int [anz_dims ];
109+ int [] outIndex = new int [anz_dims ];
110+
111+ recursion (tensor , outTensor , permute_dims , dims , 0 , inIndex , outIndex );
112+ return outTensor ;
113+ }
114+
115+ public static void recursion (TensorBlock inTensor , TensorBlock outTensor ,
116+ int [] permutation , int [] inShape , int dim , int [] inIndex , int []outIndex )
117+ {
118+ if (dim == inShape .length ) {
119+ for (int idx = 0 ; idx < permutation .length ; idx ++){
120+ outIndex [idx ] = inIndex [permutation [idx ]];
121+ }
122+ double val = (double ) inTensor .get (inIndex );
123+ outTensor .set (outIndex , val );
124+ return ;
125+ }
126+
127+ for (int idx = 0 ; idx < inShape [dim ]; idx ++){
128+ inIndex [dim ] = idx ;
129+ recursion (inTensor , outTensor , permutation , inShape , dim +1 , inIndex , outIndex );
130+ }
131+ }
132+ }
133+
134+ public static void printTensor (TensorBlock tb ) {
135+ StringBuilder sb = new StringBuilder ();
136+ int [] shape = tb .getDims ();
137+ int [] currentIndices = new int [shape .length ];
138+
139+ sb .append ("Tensor(" ).append (Arrays .toString (shape )).append ("):\n " );
140+ printRecursive (tb , shape , 0 , currentIndices , sb , 0 );
141+
142+ System .out .println (sb .toString ());
143+ }
144+
145+ private static void printRecursive (TensorBlock tb , int [] shape , int dim , int [] indices , StringBuilder sb , int indent ) {
146+ for (int k = 0 ; k < indent ; k ++) sb .append (" " );
147+
148+ sb .append ("[" );
149+
150+ if (dim == shape .length - 1 ) {
151+ for (int i = 0 ; i < shape [dim ]; i ++) {
152+ indices [dim ] = i ;
153+ double val = (double ) tb .get (indices );
154+ sb .append (String .format ("%.1f" , val ));
155+ if (i < shape [dim ] - 1 ) sb .append (", " );
156+ }
157+ sb .append ("]" );
158+ }
159+
160+ else {
161+ sb .append ("\n " );
162+ for (int i = 0 ; i < shape [dim ]; i ++) {
163+ indices [dim ] = i ;
164+ printRecursive (tb , shape , dim + 1 , indices , sb , indent + 2 );
165+
166+ if (i < shape [dim ] - 1 ) {
167+ sb .append ("," );
168+ sb .append ("\n " );
169+ if (shape .length - dim > 2 ) sb .append ("\n " );
170+ }
171+ }
172+ sb .append ("\n " );
173+ for (int k = 0 ; k < indent ; k ++) sb .append (" " );
174+ sb .append ("]" );
175+ }
176+ }
177+ }
0 commit comments