Skip to content

Commit

Permalink
feat: Add _mm_[ceil|floor]*
Browse files Browse the repository at this point in the history
  • Loading branch information
howjmay committed Jan 26, 2024
1 parent 1adda2c commit 89418a8
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 42 deletions.
49 changes: 45 additions & 4 deletions sse2rvv.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#define _sse2rvv_const const
#endif

#include <math.h>
#include <riscv_vector.h>
#include <stdint.h>
#include <stdlib.h>
Expand Down Expand Up @@ -560,13 +561,53 @@ FORCE_INLINE __m128 _mm_castsi128_ps(__m128i a) {
return __riscv_vreinterpret_v_i32m1_f32m1(a);
}

// FORCE_INLINE __m128d _mm_ceil_pd (__m128d a) {}
FORCE_INLINE __m128d _mm_ceil_pd(__m128d a) {
// FIXME riscv round doesn't work
vfloat64m1_t _a = vreinterpretq_m128d_f64(a);
double arr[2];
const int len = 2;
__riscv_vse64_v_f64m1(arr, _a, len);
for (int i = 0; i < len; i++) {
arr[i] = ceil(arr[i]);
}
return vreinterpretq_f64_m128d(__riscv_vle64_v_f64m1(arr, len));
}

// FORCE_INLINE __m128 _mm_ceil_ps (__m128 a) {}
FORCE_INLINE __m128 _mm_ceil_ps(__m128 a) {
// FIXME riscv round doesn't work
vfloat32m1_t _a = vreinterpretq_m128_f32(a);
float arr[4];
const int len = 4;
__riscv_vse32_v_f32m1(arr, _a, len);
for (int i = 0; i < len; i++) {
arr[i] = ceil(arr[i]);
}
return vreinterpretq_f32_m128(__riscv_vle32_v_f32m1(arr, len));
}

// FORCE_INLINE __m128d _mm_ceil_sd (__m128d a, __m128d b) {}
FORCE_INLINE __m128d _mm_ceil_sd(__m128d a, __m128d b) {
// FIXME riscv round doesn't work
vfloat64m1_t _a = vreinterpretq_m128d_f64(a);
vfloat64m1_t _b = vreinterpretq_m128d_f64(b);
double arr[2];
const int len = 2;
__riscv_vse64_v_f64m1(arr, _b, len);
arr[0] = ceil(arr[0]);
vfloat64m1_t _arr = __riscv_vle64_v_f64m1(arr, 1);
return vreinterpretq_f64_m128d(__riscv_vslideup_vx_f64m1(_a, _arr, 0, 1));
}

// FORCE_INLINE __m128 _mm_ceil_ss (__m128 a, __m128 b) {}
FORCE_INLINE __m128 _mm_ceil_ss(__m128 a, __m128 b) {
// FIXME riscv round doesn't work
vfloat32m1_t _a = vreinterpretq_m128_f32(a);
vfloat32m1_t _b = vreinterpretq_m128_f32(b);
float arr[4];
const int len = 4;
__riscv_vse32_v_f32m1(arr, _b, len);
arr[0] = ceil(arr[0]);
vfloat32m1_t _arr = __riscv_vle32_v_f32m1(arr, 1);
return vreinterpretq_f32_m128(__riscv_vslideup_vx_f32m1(_a, _arr, 0, 1));
}

// FORCE_INLINE void _mm_clflush (void const* p) {}

Expand Down
76 changes: 38 additions & 38 deletions tests/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9438,53 +9438,53 @@ result_t test_mm_blendv_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
}

result_t test_mm_ceil_pd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// const double *_a = (const double *)impl.test_cases_float_pointer1;
//
// double dx = ceil(_a[0]);
// double dy = ceil(_a[1]);
//
// __m128d a = load_m128d(_a);
// __m128d ret = _mm_ceil_pd(a);
//
// return validate_double(ret, dx, dy);
// #else
#ifdef ENABLE_TEST_ALL
const double *_a = (const double *)impl.test_cases_float_pointer1;

double dx = ceil(_a[0]);
double dy = ceil(_a[1]);

__m128d a = load_m128d(_a);
__m128d ret = _mm_ceil_pd(a);

return validate_double(ret, dx, dy);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_ceil_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// const float *_a = impl.test_cases_float_pointer1;
// float dx = ceilf(_a[0]);
// float dy = ceilf(_a[1]);
// float dz = ceilf(_a[2]);
// float dw = ceilf(_a[3]);
//
// __m128 a = _mm_load_ps(_a);
// __m128 c = _mm_ceil_ps(a);
// return validate_float(c, dx, dy, dz, dw);
// #else
#ifdef ENABLE_TEST_ALL
const float *_a = impl.test_cases_float_pointer1;
float dx = ceilf(_a[0]);
float dy = ceilf(_a[1]);
float dz = ceilf(_a[2]);
float dw = ceilf(_a[3]);

__m128 a = _mm_load_ps(_a);
__m128 c = _mm_ceil_ps(a);
return validate_float(c, dx, dy, dz, dw);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_ceil_sd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// const double *_a = (const double *)impl.test_cases_float_pointer1;
// const double *_b = (const double *)impl.test_cases_float_pointer2;
//
// double dx = ceil(_b[0]);
// double dy = _a[1];
//
// __m128d a = load_m128d(_a);
// __m128d b = load_m128d(_b);
// __m128d ret = _mm_ceil_sd(a, b);
//
// return validate_double(ret, dx, dy);
// #else
#ifdef ENABLE_TEST_ALL
const double *_a = (const double *)impl.test_cases_float_pointer1;
const double *_b = (const double *)impl.test_cases_float_pointer2;

double dx = ceil(_b[0]);
double dy = _a[1];

__m128d a = load_m128d(_a);
__m128d b = load_m128d(_b);
__m128d ret = _mm_ceil_sd(a, b);

return validate_double(ret, dx, dy);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_ceil_ss(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
Expand Down

0 comments on commit 89418a8

Please sign in to comment.