Skip to content

Commit

Permalink
Eps double -> float
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Mar 12, 2024
1 parent 7248092 commit 4bf0f7c
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions candle-kernels/src/layernorm_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ template <typename scalar_t>
__device__ void
rms_norm_kernel(scalar_t *__restrict__ out, // [num_tokens, hidden_size]
const scalar_t *__restrict__ input, // [num_tokens, hidden_size]
const double epsilon, const uint32_t num_tokens,
const float epsilon, const uint32_t num_tokens,
const uint32_t hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
Expand All @@ -30,47 +30,47 @@ rms_norm_kernel(scalar_t *__restrict__ out, // [num_tokens, hidden_size]
// RMS-norm kernel entry point for uint8_t data.
//
// Thin extern "C" wrapper that forwards every argument unchanged to the
// rms_norm_kernel template (scalar type deduced from the pointer arguments).
//   out         - destination buffer, [num_tokens, hidden_size]
//   input       - source buffer,      [num_tokens, hidden_size]
//   epsilon     - 32-bit float; the launcher must pass a float-sized argument
//                 (previously this parameter was a double)
//   num_tokens  - forwarded to rms_norm_kernel
//   hidden_size - forwarded to rms_norm_kernel
extern "C" __global__ void rms_norm_u8(
    uint8_t *__restrict__ out,
    const uint8_t *__restrict__ input,
    const float epsilon, const uint32_t num_tokens,
    const uint32_t hidden_size) {
  rms_norm_kernel(out, input, epsilon, num_tokens, hidden_size);
}

// RMS-norm kernel entry point for uint32_t data.
//
// Thin extern "C" wrapper that forwards every argument unchanged to the
// rms_norm_kernel template (scalar type deduced from the pointer arguments).
//   out         - destination buffer, [num_tokens, hidden_size]
//   input       - source buffer,      [num_tokens, hidden_size]
//   epsilon     - 32-bit float; the launcher must pass a float-sized argument
//                 (previously this parameter was a double)
//   num_tokens  - forwarded to rms_norm_kernel
//   hidden_size - forwarded to rms_norm_kernel
extern "C" __global__ void rms_norm_u32(
    uint32_t *__restrict__ out,
    const uint32_t *__restrict__ input,
    const float epsilon, const uint32_t num_tokens,
    const uint32_t hidden_size) {
  rms_norm_kernel(out, input, epsilon, num_tokens, hidden_size);
}

// RMS-norm kernel entry point for int64_t data.
//
// Thin extern "C" wrapper that forwards every argument unchanged to the
// rms_norm_kernel template (scalar type deduced from the pointer arguments).
//   out         - destination buffer, [num_tokens, hidden_size]
//   input       - source buffer,      [num_tokens, hidden_size]
//   epsilon     - 32-bit float; the launcher must pass a float-sized argument
//                 (previously this parameter was a double)
//   num_tokens  - forwarded to rms_norm_kernel
//   hidden_size - forwarded to rms_norm_kernel
extern "C" __global__ void rms_norm_i64(
    int64_t *__restrict__ out,
    const int64_t *__restrict__ input,
    const float epsilon, const uint32_t num_tokens,
    const uint32_t hidden_size) {
  rms_norm_kernel(out, input, epsilon, num_tokens, hidden_size);
}

// RMS-norm kernel entry point for __half (fp16) data.
//
// Thin extern "C" wrapper that forwards every argument unchanged to the
// rms_norm_kernel template (scalar type deduced from the pointer arguments).
//   out         - destination buffer, [num_tokens, hidden_size]
//   input       - source buffer,      [num_tokens, hidden_size]
//   epsilon     - 32-bit float; the launcher must pass a float-sized argument
//                 (previously this parameter was a double)
//   num_tokens  - forwarded to rms_norm_kernel
//   hidden_size - forwarded to rms_norm_kernel
extern "C" __global__ void rms_norm_f16(
    __half *__restrict__ out,
    const __half *__restrict__ input,
    const float epsilon, const uint32_t num_tokens,
    const uint32_t hidden_size) {
  rms_norm_kernel(out, input, epsilon, num_tokens, hidden_size);
}

// RMS-norm kernel entry point for float (fp32) data.
//
// Thin extern "C" wrapper that forwards every argument unchanged to the
// rms_norm_kernel template (scalar type deduced from the pointer arguments).
//   out         - destination buffer, [num_tokens, hidden_size]
//   input       - source buffer,      [num_tokens, hidden_size]
//   epsilon     - 32-bit float; the launcher must pass a float-sized argument
//                 (previously this parameter was a double)
//   num_tokens  - forwarded to rms_norm_kernel
//   hidden_size - forwarded to rms_norm_kernel
extern "C" __global__ void rms_norm_f32(
    float *__restrict__ out,
    const float *__restrict__ input,
    const float epsilon, const uint32_t num_tokens,
    const uint32_t hidden_size) {
  rms_norm_kernel(out, input, epsilon, num_tokens, hidden_size);
}

// RMS-norm kernel entry point for double (fp64) data.
//
// Thin extern "C" wrapper that forwards every argument unchanged to the
// rms_norm_kernel template (scalar type deduced from the pointer arguments).
//   out         - destination buffer, [num_tokens, hidden_size]
//   input       - source buffer,      [num_tokens, hidden_size]
//   epsilon     - 32-bit float even though the data is double; the launcher
//                 must pass a float-sized argument (previously a double)
//   num_tokens  - forwarded to rms_norm_kernel
//   hidden_size - forwarded to rms_norm_kernel
extern "C" __global__ void rms_norm_f64(
    double *__restrict__ out,
    const double *__restrict__ input,
    const float epsilon, const uint32_t num_tokens,
    const uint32_t hidden_size) {
  rms_norm_kernel(out, input, epsilon, num_tokens, hidden_size);
}
Expand All @@ -80,7 +80,7 @@ extern "C" __global__ void rms_norm_f64(
// RMS-norm kernel entry point for __nv_bfloat16 data.
//
// Thin extern "C" wrapper that forwards every argument unchanged to the
// rms_norm_kernel template (scalar type deduced from the pointer arguments).
//   out         - destination buffer, [num_tokens, hidden_size]
//   input       - source buffer,      [num_tokens, hidden_size]
//   epsilon     - 32-bit float; the launcher must pass a float-sized argument
//                 (previously this parameter was a double)
//   num_tokens  - forwarded to rms_norm_kernel
//   hidden_size - forwarded to rms_norm_kernel
// NOTE(review): the lines just above this definition are collapsed in the
// visible source; any architecture guard around bf16 support should be
// confirmed against the full file.
extern "C" __global__ void rms_norm_bf16(
    __nv_bfloat16 *__restrict__ out,
    const __nv_bfloat16 *__restrict__ input,
    const float epsilon, const uint32_t num_tokens,
    const uint32_t hidden_size) {
  rms_norm_kernel(out, input, epsilon, num_tokens, hidden_size);
}
Expand Down

0 comments on commit 4bf0f7c

Please sign in to comment.