From 4cb924aaadb230e9cad67982aba26189bba66671 Mon Sep 17 00:00:00 2001 From: Daniel Levi-Minzi <51272568+dleviminzi@users.noreply.github.com> Date: Mon, 10 Jun 2024 14:56:42 -0400 Subject: [PATCH] Always Use Neon for L2 f32 (#20) * use neon for vectors larger than 16 elements always * remove commented out code --- sqlite-vec.c | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/sqlite-vec.c b/sqlite-vec.c index b9b590f..e318907 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -186,8 +186,16 @@ static f32 l2_sqr_float_neon(const void *pVect1v, const void *pVect2v, sum3 = vfmaq_f32(sum3, diff, diff); } - return sqrt( - vaddvq_f32(vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)))); + f32 sum_scalar = vaddvq_f32(vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3))); + const f32 *pEnd2 = pVect1 + (qty - (qty16 << 4)); + while (pVect1 < pEnd2) { + f32 diff = *pVect1 - *pVect2; + sum_scalar += diff * diff; + pVect1++; + pVect2++; + } + + return sqrt(sum_scalar); } static f32 l2_sqr_int8_neon(const void *pVect1v, const void *pVect2v, @@ -263,7 +271,7 @@ static f32 l2_sqr_int8(const void *pA, const void *pB, const void *pD) { static f32 distance_l2_sqr_float(const void *a, const void *b, const void *d) { #ifdef SQLITE_VEC_ENABLE_NEON - if (((*(const size_t *)d) % 16 == 0)) { + if ((*(const size_t *)d) > 16) { return l2_sqr_float_neon(a, b, d); } #endif