Skip to content

Commit

Permalink
feat: Add _mm_[ceil|floor]*
Browse files: browse the repository at this point in the history
  • Loading branch information
howjmay committed Jan 26, 2024
1 parent 1adda2c commit 29942c3
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 97 deletions.
97 changes: 89 additions & 8 deletions sse2rvv.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#define _sse2rvv_const const
#endif

#include <math.h>
#include <riscv_vector.h>
#include <stdint.h>
#include <stdlib.h>
Expand Down Expand Up @@ -560,13 +561,53 @@ FORCE_INLINE __m128 _mm_castsi128_ps(__m128i a) {
return __riscv_vreinterpret_v_i32m1_f32m1(a);
}

// Rounds both double-precision lanes of a up with ceil() and returns the
// rounded pair as a new vector.
FORCE_INLINE __m128d _mm_ceil_pd(__m128d a) {
  // FIXME riscv round doesn't work, so each lane is spilled to memory and
  // rounded with the scalar libm routine instead.
  double lanes[2];
  __riscv_vse64_v_f64m1(lanes, vreinterpretq_m128d_f64(a), 2);
  lanes[0] = ceil(lanes[0]);
  lanes[1] = ceil(lanes[1]);
  return vreinterpretq_f64_m128d(__riscv_vle64_v_f64m1(lanes, 2));
}

// Rounds each of the four single-precision lanes of a up with ceilf() and
// returns the rounded lanes as a new vector.
FORCE_INLINE __m128 _mm_ceil_ps(__m128 a) {
  // FIXME riscv round doesn't work, so each lane is spilled to memory and
  // rounded with the scalar libm routine instead.
  vfloat32m1_t _a = vreinterpretq_m128_f32(a);
  float arr[4];
  const int len = 4;
  __riscv_vse32_v_f32m1(arr, _a, len);
  for (int i = 0; i < len; i++) {
    // ceilf keeps the computation in single precision; plain ceil() would
    // promote to double and convert back for every lane.
    arr[i] = ceilf(arr[i]);
  }
  return vreinterpretq_f32_m128(__riscv_vle32_v_f32m1(arr, len));
}

// Returns a vector whose lane 0 is ceil() of lane 0 of b and whose lane 1 is
// copied from a.
FORCE_INLINE __m128d _mm_ceil_sd(__m128d a, __m128d b) {
  // FIXME riscv round doesn't work, so lane 0 is spilled to memory and
  // rounded with the scalar libm routine instead.
  vfloat64m1_t _a = vreinterpretq_m128d_f64(a);
  vfloat64m1_t _b = vreinterpretq_m128d_f64(b);
  double arr[2];
  const int len = 2;
  __riscv_vse64_v_f64m1(arr, _b, len);
  arr[0] = ceil(arr[0]);
  vfloat64m1_t _arr = __riscv_vle64_v_f64m1(arr, 1);
  // vl is 1, so lane 1 is a tail element. The tail-undisturbed (_tu) form is
  // required to guarantee it keeps the value from _a; the non-policy form is
  // tail-agnostic and may overwrite it.
  return vreinterpretq_f64_m128d(
      __riscv_vslideup_vx_f64m1_tu(_a, _arr, 0, 1));
}

// Returns a vector whose lane 0 is ceilf() of lane 0 of b and whose lanes
// 1..3 are copied from a.
FORCE_INLINE __m128 _mm_ceil_ss(__m128 a, __m128 b) {
  // FIXME riscv round doesn't work, so lane 0 is spilled to memory and
  // rounded with the scalar libm routine instead.
  vfloat32m1_t _a = vreinterpretq_m128_f32(a);
  vfloat32m1_t _b = vreinterpretq_m128_f32(b);
  // Only lane 0 of b is consumed, so a single element is stored.
  float low;
  __riscv_vse32_v_f32m1(&low, _b, 1);
  low = ceilf(low);
  vfloat32m1_t _low = __riscv_vle32_v_f32m1(&low, 1);
  // vl is 1, so lanes 1..3 are tail elements. The tail-undisturbed (_tu) form
  // is required to guarantee they keep the values from _a; the non-policy
  // form is tail-agnostic and may overwrite them.
  return vreinterpretq_f32_m128(
      __riscv_vslideup_vx_f32m1_tu(_a, _low, 0, 1));
}

// FORCE_INLINE void _mm_clflush (void const* p) {}

Expand Down Expand Up @@ -1460,13 +1501,53 @@ FORCE_INLINE int _mm_extract_ps(__m128 a, const int imm8) {
return (int)__riscv_vmv_x_s_i32m1_i32(a_s);
}

// Rounds both double-precision lanes of a down with floor() and returns the
// rounded pair as a new vector.
FORCE_INLINE __m128d _mm_floor_pd(__m128d a) {
  // FIXME riscv round doesn't work, so each lane is spilled to memory and
  // rounded with the scalar libm routine instead.
  double lanes[2];
  __riscv_vse64_v_f64m1(lanes, vreinterpretq_m128d_f64(a), 2);
  lanes[0] = floor(lanes[0]);
  lanes[1] = floor(lanes[1]);
  return vreinterpretq_f64_m128d(__riscv_vle64_v_f64m1(lanes, 2));
}

// Rounds each of the four single-precision lanes of a down with floorf() and
// returns the rounded lanes as a new vector.
FORCE_INLINE __m128 _mm_floor_ps(__m128 a) {
  // FIXME riscv round doesn't work, so each lane is spilled to memory and
  // rounded with the scalar libm routine instead.
  vfloat32m1_t _a = vreinterpretq_m128_f32(a);
  float arr[4];
  const int len = 4;
  __riscv_vse32_v_f32m1(arr, _a, len);
  for (int i = 0; i < len; i++) {
    // floorf keeps the computation in single precision; plain floor() would
    // promote to double and convert back for every lane.
    arr[i] = floorf(arr[i]);
  }
  return vreinterpretq_f32_m128(__riscv_vle32_v_f32m1(arr, len));
}

// Returns a vector whose lane 0 is floor() of lane 0 of b and whose lane 1 is
// copied from a.
FORCE_INLINE __m128d _mm_floor_sd(__m128d a, __m128d b) {
  // FIXME riscv round doesn't work, so lane 0 is spilled to memory and
  // rounded with the scalar libm routine instead.
  vfloat64m1_t _a = vreinterpretq_m128d_f64(a);
  vfloat64m1_t _b = vreinterpretq_m128d_f64(b);
  double arr[2];
  const int len = 2;
  __riscv_vse64_v_f64m1(arr, _b, len);
  arr[0] = floor(arr[0]);
  vfloat64m1_t _arr = __riscv_vle64_v_f64m1(arr, 1);
  // vl is 1, so lane 1 is a tail element. The tail-undisturbed (_tu) form is
  // required to guarantee it keeps the value from _a; the non-policy form is
  // tail-agnostic and may overwrite it.
  return vreinterpretq_f64_m128d(
      __riscv_vslideup_vx_f64m1_tu(_a, _arr, 0, 1));
}

// Returns a vector whose lane 0 is floorf() of lane 0 of b and whose lanes
// 1..3 are copied from a.
FORCE_INLINE __m128 _mm_floor_ss(__m128 a, __m128 b) {
  // FIXME riscv round doesn't work, so lane 0 is spilled to memory and
  // rounded with the scalar libm routine instead.
  vfloat32m1_t _a = vreinterpretq_m128_f32(a);
  vfloat32m1_t _b = vreinterpretq_m128_f32(b);
  // Only lane 0 of b is consumed, so a single element is stored.
  float low;
  __riscv_vse32_v_f32m1(&low, _b, 1);
  low = floorf(low);
  vfloat32m1_t _low = __riscv_vle32_v_f32m1(&low, 1);
  // vl is 1, so lanes 1..3 are tail elements. The tail-undisturbed (_tu) form
  // is required to guarantee they keep the values from _a; the non-policy
  // form is tail-agnostic and may overwrite them.
  return vreinterpretq_f32_m128(
      __riscv_vslideup_vx_f32m1_tu(_a, _low, 0, 1));
}

// Releases the memory referenced by mem_addr by passing it to free().
FORCE_INLINE void _mm_free(void *mem_addr) {
  free(mem_addr);
}

Expand Down
178 changes: 89 additions & 89 deletions tests/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9438,53 +9438,53 @@ result_t test_mm_blendv_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
}

// Checks _mm_ceil_pd against ceil() applied to each of the two input doubles.
// Returns TEST_UNIMPL when ENABLE_TEST_ALL is not defined.
result_t test_mm_ceil_pd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifndef ENABLE_TEST_ALL
  return TEST_UNIMPL;
#else
  const double *src = (const double *)impl.test_cases_float_pointer1;
  const double expect_lo = ceil(src[0]);
  const double expect_hi = ceil(src[1]);

  __m128d actual = _mm_ceil_pd(load_m128d(src));
  return validate_double(actual, expect_lo, expect_hi);
#endif  // ENABLE_TEST_ALL
}

// Checks _mm_ceil_ps against ceilf() applied to each of the four input floats.
// Returns TEST_UNIMPL when ENABLE_TEST_ALL is not defined.
result_t test_mm_ceil_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifndef ENABLE_TEST_ALL
  return TEST_UNIMPL;
#else
  const float *src = impl.test_cases_float_pointer1;
  const float expect0 = ceilf(src[0]);
  const float expect1 = ceilf(src[1]);
  const float expect2 = ceilf(src[2]);
  const float expect3 = ceilf(src[3]);

  __m128 actual = _mm_ceil_ps(_mm_load_ps(src));
  return validate_float(actual, expect0, expect1, expect2, expect3);
#endif  // ENABLE_TEST_ALL
}

// Checks _mm_ceil_sd: the low lane must equal ceil() of the low lane of the
// second operand and the high lane must equal the high lane of the first.
// Returns TEST_UNIMPL when ENABLE_TEST_ALL is not defined.
result_t test_mm_ceil_sd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifndef ENABLE_TEST_ALL
  return TEST_UNIMPL;
#else
  const double *src_a = (const double *)impl.test_cases_float_pointer1;
  const double *src_b = (const double *)impl.test_cases_float_pointer2;

  const double expect_lo = ceil(src_b[0]);
  const double expect_hi = src_a[1];

  __m128d vec_a = load_m128d(src_a);
  __m128d vec_b = load_m128d(src_b);
  __m128d actual = _mm_ceil_sd(vec_a, vec_b);

  return validate_double(actual, expect_lo, expect_hi);
#endif  // ENABLE_TEST_ALL
}

result_t test_mm_ceil_ss(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
Expand Down Expand Up @@ -9896,70 +9896,70 @@ result_t test_mm_extract_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
}

// Checks _mm_floor_pd against floor() applied to each of the two input
// doubles. Returns TEST_UNIMPL when ENABLE_TEST_ALL is not defined.
result_t test_mm_floor_pd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifndef ENABLE_TEST_ALL
  return TEST_UNIMPL;
#else
  const double *src = (const double *)impl.test_cases_float_pointer1;
  const double expect_lo = floor(src[0]);
  const double expect_hi = floor(src[1]);

  __m128d actual = _mm_floor_pd(load_m128d(src));
  return validate_double(actual, expect_lo, expect_hi);
#endif  // ENABLE_TEST_ALL
}

// Checks _mm_floor_ps against floorf() applied to each of the four input
// floats. Returns TEST_UNIMPL when ENABLE_TEST_ALL is not defined.
result_t test_mm_floor_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifndef ENABLE_TEST_ALL
  return TEST_UNIMPL;
#else
  const float *src = impl.test_cases_float_pointer1;
  const float expect0 = floorf(src[0]);
  const float expect1 = floorf(src[1]);
  const float expect2 = floorf(src[2]);
  const float expect3 = floorf(src[3]);

  __m128 actual = _mm_floor_ps(load_m128(src));
  return validate_float(actual, expect0, expect1, expect2, expect3);
#endif  // ENABLE_TEST_ALL
}

// Checks _mm_floor_sd: the low lane must equal floor() of the low lane of the
// second operand and the high lane must equal the high lane of the first.
// Returns TEST_UNIMPL when ENABLE_TEST_ALL is not defined.
result_t test_mm_floor_sd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifndef ENABLE_TEST_ALL
  return TEST_UNIMPL;
#else
  const double *src_a = (const double *)impl.test_cases_float_pointer1;
  const double *src_b = (const double *)impl.test_cases_float_pointer2;

  const double expect_lo = floor(src_b[0]);
  const double expect_hi = src_a[1];

  __m128d vec_a = load_m128d(src_a);
  __m128d vec_b = load_m128d(src_b);
  __m128d actual = _mm_floor_sd(vec_a, vec_b);

  return validate_double(actual, expect_lo, expect_hi);
#endif  // ENABLE_TEST_ALL
}

// Checks _mm_floor_ss: lane 0 must equal floorf() of lane 0 of the second
// operand and lanes 1..3 must equal the corresponding lanes of the first.
// Returns TEST_UNIMPL when ENABLE_TEST_ALL is not defined.
result_t test_mm_floor_ss(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
  const float *_a = impl.test_cases_float_pointer1;
  // Use the second data set for b. If a and b pointed at the same data, the
  // test could not tell whether lane 0 was taken from a or from b.
  const float *_b = impl.test_cases_float_pointer2;

  float f0 = floorf(_b[0]);

  __m128 a = load_m128(_a);
  __m128 b = load_m128(_b);
  __m128 c = _mm_floor_ss(a, b);

  return validate_float(c, f0, _a[1], _a[2], _a[3]);
#else
  return TEST_UNIMPL;
#endif  // ENABLE_TEST_ALL
}

result_t test_mm_insert_epi32(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
Expand Down

0 comments on commit 29942c3

Please sign in to comment.