Skip to content

Commit

Permalink
Changed ordering of type configuration to easily see unchanged values
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisDryden committed Apr 29, 2024
1 parent 50acc12 commit cf3e6ef
Showing 1 changed file with 15 additions and 21 deletions.
36 changes: 15 additions & 21 deletions train_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -63,49 +63,43 @@ enum PrecisionMode {
PRECISION_BF16
};

// Compile-time precision configuration: picks the low-precision type floatX and the
// constants that must agree with it (cuBLAS type tag, weights file, NCCL data type).
// NOTE(review): this listing is a rendered diff hunk without +/- markers, so removed and
// added lines sit side by side (floatN, load_filename and PRECISION_MODE each appear twice
// per branch, and "#else" appears twice). Read it as before/after, not as one translation
// unit; it is not expected to compile exactly as shown.
// fp32
// Default Properties
// floatN is float in every configuration visible here.
typedef float floatN;
// Default compute type refers to cublas_compute_type, which is not defined in this hunk
// (presumably a variable set elsewhere in the file; verify).
#define CUBLAS_LOWP_COMPUTE cublas_compute_type
#ifdef MULTI_GPU
// NCCL data type tag paired with floatN; only present in MULTI_GPU builds.
const ncclDataType_t ncclFloatN = ncclFloat;
#endif

// Specific configurations based on the enabled precision
// Selection order: ENABLE_FP32, then ENABLE_FP16, otherwise the bfloat16 fallback.
#if defined(ENABLE_FP32)
typedef float floatX;
typedef float floatN;
// cuBLAS data type tag that goes with floatX in this branch.
#define CUBLAS_LOWP CUDA_R_32F
#define CUBLAS_LOWP_COMPUTE cublas_compute_type // auto-select FP32 vs TF32
const char* load_filename = "gpt2_124M.bin"; // fp32 weights
// PRECISION_MODE appears both as a variable and as a macro (two sides of the diff).
PrecisionMode PRECISION_MODE = PRECISION_FP32;
#define PRECISION_MODE PRECISION_FP32
const char* load_filename = "gpt2_124M.bin";
// Short name of the active mode (presumably for logging; no use is visible here).
const char* precision_mode_str = "fp32";

#ifdef MULTI_GPU
// NCCL data type tags paired with floatX / floatN for this branch.
const ncclDataType_t ncclFloatX = ncclFloat;
const ncclDataType_t ncclFloatN = ncclFloat;
#endif

// use fp16 (note: this may require a gradient scaler, which is currently not implemented!)
#elif defined(ENABLE_FP16)
typedef half floatX;
typedef float floatN;
#define CUBLAS_LOWP CUDA_R_16F
// Compute type is pinned to CUBLAS_COMPUTE_32F here instead of cublas_compute_type.
#define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F
// The fp16 build loads the same file as fp32 ("gpt2_124M.bin"), not a half-specific one.
const char* load_filename = "gpt2_124M.bin"; // fp32 weights
PrecisionMode PRECISION_MODE = PRECISION_FP16;
#define PRECISION_MODE PRECISION_FP16
const char* load_filename = "gpt2_124M.bin";
const char* precision_mode_str = "fp16";

#ifdef MULTI_GPU
const ncclDataType_t ncclFloatX = ncclHalf;
const ncclDataType_t ncclFloatN = ncclFloat;
#endif

// bfloat16 (default!)
// Taken when neither ENABLE_FP32 nor ENABLE_FP16 is defined.
#else
#else // Default to bfloat16
typedef __nv_bfloat16 floatX;
typedef float floatN;
#define CUBLAS_LOWP CUDA_R_16BF
// Same CUBLAS_COMPUTE_32F compute type as the fp16 branch.
#define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F
// Only this branch uses a precision-specific weights file.
const char* load_filename = "gpt2_124M_bf16.bin"; // bf16 weights
PrecisionMode PRECISION_MODE = PRECISION_BF16;
#define PRECISION_MODE PRECISION_BF16
const char* load_filename = "gpt2_124M_bf16.bin"; // bf16 weights specific filename
const char* precision_mode_str = "bf16";

#ifdef MULTI_GPU
const ncclDataType_t ncclFloatX = ncclBfloat16;
const ncclDataType_t ncclFloatN = ncclFloat;
#endif
#endif

Expand Down

0 comments on commit cf3e6ef

Please sign in to comment.