From cf3e6efc792a5803a7a3104263dd21c35a6373de Mon Sep 17 00:00:00 2001 From: Christopher Dryden Date: Mon, 29 Apr 2024 02:28:23 +0000 Subject: [PATCH] Changed ordering of type configuration to easily see unchanged values --- train_gpt2.cu | 36 +++++++++++++++--------------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index 0d536fdad..9d59f8608 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -63,49 +63,43 @@ enum PrecisionMode { PRECISION_BF16 }; -// fp32 +// Default Properties +typedef float floatN; +#define CUBLAS_LOWP_COMPUTE cublas_compute_type +#ifdef MULTI_GPU +const ncclDataType_t ncclFloatN = ncclFloat; +#endif + +// Specific configurations based on the enabled precision #if defined(ENABLE_FP32) typedef float floatX; -typedef float floatN; #define CUBLAS_LOWP CUDA_R_32F -#define CUBLAS_LOWP_COMPUTE cublas_compute_type // auto-select FP32 vs TF32 -const char* load_filename = "gpt2_124M.bin"; // fp32 weights -PrecisionMode PRECISION_MODE = PRECISION_FP32; +#define PRECISION_MODE PRECISION_FP32 +const char* load_filename = "gpt2_124M.bin"; const char* precision_mode_str = "fp32"; - #ifdef MULTI_GPU const ncclDataType_t ncclFloatX = ncclFloat; -const ncclDataType_t ncclFloatN = ncclFloat; #endif // use fp16 (note: this may require gradient scaler, currently not implemented!) #elif defined(ENABLE_FP16) typedef half floatX; -typedef float floatN; #define CUBLAS_LOWP CUDA_R_16F -#define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F -const char* load_filename = "gpt2_124M.bin"; // fp32 weights -PrecisionMode PRECISION_MODE = PRECISION_FP16; +#define PRECISION_MODE PRECISION_FP16 +const char* load_filename = "gpt2_124M.bin"; const char* precision_mode_str = "fp16"; - #ifdef MULTI_GPU const ncclDataType_t ncclFloatX = ncclHalf; -const ncclDataType_t ncclFloatN = ncclFloat; #endif -// bfloat16 (default!) -#else +#else // Default to bfloat16 typedef __nv_bfloat16 floatX; -typedef float floatN; #define CUBLAS_LOWP CUDA_R_16BF -#define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F -const char* load_filename = "gpt2_124M_bf16.bin"; // bf16 weights -PrecisionMode PRECISION_MODE = PRECISION_BF16; +#define PRECISION_MODE PRECISION_BF16 +const char* load_filename = "gpt2_124M_bf16.bin"; // bf16 weights specific filename const char* precision_mode_str = "bf16"; - #ifdef MULTI_GPU const ncclDataType_t ncclFloatX = ncclBfloat16; -const ncclDataType_t ncclFloatN = ncclFloat; #endif #endif