@@ -283,12 +283,12 @@ struct jit_bnorm_t: public jit_generator {
283
283
284
284
void prepare_l_relu_mask_avx2 () {
285
285
Label l_mask_after;
286
- jmp (l_mask_after);
286
+ jmp (l_mask_after);
287
287
align (32 );
288
288
L (l_relu_mask_avx2); /* [0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01] */
289
289
for (int i = 0 ; i < 8 ; ++i) dd (1 <<i);
290
290
#ifdef DNNL_INDIRECT_JIT_AARCH64
291
- binCommit ();
291
+ binCommit ();
292
292
#endif
293
293
L (l_mask_after);
294
294
}
@@ -358,7 +358,7 @@ struct jit_bnorm_t: public jit_generator {
358
358
} else {
359
359
vmaskmovps (Vmm (dst.getIdx ()), vtail_mask, src.getAddress ());
360
360
}
361
- jmp (l_ret);
361
+ jmp (l_ret);
362
362
}
363
363
364
364
void uni_vmovups_tail_avx512_common (const Operand &dst,
@@ -368,7 +368,7 @@ struct jit_bnorm_t: public jit_generator {
368
368
else
369
369
uni_vmovups (Vmm (dst.getIdx ()) | ktail_mask | T_z, src.getAddress ());
370
370
371
- jmp (l_ret);
371
+ jmp (l_ret);
372
372
}
373
373
374
374
void uni_vmovups_maybe_tail (const Operand &dst, const Operand &src) {
@@ -377,11 +377,11 @@ struct jit_bnorm_t: public jit_generator {
377
377
if (is_c_padded ()) {
378
378
mov (reg_tmp, ptr[rsp + stack_off_is_cblk_tail]);
379
379
cmp (reg_tmp, 0 );
380
- jz (l_no_mask);
380
+ jz (l_no_mask);
381
381
382
382
lea (reg_tmp, ptr[reg_coff + vlen]);
383
383
cmp (reg_tmp, reg_coff_max);
384
- jl (l_no_mask);
384
+ jl (l_no_mask);
385
385
assert (isa == avx512_common || isa == avx2);
386
386
if (isa == avx512_common)
387
387
uni_vmovups_tail_avx512_common (dst, src, l_ret);
@@ -400,7 +400,7 @@ struct jit_bnorm_t: public jit_generator {
400
400
void barrier () {
401
401
mov (reg_nnthr, ptr[rsp + stack_off_N_nthr]);
402
402
mov (reg_bar, ptr[rsp + stack_off_barrier]);
403
- simple_barrier::generate (*this , reg_bar, reg_nnthr);
403
+ simple_barrier::generate (*this , reg_bar, reg_nnthr);
404
404
}
405
405
406
406
Address mean_ptr (size_t offt = 0 ) {
@@ -456,7 +456,7 @@ struct jit_bnorm_t: public jit_generator {
456
456
#ifdef DNNL_INDIRECT_JIT_AARCH64
457
457
CodeGeneratorAArch64::cmp (Xbyak_aarch64::XReg (reg_ctr.getIdx ()), 0 );
458
458
#endif
459
- jnz (label);
459
+ jnz (label);
460
460
}
461
461
if (is_spatial_thr_) {
462
462
add (reg_soff, ptr[rsp + stack_off_s_tail]);
@@ -507,7 +507,7 @@ struct jit_bnorm_t: public jit_generator {
507
507
508
508
add (reg_coff, vlen);
509
509
cmp (reg_coff, reg_coff_max);
510
- jl (ch_label);
510
+ jl (ch_label);
511
511
}
512
512
}
513
513
@@ -551,7 +551,7 @@ struct jit_bnorm_t: public jit_generator {
551
551
uni_vmovups (vmmword[reg_rbuf1 + reg_coff], Vmm (0 ));
552
552
add (reg_coff, vlen);
553
553
cmp (reg_coff, reg_coff_max);
554
- jl (ch_label);
554
+ jl (ch_label);
555
555
}
556
556
}
557
557
@@ -563,7 +563,7 @@ struct jit_bnorm_t: public jit_generator {
563
563
uni_vmovups (vmmword[reg_rbuf1 + reg_coff], Vmm (0 ));
564
564
add (reg_coff, isa == sse42 ? vlen / 2 : vlen);
565
565
cmp (reg_coff, reg_coff_max);
566
- jne (zero_rbuf);
566
+ jne (zero_rbuf);
567
567
}
568
568
569
569
mov (reg_src, ptr[rsp + stack_off_src]);
@@ -590,14 +590,14 @@ struct jit_bnorm_t: public jit_generator {
590
590
591
591
add (reg_soff, reg_mb_stride_Bc);
592
592
cmp (reg_soff, reg_soff_max);
593
- jne (mean_spatial);
593
+ jne (mean_spatial);
594
594
}
595
595
596
596
Label no_mean_reduction;
597
597
barrier (); {
598
598
mov (reg_tmp, ptr[rsp + stack_off_N_ithr]);
599
599
cmp (reg_tmp, 0 );
600
- jne (no_mean_reduction);
600
+ jne (no_mean_reduction);
601
601
mov (reg_nnthr, ptr[rsp + stack_off_N_nthr]);
602
602
xor_ (reg_coff, reg_coff);
603
603
Label mean_reduction_channels;
@@ -615,19 +615,19 @@ struct jit_bnorm_t: public jit_generator {
615
615
#ifdef DNNL_INDIRECT_JIT_AARCH64
616
616
CodeGeneratorAArch64::cmp (Xbyak_aarch64::XReg (reg_ctr.getIdx ()), 0 );
617
617
#endif
618
- jnz (mean_reduction_thrs);
618
+ jnz (mean_reduction_thrs);
619
619
}
620
620
uni_vdivps (Vmm (1 ), Vmm (1 ), vchan_size);
621
621
uni_vmovups_maybe_tail (mean_ptr (), Vmm (1 ));
622
622
623
623
add (reg_coff, isa == sse42 ? vlen / 2 : vlen);
624
624
625
625
cmp (reg_coff, reg_coff_max);
626
- jne (mean_reduction_channels);
626
+ jne (mean_reduction_channels);
627
627
}
628
628
}
629
629
L (no_mean_reduction);
630
- barrier ();
630
+ barrier ();
631
631
632
632
xor_ (reg_soff, reg_soff);
633
633
Label var_spatial;
@@ -651,14 +651,14 @@ struct jit_bnorm_t: public jit_generator {
651
651
652
652
add (reg_soff, reg_mb_stride_Bc);
653
653
cmp (reg_soff, reg_soff_max);
654
- jne (var_spatial);
654
+ jne (var_spatial);
655
655
}
656
656
657
657
Label no_var_reduction;
658
658
barrier (); {
659
659
mov (reg_tmp, ptr[rsp + stack_off_N_ithr]);
660
660
cmp (reg_tmp, 0 );
661
- jne (no_var_reduction);
661
+ jne (no_var_reduction);
662
662
663
663
mov (reg_nnthr, ptr[rsp + stack_off_N_nthr]);
664
664
xor_ (reg_coff, reg_coff);
@@ -675,18 +675,18 @@ struct jit_bnorm_t: public jit_generator {
675
675
#ifdef DNNL_INDIRECT_JIT_AARCH64
676
676
CodeGeneratorAArch64::cmp (Xbyak_aarch64::XReg (reg_ctr.getIdx ()), 0 );
677
677
#endif
678
- jnz (var_reduction_thrs);
678
+ jnz (var_reduction_thrs);
679
679
}
680
680
uni_vdivps (Vmm (1 ), Vmm (1 ), vchan_size);
681
681
uni_vmovups_maybe_tail (var_ptr (), Vmm (1 ));
682
682
add (reg_coff, isa == sse42 ? vlen / 2 : vlen);
683
683
684
684
cmp (reg_coff, reg_coff_max);
685
- jne (var_reduction_channels);
685
+ jne (var_reduction_channels);
686
686
}
687
687
}
688
688
L (no_var_reduction);
689
- barrier ();
689
+ barrier ();
690
690
}
691
691
692
692
void forward_channels () {
@@ -755,9 +755,9 @@ struct jit_bnorm_t: public jit_generator {
755
755
} else {
756
756
Label unaligned_store, end_store;
757
757
test (reg_dst, vlen - 1 );
758
- jnz (unaligned_store, T_NEAR);
758
+ jnz (unaligned_store, T_NEAR);
759
759
compute (true );
760
- jmp (end_store, T_NEAR);
760
+ jmp (end_store, T_NEAR);
761
761
L (unaligned_store); {
762
762
compute (false );
763
763
}
@@ -766,7 +766,7 @@ struct jit_bnorm_t: public jit_generator {
766
766
767
767
add (reg_coff, vlen);
768
768
cmp (reg_coff, reg_coff_max);
769
- jl (ch_label);
769
+ jl (ch_label);
770
770
}
771
771
}
772
772
@@ -798,7 +798,7 @@ struct jit_bnorm_t: public jit_generator {
798
798
799
799
add (reg_soff, reg_mb_stride_Bc);
800
800
cmp (reg_soff, reg_soff_max);
801
- jnz (dst_spatial);
801
+ jnz (dst_spatial);
802
802
}
803
803
}
804
804
@@ -865,7 +865,7 @@ struct jit_bnorm_t: public jit_generator {
865
865
uni_vmovups (vmmword[reg_rbuf2 + reg_coff], Vmm (1 ));
866
866
add (reg_coff, vlen);
867
867
cmp (reg_coff, reg_coff_max);
868
- jl (sh_channels);
868
+ jl (sh_channels);
869
869
}
870
870
}
871
871
@@ -941,9 +941,9 @@ struct jit_bnorm_t: public jit_generator {
941
941
} else {
942
942
Label unaligned_store, end_store;
943
943
test (reg_diff_src, vlen - 1 );
944
- jnz (unaligned_store, T_NEAR);
944
+ jnz (unaligned_store, T_NEAR);
945
945
compute (true );
946
- jmp (end_store, T_NEAR);
946
+ jmp (end_store, T_NEAR);
947
947
L (unaligned_store); {
948
948
compute (false );
949
949
}
@@ -952,7 +952,7 @@ struct jit_bnorm_t: public jit_generator {
952
952
953
953
add (reg_coff, vlen);
954
954
cmp (reg_coff, reg_coff_max);
955
- jl (diff_channels);
955
+ jl (diff_channels);
956
956
}
957
957
}
958
958
@@ -966,7 +966,7 @@ struct jit_bnorm_t: public jit_generator {
966
966
uni_vmovups (vmmword[reg_rbuf2 + reg_coff], Vmm (0 ));
967
967
add (reg_coff, isa == sse42 ? vlen / 2 : vlen);
968
968
cmp (reg_coff, reg_coff_max);
969
- jne (zero_rbuf);
969
+ jne (zero_rbuf);
970
970
}
971
971
972
972
mov (reg_src, ptr[rsp + stack_off_src]);
@@ -994,7 +994,7 @@ struct jit_bnorm_t: public jit_generator {
994
994
}
995
995
add (reg_soff, reg_mb_stride_Bc);
996
996
cmp (reg_soff, reg_soff_max);
997
- jne (sh_spatial);
997
+ jne (sh_spatial);
998
998
}
999
999
1000
1000
mov (reg_diff_scale_shift, ptr[rsp + stack_off_diff_scale_shift]);
@@ -1004,7 +1004,7 @@ struct jit_bnorm_t: public jit_generator {
1004
1004
mov (reg_tmp, ptr[rsp + stack_off_N_ithr]);
1005
1005
cmp (reg_tmp, 0 );
1006
1006
Label sh_reduction_channels;
1007
- jne (no_sh_reduction, T_NEAR);
1007
+ jne (no_sh_reduction, T_NEAR);
1008
1008
1009
1009
mov (reg_nnthr, ptr[rsp + stack_off_N_nthr]);
1010
1010
xor_ (reg_coff, reg_coff);
@@ -1026,7 +1026,7 @@ struct jit_bnorm_t: public jit_generator {
1026
1026
#ifdef DNNL_INDIRECT_JIT_AARCH64
1027
1027
CodeGeneratorAArch64::cmp (Xbyak_aarch64::XReg (reg_ctr.getIdx ()), 0 );
1028
1028
#endif
1029
- jnz (sh_reduction_thrs);
1029
+ jnz (sh_reduction_thrs);
1030
1030
}
1031
1031
uni_vmulps (Vmm (0 ), Vmm (0 ), vsqrtvar);
1032
1032
uni_vmovups_maybe_tail (diff_gamma_ptr (), Vmm (0 ));
@@ -1037,7 +1037,7 @@ struct jit_bnorm_t: public jit_generator {
1037
1037
}
1038
1038
}
1039
1039
L (no_sh_reduction);
1040
- barrier ();
1040
+ barrier ();
1041
1041
1042
1042
mov (reg_diff_src, ptr[rsp + stack_off_diff_src]);
1043
1043
if (with_relu) {
@@ -1087,11 +1087,51 @@ struct jit_bnorm_t: public jit_generator {
1087
1087
unroll_blocks = isa == avx512_common && !is_spatial_thr_ ? 4 : 1 ;
1088
1088
unroll_regs = isa == avx512_common && !is_spatial_thr_ ? 4 : 1 ;
1089
1089
1090
- preamble ();
1091
1090
#ifdef DNNL_INDIRECT_JIT_AARCH64
1092
- setAll1Preg0_7 (7 );
1093
- #endif
1091
+ for (int phase = 0 ; phase < 2 ; phase++){
1092
+ if ( phase == 0 ){
1093
+ initSearchPReg ();
1094
+ initSearchZReg ();
1095
+ unSetGenJitMode ();
1096
+ } else {
1097
+ this ->clearCodeArray ();
1098
+ setGenJitMode ();
1099
+ }
1100
+ preamble ();
1101
+ setAll1Preg0_7 (7 );
1102
+
1103
+ if (is_bf16_) {
1104
+ // init emulation of bfloat16 operations
1105
+ if (!mayiuse (avx512_core_bf16)) {
1106
+ bf16_emu_ = new bf16_emulation_t (this , vcvt_bf16_one,
1107
+ vcvt_bf16_eve, vcvt_bf16_sel, reg_bf16_tmp,
1108
+ vcvt_bf16_tmp, vcvt_bf16_tmp);
1109
+ bf16_emu_->init_vcvtneps2bf16 ();
1110
+ }
1111
+ }
1112
+
1113
+ if (isa == avx512_common)
1114
+ prepare_tail_mask_avx512_common ();
1115
+ else if (isa == avx2)
1116
+ prepare_tail_mask_avx2_common ();
1094
1117
1118
+ compute_static_strides ();
1119
+ sub (rsp, stack_size_required);
1120
+ load_common_params ();
1121
+ prepare_relu ();
1122
+
1123
+ if (bdesc_->is_fwd ()) {
1124
+ if (!bdesc_->stats_is_src ()) {
1125
+ compute_mean_variance ();
1126
+ }
1127
+ forward ();
1128
+ } else {
1129
+ backward ();
1130
+ }
1131
+ add (rsp, stack_size_required);
1132
+ }
1133
+ #else
1134
+ preamble ();
1095
1135
if (is_bf16_) {
1096
1136
// init emulation of bfloat16 operations
1097
1137
if (!mayiuse (avx512_core_bf16)) {
@@ -1121,6 +1161,7 @@ struct jit_bnorm_t: public jit_generator {
1121
1161
backward ();
1122
1162
}
1123
1163
add (rsp, stack_size_required);
1164
+ #endif
1124
1165
1125
1166
#ifdef DNNL_INDIRECT_JIT_AARCH64
1126
1167
clearAll1Preg0_7 ();
0 commit comments