From f52800f1a06b9ba4f8d6f4695106c73b7a409629 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristian=20I=C3=B1iguez=20Rodriguez?= <60439856+criniguez@users.noreply.github.com> Date: Thu, 16 Jan 2025 14:26:46 +0100 Subject: [PATCH] Vectorized extend kernels integration in WFA2-lib. (#99) Adding support for AVX2 and AVX512 accelerated extend @criniguez --- wavefront/wavefront_extend.c | 4 - wavefront/wavefront_extend_kernels.c | 75 ++- wavefront/wavefront_extend_kernels_avx.c | 596 +++++++++++++++++++++-- wavefront/wavefront_extend_kernels_avx.h | 33 ++ 4 files changed, 644 insertions(+), 64 deletions(-) diff --git a/wavefront/wavefront_extend.c b/wavefront/wavefront_extend.c index 1ba7886..77ede3f 100644 --- a/wavefront/wavefront_extend.c +++ b/wavefront/wavefront_extend.c @@ -54,11 +54,7 @@ void wavefront_extend_end2end_dispatcher_seq( wavefront_sequences_t* const seqs = &wf_aligner->sequences; // Check the sequence mode if (seqs->mode == wf_sequences_ascii) { -//#if __AVX2__ // TODO -// wavefront_extend_matches_packed_end2end_avx2(wf_aligner,mwavefront,lo,hi); -//#else wavefront_extend_matches_packed_end2end(wf_aligner,mwavefront,lo,hi); -//#endif } else { wf_offset_t dummy; wavefront_extend_matches_custom(wf_aligner,mwavefront,score,lo,hi,false,&dummy); diff --git a/wavefront/wavefront_extend_kernels.c b/wavefront/wavefront_extend_kernels.c index f962ff2..031f951 100644 --- a/wavefront/wavefront_extend_kernels.c +++ b/wavefront/wavefront_extend_kernels.c @@ -35,6 +35,7 @@ #include "wavefront_extend_kernels.h" #include "wavefront_termination.h" +#include "wavefront_extend_kernels_avx.h" #if __BYTE_ORDER == __LITTLE_ENDIAN #define wavefront_extend_matches_kernel wavefront_extend_matches_kernel_blockwise @@ -63,6 +64,7 @@ FORCE_INLINE wf_offset_t wavefront_extend_matches_kernel_charwise( // Return extended offset return offset; } + FORCE_INLINE wf_offset_t wavefront_extend_matches_kernel_blockwise( wavefront_aligner_t* const wf_aligner, const int k, @@ -88,6 +90,7 @@ FORCE_INLINE wf_offset_t wavefront_extend_matches_kernel_blockwise( // Return extended offset return offset; } + /* * Wavefront-Extend Inner Kernels * Wavefront offset extension comparing characters @@ -100,42 +103,67 @@ FORCE_NO_INLINE void wavefront_extend_matches_packed_end2end( wavefront_t* const mwavefront, const int lo, const int hi) { - wf_offset_t* const offsets = mwavefront->offsets; - int k; - for (k=lo;k<=hi;++k) { - // Fetch offset - const wf_offset_t offset = offsets[k]; - if (offset == WAVEFRONT_OFFSET_NULL) continue; - // Extend offset - offsets[k] = wavefront_extend_matches_kernel(wf_aligner,k,offset); - } + #if __AVX2__ && __BYTE_ORDER == __LITTLE_ENDIAN + #if __AVX512CD__ && __AVX512VL__ + wavefront_extend_matches_packed_end2end_avx512(wf_aligner, mwavefront, lo, hi); + #else + wavefront_extend_matches_packed_end2end_avx2(wf_aligner, mwavefront, lo, hi); + #endif + #else + wf_offset_t* const offsets = mwavefront->offsets; + int k; + for (k=lo;k<=hi;++k) { + // Fetch offset + const wf_offset_t offset = offsets[k]; + if (offset == WAVEFRONT_OFFSET_NULL) continue; + // Extend offset + offsets[k] = wavefront_extend_matches_kernel(wf_aligner,k,offset); + } + #endif } + FORCE_NO_INLINE wf_offset_t wavefront_extend_matches_packed_end2end_max( wavefront_aligner_t* const wf_aligner, wavefront_t* const mwavefront, const int lo, const int hi) { - wf_offset_t* const offsets = mwavefront->offsets; - wf_offset_t max_antidiag = 0; - int k; - for (k=lo;k<=hi;++k) { - // Fetch offset - const wf_offset_t offset = offsets[k]; - if (offset == WAVEFRONT_OFFSET_NULL) continue; - // Extend offset - offsets[k] = wavefront_extend_matches_kernel(wf_aligner,k,offset); - // Compute max - const wf_offset_t antidiag = WAVEFRONT_ANTIDIAGONAL(k,offsets[k]); - if (max_antidiag < antidiag) max_antidiag = antidiag; - } - return max_antidiag; + #if __AVX2__ && __BYTE_ORDER == __LITTLE_ENDIAN + #if __AVX512CD__ && __AVX512VL__ + return wavefront_extend_matches_packed_end2end_max_avx512(wf_aligner, mwavefront, lo, hi); + #else + return wavefront_extend_matches_packed_end2end_max_avx2(wf_aligner, mwavefront, lo, hi); + #endif + #else + wf_offset_t* const offsets = mwavefront->offsets; + wf_offset_t max_antidiag = 0; + int k; + for (k=lo;k<=hi;++k) { + // Fetch offset + const wf_offset_t offset = offsets[k]; + if (offset == WAVEFRONT_OFFSET_NULL) continue; + // Extend offset + offsets[k] = wavefront_extend_matches_kernel(wf_aligner,k,offset); + // Compute max + const wf_offset_t antidiag = WAVEFRONT_ANTIDIAGONAL(k,offsets[k]); + if (max_antidiag < antidiag) max_antidiag = antidiag; + } + return max_antidiag; + #endif } + FORCE_NO_INLINE bool wavefront_extend_matches_packed_endsfree( wavefront_aligner_t* const wf_aligner, wavefront_t* const mwavefront, const int score, const int lo, const int hi) { + #if __AVX2__ && __BYTE_ORDER == __LITTLE_ENDIAN + #if __AVX512CD__ && __AVX512VL__ + return wavefront_extend_matches_packed_endsfree_avx512(wf_aligner, mwavefront, score, lo, hi); + #else + return wavefront_extend_matches_packed_endsfree_avx2(wf_aligner, mwavefront, score, lo, hi); + #endif + #else // Parameters wf_offset_t* const offsets = mwavefront->offsets; int k; @@ -162,6 +190,7 @@ FORCE_NO_INLINE bool wavefront_extend_matches_packed_endsfree( } // Alignment not finished return false; + #endif } /* * Wavefront-Extend Inner Kernel (Custom match function) diff --git a/wavefront/wavefront_extend_kernels_avx.c b/wavefront/wavefront_extend_kernels_avx.c index b31d8d5..244c9d4 100644 --- a/wavefront/wavefront_extend_kernels_avx.c +++ b/wavefront/wavefront_extend_kernels_avx.c @@ -35,6 +35,7 @@ #include "wavefront_heuristic.h" #include "wavefront_extend_kernels.h" #include "wavefront_extend_kernels_avx.h" +#include "wavefront_termination.h" #if __AVX2__ #include @@ -66,25 +67,29 @@ FORCE_INLINE wf_offset_t wavefront_extend_matches_packed_kernel( // Return extended offset return offset; } + /* * SIMD clz, use a native instruction when available (AVX512 CD or VL * extensions), or emulate the clz behavior. */ FORCE_INLINE __m256i avx2_lzcnt_epi32(__m256i v) { -#if __AVX512CD__ && __AVX512VL__ - return _mm256_lzcnt_epi32(v); -#else - // Emulate clz for AVX2: https://stackoverflow.com/a/58827596 - v = _mm256_andnot_si256(_mm256_srli_epi32(v,8),v); // keep 8 MSB - v = _mm256_castps_si256(_mm256_cvtepi32_ps(v)); // convert an integer to float - v = _mm256_srli_epi32(v,23); // shift down the exponent - v = _mm256_subs_epu16(_mm256_set1_epi32(158),v); // undo bias - v = _mm256_min_epi16(v,_mm256_set1_epi32(32)); // clamp at 32 - return v; -#endif + #if __AVX512CD__ && __AVX512VL__ + return _mm256_lzcnt_epi32(v); + #else + // Emulate clz for AVX2: https://stackoverflow.com/a/58827596 + v = _mm256_andnot_si256(_mm256_srli_epi32(v,8),v); // keep 8 MSB + v = _mm256_castps_si256(_mm256_cvtepi32_ps(v)); // convert an integer to float + v = _mm256_srli_epi32(v,23); // shift down the exponent + v = _mm256_subs_epu16(_mm256_set1_epi32(158),v); // undo bias + v = _mm256_min_epi16(v,_mm256_set1_epi32(32)); // clamp at 32 + return v; + #endif } + + + /* - * Wavefront-Extend Inner Kernel (SIMD AVX2/AVX512) + * Wavefront-Extend Inner Kernel (SIMD AVX2) */ FORCE_NO_INLINE void wavefront_extend_matches_packed_end2end_avx2( wavefront_aligner_t* const wf_aligner, @@ -98,7 +103,6 @@ FORCE_NO_INLINE void wavefront_extend_matches_packed_end2end_avx2( const char* pattern = wf_aligner->sequences.pattern; const char* text = wf_aligner->sequences.text; const __m256i vector_null = _mm256_set1_epi32(-1); - const __m256i fours = _mm256_set1_epi32(4); const __m256i eights = _mm256_set1_epi32(8); const __m256i vecShuffle = _mm256_set_epi8(28,29,30,31,24,25,26,27, 20,21,22,23,16,17,18,19, @@ -115,40 +119,41 @@ FORCE_NO_INLINE void wavefront_extend_matches_packed_end2end_avx2( offsets[k] = wavefront_extend_matches_packed_kernel(wf_aligner,k,offset); } if (num_of_diagonals < elems_per_register) return; + k_min += loop_peeling_iters; __m256i ks = _mm256_set_epi32 ( k_min+7,k_min+6,k_min+5,k_min+4, k_min+3,k_min+2,k_min+1,k_min); - // Main SIMD extension loop - for (k=k_min;k<=k_max;k+=elems_per_register) { + + + for (k=k_min; k<=k_max; k+=elems_per_register) { __m256i offsets_vector = _mm256_lddqu_si256 ((__m256i*)&offsets[k]); - __m256i h_vector = offsets_vector; - __m256i v_vector = _mm256_sub_epi32(offsets_vector,ks); - ks =_mm256_add_epi32 (ks, eights); + __m256i h_vector = offsets_vector; + __m256i v_vector = _mm256_sub_epi32(offsets_vector,ks); + // NULL offsets will read at index 0 (avoid segfaults) - __m256i null_mask = _mm256_cmpgt_epi32(offsets_vector,vector_null); - v_vector = _mm256_and_si256(null_mask,v_vector); - h_vector = _mm256_and_si256(null_mask,h_vector); + __m256i null_mask = _mm256_cmpgt_epi32(offsets_vector, vector_null); + v_vector = _mm256_and_si256(null_mask, v_vector); + h_vector = _mm256_and_si256(null_mask, h_vector); + __m256i pattern_vector = _mm256_i32gather_epi32((int const*)&pattern[0],v_vector,1); - __m256i text_vector = _mm256_i32gather_epi32((int const*)&text[0],h_vector,1); - // Change endianess to make the xor + clz character comparison - pattern_vector = _mm256_shuffle_epi8(pattern_vector,vecShuffle); - text_vector = _mm256_shuffle_epi8(text_vector,vecShuffle); + __m256i text_vector = _mm256_i32gather_epi32((int const*)&text[0],h_vector,1); + __m256i vector_mask = _mm256_cmpeq_epi32(text_vector, pattern_vector); + int mask = _mm256_movemask_epi8(vector_mask); + __m256i xor_result_vector = _mm256_xor_si256(pattern_vector,text_vector); - __m256i clz_vector = avx2_lzcnt_epi32(xor_result_vector); - // Divide clz by 8 to get the number of equal characters - // Assume there are sentinels on sequences so we won't count characters - // outside the sequences - __m256i equal_chars = _mm256_srli_epi32(clz_vector,3); - offsets_vector = _mm256_add_epi32 (offsets_vector,equal_chars); - v_vector = _mm256_add_epi32 (v_vector,fours); - h_vector = _mm256_add_epi32 (h_vector,fours); - // Lanes to continue == 0xffffffff, other lanes = 0 - __m256i vector_mask = _mm256_cmpeq_epi32(equal_chars,fours); + xor_result_vector = _mm256_shuffle_epi8(xor_result_vector, vecShuffle); + __m256i clz_vector = avx2_lzcnt_epi32(xor_result_vector); + + __m256i equal_chars = _mm256_srli_epi32(clz_vector,3); + //equal_chars = _mm256_and_si256(null_mask, equal_chars); + offsets_vector = _mm256_add_epi32 (offsets_vector,equal_chars); + ks = _mm256_add_epi32 (ks, eights); + _mm256_storeu_si256((__m256i*)&offsets[k],offsets_vector); - int mask = _mm256_movemask_epi8(vector_mask); + if(mask == 0) continue; - // ctz(0) is undefined + while (mask != 0) { int tz = __builtin_ctz(mask); int curr_k = k + (tz/4); @@ -164,4 +169,521 @@ FORCE_NO_INLINE void wavefront_extend_matches_packed_end2end_avx2( } } + +FORCE_NO_INLINE wf_offset_t wavefront_extend_matches_packed_end2end_max_avx2( + wavefront_aligner_t* const wf_aligner, + wavefront_t* const mwavefront, + const int lo, + const int hi) { + // Parameters + + const int elems_per_register = 8; + const char* pattern = wf_aligner->sequences.pattern; + const char* text = wf_aligner->sequences.text; + + wf_offset_t* const offsets = mwavefront->offsets; + wf_offset_t max_antidiag = 0; + + int k_min = lo; + int k_max = hi; + + int num_of_diagonals = k_max - k_min + 1; + int loop_peeling_iters = num_of_diagonals % elems_per_register; + int k; + + const __m256i vector_null = _mm256_set1_epi32(-1); + const __m256i eights = _mm256_set1_epi32(8); + const __m256i vecShuffle = _mm256_set_epi8(28,29,30,31,24,25,26,27, + 20,21,22,23,16,17,18,19, + 12,13,14,15, 8, 9,10,11, + 4 , 5, 6, 7, 0, 1, 2 ,3); + + + for (k=k_min;k= 0) { + offset = wavefront_extend_matches_packed_kernel(wf_aligner,curr_k,offset); + offsets[curr_k] = offset; + const wf_offset_t antidiag = WAVEFRONT_ANTIDIAGONAL(curr_k, offset); + if (max_antidiag < antidiag) max_antidiag = antidiag; + } else { + offsets[curr_k] = WAVEFRONT_OFFSET_NULL; + } + mask &= (0xfffffff0 << tz); + } + max_antidiag_v = _mm256_set1_epi32(max_antidiag); + } + } + + const wf_offset_t max_antidiagonal_buffer[8]; + _mm256_storeu_si256((__m256i*)&max_antidiagonal_buffer[0], max_antidiag_v); + for (int i = 0; i < 8; i++) + { + const wf_offset_t antidiag = max_antidiagonal_buffer[i]; + if (max_antidiag < antidiag) max_antidiag = antidiag; + } + return max_antidiag; +} + + +FORCE_NO_INLINE bool wavefront_extend_matches_packed_endsfree_avx2( + wavefront_aligner_t* const wf_aligner, + wavefront_t* const mwavefront, + const int score, + const int lo, + const int hi) { + // Parameters + + const int elems_per_register = 8; + const char* pattern = wf_aligner->sequences.pattern; + const char* text = wf_aligner->sequences.text; + + wf_offset_t* const offsets = mwavefront->offsets; + + int k_min = lo; + int k_max = hi; + + int num_of_diagonals = k_max - k_min + 1; + int loop_peeling_iters = num_of_diagonals % elems_per_register; + int k; + + const __m256i vector_null = _mm256_set1_epi32(-1); + const __m256i eights = _mm256_set1_epi32(8); + const __m256i vecShuffle = _mm256_set_epi8(28,29,30,31,24,25,26,27, + 20,21,22,23,16,17,18,19, + 12,13,14,15, 8, 9,10,11, + 4 , 5, 6, 7, 0, 1, 2 ,3); + + for (k=k_min;k= 0) { + offsets[curr_k] = wavefront_extend_matches_packed_kernel(wf_aligner,curr_k,offset); + if (wavefront_termination_endsfree(wf_aligner,mwavefront,score,curr_k,offsets[curr_k])) { + return true; // Quit (we are done) + } + } else { + offsets[curr_k] = WAVEFRONT_OFFSET_NULL; + } + mask &= (0xfffffff0 << tz); + } + } + for (k=k_min; k <= k_max; k++) { + const wf_offset_t offset = offsets[k]; + if (offset < 0) continue; + // Check ends-free reaching boundaries + if (wavefront_termination_endsfree(wf_aligner,mwavefront,score,k,offset)) { + return true; // Quit (we are done) + } + } + return false; +} + + +#if __AVX512CD__ && __AVX512VL__ +/* + * Wavefront-Extend Inner Kernel (SIMD AVX512) + */ +FORCE_NO_INLINE void wavefront_extend_matches_packed_end2end_avx512( + wavefront_aligner_t* const wf_aligner, + wavefront_t* const mwavefront, + const int lo, + const int hi) { + // Parameters + wf_offset_t* const offsets = mwavefront->offsets; + int k_min = lo; + int k_max = hi; + const char* pattern = wf_aligner->sequences.pattern; + const char* text = wf_aligner->sequences.text; + const __m512i zero_vector = _mm512_setzero_si512(); + const __m512i vector_null = _mm512_set1_epi32(-1); + const __m512i sixteens = _mm512_set1_epi32(16); + const __m512i vecShuffle = _mm512_set_epi8(60,61,62,63,56,57,58,59, + 52,53,54,55,48,49,50,51, + 44,45,46,47,40,41,42,43, + 36,37,38,39,32,33,34,35, + 28,29,30,31,24,25,26,27, + 20,21,22,23,16,17,18,19, + 12,13,14,15,8,9,10,11, + 4,5,6,7,0,1,2,3); + const int elems_per_register = 16; + int num_of_diagonals = k_max - k_min + 1; + int loop_peeling_iters = num_of_diagonals % elems_per_register; + int k; + for (k=k_min;k> i) & 1) == 0) continue; + const int curr_k = k + i; + const wf_offset_t offset = offsets[curr_k]; + if (offset >= 0) { + offsets[curr_k] = wavefront_extend_matches_packed_kernel(wf_aligner,curr_k,offset); + } else { + offsets[curr_k] = WAVEFRONT_OFFSET_NULL; + } + } + } +} + + +FORCE_NO_INLINE wf_offset_t wavefront_extend_matches_packed_end2end_max_avx512( + wavefront_aligner_t* const wf_aligner, + wavefront_t* const mwavefront, + const int lo, + const int hi) { + // Parameters + + const int elems_per_register = 16; + const char* pattern = wf_aligner->sequences.pattern; + const char* text = wf_aligner->sequences.text; + + wf_offset_t* const offsets = mwavefront->offsets; + wf_offset_t max_antidiag = 0; + + int k_min = lo; + int k_max = hi; + + int num_of_diagonals = k_max - k_min + 1; + int loop_peeling_iters = num_of_diagonals % elems_per_register; + int k; + + const __m512i zero_vector = _mm512_setzero_si512(); + const __m512i vector_null = _mm512_set1_epi32(-1); + const __m512i sixteens = _mm512_set1_epi32(16); + const __m512i vecShuffle = _mm512_set_epi8(60,61,62,63,56,57,58,59, + 52,53,54,55,48,49,50,51, + 44,45,46,47,40,41,42,43, + 36,37,38,39,32,33,34,35, + 28,29,30,31,24,25,26,27, + 20,21,22,23,16,17,18,19, + 12,13,14,15,8,9,10,11, + 4,5,6,7,0,1,2,3); + + for (k=k_min;k> i) & 1) == 0) continue; + const int curr_k = k + i; + wf_offset_t offset = offsets[curr_k]; + if (offset >= 0) + { + offset = wavefront_extend_matches_packed_kernel(wf_aligner,curr_k,offset); + offsets[curr_k] = offset; + const wf_offset_t antidiag = WAVEFRONT_ANTIDIAGONAL(curr_k, offset); + if (max_antidiag < antidiag) max_antidiag = antidiag; + } + else + { + offsets[curr_k] = WAVEFRONT_OFFSET_NULL; + } + } + max_antidiag_v = _mm512_set1_epi32(max_antidiag); + } + + return _mm512_reduce_max_epi32(max_antidiag_v); +} + + +FORCE_NO_INLINE bool wavefront_extend_matches_packed_endsfree_avx512( + wavefront_aligner_t* const wf_aligner, + wavefront_t* const mwavefront, + const int score, + const int lo, + const int hi) { + + // Parameters + wf_offset_t* const offsets = mwavefront->offsets; + int k_min = lo; + int k_max = hi; + const char* pattern = wf_aligner->sequences.pattern; + const char* text = wf_aligner->sequences.text; + const __m512i zero_vector = _mm512_setzero_si512(); + const __m512i vector_null = _mm512_set1_epi32(-1); + const __m512i sixteens = _mm512_set1_epi32(16); + const __m512i vecShuffle = _mm512_set_epi8(60,61,62,63,56,57,58,59, + 52,53,54,55,48,49,50,51, + 44,45,46,47,40,41,42,43, + 36,37,38,39,32,33,34,35, + 28,29,30,31,24,25,26,27, + 20,21,22,23,16,17,18,19, + 12,13,14,15,8,9,10,11, + 4,5,6,7,0,1,2,3); + const int elems_per_register = 16; + int num_of_diagonals = k_max - k_min + 1; + int loop_peeling_iters = num_of_diagonals % elems_per_register; + int k; + + for (k=k_min;k> i) & 1) == 0) continue; + const int curr_k = k + i; + const wf_offset_t offset = offsets[curr_k]; + if (offset >= 0) + { + offsets[curr_k] = wavefront_extend_matches_packed_kernel(wf_aligner,curr_k,offset); + if (wavefront_termination_endsfree(wf_aligner,mwavefront,score,curr_k,offsets[curr_k])) { + return true; // Quit (we are done) + } + } + else + { + offsets[curr_k] = WAVEFRONT_OFFSET_NULL; + } + } + } + + for (k=k_min; k <= k_max; k++) { + const wf_offset_t offset = offsets[k]; + if (offset < 0) continue; + // Check ends-free reaching boundaries + if (wavefront_termination_endsfree(wf_aligner,mwavefront,score,k,offset)) { + return true; // Quit (we are done) + } + } + return false; +} +#endif + #endif // AVX2 diff --git a/wavefront/wavefront_extend_kernels_avx.h b/wavefront/wavefront_extend_kernels_avx.h index 0e932d4..e25f1a5 100644 --- a/wavefront/wavefront_extend_kernels_avx.h +++ b/wavefront/wavefront_extend_kernels_avx.h @@ -42,6 +42,39 @@ void wavefront_extend_matches_packed_end2end_avx2( const int lo, const int hi); +wf_offset_t wavefront_extend_matches_packed_end2end_max_avx2( + wavefront_aligner_t* const wf_aligner, + wavefront_t* const mwavefront, + const int lo, + const int hi); + +bool wavefront_extend_matches_packed_endsfree_avx2( + wavefront_aligner_t* const wf_aligner, + wavefront_t* const mwavefront, + const int score, + const int lo, + const int hi); + +#if __AVX512CD__ && __AVX512VL__ +void wavefront_extend_matches_packed_end2end_avx512( + wavefront_aligner_t* const wf_aligner, + wavefront_t* const mwavefront, + const int lo, + const int hi); + +wf_offset_t wavefront_extend_matches_packed_end2end_max_avx512( + wavefront_aligner_t* const wf_aligner, + wavefront_t* const mwavefront, + const int lo, + const int hi); + +bool wavefront_extend_matches_packed_endsfree_avx512( + wavefront_aligner_t* const wf_aligner, + wavefront_t* const mwavefront, + const int score, + const int lo, + const int hi); +#endif #endif // AVX2 #endif /* WAVEFRONT_EXTEND_AVX_H_ */