Skip to content

Commit c3e44de

Browse files
committed
Merge branch '32_bug_int8_conv1x1' into 'fjdev'
32 bug int8 conv1x1 See merge request postk_dl/dnnl_aarch64!57
2 parents 237ff92 + 15e8a08 commit c3e44de

9 files changed

+179
-78
lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ else()
6868
${CMAKE_CURRENT_SOURCE_DIR}/cpu/jit_sve_x8s8s32x_1x1_conv_kernel.cpp
6969
${CMAKE_CURRENT_SOURCE_DIR}/cpu/jit_sve_x8s8s32x_1x1_convolution.cpp
7070
${CMAKE_CURRENT_SOURCE_DIR}/cpu/jit_sve_x8s8s32x_conv_kernel.cpp
71-
${CMAKE_CURRENT_SOURCE_DIR}/cpu/jit_sve_x8s8s32x_convolution.hpp
71+
${CMAKE_CURRENT_SOURCE_DIR}/cpu/jit_sve_x8s8s32x_convolution.cpp
7272
)
7373
endif()
7474

src/cpu/jit_sve_x8s8s32x_1x1_conv_kernel.cpp

Lines changed: 76 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,10 @@ void jit_sve_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk,
151151
jcp.typesize_bia * jcp.oc_block * i_load);
152152
};
153153

154-
// auto comp_ptr = [=](int i_load) {
155-
// return SVE_compress_addr(reg_comp_data,
156-
// sizeof(int32_t) * jcp.oc_block * i_load);
157-
// };
154+
auto comp_ptr = [=](int i_load) {
155+
return SVE_compress_addr(reg_comp_data,
156+
sizeof(int32_t) * jcp.oc_block * i_load);
157+
};
158158

159159
auto scale_ptr = [=](int i_load) {
160160
return SVE_compress_addr(reg_ptr_scales,
@@ -234,6 +234,12 @@ void jit_sve_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk,
234234
auto r = vreg_accum(i_load, i_ur);
235235
vpxord(r, r, r);
236236
}
237+
if (!jcp.signed_input) {
238+
xor_(reg_scratch, reg_scratch);
239+
Reg8 _t8 = reg_scratch.cvt8();
240+
mov(_t8, (int8_t)-128);
241+
vpbroadcastb(zmm_shift, _t8);
242+
}
237243
};
238244

239245
auto store = [=](const bool mask_flag_in) {
@@ -251,15 +257,25 @@ void jit_sve_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk,
251257
for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
252258
const bool mask_flag = mask_flag_in && i_load == load_loop_blk - 1;
253259
auto zmm_bias = zmm_tmp;
260+
auto zmm_comp = zmm_bcast;
254261
if (jcp.with_bias) {
262+
if (!jcp.signed_input)
263+
mov(reg_bias_data,
264+
SVE_compress_addr(rsp,reg_bias_data_off));
255265
cvt2ps(jcp.bia_dt, zmm_bias, bias_ptr(i_load), mask_flag);
256266
}
267+
if (!jcp.signed_input) {
268+
mov(reg_comp_data, SVE_compress_addr(rsp, reg_comp_data_off));
269+
cvt2ps(data_type::s32, zmm_comp, comp_ptr(i_load), mask_flag);
270+
}
257271

258272
auto zmm_scale = zmm_one;
259273
vmovups(zmm_scale, scale_ptr(i_load));
260274
for (int i_ur = 0; i_ur < ur; ++i_ur) {
261275
auto r = vreg_accum(i_load, i_ur);
262276
CGA64::scvtf(xa::ZRegS(r.getIdx()), xa::PReg(vmask.getIdx()), xa::ZRegS(r.getIdx())); //< vcvtdq2ps(r, r);
277+
if (!jcp.signed_input)
278+
vsubps(r, r, zmm_comp);
263279
if (jcp.with_bias)
264280
vaddps(r, r, zmm_bias);
265281

@@ -393,29 +409,47 @@ void jit_sve_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk,
393409
for (int i_load = 0; i_load < load_loop_blk; ++i_load)
394410
vmovups(vreg_load(i_load), load_ptr(i_reduce, i_load));
395411
for (int i_ur = 0; i_ur < ur; ++i_ur) {
396-
if (last_block && tail_size != 0
397-
&& i_reduce == loop_unroll - reduce_step) {
398-
Xmm xmm_bcast = Xmm(zmm_bcast.getIdx());
399-
for (int r = 0; r < tail_size; ++r)
400-
vpinsrb(xmm_bcast, xmm_bcast, ptr[aux_reg_bcast_data
401-
+ jcp.ic_without_padding * i_ur + i_reduce + r], r);
402-
Zmm _bcast = ((i_ur % 2) == 0)? zmm_bcast : zmm_bcast2;
403-
vpbroadcastd(_bcast, xmm_bcast);
412+
if (jcp.signed_input) {
413+
if (last_block && tail_size != 0
414+
&& i_reduce == loop_unroll - reduce_step) {
415+
Xmm xmm_bcast = Xmm(zmm_bcast.getIdx());
416+
for (int r = 0; r < tail_size; ++r)
417+
vpinsrb(xmm_bcast, xmm_bcast, ptr[aux_reg_bcast_data
418+
+ jcp.ic_without_padding * i_ur + i_reduce + r], r);
419+
Zmm _bcast = ((i_ur % 2) == 0)? zmm_bcast : zmm_bcast2;
420+
vpbroadcastd(_bcast, xmm_bcast);
421+
} else {
422+
if(i_ur == 0) {
423+
// vpbroadcastd(zmm_bcast, bcast_ptr(i_reduce, i_ur, false));
424+
bcast_ptr(zmm_bcast, i_reduce, i_ur, false);
425+
}
426+
if ((i_ur+1) < ur) {
427+
Zmm _bcast = ((i_ur % 2) == 0)? zmm_bcast2 : zmm_bcast;
428+
// vpbroadcastd(zmm_bcast, bcast_ptr(i_reduce, (i_ur+1), false));
429+
bcast_ptr(_bcast, i_reduce, (i_ur+1), false);
430+
}
431+
}
432+
for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
433+
Zmm _bcast = ((i_ur % 2) == 0)? zmm_bcast : zmm_bcast2;
434+
compute(vreg_accum(i_load, i_ur),
435+
vreg_load(i_load), _bcast);
436+
}
404437
} else {
405-
if(i_ur == 0) {
406-
// vpbroadcastd(zmm_bcast, bcast_ptr(i_reduce, i_ur, false));
407-
bcast_ptr(zmm_bcast, i_reduce, i_ur, false);
408-
}
409-
if ((i_ur+1) < ur) {
410-
Zmm _bcast = ((i_ur % 2) == 0)? zmm_bcast2 : zmm_bcast;
411-
// vpbroadcastd(zmm_bcast, bcast_ptr(i_reduce, (i_ur+1), false));
412-
bcast_ptr(_bcast, i_reduce, (i_ur+1), false);
413-
}
414-
}
415-
for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
416-
Zmm _bcast = ((i_ur % 2) == 0)? zmm_bcast : zmm_bcast2;
417-
compute(vreg_accum(i_load, i_ur),
418-
vreg_load(i_load), _bcast);
438+
if (last_block && tail_size != 0
439+
&& i_reduce == loop_unroll - reduce_step) {
440+
Xmm xmm_bcast = Xmm(zmm_bcast.getIdx());
441+
for (int r = 0; r < tail_size; ++r)
442+
vpinsrb(xmm_bcast, xmm_bcast, ptr[aux_reg_bcast_data
443+
+ jcp.ic_without_padding * i_ur + i_reduce + r], r);
444+
vpbroadcastd(zmm_bcast, xmm_bcast);
445+
} else {
446+
bcast_ptr(zmm_bcast, i_reduce, i_ur, false);
447+
}
448+
vpaddb(zmm_bcast, zmm_bcast, zmm_shift);
449+
for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
450+
compute(vreg_accum(i_load, i_ur),
451+
vreg_load(i_load), zmm_bcast);
452+
}
419453
}
420454
}
421455
}
@@ -498,6 +532,12 @@ void jit_sve_x8s8s32x_1x1_conv_kernel::generate()
498532

499533
if (jcp.with_bias)
500534
mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
535+
if (!jcp.signed_input) {
536+
mov(SVE_compress_addr(rsp, reg_bias_data_off), reg_bias_data);
537+
mov(reg_comp_data, ptr[param1 + GET_OFF(compensation)]);
538+
mov(SVE_compress_addr(rsp, reg_comp_data_off), reg_comp_data);
539+
}
540+
501541
mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
502542
mov(SVE_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales);
503543
mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
@@ -515,8 +555,18 @@ void jit_sve_x8s8s32x_1x1_conv_kernel::generate()
515555
bcast_loop(load_loop_blk);
516556
add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
517557
if (jcp.with_bias) {
558+
if (!jcp.signed_input)
559+
mov(reg_bias_data, SVE_compress_addr(rsp, reg_bias_data_off));
518560
add(reg_bias_data,
519561
load_loop_blk * jcp.load_block * jcp.typesize_bia);
562+
if (!jcp.signed_input)
563+
mov(SVE_compress_addr(rsp, reg_bias_data_off), reg_bias_data);
564+
}
565+
if (!jcp.signed_input) {
566+
mov(reg_comp_data, SVE_compress_addr(rsp, reg_comp_data_off));
567+
add(reg_comp_data,
568+
load_loop_blk * jcp.load_block * sizeof(int32_t));
569+
mov(SVE_compress_addr(rsp, reg_comp_data_off), reg_comp_data);
520570
}
521571
mov(SVE_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
522572
mov(reg_ptr_scales, SVE_compress_addr(rsp, reg_ptr_sum_scale_off));

src/cpu/jit_sve_x8s8s32x_1x1_conv_kernel.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ struct jit_sve_x8s8s32x_1x1_conv_kernel: public jit_generator {
136136
Xbyak::Zmm zmm_zero = Xbyak::Zmm(30);
137137
Xbyak::Zmm zmm_bcast = Xbyak::Zmm(31);
138138
Xbyak::Zmm zmm_bcast2 = Xbyak::Zmm(30);
139+
Xbyak::Zmm zmm_shift = Xbyak::Zmm(30);
139140

140141
int bcast_loop_work_off = 0;
141142
int reg_bias_data_off = 8;

src/cpu/jit_sve_x8s8s32x_1x1_convolution.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ ::execute_forward_thr(const int ithr, const int nthr, const src_data_t *src,
9393
int offset = jcp.ngroups * (jcp.oc / jcp.oc_block) * (jcp.ic / jcp.ic_block)
9494
* jcp.oc_block * jcp.ic_block;
9595
wei_data_t *w = const_cast<wei_data_t *>(weights);
96-
int32_t* compensation = (jcp.signed_input)
96+
int32_t* compensation = (!jcp.signed_input)
9797
? reinterpret_cast<int32_t *>(w + offset) : 0;
9898

9999
auto step = [](int default_step, int remaining, int tail_step) {
@@ -173,7 +173,7 @@ ::execute_forward_thr(const int ithr, const int nthr, const src_data_t *src,
173173
? weights_d.blk_off(g, ocb, icb)
174174
: weights_d.blk_off(ocb, icb)];
175175
p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size];
176-
p.compensation = (jcp.signed_input)
176+
p.compensation = (!jcp.signed_input)
177177
? &compensation[_ocb * jcp.oc_block] : 0;
178178
p.scales = &oscales.scales_[jcp.is_oc_scale * _ocb * jcp.oc_block];
179179
if (pd()->rtus_.reduce_src_) {

src/cpu/jit_sve_x8s8s32x_1x1_convolution.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ struct jit_sve_x8s8s32x_1x1_convolution_fwd_t : public cpu_primitive_t {
120120
CHECK(this->dst_pd_.set_format(nhwc));
121121
if (this->weights_pd_.desc()->format == any)
122122
CHECK(this->weights_pd_.set_format(this->with_groups()
123-
? (is_sign_input ? gOIhw4i16o4i_s8s8 : gOIhw4i16o4i)
124-
: (is_sign_input ? OIhw4i16o4i_s8s8 : OIhw4i16o4i)));
123+
? (!is_sign_input ? gOIhw4i16o4i_s8s8 : gOIhw4i16o4i)
124+
: (!is_sign_input ? OIhw4i16o4i_s8s8 : OIhw4i16o4i)));
125125
if (this->bias_pd_.desc()->format == any)
126126
CHECK(this->bias_pd_.set_format(x));
127127
if (this->desc()->alg_kind == alg_kind::convolution_auto)

0 commit comments

Comments
 (0)