diff --git a/llmc/cuda_utils.cuh b/llmc/cuda_utils.cuh
index 392c4711f..14ee2b2b3 100644
--- a/llmc/cuda_utils.cuh
+++ b/llmc/cuda_utils.cuh
@@ -79,6 +79,37 @@ __device__ void store128cg(ElementType* target, Packed128<ElementType> value) {
 typedef Packed128<float> f128;
 typedef Packed128<floatX> x128;
 
+// ----------------------------------------------------------------------------
+// Copy, cast functions
+
+// device functions and the kernel to cast data between types
+template<typename Td, typename Ts>
+__device__ Td cast_value(Ts val);
+
+template<>
+__device__ float cast_value<float, float>(float val) {
+    return val;
+}
+
+template<>
+__device__ float cast_value<float, half>(half val) {
+    return __half2float(val);
+}
+
+template<>
+__device__ float cast_value<float, __nv_bfloat16>(__nv_bfloat16 val) {
+    return __bfloat162float(val);
+}
+
+template<typename Td, typename Ts>
+__global__ void copy_and_cast_kernel(Td* dst, const Ts* src, size_t n) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    // need to try grid stride looping for more perf later
+    if (idx < n) {
+        dst[idx] = cast_value<Td, Ts>(src[idx]);
+    }
+}
+
 // ----------------------------------------------------------------------------
 // Warp/Block communication primitives
 
diff --git a/train_gpt2.cu b/train_gpt2.cu
index 489f1be45..f9eeac37a 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -34,7 +34,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
 #include "llmc/cuda_common.h"
 // defines:
 // Packed128, f128, x128
-// warpReduceSum, warpReduceMax, blockReduce
+// warpReduceSum, warpReduceMax, blockReduce, copy_and_cast_kernel
 #include "llmc/cuda_utils.cuh"
 // defines: CUBLAS_LOWP, cublasCheck, cublaslt_workspace_size, cublaslt_workspace
 // defines: cublas_compute, cublaslt_handle, cublas_handle
@@ -250,37 +250,6 @@ void set_zero_configs(MultiGpuConfig* multi_gpu_config, int zero_stage, size_t t
     }
 }
 
-// ----------------------------------------------------------------------------
-// Kernels
-
-// device functions and the kernel to cast data between types
-template<typename Td, typename Ts>
-__device__ Td cast_value(Ts val);
-
-template<>
-__device__ float cast_value<float, float>(float val) {
-    return val;
-}
-
-template<>
-__device__ float cast_value<float, half>(half val) {
-    return __half2float(val);
-}
-
-template<>
-__device__ float cast_value<float, __nv_bfloat16>(__nv_bfloat16 val) {
-    return __bfloat162float(val);
-}
-
-template<typename Td, typename Ts>
-__global__ void copy_and_cast_kernel(Td* dst, const Ts* src, size_t n) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    // need to try grid stride looping for more perf later
-    if (idx < n) {
-        dst[idx] = cast_value<Td, Ts>(src[idx]);
-    }
-}
-
 // ----------------------------------------------------------------------------
 // GPT-2 model definition