Skip to content

Commit a7f1212

Browse files
committed
Merge branch '28_special_preg' into 'fjdev'
28 special preg See merge request postk_dl/dnnl_aarch64!58
2 parents c3e44de + f56b007 commit a7f1212

File tree

4 files changed

+471
-155
lines changed

4 files changed

+471
-155
lines changed

src/common/utils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ int mkldnn_getenv(const char *name, char *buffer, int buffer_size) {
6565
result = int_value_length;
6666
#ifndef _WIN32
6767
if (value)
68-
strncpy(buffer, value, buffer_size - 1);
68+
strncpy(buffer, value, buffer_size - 1);
6969
#endif
7070
}
7171
}

src/cpu/jit_uni_batch_normalization.cpp

Lines changed: 78 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -283,12 +283,12 @@ struct jit_bnorm_t: public jit_generator {
283283

284284
void prepare_l_relu_mask_avx2() {
285285
Label l_mask_after;
286-
jmp(l_mask_after);
286+
jmp(l_mask_after);
287287
align(32);
288288
L(l_relu_mask_avx2); /* [0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01] */
289289
for (int i = 0; i < 8; ++i) dd(1<<i);
290290
#ifdef DNNL_INDIRECT_JIT_AARCH64
291-
binCommit();
291+
binCommit();
292292
#endif
293293
L(l_mask_after);
294294
}
@@ -358,7 +358,7 @@ struct jit_bnorm_t: public jit_generator {
358358
} else {
359359
vmaskmovps(Vmm(dst.getIdx()), vtail_mask, src.getAddress());
360360
}
361-
jmp(l_ret);
361+
jmp(l_ret);
362362
}
363363

364364
void uni_vmovups_tail_avx512_common(const Operand &dst,
@@ -368,7 +368,7 @@ struct jit_bnorm_t: public jit_generator {
368368
else
369369
uni_vmovups(Vmm(dst.getIdx()) | ktail_mask | T_z, src.getAddress());
370370

371-
jmp(l_ret);
371+
jmp(l_ret);
372372
}
373373

374374
void uni_vmovups_maybe_tail(const Operand &dst, const Operand &src) {
@@ -377,11 +377,11 @@ struct jit_bnorm_t: public jit_generator {
377377
if (is_c_padded()) {
378378
mov(reg_tmp, ptr[rsp + stack_off_is_cblk_tail]);
379379
cmp(reg_tmp, 0);
380-
jz(l_no_mask);
380+
jz(l_no_mask);
381381

382382
lea(reg_tmp, ptr[reg_coff + vlen]);
383383
cmp(reg_tmp, reg_coff_max);
384-
jl(l_no_mask);
384+
jl(l_no_mask);
385385
assert(isa == avx512_common || isa == avx2);
386386
if (isa == avx512_common)
387387
uni_vmovups_tail_avx512_common(dst, src, l_ret);
@@ -400,7 +400,7 @@ struct jit_bnorm_t: public jit_generator {
400400
void barrier() {
401401
mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
402402
mov(reg_bar, ptr[rsp + stack_off_barrier]);
403-
simple_barrier::generate(*this, reg_bar, reg_nnthr);
403+
simple_barrier::generate(*this, reg_bar, reg_nnthr);
404404
}
405405

406406
Address mean_ptr(size_t offt = 0) {
@@ -456,7 +456,7 @@ struct jit_bnorm_t: public jit_generator {
456456
#ifdef DNNL_INDIRECT_JIT_AARCH64
457457
CodeGeneratorAArch64::cmp(Xbyak_aarch64::XReg(reg_ctr.getIdx()), 0);
458458
#endif
459-
jnz(label);
459+
jnz(label);
460460
}
461461
if (is_spatial_thr_) {
462462
add(reg_soff, ptr[rsp + stack_off_s_tail]);
@@ -507,7 +507,7 @@ struct jit_bnorm_t: public jit_generator {
507507

508508
add(reg_coff, vlen);
509509
cmp(reg_coff, reg_coff_max);
510-
jl(ch_label);
510+
jl(ch_label);
511511
}
512512
}
513513

@@ -551,7 +551,7 @@ struct jit_bnorm_t: public jit_generator {
551551
uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
552552
add(reg_coff, vlen);
553553
cmp(reg_coff, reg_coff_max);
554-
jl(ch_label);
554+
jl(ch_label);
555555
}
556556
}
557557

@@ -563,7 +563,7 @@ struct jit_bnorm_t: public jit_generator {
563563
uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
564564
add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
565565
cmp(reg_coff, reg_coff_max);
566-
jne(zero_rbuf);
566+
jne(zero_rbuf);
567567
}
568568

569569
mov(reg_src, ptr[rsp + stack_off_src]);
@@ -590,14 +590,14 @@ struct jit_bnorm_t: public jit_generator {
590590

591591
add(reg_soff, reg_mb_stride_Bc);
592592
cmp(reg_soff, reg_soff_max);
593-
jne(mean_spatial);
593+
jne(mean_spatial);
594594
}
595595

596596
Label no_mean_reduction;
597597
barrier(); {
598598
mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
599599
cmp(reg_tmp, 0);
600-
jne(no_mean_reduction);
600+
jne(no_mean_reduction);
601601
mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
602602
xor_(reg_coff, reg_coff);
603603
Label mean_reduction_channels;
@@ -615,19 +615,19 @@ struct jit_bnorm_t: public jit_generator {
615615
#ifdef DNNL_INDIRECT_JIT_AARCH64
616616
CodeGeneratorAArch64::cmp(Xbyak_aarch64::XReg(reg_ctr.getIdx()), 0);
617617
#endif
618-
jnz(mean_reduction_thrs);
618+
jnz(mean_reduction_thrs);
619619
}
620620
uni_vdivps(Vmm(1), Vmm(1), vchan_size);
621621
uni_vmovups_maybe_tail(mean_ptr(), Vmm(1));
622622

623623
add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
624624

625625
cmp(reg_coff, reg_coff_max);
626-
jne(mean_reduction_channels);
626+
jne(mean_reduction_channels);
627627
}
628628
}
629629
L(no_mean_reduction);
630-
barrier();
630+
barrier();
631631

632632
xor_(reg_soff, reg_soff);
633633
Label var_spatial;
@@ -651,14 +651,14 @@ struct jit_bnorm_t: public jit_generator {
651651

652652
add(reg_soff, reg_mb_stride_Bc);
653653
cmp(reg_soff, reg_soff_max);
654-
jne(var_spatial);
654+
jne(var_spatial);
655655
}
656656

657657
Label no_var_reduction;
658658
barrier(); {
659659
mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
660660
cmp(reg_tmp, 0);
661-
jne(no_var_reduction);
661+
jne(no_var_reduction);
662662

663663
mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
664664
xor_(reg_coff, reg_coff);
@@ -675,18 +675,18 @@ struct jit_bnorm_t: public jit_generator {
675675
#ifdef DNNL_INDIRECT_JIT_AARCH64
676676
CodeGeneratorAArch64::cmp(Xbyak_aarch64::XReg(reg_ctr.getIdx()), 0);
677677
#endif
678-
jnz(var_reduction_thrs);
678+
jnz(var_reduction_thrs);
679679
}
680680
uni_vdivps(Vmm(1), Vmm(1), vchan_size);
681681
uni_vmovups_maybe_tail(var_ptr(), Vmm(1));
682682
add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
683683

684684
cmp(reg_coff, reg_coff_max);
685-
jne(var_reduction_channels);
685+
jne(var_reduction_channels);
686686
}
687687
}
688688
L(no_var_reduction);
689-
barrier();
689+
barrier();
690690
}
691691

692692
void forward_channels() {
@@ -755,9 +755,9 @@ struct jit_bnorm_t: public jit_generator {
755755
} else {
756756
Label unaligned_store, end_store;
757757
test(reg_dst, vlen - 1);
758-
jnz(unaligned_store, T_NEAR);
758+
jnz(unaligned_store, T_NEAR);
759759
compute(true);
760-
jmp(end_store, T_NEAR);
760+
jmp(end_store, T_NEAR);
761761
L(unaligned_store); {
762762
compute(false);
763763
}
@@ -766,7 +766,7 @@ struct jit_bnorm_t: public jit_generator {
766766

767767
add(reg_coff, vlen);
768768
cmp(reg_coff, reg_coff_max);
769-
jl(ch_label);
769+
jl(ch_label);
770770
}
771771
}
772772

@@ -798,7 +798,7 @@ struct jit_bnorm_t: public jit_generator {
798798

799799
add(reg_soff, reg_mb_stride_Bc);
800800
cmp(reg_soff, reg_soff_max);
801-
jnz(dst_spatial);
801+
jnz(dst_spatial);
802802
}
803803
}
804804

@@ -865,7 +865,7 @@ struct jit_bnorm_t: public jit_generator {
865865
uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(1));
866866
add(reg_coff, vlen);
867867
cmp(reg_coff, reg_coff_max);
868-
jl(sh_channels);
868+
jl(sh_channels);
869869
}
870870
}
871871

@@ -941,9 +941,9 @@ struct jit_bnorm_t: public jit_generator {
941941
} else {
942942
Label unaligned_store, end_store;
943943
test(reg_diff_src, vlen - 1);
944-
jnz(unaligned_store, T_NEAR);
944+
jnz(unaligned_store, T_NEAR);
945945
compute(true);
946-
jmp(end_store, T_NEAR);
946+
jmp(end_store, T_NEAR);
947947
L(unaligned_store); {
948948
compute(false);
949949
}
@@ -952,7 +952,7 @@ struct jit_bnorm_t: public jit_generator {
952952

953953
add(reg_coff, vlen);
954954
cmp(reg_coff, reg_coff_max);
955-
jl(diff_channels);
955+
jl(diff_channels);
956956
}
957957
}
958958

@@ -966,7 +966,7 @@ struct jit_bnorm_t: public jit_generator {
966966
uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(0));
967967
add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
968968
cmp(reg_coff, reg_coff_max);
969-
jne(zero_rbuf);
969+
jne(zero_rbuf);
970970
}
971971

972972
mov(reg_src, ptr[rsp + stack_off_src]);
@@ -994,7 +994,7 @@ struct jit_bnorm_t: public jit_generator {
994994
}
995995
add(reg_soff, reg_mb_stride_Bc);
996996
cmp(reg_soff, reg_soff_max);
997-
jne(sh_spatial);
997+
jne(sh_spatial);
998998
}
999999

10001000
mov(reg_diff_scale_shift, ptr[rsp + stack_off_diff_scale_shift]);
@@ -1004,7 +1004,7 @@ struct jit_bnorm_t: public jit_generator {
10041004
mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
10051005
cmp(reg_tmp, 0);
10061006
Label sh_reduction_channels;
1007-
jne(no_sh_reduction, T_NEAR);
1007+
jne(no_sh_reduction, T_NEAR);
10081008

10091009
mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
10101010
xor_(reg_coff, reg_coff);
@@ -1026,7 +1026,7 @@ struct jit_bnorm_t: public jit_generator {
10261026
#ifdef DNNL_INDIRECT_JIT_AARCH64
10271027
CodeGeneratorAArch64::cmp(Xbyak_aarch64::XReg(reg_ctr.getIdx()), 0);
10281028
#endif
1029-
jnz(sh_reduction_thrs);
1029+
jnz(sh_reduction_thrs);
10301030
}
10311031
uni_vmulps(Vmm(0), Vmm(0), vsqrtvar);
10321032
uni_vmovups_maybe_tail(diff_gamma_ptr(), Vmm(0));
@@ -1037,7 +1037,7 @@ struct jit_bnorm_t: public jit_generator {
10371037
}
10381038
}
10391039
L(no_sh_reduction);
1040-
barrier();
1040+
barrier();
10411041

10421042
mov(reg_diff_src, ptr[rsp + stack_off_diff_src]);
10431043
if (with_relu) {
@@ -1087,11 +1087,51 @@ struct jit_bnorm_t: public jit_generator {
10871087
unroll_blocks = isa == avx512_common && !is_spatial_thr_ ? 4 : 1;
10881088
unroll_regs = isa == avx512_common && !is_spatial_thr_ ? 4 : 1;
10891089

1090-
preamble();
10911090
#ifdef DNNL_INDIRECT_JIT_AARCH64
1092-
setAll1Preg0_7(7);
1093-
#endif
1091+
for(int phase = 0; phase < 2; phase++){
1092+
if ( phase == 0){
1093+
initSearchPReg();
1094+
initSearchZReg();
1095+
unSetGenJitMode();
1096+
} else {
1097+
this->clearCodeArray();
1098+
setGenJitMode();
1099+
}
1100+
preamble();
1101+
setAll1Preg0_7(7);
1102+
1103+
if (is_bf16_) {
1104+
// init emulation of bfloat16 operations
1105+
if (!mayiuse(avx512_core_bf16)) {
1106+
bf16_emu_ = new bf16_emulation_t(this, vcvt_bf16_one,
1107+
vcvt_bf16_eve, vcvt_bf16_sel, reg_bf16_tmp,
1108+
vcvt_bf16_tmp, vcvt_bf16_tmp);
1109+
bf16_emu_->init_vcvtneps2bf16();
1110+
}
1111+
}
1112+
1113+
if (isa == avx512_common)
1114+
prepare_tail_mask_avx512_common();
1115+
else if (isa == avx2)
1116+
prepare_tail_mask_avx2_common();
10941117

1118+
compute_static_strides();
1119+
sub(rsp, stack_size_required);
1120+
load_common_params();
1121+
prepare_relu();
1122+
1123+
if (bdesc_->is_fwd()) {
1124+
if (!bdesc_->stats_is_src()) {
1125+
compute_mean_variance();
1126+
}
1127+
forward();
1128+
} else {
1129+
backward();
1130+
}
1131+
add(rsp, stack_size_required);
1132+
}
1133+
#else
1134+
preamble();
10951135
if (is_bf16_) {
10961136
// init emulation of bfloat16 operations
10971137
if (!mayiuse(avx512_core_bf16)) {
@@ -1121,6 +1161,7 @@ struct jit_bnorm_t: public jit_generator {
11211161
backward();
11221162
}
11231163
add(rsp, stack_size_required);
1164+
#endif
11241165

11251166
#ifdef DNNL_INDIRECT_JIT_AARCH64
11261167
clearAll1Preg0_7();

0 commit comments

Comments
 (0)