@@ -16,38 +16,66 @@ struct FloatMatrix {
1616 float data [FOUR ][FOUR ][FOUR ];
1717};
1818
19+ #ifdef MAT2_IS_TRANSPOSED
20+ vec4 matmul_naive_W_packed_W_packed (
21+ #else
1922vec4 matmul_naive_W_packed_H_packed (
20- sampler3D im_mat1 ,
21- sampler3D im_mat2 ,
22- ivec3 mat1_pos ,
23- ivec3 mat2_pos ,
23+ #endif
24+ const sampler3D im_mat1 ,
25+ const sampler3D im_mat2 ,
26+ const ivec3 out_pos ,
2427 const int width ) {
28+ ivec3 mat1_pos = ivec3 (0 , out_pos .y , out_pos .z );
29+ #ifdef MAT2_IS_TRANSPOSED
30+ ivec3 mat2_pos = ivec3 (0 , out_pos .x * 4 , 0 );
31+ #else
32+ ivec3 mat2_pos = ivec3 (out_pos .x * 4 , 0 , out_pos .z );
33+ #endif
34+
2535 vec4 texel = vec4 (0 );
26- int K = (width + 3 ) / 4 ;
36+ const int K = (width + 3 ) / 4 ;
2737
2838 for (int i = 0 ; i < K ; ++ i ) {
29- vec4 mat1_tex = texelFetch (im_mat1 , mat1_pos , 0 );
30- vec4 sums = vec4 (
39+ const vec4 mat1_tex = texelFetch (im_mat1 , mat1_pos , 0 );
40+ #ifdef MAT2_IS_TRANSPOSED
41+ const vec4 sums = vec4 (
42+ dot (mat1_tex , texelFetch (im_mat2 , mat2_pos , 0 )),
43+ dot (mat1_tex , texelFetch (im_mat2 , mat2_pos + ivec3 (0 , 1 , 0 ), 0 )),
44+ dot (mat1_tex , texelFetch (im_mat2 , mat2_pos + ivec3 (0 , 2 , 0 ), 0 )),
45+ dot (mat1_tex , texelFetch (im_mat2 , mat2_pos + ivec3 (0 , 3 , 0 ), 0 )));
46+ #else
47+ const vec4 sums = vec4 (
3148 dot (mat1_tex , texelFetch (im_mat2 , mat2_pos , 0 )),
3249 dot (mat1_tex , texelFetch (im_mat2 , mat2_pos + ivec3 (1 , 0 , 0 ), 0 )),
3350 dot (mat1_tex , texelFetch (im_mat2 , mat2_pos + ivec3 (2 , 0 , 0 ), 0 )),
3451 dot (mat1_tex , texelFetch (im_mat2 , mat2_pos + ivec3 (3 , 0 , 0 ), 0 )));
52+ #endif
3553
3654 texel += sums ;
3755
3856 mat1_pos .x ++ ;
57+ #ifdef MAT2_IS_TRANSPOSED
58+ mat2_pos .x ++ ;
59+ #else
3960 mat2_pos .y ++ ;
61+ #endif
4062 }
4163
4264 return texel ;
4365}
4466
67+ #ifdef MAT2_IS_TRANSPOSED
68+ vec4 matmul_naive_W_packed_H_packed (
69+ #else
4570vec4 matmul_naive_W_packed_W_packed (
46- sampler3D im_mat1 ,
47- sampler3D im_mat2 ,
48- ivec3 mat1_pos ,
49- ivec3 mat2_pos ,
71+ #endif
72+ const sampler3D im_mat1 ,
73+ const sampler3D im_mat2 ,
74+ const ivec3 out_pos ,
5075 const int width ) {
76+ ivec3 mat1_pos = ivec3 (0 , out_pos .y , out_pos .z );
77+ ivec3 mat2_pos = ivec3 (out_pos .x , 0 , out_pos .z );
78+
5179 vec4 texel = vec4 (0 );
5280 int K = divup4 (width );
5381
@@ -87,7 +115,7 @@ vec4 get_texel_W_packed(
87115 else if (broadcast_at_height ) {
88116 self_texel = texelFetch (im_self , ivec3 (pos .x , 0 , 0 ), 0 );
89117 } else {
90- self_texel = texelFetch (im_self , pos , 0 );
118+ self_texel = texelFetch (im_self , ivec3 ( pos . x , pos . y , 0 ) , 0 );
91119 }
92120
93121 return self_texel ;
@@ -112,7 +140,7 @@ vec4 get_texel_C_packed(
112140 else if (broadcast_at_height ) {
113141 self_texel = texelFetch (im_self , ivec3 (pos .x , 0 , 0 ), 0 );
114142 } else {
115- self_texel = texelFetch (im_self , pos , 0 );
143+ self_texel = texelFetch (im_self , ivec3 ( pos . x , pos . y , 0 ) , 0 );
116144 }
117145
118146 return self_texel ;
@@ -123,8 +151,7 @@ FloatMatrix matmul_partial_4x4(
123151 sampler3D im_mat2 ,
124152 const ivec3 pos ,
125153 const int batch_size ,
126- const int K_texel_len ,
127- const int packed_dim_padding ) {
154+ const int K_texel_len ) {
128155 FloatMatrix results ;
129156 for (int i = 0 ; i < FOUR ; i ++ ) {
130157 for (int j = 0 ; j < FOUR ; j ++ ) {
@@ -133,43 +160,36 @@ FloatMatrix matmul_partial_4x4(
133160 }
134161 }
135162 }
136- vec4 im_mat1_partial_rows [FOUR ];
137- vec4 im_mat2_partial_cols [FOUR ];
163+ vec4 im_mat1_partial_load [FOUR ];
164+ vec4 im_mat2_partial_load [FOUR ];
138165
139166 for (int batch_idx = 0 ; batch_idx < FOUR ; batch_idx ++ ) {
140167 if (FOUR * pos .z + batch_idx >= batch_size ) {
141168 break ;
142169 }
143- // read and cache 4x4 tile of im_mat1 (4 adjacent rows)
170+ int mat_z = FOUR * pos . z + batch_idx ;
144171 for (int mat1_x = 0 ; mat1_x < K_texel_len ; mat1_x ++ ) {
145- for (int mat1_row = 0 ; mat1_row < FOUR ; mat1_row ++ ) {
146- const int mat1_y = (FOUR * pos .y ) + mat1_row ;
147- const ivec3 mat1_pos = ivec3 (mat1_x , mat1_y , FOUR * pos .z + batch_idx );
148- im_mat1_partial_rows [mat1_row ] = texelFetch (im_mat1 , mat1_pos , 0 );
149- // set the value out of the boundary to be 0
150- if (mat1_x == K_texel_len - 1 && packed_dim_padding > 0 ) {
151- for (int kk = 0 ; kk < packed_dim_padding ; kk ++ ) {
152- im_mat1_partial_rows [mat1_row ][3 - kk ] = 0 ;
153- }
154- }
155- }
156- // read and cache 4x4 tile of im_mat2 (4 adjacent columns)
157- for (int mat2_col = 0 ; mat2_col < FOUR ; mat2_col ++ ) {
158- const int mat2_x = (FOUR * pos .x ) + mat2_col ;
159- const ivec3 pos_rd = ivec3 (mat2_x , mat1_x , FOUR * pos .z + batch_idx );
160- im_mat2_partial_cols [mat2_col ] = texelFetch (im_mat2 , pos_rd , 0 );
161- // set the value out of the boundary to be 0
162- if (mat1_x == K_texel_len - 1 && packed_dim_padding > 0 ) {
163- for (int kk = 0 ; kk < packed_dim_padding ; kk ++ ) {
164- im_mat2_partial_cols [mat2_col ][3 - kk ] = 0 ;
165- }
166- }
172+ for (int offset = 0 ; offset < FOUR ; offset ++ ) {
173+ // read and cache 4x4 tile of im_mat1
174+ const int mat1_y = (FOUR * pos .y ) + offset ;
175+ const ivec3 mat1_pos = ivec3 (mat1_x , mat1_y , mat_z );
176+ im_mat1_partial_load [offset ] = texelFetch (im_mat1 , mat1_pos , 0 );
177+ // read and cache 4x4 tile of im_mat2
178+ #ifdef MAT2_IS_TRANSPOSED
179+ const int mat2_y = (FOUR * pos .x ) + offset ;
180+ const ivec3 mat2_pos = ivec3 (mat1_x , mat2_y , 0 );
181+ im_mat2_partial_load [offset ] = texelFetch (im_mat2 , mat2_pos , 0 );
182+ #else
183+ const int mat2_x = (FOUR * pos .x ) + offset ;
184+ const ivec3 mat2_pos = ivec3 (mat2_x , mat1_x , mat_z );
185+ im_mat2_partial_load [offset ] = texelFetch (im_mat2 , mat2_pos , 0 );
186+ #endif
167187 }
168188 // perform partial dot products and add partial result to results
169189 for (int out_row = 0 ; out_row < FOUR ; out_row ++ ) {
170190 for (int out_col = 0 ; out_col < FOUR ; out_col ++ ) {
171191 results .data [out_row ][out_col ][batch_idx ] +=
172- dot (im_mat1_partial_rows [out_row ], im_mat2_partial_cols [out_col ]);
192+ dot (im_mat1_partial_load [out_row ], im_mat2_partial_load [out_col ]);
173193 }
174194 }
175195 }
0 commit comments