Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add _mm_[ceil|floor]_[pd|ps|sd|ss] #66

Merged
merged 1 commit into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 89 additions & 8 deletions sse2rvv.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#define _sse2rvv_const const
#endif

#include <math.h>
#include <riscv_vector.h>
#include <stdint.h>
#include <stdlib.h>
Expand Down Expand Up @@ -560,13 +561,53 @@ FORCE_INLINE __m128 _mm_castsi128_ps(__m128i a) {
return __riscv_vreinterpret_v_i32m1_f32m1(a);
}

// FORCE_INLINE __m128d _mm_ceil_pd (__m128d a) {}
FORCE_INLINE __m128d _mm_ceil_pd(__m128d a) {
// FIXME riscv round doesn't work
vfloat64m1_t _a = vreinterpretq_m128d_f64(a);
double arr[2];
const int len = 2;
__riscv_vse64_v_f64m1(arr, _a, len);
for (int i = 0; i < len; i++) {
arr[i] = ceil(arr[i]);
}
return vreinterpretq_f64_m128d(__riscv_vle64_v_f64m1(arr, len));
}

// FORCE_INLINE __m128 _mm_ceil_ps (__m128 a) {}
FORCE_INLINE __m128 _mm_ceil_ps(__m128 a) {
// FIXME riscv round doesn't work
vfloat32m1_t _a = vreinterpretq_m128_f32(a);
float arr[4];
const int len = 4;
__riscv_vse32_v_f32m1(arr, _a, len);
for (int i = 0; i < len; i++) {
arr[i] = ceil(arr[i]);
}
return vreinterpretq_f32_m128(__riscv_vle32_v_f32m1(arr, len));
}

// FORCE_INLINE __m128d _mm_ceil_sd (__m128d a, __m128d b) {}
FORCE_INLINE __m128d _mm_ceil_sd(__m128d a, __m128d b) {
// FIXME riscv round doesn't work
vfloat64m1_t _a = vreinterpretq_m128d_f64(a);
vfloat64m1_t _b = vreinterpretq_m128d_f64(b);
double arr[2];
const int len = 2;
__riscv_vse64_v_f64m1(arr, _b, len);
arr[0] = ceil(arr[0]);
vfloat64m1_t _arr = __riscv_vle64_v_f64m1(arr, 1);
return vreinterpretq_f64_m128d(__riscv_vslideup_vx_f64m1(_a, _arr, 0, 1));
}

// FORCE_INLINE __m128 _mm_ceil_ss (__m128 a, __m128 b) {}
FORCE_INLINE __m128 _mm_ceil_ss(__m128 a, __m128 b) {
// FIXME riscv round doesn't work
vfloat32m1_t _a = vreinterpretq_m128_f32(a);
vfloat32m1_t _b = vreinterpretq_m128_f32(b);
float arr[4];
const int len = 4;
__riscv_vse32_v_f32m1(arr, _b, len);
arr[0] = ceil(arr[0]);
vfloat32m1_t _arr = __riscv_vle32_v_f32m1(arr, 1);
return vreinterpretq_f32_m128(__riscv_vslideup_vx_f32m1(_a, _arr, 0, 1));
}

// FORCE_INLINE void _mm_clflush (void const* p) {}

Expand Down Expand Up @@ -1460,13 +1501,53 @@ FORCE_INLINE int _mm_extract_ps(__m128 a, const int imm8) {
return (int)__riscv_vmv_x_s_i32m1_i32(a_s);
}

// FORCE_INLINE __m128d _mm_floor_pd (__m128d a) {}
FORCE_INLINE __m128d _mm_floor_pd(__m128d a) {
// FIXME riscv round doesn't work
vfloat64m1_t _a = vreinterpretq_m128d_f64(a);
double arr[2];
const int len = 2;
__riscv_vse64_v_f64m1(arr, _a, len);
for (int i = 0; i < len; i++) {
arr[i] = floor(arr[i]);
}
return vreinterpretq_f64_m128d(__riscv_vle64_v_f64m1(arr, len));
}

// FORCE_INLINE __m128 _mm_floor_ps (__m128 a) {}
FORCE_INLINE __m128 _mm_floor_ps(__m128 a) {
// FIXME riscv round doesn't work
vfloat32m1_t _a = vreinterpretq_m128_f32(a);
float arr[4];
const int len = 4;
__riscv_vse32_v_f32m1(arr, _a, len);
for (int i = 0; i < len; i++) {
arr[i] = floor(arr[i]);
}
return vreinterpretq_f32_m128(__riscv_vle32_v_f32m1(arr, len));
}

// FORCE_INLINE __m128d _mm_floor_sd (__m128d a, __m128d b) {}
FORCE_INLINE __m128d _mm_floor_sd(__m128d a, __m128d b) {
// FIXME riscv round doesn't work
vfloat64m1_t _a = vreinterpretq_m128d_f64(a);
vfloat64m1_t _b = vreinterpretq_m128d_f64(b);
double arr[2];
const int len = 2;
__riscv_vse64_v_f64m1(arr, _b, len);
arr[0] = floor(arr[0]);
vfloat64m1_t _arr = __riscv_vle64_v_f64m1(arr, 1);
return vreinterpretq_f64_m128d(__riscv_vslideup_vx_f64m1(_a, _arr, 0, 1));
}

// FORCE_INLINE __m128 _mm_floor_ss (__m128 a, __m128 b) {}
FORCE_INLINE __m128 _mm_floor_ss(__m128 a, __m128 b) {
// FIXME riscv round doesn't work
vfloat32m1_t _a = vreinterpretq_m128_f32(a);
vfloat32m1_t _b = vreinterpretq_m128_f32(b);
float arr[4];
const int len = 4;
__riscv_vse32_v_f32m1(arr, _b, len);
arr[0] = floor(arr[0]);
vfloat32m1_t _arr = __riscv_vle32_v_f32m1(arr, 1);
return vreinterpretq_f32_m128(__riscv_vslideup_vx_f32m1(_a, _arr, 0, 1));
}

FORCE_INLINE void _mm_free(void *mem_addr) { free(mem_addr); }

Expand Down
178 changes: 89 additions & 89 deletions tests/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9438,53 +9438,53 @@ result_t test_mm_blendv_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
}

result_t test_mm_ceil_pd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// const double *_a = (const double *)impl.test_cases_float_pointer1;
//
// double dx = ceil(_a[0]);
// double dy = ceil(_a[1]);
//
// __m128d a = load_m128d(_a);
// __m128d ret = _mm_ceil_pd(a);
//
// return validate_double(ret, dx, dy);
// #else
#ifdef ENABLE_TEST_ALL
const double *_a = (const double *)impl.test_cases_float_pointer1;

double dx = ceil(_a[0]);
double dy = ceil(_a[1]);

__m128d a = load_m128d(_a);
__m128d ret = _mm_ceil_pd(a);

return validate_double(ret, dx, dy);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_ceil_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// const float *_a = impl.test_cases_float_pointer1;
// float dx = ceilf(_a[0]);
// float dy = ceilf(_a[1]);
// float dz = ceilf(_a[2]);
// float dw = ceilf(_a[3]);
//
// __m128 a = _mm_load_ps(_a);
// __m128 c = _mm_ceil_ps(a);
// return validate_float(c, dx, dy, dz, dw);
// #else
#ifdef ENABLE_TEST_ALL
const float *_a = impl.test_cases_float_pointer1;
float dx = ceilf(_a[0]);
float dy = ceilf(_a[1]);
float dz = ceilf(_a[2]);
float dw = ceilf(_a[3]);

__m128 a = _mm_load_ps(_a);
__m128 c = _mm_ceil_ps(a);
return validate_float(c, dx, dy, dz, dw);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_ceil_sd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// const double *_a = (const double *)impl.test_cases_float_pointer1;
// const double *_b = (const double *)impl.test_cases_float_pointer2;
//
// double dx = ceil(_b[0]);
// double dy = _a[1];
//
// __m128d a = load_m128d(_a);
// __m128d b = load_m128d(_b);
// __m128d ret = _mm_ceil_sd(a, b);
//
// return validate_double(ret, dx, dy);
// #else
#ifdef ENABLE_TEST_ALL
const double *_a = (const double *)impl.test_cases_float_pointer1;
const double *_b = (const double *)impl.test_cases_float_pointer2;

double dx = ceil(_b[0]);
double dy = _a[1];

__m128d a = load_m128d(_a);
__m128d b = load_m128d(_b);
__m128d ret = _mm_ceil_sd(a, b);

return validate_double(ret, dx, dy);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_ceil_ss(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
Expand Down Expand Up @@ -9896,70 +9896,70 @@ result_t test_mm_extract_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
}

result_t test_mm_floor_pd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// const double *_a = (const double *)impl.test_cases_float_pointer1;
//
// double dx = floor(_a[0]);
// double dy = floor(_a[1]);
//
// __m128d a = load_m128d(_a);
// __m128d ret = _mm_floor_pd(a);
//
// return validate_double(ret, dx, dy);
// #else
#ifdef ENABLE_TEST_ALL
const double *_a = (const double *)impl.test_cases_float_pointer1;

double dx = floor(_a[0]);
double dy = floor(_a[1]);

__m128d a = load_m128d(_a);
__m128d ret = _mm_floor_pd(a);

return validate_double(ret, dx, dy);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_floor_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// const float *_a = impl.test_cases_float_pointer1;
// float dx = floorf(_a[0]);
// float dy = floorf(_a[1]);
// float dz = floorf(_a[2]);
// float dw = floorf(_a[3]);
//
// __m128 a = load_m128(_a);
// __m128 c = _mm_floor_ps(a);
// return validate_float(c, dx, dy, dz, dw);
// #else
#ifdef ENABLE_TEST_ALL
const float *_a = impl.test_cases_float_pointer1;
float dx = floorf(_a[0]);
float dy = floorf(_a[1]);
float dz = floorf(_a[2]);
float dw = floorf(_a[3]);

__m128 a = load_m128(_a);
__m128 c = _mm_floor_ps(a);
return validate_float(c, dx, dy, dz, dw);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_floor_sd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// const double *_a = (const double *)impl.test_cases_float_pointer1;
// const double *_b = (const double *)impl.test_cases_float_pointer2;
//
// double dx = floor(_b[0]);
// double dy = _a[1];
//
// __m128d a = load_m128d(_a);
// __m128d b = load_m128d(_b);
// __m128d ret = _mm_floor_sd(a, b);
//
// return validate_double(ret, dx, dy);
// #else
#ifdef ENABLE_TEST_ALL
const double *_a = (const double *)impl.test_cases_float_pointer1;
const double *_b = (const double *)impl.test_cases_float_pointer2;

double dx = floor(_b[0]);
double dy = _a[1];

__m128d a = load_m128d(_a);
__m128d b = load_m128d(_b);
__m128d ret = _mm_floor_sd(a, b);

return validate_double(ret, dx, dy);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_floor_ss(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// const float *_a = impl.test_cases_float_pointer1;
// const float *_b = impl.test_cases_float_pointer1;
//
// float f0 = floorf(_b[0]);
//
// __m128 a = load_m128(_a);
// __m128 b = load_m128(_b);
// __m128 c = _mm_floor_ss(a, b);
//
// return validate_float(c, f0, _a[1], _a[2], _a[3]);
// #else
#ifdef ENABLE_TEST_ALL
const float *_a = impl.test_cases_float_pointer1;
const float *_b = impl.test_cases_float_pointer1;

float f0 = floorf(_b[0]);

__m128 a = load_m128(_a);
__m128 b = load_m128(_b);
__m128 c = _mm_floor_ss(a, b);

return validate_float(c, f0, _a[1], _a[2], _a[3]);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_insert_epi32(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
Expand Down
Loading