@@ -179,16 +179,14 @@ KERNEL(eltwise)(
179
179
// Zero-fill the padded memory area of the blocked-format output, since it might be used as an input to a oneDNN concatenation
180
180
if (d4 + d3 + d2 + d1 == 0 ) {
181
181
const uint b_size = OUTPUT_SIZES [3 ], f_size = OUTPUT_SIZES [2 ], y_size = OUTPUT_SIZES [1 ], x_size = OUTPUT_SIZES [0 ];
182
- const uint BLOCK_F = FEATURE_BLOCK_SIZE ;
183
182
184
- #if BATCH_BLOCK_SIZE
185
- const uint BLOCK_B = BATCH_BLOCK_SIZE ;
183
+ #if BATCH_BLOCK_SIZE && FEATURE_BLOCK_SIZE
186
184
const uint padded_fs = (f_size + FEATURE_BLOCK_SIZE - 1 ) / FEATURE_BLOCK_SIZE ;
187
185
const uint padded_bs = (b_size + BATCH_BLOCK_SIZE - 1 ) / BATCH_BLOCK_SIZE ;
188
186
const uint z_size = 1 ;
189
187
190
- const uint bsv_pitch = BLOCK_F ;
191
- const uint x_pitch = bsv_pitch * BLOCK_B ;
188
+ const uint bsv_pitch = FEATURE_BLOCK_SIZE ;
189
+ const uint x_pitch = bsv_pitch * BATCH_BLOCK_SIZE ;
192
190
const uint y_pitch = x_pitch * x_size ;
193
191
const uint z_pitch = y_pitch * y_size ;
194
192
const uint fs_pitch = z_pitch * z_size ;
@@ -202,10 +200,10 @@ KERNEL(eltwise)(
202
200
for (uint z = 0 ; z < z_size ; ++ z ) {
203
201
for (uint y = 0 ; y < y_size ; ++ y ) {
204
202
for (uint x = 0 ; x < x_size ; ++ x ) {
205
- for (uint bsv = 0 ; bsv < BLOCK_B ; ++ bsv ) {
206
- for (uint fsv = 0 ; fsv < BLOCK_F ; ++ fsv ) {
207
- b = bs * BLOCK_B + bsv ;
208
- f = fs * BLOCK_F + fsv ;
203
+ for (uint bsv = 0 ; bsv < BATCH_BLOCK_SIZE ; ++ bsv ) {
204
+ for (uint fsv = 0 ; fsv < FEATURE_BLOCK_SIZE ; ++ fsv ) {
205
+ b = bs * BATCH_BLOCK_SIZE + bsv ;
206
+ f = fs * FEATURE_BLOCK_SIZE + fsv ;
209
207
if (b >= b_size || f >= f_size ) {
210
208
offset = bs * bs_pitch + fs * fs_pitch + z * z_pitch +
211
209
y * y_pitch + x * x_pitch + bsv * bsv_pitch + fsv ;
@@ -221,7 +219,7 @@ KERNEL(eltwise)(
221
219
#elif FEATURE_BLOCK_SIZE
222
220
const uint padded_fs = (f_size + FEATURE_BLOCK_SIZE - 1 ) / FEATURE_BLOCK_SIZE ;
223
221
224
- const uint x_pitch = BLOCK_F ;
222
+ const uint x_pitch = FEATURE_BLOCK_SIZE ;
225
223
const uint y_pitch = x_pitch * x_size ;
226
224
const uint fs_pitch = y_pitch * y_size ;
227
225
const uint b_pitch = fs_pitch * padded_fs ;
@@ -232,8 +230,8 @@ KERNEL(eltwise)(
232
230
for (uint fs = padded_fs - 1 ; fs < padded_fs ; ++ fs ) {
233
231
for (uint y = 0 ; y < y_size ; ++ y ) {
234
232
for (uint x = 0 ; x < x_size ; ++ x ) {
235
- for (uint fsv = 0 ; fsv < BLOCK_F ; ++ fsv ) {
236
- f = fs * BLOCK_F + fsv ;
233
+ for (uint fsv = 0 ; fsv < FEATURE_BLOCK_SIZE ; ++ fsv ) {
234
+ f = fs * FEATURE_BLOCK_SIZE + fsv ;
237
235
if (f >= f_size ) {
238
236
offset = b * b_pitch + fs * fs_pitch + y * y_pitch + x * x_pitch + fsv ;
239
237
output [offset ] = 0 ;
0 commit comments