@@ -177,13 +177,26 @@ KERNEL(eltwise)(
177
177
uint output_offset = GET_INDEX (OUTPUT ,, OUTPUT_IDX_ORDER );
178
178
179
179
// zero-padding the blocked format padded memory area since it might be used as input of onednn concatenation
180
- if (d4 + d3 + d2 + d1 == 0 ) {
180
+ const size_t g0 = get_group_id (0 );
181
+ const size_t g1 = get_group_id (1 );
182
+ const size_t g2 = get_group_id (2 );
183
+ const uint l_z = get_local_id (2 );
184
+
185
+ if (g0 + g1 + g2 + l_z == 0 ) {
186
+ const uint l_x = get_local_id (0 );
187
+ const uint l_y = get_local_id (1 );
188
+ const uint ls_x = get_local_size (0 );
189
+ const uint ls_y = get_local_size (1 );
190
+
191
+ const uint BLOCK_Y = y_size / ls_y ;
192
+ const uint BLOCK_X = x_size / ls_x ;
181
193
const uint b_size = OUTPUT_SIZES [3 ], f_size = OUTPUT_SIZES [2 ], y_size = OUTPUT_SIZES [1 ], x_size = OUTPUT_SIZES [0 ];
182
194
183
195
#if BATCH_BLOCK_SIZE && FEATURE_BLOCK_SIZE
196
+ const uint z_size = 1 ;
197
+
184
198
const uint padded_fs = (f_size + FEATURE_BLOCK_SIZE - 1 ) / FEATURE_BLOCK_SIZE ;
185
199
const uint padded_bs = (b_size + BATCH_BLOCK_SIZE - 1 ) / BATCH_BLOCK_SIZE ;
186
- const uint z_size = 1 ;
187
200
188
201
const uint bsv_pitch = FEATURE_BLOCK_SIZE ;
189
202
const uint x_pitch = bsv_pitch * BATCH_BLOCK_SIZE ;
@@ -198,8 +211,8 @@ KERNEL(eltwise)(
198
211
for (uint bs = 0 ; bs < padded_bs ; ++ bs ) {
199
212
for (uint fs = 0 ; fs < padded_fs ; ++ fs ) {
200
213
for (uint z = 0 ; z < z_size ; ++ z ) {
201
- for (uint y = 0 ; y < y_size ; ++ y ) {
202
- for (uint x = 0 ; x < x_size ; ++ x ) {
214
+ for (uint y = l_y * BLOCK_Y ; y < ( l_y + 1 ) * BLOCK_Y ; ++ y ) {
215
+ for (uint x = l_x * BLOCK_X ; x < ( l_x + 1 ) * BLOCK_X ; ++ x ) {
203
216
for (uint bsv = 0 ; bsv < BATCH_BLOCK_SIZE ; ++ bsv ) {
204
217
for (uint fsv = 0 ; fsv < FEATURE_BLOCK_SIZE ; ++ fsv ) {
205
218
b = bs * BATCH_BLOCK_SIZE + bsv ;
@@ -228,14 +241,12 @@ KERNEL(eltwise)(
228
241
uint offset = 0 ;
229
242
for (uint b = 0 ; b < b_size ; ++ b ) {
230
243
for (uint fs = padded_fs - 1 ; fs < padded_fs ; ++ fs ) {
231
- for (uint y = 0 ; y < y_size ; ++ y ) {
232
- for (uint x = 0 ; x < x_size ; ++ x ) {
233
- for (uint fsv = 0 ; fsv < FEATURE_BLOCK_SIZE ; ++ fsv ) {
244
+ for (uint y = l_y * BLOCK_Y ; y < ( l_y + 1 ) * BLOCK_Y ; ++ y ) {
245
+ for (uint x = l_x * BLOCK_X ; x < ( l_x + 1 ) * BLOCK_X ; ++ x ) {
246
+ for (uint fsv = f_size % FEATURE_BLOCK_SIZE ; fsv < FEATURE_BLOCK_SIZE ; ++ fsv ) {
234
247
f = fs * FEATURE_BLOCK_SIZE + fsv ;
235
- if (f >= f_size ) {
236
- offset = b * b_pitch + fs * fs_pitch + y * y_pitch + x * x_pitch + fsv ;
237
- output [offset ] = 0 ;
238
- }
248
+ offset = b * b_pitch + fs * fs_pitch + y * y_pitch + x * x_pitch + fsv ;
249
+ output [offset ] = 0 ;
239
250
}
240
251
}
241
252
}
0 commit comments