Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add _mm_pack[u]s_[epi16|epi32] #54

Merged
merged 1 commit into from
Jan 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 49 additions & 5 deletions sse2rvv.h
Original file line number Diff line number Diff line change
Expand Up @@ -2196,13 +2196,57 @@ FORCE_INLINE __m128i _mm_or_si128(__m128i a, __m128i b) {
return vreinterpretq_i32_m128i(__riscv_vor_vv_i32m1(_a, _b, 4));
}

// FORCE_INLINE __m128i _mm_packs_epi16 (__m128i a, __m128i b) {}

// FORCE_INLINE __m128i _mm_packs_epi32 (__m128i a, __m128i b) {}
FORCE_INLINE __m128i _mm_packs_epi16(__m128i a, __m128i b) {
vint16m1_t _a = vreinterpretq_m128i_i16(a);
vint16m1_t _b = vreinterpretq_m128i_i16(b);
vint8m1_t a_sat = __riscv_vlmul_ext_v_i8mf2_i8m1(
__riscv_vnclip_wx_i8mf2(_a, 0, __RISCV_VXRM_RDN, 8));
vint8m1_t b_sat = __riscv_vlmul_ext_v_i8mf2_i8m1(
__riscv_vnclip_wx_i8mf2(_b, 0, __RISCV_VXRM_RDN, 8));
return vreinterpretq_i8_m128i(__riscv_vslideup_vx_i8m1(a_sat, b_sat, 8, 16));
}

// FORCE_INLINE __m128i _mm_packus_epi16 (__m128i a, __m128i b) {}
FORCE_INLINE __m128i _mm_packs_epi32(__m128i a, __m128i b) {
vint32m1_t _a = vreinterpretq_m128i_i32(a);
vint32m1_t _b = vreinterpretq_m128i_i32(b);
vint16m1_t a_sat = __riscv_vlmul_ext_v_i16mf2_i16m1(
__riscv_vnclip_wx_i16mf2(_a, 0, __RISCV_VXRM_RDN, 4));
vint16m1_t b_sat = __riscv_vlmul_ext_v_i16mf2_i16m1(
__riscv_vnclip_wx_i16mf2(_b, 0, __RISCV_VXRM_RDN, 4));
return vreinterpretq_i16_m128i(__riscv_vslideup_vx_i16m1(a_sat, b_sat, 4, 8));
}

// FORCE_INLINE __m128i _mm_packus_epi32 (__m128i a, __m128i b) {}
FORCE_INLINE __m128i _mm_packus_epi16(__m128i a, __m128i b) {
vint16m1_t _a = vreinterpretq_m128i_i16(a);
vint16m1_t _b = vreinterpretq_m128i_i16(b);
vbool16_t a_neg_mask = __riscv_vmslt_vx_i16m1_b16(_a, 0, 8);
vbool16_t b_neg_mask = __riscv_vmslt_vx_i16m1_b16(_b, 0, 8);
vuint16m1_t a_unsigned = __riscv_vreinterpret_v_i16m1_u16m1(
__riscv_vmerge_vxm_i16m1(_a, 0, a_neg_mask, 8));
vuint16m1_t b_unsigned = __riscv_vreinterpret_v_i16m1_u16m1(
__riscv_vmerge_vxm_i16m1(_b, 0, b_neg_mask, 8));
vuint8m1_t a_sat = __riscv_vlmul_ext_v_u8mf2_u8m1(
__riscv_vnclipu_wx_u8mf2(a_unsigned, 0, __RISCV_VXRM_RDN, 8));
vuint8m1_t b_sat = __riscv_vlmul_ext_v_u8mf2_u8m1(
__riscv_vnclipu_wx_u8mf2(b_unsigned, 0, __RISCV_VXRM_RDN, 8));
return vreinterpretq_u8_m128i(__riscv_vslideup_vx_u8m1(a_sat, b_sat, 8, 16));
}

FORCE_INLINE __m128i _mm_packus_epi32(__m128i a, __m128i b) {
vint32m1_t _a = vreinterpretq_m128i_i32(a);
vint32m1_t _b = vreinterpretq_m128i_i32(b);
vbool32_t a_neg_mask = __riscv_vmslt_vx_i32m1_b32(_a, 0, 4);
vbool32_t b_neg_mask = __riscv_vmslt_vx_i32m1_b32(_b, 0, 4);
vuint32m1_t a_unsigned = __riscv_vreinterpret_v_i32m1_u32m1(
__riscv_vmerge_vxm_i32m1(_a, 0, a_neg_mask, 4));
vuint32m1_t b_unsigned = __riscv_vreinterpret_v_i32m1_u32m1(
__riscv_vmerge_vxm_i32m1(_b, 0, b_neg_mask, 4));
vuint16m1_t a_sat = __riscv_vlmul_ext_v_u16mf2_u16m1(
__riscv_vnclipu_wx_u16mf2(a_unsigned, 0, __RISCV_VXRM_RDN, 4));
vuint16m1_t b_sat = __riscv_vlmul_ext_v_u16mf2_u16m1(
__riscv_vnclipu_wx_u16mf2(b_unsigned, 0, __RISCV_VXRM_RDN, 4));
return vreinterpretq_u16_m128i(__riscv_vslideup_vx_u16m1(a_sat, b_sat, 4, 8));
}

// FORCE_INLINE void _mm_pause (void) {}

Expand Down
248 changes: 124 additions & 124 deletions tests/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6285,108 +6285,108 @@ result_t test_mm_or_si128(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
}

result_t test_mm_packs_epi16(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// int8_t max = INT8_MAX;
// int8_t min = INT8_MIN;
// const int16_t *_a = (const int16_t *)impl.test_cases_int_pointer1;
// const int16_t *_b = (const int16_t *)impl.test_cases_int_pointer2;
//
// int8_t d[16];
// for (int i = 0; i < 8; i++) {
// if (_a[i] > max)
// d[i] = max;
// else if (_a[i] < min)
// d[i] = min;
// else
// d[i] = (int8_t)_a[i];
// }
// for (int i = 0; i < 8; i++) {
// if (_b[i] > max)
// d[i + 8] = max;
// else if (_b[i] < min)
// d[i + 8] = min;
// else
// d[i + 8] = (int8_t)_b[i];
// }
//
// __m128i a = load_m128i(_a);
// __m128i b = load_m128i(_b);
// __m128i c = _mm_packs_epi16(a, b);
//
// return VALIDATE_INT8_M128(c, d);
// #else
#ifdef ENABLE_TEST_ALL
int8_t max = INT8_MAX;
int8_t min = INT8_MIN;
const int16_t *_a = (const int16_t *)impl.test_cases_int_pointer1;
const int16_t *_b = (const int16_t *)impl.test_cases_int_pointer2;

int8_t d[16];
for (int i = 0; i < 8; i++) {
if (_a[i] > max)
d[i] = max;
else if (_a[i] < min)
d[i] = min;
else
d[i] = (int8_t)_a[i];
}
for (int i = 0; i < 8; i++) {
if (_b[i] > max)
d[i + 8] = max;
else if (_b[i] < min)
d[i + 8] = min;
else
d[i + 8] = (int8_t)_b[i];
}

__m128i a = load_m128i(_a);
__m128i b = load_m128i(_b);
__m128i c = _mm_packs_epi16(a, b);

return VALIDATE_INT8_M128(c, d);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_packs_epi32(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// int16_t max = INT16_MAX;
// int16_t min = INT16_MIN;
// const int32_t *_a = (const int32_t *)impl.test_cases_int_pointer1;
// const int32_t *_b = (const int32_t *)impl.test_cases_int_pointer2;
//
// int16_t d[8];
// for (int i = 0; i < 4; i++) {
// if (_a[i] > max)
// d[i] = max;
// else if (_a[i] < min)
// d[i] = min;
// else
// d[i] = (int16_t)_a[i];
// }
// for (int i = 0; i < 4; i++) {
// if (_b[i] > max)
// d[i + 4] = max;
// else if (_b[i] < min)
// d[i + 4] = min;
// else
// d[i + 4] = (int16_t)_b[i];
// }
//
// __m128i a = load_m128i(_a);
// __m128i b = load_m128i(_b);
// __m128i c = _mm_packs_epi32(a, b);
//
// return VALIDATE_INT16_M128(c, d);
// #else
#ifdef ENABLE_TEST_ALL
int16_t max = INT16_MAX;
int16_t min = INT16_MIN;
const int32_t *_a = (const int32_t *)impl.test_cases_int_pointer1;
const int32_t *_b = (const int32_t *)impl.test_cases_int_pointer2;

int16_t d[8];
for (int i = 0; i < 4; i++) {
if (_a[i] > max)
d[i] = max;
else if (_a[i] < min)
d[i] = min;
else
d[i] = (int16_t)_a[i];
}
for (int i = 0; i < 4; i++) {
if (_b[i] > max)
d[i + 4] = max;
else if (_b[i] < min)
d[i + 4] = min;
else
d[i + 4] = (int16_t)_b[i];
}

__m128i a = load_m128i(_a);
__m128i b = load_m128i(_b);
__m128i c = _mm_packs_epi32(a, b);

return VALIDATE_INT16_M128(c, d);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_packus_epi16(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// uint8_t max = UINT8_MAX;
// uint8_t min = 0;
// const int16_t *_a = (const int16_t *)impl.test_cases_int_pointer1;
// const int16_t *_b = (const int16_t *)impl.test_cases_int_pointer2;
//
// uint8_t d[16];
// for (int i = 0; i < 8; i++) {
// if (_a[i] > (int16_t)max)
// d[i] = max;
// else if (_a[i] < (int16_t)min)
// d[i] = min;
// else
// d[i] = (uint8_t)_a[i];
// }
// for (int i = 0; i < 8; i++) {
// if (_b[i] > (int16_t)max)
// d[i + 8] = max;
// else if (_b[i] < (int16_t)min)
// d[i + 8] = min;
// else
// d[i + 8] = (uint8_t)_b[i];
// }
//
// __m128i a = load_m128i(_a);
// __m128i b = load_m128i(_b);
// __m128i c = _mm_packus_epi16(a, b);
//
// return VALIDATE_UINT8_M128(c, d);
// #else
#ifdef ENABLE_TEST_ALL
uint8_t max = UINT8_MAX;
uint8_t min = 0;
const int16_t *_a = (const int16_t *)impl.test_cases_int_pointer1;
const int16_t *_b = (const int16_t *)impl.test_cases_int_pointer2;

uint8_t d[16];
for (int i = 0; i < 8; i++) {
if (_a[i] > (int16_t)max)
d[i] = max;
else if (_a[i] < (int16_t)min)
d[i] = min;
else
d[i] = (uint8_t)_a[i];
}
for (int i = 0; i < 8; i++) {
if (_b[i] > (int16_t)max)
d[i + 8] = max;
else if (_b[i] < (int16_t)min)
d[i + 8] = min;
else
d[i + 8] = (uint8_t)_b[i];
}

__m128i a = load_m128i(_a);
__m128i b = load_m128i(_b);
__m128i c = _mm_packus_epi16(a, b);

return VALIDATE_UINT8_M128(c, d);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_pause(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
Expand Down Expand Up @@ -10275,38 +10275,38 @@ result_t test_mm_mullo_epi32(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
}

result_t test_mm_packus_epi32(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// uint16_t max = UINT16_MAX;
// uint16_t min = 0;
// const int32_t *_a = (const int32_t *)impl.test_cases_int_pointer1;
// const int32_t *_b = (const int32_t *)impl.test_cases_int_pointer2;
//
// uint16_t d[8];
// for (int i = 0; i < 4; i++) {
// if (_a[i] > (int32_t)max)
// d[i] = max;
// else if (_a[i] < (int32_t)min)
// d[i] = min;
// else
// d[i] = (uint16_t)_a[i];
// }
// for (int i = 0; i < 4; i++) {
// if (_b[i] > (int32_t)max)
// d[i + 4] = max;
// else if (_b[i] < (int32_t)min)
// d[i + 4] = min;
// else
// d[i + 4] = (uint16_t)_b[i];
// }
//
// __m128i a = load_m128i(_a);
// __m128i b = load_m128i(_b);
// __m128i c = _mm_packus_epi32(a, b);
//
// return VALIDATE_UINT16_M128(c, d);
// #else
#ifdef ENABLE_TEST_ALL
uint16_t max = UINT16_MAX;
uint16_t min = 0;
const int32_t *_a = (const int32_t *)impl.test_cases_int_pointer1;
const int32_t *_b = (const int32_t *)impl.test_cases_int_pointer2;

uint16_t d[8];
for (int i = 0; i < 4; i++) {
if (_a[i] > (int32_t)max)
d[i] = max;
else if (_a[i] < (int32_t)min)
d[i] = min;
else
d[i] = (uint16_t)_a[i];
}
for (int i = 0; i < 4; i++) {
if (_b[i] > (int32_t)max)
d[i + 4] = max;
else if (_b[i] < (int32_t)min)
d[i + 4] = min;
else
d[i + 4] = (uint16_t)_b[i];
}

__m128i a = load_m128i(_a);
__m128i b = load_m128i(_b);
__m128i c = _mm_packus_epi32(a, b);

return VALIDATE_UINT16_M128(c, d);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_round_pd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
Expand Down