Skip to content

Commit

Permalink
feat(cuda): new fft impl
Browse files Browse the repository at this point in the history
  • Loading branch information
sarah el kazdadi committed Aug 29, 2024
1 parent 6e2908a commit f49213e
Show file tree
Hide file tree
Showing 7 changed files with 438 additions and 641 deletions.
64 changes: 58 additions & 6 deletions backends/tfhe-cuda-backend/cuda/src/crypto/torus.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,64 @@ __device__ inline void typecast_double_to_torus<uint32_t>(double x,
template <>
__device__ inline void typecast_double_to_torus<uint64_t>(double x,
uint64_t &r) {
// The ull intrinsic does not behave in the same way on all architectures and
// on some platforms this causes the cmux tree test to fail
// Hence the intrinsic is not used here
uint128 nnnn = make_uint128_from_float(x);
uint64_t lll = nnnn.lo_;
r = lll;
uint64_t x_bits = *((uint64_t const *)(&x));

uint64_t biased_exp = (x_bits >> 52) & 0x7FF;

int exp = int(biased_exp) - 1023;
int shift = exp - 52;

uint64_t mantissa = x_bits & ((uint64_t(1) << 52) - 1);
mantissa |= (uint64_t(1) << 52);

shift = shift < -63 ? -63 : shift;
shift = shift > 63 ? 63 : shift;
bool positive = shift >= 0;
shift = positive ? shift : (-1 - shift);

uint64_t left_shift = mantissa << shift;
uint64_t right_shift = mantissa >> shift;
right_shift += right_shift & 1;
right_shift >>= 1;

mantissa = positive ? left_shift : right_shift;
r = mantissa;
}

template <typename T>
__device__ inline void typecast_double_round_to_torus(double x, T &r) {
double mx = (sizeof(T) == 4) ? 4294967296.0 : 18446744073709551616.0;
double frac = x - floor(x);
frac *= mx;
typecast_double_to_torus(frac, r);
}

template <>
__device__ inline void typecast_double_round_to_torus<uint64_t>(double x,
uint64_t &r) {
uint64_t x_bits = *((uint64_t const *)(&x));

uint64_t biased_exp = (x_bits >> 52) & 0x7FF;
bool sign = x_bits >> 63 != 0;

int exp = int(biased_exp) + (64 - 1023);
int shift = exp - 52;

uint64_t mantissa = x_bits & ((uint64_t(1) << 52) - 1);
mantissa |= (uint64_t(1) << 52);

shift = shift < -63 ? -63 : shift;
shift = shift > 63 ? 63 : shift;
bool positive = shift >= 0;
shift = positive ? shift : (-1 - shift);

uint64_t left_shift = mantissa << shift;
uint64_t right_shift = mantissa >> shift;
right_shift += right_shift & 1;
right_shift >>= 1;

mantissa = positive ? left_shift : right_shift;
r = sign ? -mantissa : mantissa;
}

template <typename T>
Expand Down
Loading

0 comments on commit f49213e

Please sign in to comment.