Skip to content

Commit

Permalink
Merge pull request lh3#359 from jmarshall/neon
Browse files Browse the repository at this point in the history
Add ARM Neon and scalar implementations of SIMD functions
  • Loading branch information
lh3 authored Aug 31, 2022
2 parents b56db22 + c77ace7 commit 2e603e4
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ fastmap.o: bwa.h bntseq.h bwt.h bwamem.h kvec.h malloc_wrap.h utils.h kseq.h
is.o: malloc_wrap.h
kopen.o: malloc_wrap.h
kstring.o: kstring.h malloc_wrap.h
ksw.o: ksw.h malloc_wrap.h
ksw.o: ksw.h neon_sse.h scalar_sse.h malloc_wrap.h
main.o: kstring.h malloc_wrap.h utils.h
malloc_wrap.o: malloc_wrap.h
maxk.o: bwa.h bntseq.h bwt.h bwamem.h kseq.h malloc_wrap.h
Expand Down
46 changes: 41 additions & 5 deletions ksw.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@
#include <stdlib.h>
#include <stdint.h>
#include <assert.h>
#if defined __SSE2__
#include <emmintrin.h>
#elif defined __ARM_NEON
#include "neon_sse.h"
#else
#include "scalar_sse.h"
#endif
#include "ksw.h"

#ifdef USE_MALLOC_WRAPPERS
Expand Down Expand Up @@ -108,13 +114,19 @@ kswq_t *ksw_qinit(int size, int qlen, const uint8_t *query, int m, const int8_t
return q;
}

#if defined __ARM_NEON
// NEON replacement for the _mm_slli_si128 byte shift used below: vextq_u8 extracts
// 16 bytes starting at index 16 - n from the concatenation of `zero` and `a`.
// This macro implicitly uses each function's `zero` local variable, so it can
// only be expanded where a vector named `zero` is in scope.
// NOTE(review): `zero` is presumably an all-zero vector, and vextq_u8 presumably
// needs a compile-time constant index in 0..15 (so n in 1..16) -- confirm. The
// callers visible here pass n = 1 and n = 2.
#define _mm_slli_si128(a, n) (vextq_u8(zero, (a), 16 - (n)))
#endif

kswr_t ksw_u8(kswq_t *q, int tlen, const uint8_t *target, int _o_del, int _e_del, int _o_ins, int _e_ins, int xtra) // the first gap costs -(_o+_e)
{
int slen, i, m_b, n_b, te = -1, gmax = 0, minsc, endsc;
uint64_t *b;
__m128i zero, oe_del, e_del, oe_ins, e_ins, shift, *H0, *H1, *E, *Hmax;
kswr_t r;

#if defined __SSE2__
#define __max_16(ret, xx) do { \
(xx) = _mm_max_epu8((xx), _mm_srli_si128((xx), 8)); \
(xx) = _mm_max_epu8((xx), _mm_srli_si128((xx), 4)); \
Expand All @@ -123,6 +135,18 @@ kswr_t ksw_u8(kswq_t *q, int tlen, const uint8_t *target, int _o_del, int _e_del
(ret) = _mm_extract_epi16((xx), 0) & 0x00ff; \
} while (0)

// Given entries with arbitrary values, return whether they are all 0x00
#define allzero_16(xx) (_mm_movemask_epi8(_mm_cmpeq_epi8((xx), zero)) == 0xffff)

#elif defined __ARM_NEON
#define __max_16(ret, xx) (ret) = vmaxvq_u8((xx))
#define allzero_16(xx) (vmaxvq_u8((xx)) == 0)

#else
#define __max_16(ret, xx) (ret) = m128i_max_u8((xx))
#define allzero_16(xx) (m128i_allzero((xx)))
#endif

// initialization
r = g_defr;
minsc = (xtra&KSW_XSUBO)? xtra&0xffff : 0x10000;
Expand All @@ -143,7 +167,7 @@ kswr_t ksw_u8(kswq_t *q, int tlen, const uint8_t *target, int _o_del, int _e_del
}
// the core loop
for (i = 0; i < tlen; ++i) {
int j, k, cmp, imax;
int j, k, imax;
__m128i e, h, t, f = zero, max = zero, *S = q->qp + target[i] * slen; // s is the 1st score vector
h = _mm_load_si128(H0 + slen - 1); // h={2,5,8,11,14,17,-1,-1} in the above example
h = _mm_slli_si128(h, 1); // h=H(i-1,-1); << instead of >> because x64 is little-endian
Expand Down Expand Up @@ -182,8 +206,7 @@ kswr_t ksw_u8(kswq_t *q, int tlen, const uint8_t *target, int _o_del, int _e_del
_mm_store_si128(H1 + j, h);
h = _mm_subs_epu8(h, oe_ins);
f = _mm_subs_epu8(f, e_ins);
cmp = _mm_movemask_epi8(_mm_cmpeq_epi8(_mm_subs_epu8(f, h), zero));
if (UNLIKELY(cmp == 0xffff)) goto end_loop16;
if (UNLIKELY(allzero_16(_mm_subs_epu8(f, h)))) goto end_loop16;
}
}
end_loop16:
Expand Down Expand Up @@ -236,13 +259,26 @@ kswr_t ksw_i16(kswq_t *q, int tlen, const uint8_t *target, int _o_del, int _e_de
__m128i zero, oe_del, e_del, oe_ins, e_ins, *H0, *H1, *E, *Hmax;
kswr_t r;

// Per-backend helpers for vectors viewed as eight 16-bit lanes:
//   __max_8(ret, xx)  stores the largest signed 16-bit lane of xx in ret;
//   allzero_0f_8(xx)  tests whether a comparison mask has no lane set.
#if defined __SSE2__
// Fold the vector onto itself by 8, 4 and then 2 bytes, taking the signed
// maximum each time, so that lane 0 ends up holding the overall maximum.
// Note that this variant overwrites xx.
#define __max_8(ret, xx) do { \
(xx) = _mm_max_epi16((xx), _mm_srli_si128((xx), 8)); \
(xx) = _mm_max_epi16((xx), _mm_srli_si128((xx), 4)); \
(xx) = _mm_max_epi16((xx), _mm_srli_si128((xx), 2)); \
(ret) = _mm_extract_epi16((xx), 0); \
} while (0)

// Given entries all either 0x0000 or 0xffff, return whether they are all 0x0000
#define allzero_0f_8(xx) (!_mm_movemask_epi8((xx)))

#elif defined __ARM_NEON
// Single across-vector reductions; unlike the SSE2 variant, xx is left unmodified.
// NOTE(review): vmaxvq_s16/vmaxvq_u16 are believed to be AArch64-only, whereas
// __ARM_NEON may also be defined on 32-bit ARM -- confirm before building there.
#define __max_8(ret, xx) (ret) = vmaxvq_s16(vreinterpretq_s16_u8((xx)))
#define allzero_0f_8(xx) (vmaxvq_u16(vreinterpretq_u16_u8((xx))) == 0)

#else
// Scalar fallback: plain loops provided by scalar_sse.h; xx is left unmodified.
#define __max_8(ret, xx) (ret) = m128i_max_s16((xx))
#define allzero_0f_8(xx) (m128i_allzero((xx)))
#endif

// initialization
r = g_defr;
minsc = (xtra&KSW_XSUBO)? xtra&0xffff : 0x10000;
Expand All @@ -267,7 +303,7 @@ kswr_t ksw_i16(kswq_t *q, int tlen, const uint8_t *target, int _o_del, int _e_de
h = _mm_load_si128(H0 + slen - 1); // h={2,5,8,11,14,17,-1,-1} in the above example
h = _mm_slli_si128(h, 2);
for (j = 0; LIKELY(j < slen); ++j) {
h = _mm_adds_epi16(h, *S++);
h = _mm_adds_epi16(h, _mm_load_si128(S++));
e = _mm_load_si128(E + j);
h = _mm_max_epi16(h, e);
h = _mm_max_epi16(h, f);
Expand All @@ -290,7 +326,7 @@ kswr_t ksw_i16(kswq_t *q, int tlen, const uint8_t *target, int _o_del, int _e_de
_mm_store_si128(H1 + j, h);
h = _mm_subs_epu16(h, oe_ins);
f = _mm_subs_epu16(f, e_ins);
if(UNLIKELY(!_mm_movemask_epi8(_mm_cmpgt_epi16(f, h)))) goto end_loop8;
if(UNLIKELY(allzero_0f_8(_mm_cmpgt_epi16(f, h)))) goto end_loop8;
}
}
end_loop8:
Expand Down
33 changes: 33 additions & 0 deletions neon_sse.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#ifndef NEON_SSE_H
#define NEON_SSE_H

#include <arm_neon.h>

// Give NEON's vector of sixteen unsigned bytes the SSE2 type name, so code
// written against __m128i (such as ksw.c) compiles unchanged. The 16-bit
// operations below reinterpret this type lane-wise as needed.
typedef uint8x16_t __m128i;

// NEON counterpart of SSE2's _mm_load_si128: read 16 bytes starting at ptr.
static inline __m128i _mm_load_si128(const __m128i *ptr)
{
	const uint8_t *bytes = (const uint8_t *) ptr;
	return vld1q_u8(bytes);
}
// NEON counterpart of SSE2's _mm_set1_epi32: replicate n into all four 32-bit
// lanes, then view the result as the byte vector used for __m128i.
static inline __m128i _mm_set1_epi32(int n)
{
	int32x4_t lanes = vdupq_n_s32(n);
	return vreinterpretq_u8_s32(lanes);
}
// NEON counterpart of SSE2's _mm_store_si128: write the 16 bytes of a to ptr.
static inline void _mm_store_si128(__m128i *ptr, __m128i a)
{
	uint8_t *bytes = (uint8_t *) ptr;
	vst1q_u8(bytes, a);
}

// NEON counterpart of SSE2's _mm_adds_epu8: lane-wise saturating add of
// sixteen unsigned bytes.
static inline __m128i _mm_adds_epu8(__m128i a, __m128i b)
{
	__m128i sum = vqaddq_u8(a, b);
	return sum;
}
// NEON counterpart of SSE2's _mm_max_epu8: lane-wise maximum of sixteen
// unsigned bytes.
static inline __m128i _mm_max_epu8(__m128i a, __m128i b)
{
	__m128i larger = vmaxq_u8(a, b);
	return larger;
}
// NEON counterpart of SSE2's _mm_set1_epi8: replicate the signed byte n into
// all sixteen lanes, then view the result as the byte vector used for __m128i.
static inline __m128i _mm_set1_epi8(int8_t n)
{
	int8x16_t lanes = vdupq_n_s8(n);
	return vreinterpretq_u8_s8(lanes);
}
// NEON counterpart of SSE2's _mm_subs_epu8: lane-wise saturating subtract of
// sixteen unsigned bytes (a - b).
static inline __m128i _mm_subs_epu8(__m128i a, __m128i b)
{
	__m128i diff = vqsubq_u8(a, b);
	return diff;
}

// Local shorthands for reinterpreting between the byte vector used as __m128i
// and vectors of eight 16-bit lanes. M128I/UM128I convert signed/unsigned
// 16-bit vectors back to __m128i; S16/U16 convert __m128i to signed/unsigned
// 16-bit vectors. They are #undef'd again at the end of this header.
#define M128I(a) vreinterpretq_u8_s16((a))
#define UM128I(a) vreinterpretq_u8_u16((a))
#define S16(a) vreinterpretq_s16_u8((a))
#define U16(a) vreinterpretq_u16_u8((a))

// NEON counterpart of SSE2's _mm_adds_epi16: lane-wise saturating add of eight
// signed 16-bit values.
static inline __m128i _mm_adds_epi16(__m128i a, __m128i b)
{
	int16x8_t lhs = vreinterpretq_s16_u8(a);
	int16x8_t rhs = vreinterpretq_s16_u8(b);
	return vreinterpretq_u8_s16(vqaddq_s16(lhs, rhs));
}
// NEON counterpart of SSE2's _mm_cmpgt_epi16: signed 16-bit lane-wise a > b,
// returning the comparison mask produced by vcgtq_s16 as a byte vector.
static inline __m128i _mm_cmpgt_epi16(__m128i a, __m128i b)
{
	int16x8_t lhs = vreinterpretq_s16_u8(a);
	int16x8_t rhs = vreinterpretq_s16_u8(b);
	uint16x8_t mask = vcgtq_s16(lhs, rhs);
	return vreinterpretq_u8_u16(mask);
}
// NEON counterpart of SSE2's _mm_max_epi16: lane-wise maximum of eight signed
// 16-bit values.
static inline __m128i _mm_max_epi16(__m128i a, __m128i b)
{
	int16x8_t lhs = vreinterpretq_s16_u8(a);
	int16x8_t rhs = vreinterpretq_s16_u8(b);
	return vreinterpretq_u8_s16(vmaxq_s16(lhs, rhs));
}
// NEON counterpart of SSE2's _mm_set1_epi16: replicate n into all eight 16-bit
// lanes, then view the result as the byte vector used for __m128i.
static inline __m128i _mm_set1_epi16(int16_t n)
{
	int16x8_t lanes = vdupq_n_s16(n);
	return vreinterpretq_u8_s16(lanes);
}
// NEON counterpart of SSE2's _mm_subs_epu16: lane-wise saturating subtract of
// eight unsigned 16-bit values (a - b).
static inline __m128i _mm_subs_epu16(__m128i a, __m128i b)
{
	uint16x8_t lhs = vreinterpretq_u16_u8(a);
	uint16x8_t rhs = vreinterpretq_u16_u8(b);
	return vreinterpretq_u8_u16(vqsubq_u16(lhs, rhs));
}

// Remove the reinterpret shorthands defined earlier in this header so that
// these short names do not leak into files that include it.
#undef M128I
#undef UM128I
#undef S16
#undef U16

#endif
119 changes: 119 additions & 0 deletions scalar_sse.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#ifndef SCALAR_SSE_H
#define SCALAR_SSE_H

#include <assert.h>
#include <stdint.h>
#include <string.h>

// Scalar stand-in for the SSE2 128-bit integer vector: the same 16 bytes viewed
// either as sixteen unsigned 8-bit lanes or as eight signed 16-bit lanes.
typedef union m128i {
	uint8_t u8[16];
	int16_t i16[8];
} __m128i;

// Broadcast the 32-bit value n into each of the four 32-bit lanes.
//
// The earlier memset()-based version replicated the low byte of n into every
// byte, which only equals a true 32-bit broadcast when all four bytes of n are
// identical (e.g. n == 0); n == 1 produced 0x01010101 in each lane. Copying the
// object representation of n into each lane is correct for every value, so the
// former 0..255 assertion is no longer required.
static inline __m128i _mm_set1_epi32(int32_t n) {
	__m128i r;
	int lane;
	for (lane = 0; lane < 4; lane++)
		memcpy(&r.u8[4 * lane], &n, sizeof n);
	return r;
}

// Scalar _mm_load_si128: copy the 16 bytes at ptr into a vector value.
static inline __m128i _mm_load_si128(const __m128i *ptr)
{
	__m128i loaded;
	memcpy(&loaded, ptr, sizeof loaded);
	return loaded;
}
// Scalar _mm_store_si128: copy the 16 bytes of a to the location ptr.
static inline void _mm_store_si128(__m128i *ptr, __m128i a)
{
	const __m128i *src = &a;
	memcpy(ptr, src, sizeof *src);
}

// Return 1 when every one of the 16 bytes of a is zero, otherwise 0.
static inline int m128i_allzero(__m128i a)
{
	unsigned seen = 0;
	int lane;
	// OR all bytes together: the accumulator stays 0 only if each byte is 0.
	for (lane = 0; lane < 16; lane++)
		seen |= a.u8[lane];
	return seen == 0;
}

// Scalar _mm_slli_si128: shift the whole vector by n bytes towards higher byte
// indices, filling the vacated low bytes with zero.
//
// n: byte count. Counts outside 0..16 are treated as 16, i.e. the result is
// all zero; this mirrors the SSE2 instruction, which yields zero for any count
// above 15. Previously such counts made `16 - n` negative, which turned into a
// huge size_t and an out-of-bounds memmove.
static inline __m128i _mm_slli_si128(__m128i a, int n) {
	__m128i r;
	int i;
	if (n < 0 || n > 16) n = 16;
	for (i = 0; i < 16; i++)
		r.u8[i] = (i < n) ? 0 : a.u8[i - n];
	return r;
}

// Scalar _mm_adds_epu8: add the sixteen unsigned bytes of a and b lane by
// lane, clamping each sum to 255.
static inline __m128i _mm_adds_epu8(__m128i a, __m128i b)
{
	int lane;
	for (lane = 0; lane < 16; lane++) {
		unsigned sum = (unsigned) a.u8[lane] + (unsigned) b.u8[lane];
		a.u8[lane] = (uint8_t) (sum > 255u ? 255u : sum);
	}
	return a;
}

// Scalar _mm_max_epu8: lane-wise maximum of sixteen unsigned bytes.
static inline __m128i _mm_max_epu8(__m128i a, __m128i b)
{
	int lane;
	for (lane = 0; lane < 16; lane++) {
		uint8_t x = a.u8[lane], y = b.u8[lane];
		a.u8[lane] = (y > x) ? y : x;
	}
	return a;
}

// Return the largest of the sixteen unsigned bytes of a.
static inline uint8_t m128i_max_u8(__m128i a)
{
	uint8_t best = a.u8[0];
	int lane;
	for (lane = 1; lane < 16; lane++)
		best = (a.u8[lane] > best) ? a.u8[lane] : best;
	return best;
}

static inline __m128i _mm_set1_epi8(int8_t n) { __m128i r; memset(&r, n, sizeof r); return r; }

// Scalar _mm_subs_epu8: subtract the sixteen unsigned bytes of b from those of
// a lane by lane, clamping each difference at 0.
static inline __m128i _mm_subs_epu8(__m128i a, __m128i b)
{
	int lane;
	for (lane = 0; lane < 16; lane++) {
		uint8_t x = a.u8[lane], y = b.u8[lane];
		a.u8[lane] = (x > y) ? (uint8_t) (x - y) : 0;
	}
	return a;
}

// Scalar _mm_adds_epi16: add the eight signed 16-bit lanes of a and b with
// signed saturation.
//
// Each sum is formed in 32 bits and clamped to [INT16_MIN, INT16_MAX]. The
// lower clamp was previously missing, so sums below -32768 wrapped around
// instead of saturating.
static inline __m128i _mm_adds_epi16(__m128i a, __m128i b) {
	int i;
	for (i = 0; i < 8; i++) {
		int32_t sum = (int32_t) a.i16[i] + (int32_t) b.i16[i];
		if (sum > INT16_MAX) sum = INT16_MAX;
		else if (sum < INT16_MIN) sum = INT16_MIN;
		a.i16[i] = (int16_t) sum;
	}
	return a;
}

// Scalar _mm_cmpgt_epi16: signed lane-wise a > b over eight 16-bit lanes.
// Each result lane has all bits set when the comparison holds, otherwise all
// bits clear.
//
// The "true" value is written as -1 rather than 0xffff: 0xffff does not fit in
// int16_t, so storing it relied on an implementation-defined conversion.
static inline __m128i _mm_cmpgt_epi16(__m128i a, __m128i b) {
	int i;
	for (i = 0; i < 8; i++)
		a.i16[i] = (a.i16[i] > b.i16[i]) ? (int16_t) -1 : (int16_t) 0;
	return a;
}

// Scalar _mm_max_epi16: lane-wise maximum of eight signed 16-bit values.
static inline __m128i _mm_max_epi16(__m128i a, __m128i b)
{
	int lane;
	for (lane = 0; lane < 8; lane++) {
		int16_t x = a.i16[lane], y = b.i16[lane];
		a.i16[lane] = (y > x) ? y : x;
	}
	return a;
}

// Scalar _mm_set1_epi16: replicate n into all eight 16-bit lanes.
static inline __m128i _mm_set1_epi16(int16_t n)
{
	__m128i filled;
	int lane;
	for (lane = 0; lane < 8; lane++)
		filled.i16[lane] = n;
	return filled;
}

// Return the largest of the eight signed 16-bit lanes of a.
static inline int16_t m128i_max_s16(__m128i a)
{
	int16_t best = a.i16[0];
	int lane;
	for (lane = 1; lane < 8; lane++)
		best = (a.i16[lane] > best) ? a.i16[lane] : best;
	return best;
}

// Scalar _mm_subs_epu16: subtract the eight 16-bit lanes of b from those of a,
// treating both as UNSIGNED and clamping each difference at 0.
//
// The lanes are stored as int16_t, so each one is first converted to its
// unsigned value. Doing the arithmetic on the signed values, as before, gave
// wrong results whenever a lane had its top bit set (e.g. 0xffff - 1 became 0
// instead of 0xfffe) and disagreed with the NEON version (vqsubq_u16). The
// result is written back with memcpy so that bit patterns above 0x7fff are
// stored without an out-of-range signed conversion.
static inline __m128i _mm_subs_epu16(__m128i a, __m128i b) {
	int i;
	for (i = 0; i < 8; i++) {
		uint16_t ua = (uint16_t) a.i16[i];
		uint16_t ub = (uint16_t) b.i16[i];
		uint16_t diff = (ua > ub) ? (uint16_t) (ua - ub) : 0;
		memcpy(&a.i16[i], &diff, sizeof diff);
	}
	return a;
}

#endif

0 comments on commit 2e603e4

Please sign in to comment.