@@ -151,10 +151,10 @@ void jit_sve_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk,
151
151
jcp.typesize_bia * jcp.oc_block * i_load);
152
152
};
153
153
154
- // auto comp_ptr = [=](int i_load) {
155
- // return SVE_compress_addr(reg_comp_data,
156
- // sizeof(int32_t) * jcp.oc_block * i_load);
157
- // };
154
+ auto comp_ptr = [=](int i_load) {
155
+ return SVE_compress_addr (reg_comp_data,
156
+ sizeof (int32_t ) * jcp.oc_block * i_load);
157
+ };
158
158
159
159
auto scale_ptr = [=](int i_load) {
160
160
return SVE_compress_addr (reg_ptr_scales,
@@ -234,6 +234,12 @@ void jit_sve_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk,
234
234
auto r = vreg_accum (i_load, i_ur);
235
235
vpxord (r, r, r);
236
236
}
237
+ if (!jcp.signed_input ) {
238
+ xor_ (reg_scratch, reg_scratch);
239
+ Reg8 _t8 = reg_scratch.cvt8 ();
240
+ mov (_t8, (int8_t )-128 );
241
+ vpbroadcastb (zmm_shift, _t8);
242
+ }
237
243
};
238
244
239
245
auto store = [=](const bool mask_flag_in) {
@@ -251,15 +257,25 @@ void jit_sve_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk,
251
257
for (int i_load = 0 ; i_load < load_loop_blk; ++i_load) {
252
258
const bool mask_flag = mask_flag_in && i_load == load_loop_blk - 1 ;
253
259
auto zmm_bias = zmm_tmp;
260
+ auto zmm_comp = zmm_bcast;
254
261
if (jcp.with_bias ) {
262
+ if (!jcp.signed_input )
263
+ mov (reg_bias_data,
264
+ SVE_compress_addr (rsp,reg_bias_data_off));
255
265
cvt2ps (jcp.bia_dt , zmm_bias, bias_ptr (i_load), mask_flag);
256
266
}
267
+ if (!jcp.signed_input ) {
268
+ mov (reg_comp_data, SVE_compress_addr (rsp, reg_comp_data_off));
269
+ cvt2ps (data_type::s32, zmm_comp, comp_ptr (i_load), mask_flag);
270
+ }
257
271
258
272
auto zmm_scale = zmm_one;
259
273
vmovups (zmm_scale, scale_ptr (i_load));
260
274
for (int i_ur = 0 ; i_ur < ur; ++i_ur) {
261
275
auto r = vreg_accum (i_load, i_ur);
262
276
CGA64::scvtf (xa::ZRegS (r.getIdx ()), xa::PReg (vmask.getIdx ()), xa::ZRegS (r.getIdx ())); // < vcvtdq2ps(r, r);
277
+ if (!jcp.signed_input )
278
+ vsubps (r, r, zmm_comp);
263
279
if (jcp.with_bias )
264
280
vaddps (r, r, zmm_bias);
265
281
@@ -393,29 +409,47 @@ void jit_sve_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk,
393
409
for (int i_load = 0 ; i_load < load_loop_blk; ++i_load)
394
410
vmovups (vreg_load (i_load), load_ptr (i_reduce, i_load));
395
411
for (int i_ur = 0 ; i_ur < ur; ++i_ur) {
396
- if (last_block && tail_size != 0
397
- && i_reduce == loop_unroll - reduce_step) {
398
- Xmm xmm_bcast = Xmm (zmm_bcast.getIdx ());
399
- for (int r = 0 ; r < tail_size; ++r)
400
- vpinsrb (xmm_bcast, xmm_bcast, ptr[aux_reg_bcast_data
401
- + jcp.ic_without_padding * i_ur + i_reduce + r], r);
402
- Zmm _bcast = ((i_ur % 2 ) == 0 )? zmm_bcast : zmm_bcast2;
403
- vpbroadcastd (_bcast, xmm_bcast);
412
+ if (jcp.signed_input ) {
413
+ if (last_block && tail_size != 0
414
+ && i_reduce == loop_unroll - reduce_step) {
415
+ Xmm xmm_bcast = Xmm (zmm_bcast.getIdx ());
416
+ for (int r = 0 ; r < tail_size; ++r)
417
+ vpinsrb (xmm_bcast, xmm_bcast, ptr[aux_reg_bcast_data
418
+ + jcp.ic_without_padding * i_ur + i_reduce + r], r);
419
+ Zmm _bcast = ((i_ur % 2 ) == 0 )? zmm_bcast : zmm_bcast2;
420
+ vpbroadcastd (_bcast, xmm_bcast);
421
+ } else {
422
+ if (i_ur == 0 ) {
423
+ // vpbroadcastd(zmm_bcast, bcast_ptr(i_reduce, i_ur, false));
424
+ bcast_ptr (zmm_bcast, i_reduce, i_ur, false );
425
+ }
426
+ if ((i_ur+1 ) < ur) {
427
+ Zmm _bcast = ((i_ur % 2 ) == 0 )? zmm_bcast2 : zmm_bcast;
428
+ // vpbroadcastd(zmm_bcast, bcast_ptr(i_reduce, (i_ur+1), false));
429
+ bcast_ptr (_bcast, i_reduce, (i_ur+1 ), false );
430
+ }
431
+ }
432
+ for (int i_load = 0 ; i_load < load_loop_blk; ++i_load) {
433
+ Zmm _bcast = ((i_ur % 2 ) == 0 )? zmm_bcast : zmm_bcast2;
434
+ compute (vreg_accum (i_load, i_ur),
435
+ vreg_load (i_load), _bcast);
436
+ }
404
437
} else {
405
- if (i_ur == 0 ) {
406
- // vpbroadcastd(zmm_bcast, bcast_ptr(i_reduce, i_ur, false));
407
- bcast_ptr (zmm_bcast, i_reduce, i_ur, false );
408
- }
409
- if ((i_ur+1 ) < ur) {
410
- Zmm _bcast = ((i_ur % 2 ) == 0 )? zmm_bcast2 : zmm_bcast;
411
- // vpbroadcastd(zmm_bcast, bcast_ptr(i_reduce, (i_ur+1), false));
412
- bcast_ptr (_bcast, i_reduce, (i_ur+1 ), false );
413
- }
414
- }
415
- for (int i_load = 0 ; i_load < load_loop_blk; ++i_load) {
416
- Zmm _bcast = ((i_ur % 2 ) == 0 )? zmm_bcast : zmm_bcast2;
417
- compute (vreg_accum (i_load, i_ur),
418
- vreg_load (i_load), _bcast);
438
+ if (last_block && tail_size != 0
439
+ && i_reduce == loop_unroll - reduce_step) {
440
+ Xmm xmm_bcast = Xmm (zmm_bcast.getIdx ());
441
+ for (int r = 0 ; r < tail_size; ++r)
442
+ vpinsrb (xmm_bcast, xmm_bcast, ptr[aux_reg_bcast_data
443
+ + jcp.ic_without_padding * i_ur + i_reduce + r], r);
444
+ vpbroadcastd (zmm_bcast, xmm_bcast);
445
+ } else {
446
+ bcast_ptr (zmm_bcast, i_reduce, i_ur, false );
447
+ }
448
+ vpaddb (zmm_bcast, zmm_bcast, zmm_shift);
449
+ for (int i_load = 0 ; i_load < load_loop_blk; ++i_load) {
450
+ compute (vreg_accum (i_load, i_ur),
451
+ vreg_load (i_load), zmm_bcast);
452
+ }
419
453
}
420
454
}
421
455
}
@@ -498,6 +532,12 @@ void jit_sve_x8s8s32x_1x1_conv_kernel::generate()
498
532
499
533
if (jcp.with_bias )
500
534
mov (reg_bias_data, ptr[param1 + GET_OFF (bias_data)]);
535
+ if (!jcp.signed_input ) {
536
+ mov (SVE_compress_addr (rsp, reg_bias_data_off), reg_bias_data);
537
+ mov (reg_comp_data, ptr[param1 + GET_OFF (compensation)]);
538
+ mov (SVE_compress_addr (rsp, reg_comp_data_off), reg_comp_data);
539
+ }
540
+
501
541
mov (reg_ptr_scales, ptr[param1 + GET_OFF (scales)]);
502
542
mov (SVE_compress_addr (rsp, reg_ptr_sum_scale_off), reg_ptr_scales);
503
543
mov (reg_bcast_data, ptr[param1 + GET_OFF (bcast_data)]);
@@ -515,8 +555,18 @@ void jit_sve_x8s8s32x_1x1_conv_kernel::generate()
515
555
bcast_loop (load_loop_blk);
516
556
add (reg_load_data, load_loop_blk * jcp.load_loop_load_step );
517
557
if (jcp.with_bias ) {
558
+ if (!jcp.signed_input )
559
+ mov (reg_bias_data, SVE_compress_addr (rsp, reg_bias_data_off));
518
560
add (reg_bias_data,
519
561
load_loop_blk * jcp.load_block * jcp.typesize_bia );
562
+ if (!jcp.signed_input )
563
+ mov (SVE_compress_addr (rsp, reg_bias_data_off), reg_bias_data);
564
+ }
565
+ if (!jcp.signed_input ) {
566
+ mov (reg_comp_data, SVE_compress_addr (rsp, reg_comp_data_off));
567
+ add (reg_comp_data,
568
+ load_loop_blk * jcp.load_block * sizeof (int32_t ));
569
+ mov (SVE_compress_addr (rsp, reg_comp_data_off), reg_comp_data);
520
570
}
521
571
mov (SVE_compress_addr (rsp, reg_bcast_data_off), reg_bcast_data);
522
572
mov (reg_ptr_scales, SVE_compress_addr (rsp, reg_ptr_sum_scale_off));
0 commit comments