diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu
index ee20fe5fe8..024642c6b1 100644
--- a/candle-kernels/src/cast.cu
+++ b/candle-kernels/src/cast.cu
@@ -1,6 +1,53 @@
 #include "cuda_utils.cuh"
 #include <stdint.h>
 
+template <typename S, typename T>
+__device__ void cast_(
+    const size_t numel,
+    const size_t num_dims,
+    const size_t *info,
+    const S *inp,
+    T *out
+) {
+    const size_t *dims = info;
+    const size_t *strides = info + num_dims;
+    if (is_contiguous(num_dims, dims, strides)) {
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+            out[i] = inp[i];
+        }
+    }
+    else {
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+            unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
+            out[i] = inp[strided_i];
+        }
+    }
+}
+
+template <typename S, typename T, typename I>
+__device__ void cast_through(
+    const size_t numel,
+    const size_t num_dims,
+    const size_t *info,
+    const S *inp,
+    T *out
+) {
+    const size_t *dims = info;
+    const size_t *strides = info + num_dims;
+    if (is_contiguous(num_dims, dims, strides)) {
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+            out[i] = static_cast<T>(static_cast<I>(inp[i]));
+        }
+    }
+    else {
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+            unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
+            out[i] = static_cast<T>(static_cast<I>(inp[strided_i]));
+        }
+    }
+}
+
+
 #define CAST_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
 extern "C" __global__ void FN_NAME( \
     const size_t numel, \
@@ -9,22 +56,10 @@ extern "C" __global__ void FN_NAME( \
     const SRC_TYPENAME *inp, \
     DST_TYPENAME *out \
 ) { \
-    const size_t *dims = info; \
-    const size_t *strides = info + num_dims; \
-    if (is_contiguous(num_dims, dims, strides)) { \
-        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
-            out[i] = inp[i]; \
-        } \
-    } \
-    else { \
-        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
-            unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
-            out[i] = inp[strided_i]; \
-        } \
-    } \
+    cast_<SRC_TYPENAME, DST_TYPENAME>(numel, num_dims, info, inp, out); \
 } \
 
-#define CAST_BF_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
+#define CAST_THROUGH_OP(SRC_TYPENAME, DST_TYPENAME, INT_TYPENAME, FN_NAME) \
 extern "C" __global__ void FN_NAME( \
     const size_t numel, \
     const size_t num_dims, \
@@ -32,25 +67,12 @@ extern "C" __global__ void FN_NAME( \
     const SRC_TYPENAME *inp, \
     DST_TYPENAME *out \
 ) { \
-    const size_t *dims = info; \
-    const size_t *strides = info + num_dims; \
-    if (is_contiguous(num_dims, dims, strides)) { \
-        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
-            out[i] = (DST_TYPENAME) (float) inp[i]; \
-        } \
-    } \
-    else { \
-        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
-            unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
-            out[i] = (DST_TYPENAME) (float) inp[strided_i]; \
-        } \
-    } \
+    cast_through<SRC_TYPENAME, DST_TYPENAME, INT_TYPENAME>(numel, num_dims, info, inp, out); \
 } \
 
 
 #if __CUDA_ARCH__ >= 800
 CAST_OP(__nv_bfloat16, __nv_bfloat16, cast_bf16_bf16)
-// CAST_OP(__nv_bfloat16, uint8_t, cast_bf16_u8)
 CAST_OP(__nv_bfloat16, uint32_t, cast_bf16_u32)
 CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
 CAST_OP(__nv_bfloat16, double, cast_bf16_f64)
@@ -58,14 +80,15 @@ CAST_OP(uint8_t, __nv_bfloat16, cast_u8_bf16)
 CAST_OP(uint32_t, __nv_bfloat16, cast_u32_bf16)
 CAST_OP(float, __nv_bfloat16, cast_f32_bf16)
 CAST_OP(double, __nv_bfloat16, cast_f64_bf16)
-CAST_BF_OP(__nv_bfloat16, __half, cast_bf16_f16)
-CAST_BF_OP(__half, __nv_bfloat16, cast_f16_bf16)
+CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8)
+CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16)
+CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
 CAST_OP(__half, __half, cast_f16_f16)
-// CAST_OP(__half, uint8_t, cast_f16_u8 )
+CAST_THROUGH_OP(__half, uint8_t, float, cast_f16_u8)
 CAST_OP(__half, uint32_t, cast_f16_u32)
 CAST_OP(__half, float, cast_f16_f32)
 CAST_OP(__half, double, cast_f16_f64)
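
Note: the new macro is a thin wrapper, so for reference, CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8) expands to roughly the kernel below (whitespace added). Each element takes a hop through the intermediate type; that is presumably what lets the previously commented-out bf16 -> u8 and f16 -> u8 casts compile, since bf16 -> float and float -> u8 conversions both exist even where a direct conversion does not.

extern "C" __global__ void cast_bf16_u8(
    const size_t numel,
    const size_t num_dims,
    const size_t *info,
    const __nv_bfloat16 *inp,
    uint8_t *out
) {
    // Widen each element to float first, then narrow to uint8_t.
    cast_through<__nv_bfloat16, uint8_t, float>(numel, num_dims, info, inp, out);
}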
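
The kernels lean on two helpers from cuda_utils.cuh that this diff does not show. As a sketch of their assumed contracts (not the crate's exact source): info packs dims followed by strides, is_contiguous checks for standard row-major strides, and get_strided_index maps a flat output index to the strided input offset.

// Sketch of the assumed helper contracts; the real definitions live in
// cuda_utils.cuh and may differ in detail.
__device__ bool is_contiguous(const size_t num_dims, const size_t *dims, const size_t *strides) {
    // Row-major contiguity: the last dim has stride 1, and each earlier
    // stride is the product of the dims after it.
    size_t acc = 1;
    for (unsigned int d = 0; d < num_dims; d++) {
        unsigned int dim_idx = num_dims - 1 - d;
        if (strides[dim_idx] != acc) return false;
        acc *= dims[dim_idx];
    }
    return true;
}

__device__ unsigned int get_strided_index(unsigned int idx, const size_t num_dims, const size_t *dims, const size_t *strides) {
    // Decompose the flat index into per-dim coordinates (last dim fastest),
    // then re-linearize using the input's strides.
    unsigned int strided_i = 0;
    for (unsigned int d = 0; d < num_dims; d++) {
        unsigned int dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}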
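
These are extern "C" __global__ entry points meant to be compiled to PTX and looked up by name at runtime (candle drives them from Rust via cudarc). A minimal driver-API launch of the new cast_f16_u8 kernel might look like the sketch below; the PTX path, sizes, and launch geometry are illustrative assumptions, not candle's actual dispatch code.

// Hypothetical host-side launch via the CUDA driver API. "cast.ptx" is an
// assumed output of `nvcc --ptx -arch=sm_53 cast.cu` (sm_53 or newer, so the
// f16 kernels survive the __CUDA_ARCH__ >= 530 guard).
#include <cuda.h>

int main() {
    cuInit(0);
    CUdevice dev;  cuDeviceGet(&dev, 0);
    CUcontext ctx; cuCtxCreate(&ctx, 0, dev);

    CUmodule mod;  cuModuleLoad(&mod, "cast.ptx");
    CUfunction fn; cuModuleGetFunction(&fn, mod, "cast_f16_u8"); // unmangled thanks to extern "C"

    size_t numel = 1024, num_dims = 1;
    size_t h_info[2] = {numel, 1};      // dims then strides: contiguous 1-D

    CUdeviceptr d_info, d_inp, d_out;
    cuMemAlloc(&d_info, sizeof(h_info));
    cuMemAlloc(&d_inp, numel * 2);      // __half is 2 bytes
    cuMemAlloc(&d_out, numel);          // uint8_t output
    cuMemcpyHtoD(d_info, h_info, sizeof(h_info));

    // Grid-stride kernels tolerate any launch shape; this just covers numel.
    unsigned block = 256;
    unsigned grid = (unsigned)((numel + block - 1) / block);
    void *args[] = { &numel, &num_dims, &d_info, &d_inp, &d_out };
    cuLaunchKernel(fn, grid, 1, 1, block, 1, 1, 0, NULL, args, NULL);
    cuCtxSynchronize();

    cuMemFree(d_inp); cuMemFree(d_out); cuMemFree(d_info);
    cuModuleUnload(mod); cuCtxDestroy(ctx);
    return 0;
}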