From af56fa6b38ecc9cf059f427b859ccffac900c178 Mon Sep 17 00:00:00 2001
From: Chinthaka Gamanayakege
Date: Mon, 10 Jun 2024 10:13:37 +0000
Subject: [PATCH] replacing mpi functionality in train_gpt2.cu using distributed iface

---
 train_gpt2.cu | 98 ++++++++++++++------------------------------------
 1 file changed, 26 insertions(+), 72 deletions(-)

diff --git a/train_gpt2.cu b/train_gpt2.cu
index 286afbb6d..01c4de6de 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -63,8 +63,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
 #include "llmc/global_norm.cuh"
 // ----------- Multi-GPU support -----------
 #ifdef MULTI_GPU
-#include <mpi.h>
-#include <nccl.h>
+#include "llmc/distributed.h"
 #endif
 
 // ----------------------------------------------------------------------------
@@ -92,17 +91,6 @@ void nccl_check(ncclResult_t status, const char *file, int line) {
 }
 #define ncclCheck(err) (nccl_check(err, __FILE__, __LINE__))
 
-void mpi_check(int status, const char *file, int line) {
-    if (status != MPI_SUCCESS) {
-        char mpi_error[4096];
-        int mpi_error_len = 0;
-        assert(MPI_Error_string(status, &mpi_error[0], &mpi_error_len) == MPI_SUCCESS);
-        printf("[MPI ERROR] at file %s:%d:\n%.*s\n", file, line, mpi_error_len, mpi_error);
-        exit(EXIT_FAILURE);
-    }
-}
-#define mpiCheck(err) (mpi_check(err, __FILE__, __LINE__))
-
 #endif // MULTI_GPU
 
 // ----------------------------------------------------------------------------
@@ -130,61 +118,21 @@ typedef struct {
 // one global variable to hold the multi-GPU configuration for this process
 MultiGpuConfig multi_gpu_config;
 
+MultiGpuConfig multi_gpu_config_init(int process_rank, int num_processes, int gpus_per_node,const char *server_ip, int server_port) {
 #ifdef MULTI_GPU
-// Determine which GPU this process should use.
-// Processes on the same machines use different GPU indicies. Processes on other machines don't.
-// Copied from NCCL examples: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/examples.html#example-2-one-device-per-process-or-thread
-int multi_gpu_get_local_device_idx(int process_rank, int num_processes) {
-    char hostname[1024];
-    hostname[1023] = '\0';
-    // All processes on the same machine will share the same hostname.
-    gethostname(hostname, 1023);
-    for (int i=0; i < 1024; i++) {
-        if (hostname[i] == '.') {
-            hostname[i] = '\0';
-            break;
-        }
-    }
-    uint64_t hostname_hash = 5381;
-    for (int c = 0; hostname[c] != '\0'; c++){ hostname_hash = ((hostname_hash << 5) + hostname_hash) ^ hostname[c]; }
-
-    // Distribute all hostname hashes to all processes.
-    uint64_t* all_hostsname_hashes = (uint64_t*)malloc(num_processes * sizeof(uint64_t));
-    all_hostsname_hashes[process_rank] = hostname_hash;
-    mpiCheck(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_hostsname_hashes, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD));
-
-    // Identify which GPU we need to use.
-    int local_device_idx = 0;
-    for (int current_process = 0; current_process < num_processes; ++current_process) {
-        if (current_process == process_rank) {
-            // Found my gpu, local_device_idx now has my target GPU index.
-            break;
-        }
-        if (all_hostsname_hashes[current_process] == all_hostsname_hashes[process_rank]) {
-            // This process ID runs on the same machine, but it's not me, skip this GPU
-            local_device_idx++;
-        }
-    }
-
-    free(all_hostsname_hashes);
-    return local_device_idx;
-}
-#endif
+    distributed_init(num_processes, process_rank, server_ip, server_port);
 
-MultiGpuConfig multi_gpu_config_init(int *argc, char ***argv) {
-#ifdef MULTI_GPU
-    // Initialize MPI.
     MultiGpuConfig result;
-    mpiCheck(MPI_Init(argc, argv));
-    mpiCheck(MPI_Comm_rank(MPI_COMM_WORLD, &result.process_rank));
-    mpiCheck(MPI_Comm_size(MPI_COMM_WORLD, &result.num_processes));
-    result.local_device_idx = multi_gpu_get_local_device_idx(result.process_rank, result.num_processes);
+    result.process_rank = process_rank;
+    result.num_processes = num_processes;
+    result.local_device_idx = process_rank % gpus_per_node;
+
     cudaCheck(cudaSetDevice(result.local_device_idx));
     ncclUniqueId nccl_id;
     if (result.process_rank == 0) {
         ncclCheck(ncclGetUniqueId(&nccl_id));
     }
-    mpiCheck(MPI_Bcast((void *)&nccl_id, sizeof(nccl_id), MPI_BYTE, 0, MPI_COMM_WORLD));
+    distributed_broadcast(&nccl_id);
     ncclCheck(ncclCommInitRank(&result.nccl_comm, result.num_processes, nccl_id, result.process_rank));
     return result;
 #else
@@ -201,15 +149,12 @@ MultiGpuConfig multi_gpu_config_init(int *argc, char ***argv) {
 void multi_gpu_config_free(const MultiGpuConfig* multi_gpu_config) {
 #ifdef MULTI_GPU
     ncclCheck(ncclCommDestroy(multi_gpu_config->nccl_comm));
-    mpiCheck(MPI_Finalize());
 #endif
 }
 
 void multi_gpu_barrier(const MultiGpuConfig* multi_gpu_config) {
 #ifdef MULTI_GPU
-    if (multi_gpu_config->num_processes > 1) {
-        mpiCheck(MPI_Barrier(MPI_COMM_WORLD));
-    }
+    distributed_barrier();
 #endif
 }
 
@@ -1062,13 +1007,9 @@ void gpt2_backward(GPT2 *model, int* inputs) {
 // Compute sum of a single CPU value across all GPU processes. No-op when multi-GPU is disabled.
 float multi_gpu_cpu_float_sum(float value) {
 #ifdef MULTI_GPU
-    // note MPI doesn't support all reduce with mean, only sum
-    float result;
-    mpiCheck(MPI_Allreduce(&value, &result, 1, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD));
-    return result;
-#else
-    return value;
+    distributed_reduce(&value);
 #endif
+    return value;
 }
 
 // Averages out the loss and gradients across all GPUs. No-op when multi-GPU is disabled.
@@ -1438,8 +1379,6 @@ void error_usage() {
 // ----------------------------------------------------------------------------
 // main training loop
 int main(int argc, char *argv[]) {
-    multi_gpu_config = multi_gpu_config_init(&argc, &argv);
-
     // read in the (optional) command line arguments
     const char* train_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_train.bin";
     const char* val_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_val.bin";
@@ -1465,6 +1404,13 @@ int main(int argc, char *argv[]) {
     int recompute = 1; // recompute during backward setting, 0 = none, 1 = recompute gelu
     int zero_stage = 0; // Zero Optimization Stage for Multi-GPU training
     int hellaswag_eval = 0;
+
+    int num_processes = 1;
+    int process_rank = 0;
+    int gpus_per_node = 8;
+    char *server_ip = "127.0.0.1";
+    int server_port = 8090;
+
     for (int i = 1; i < argc; i+=2) {
         if (i + 1 >= argc) { error_usage(); } // must have arg after flag
         if (argv[i][0] != '-') { error_usage(); } // must start with dash
@@ -1494,8 +1440,16 @@ int main(int argc, char *argv[]) {
         else if (argv[i][1] == 'z') { zero_stage = atoi(argv[i+1]); }
         else if (argv[i][1] == 'r') { recompute = atoi(argv[i+1]); }
         else if (argv[i][1] == 'h') { hellaswag_eval = atoi(argv[i+1]); }
+        else if (argv[i][1] == 'p') { num_processes = atoi(argv[i+1]); }
+        else if (argv[i][1] == 'k') { process_rank = atoi(argv[i+1]); }
+        else if (argv[i][1] == 'i') { server_ip = argv[i+1]; }
+        else if (argv[i][1] == 'j') { server_port = atoi(argv[i+1]); }
+
         else { error_usage(); }
     }
+
+    multi_gpu_config = multi_gpu_config_init(process_rank, num_processes, gpus_per_node, server_ip, server_port);
+
     // should do a bit more error checking here
     assert(warmup_iterations >= 0);
     if (output_log_dir != NULL) {
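
Note on llmc/distributed.h: the header is referenced by the patch but not included in it. For review context, below is a minimal sketch of the interface implied by the call sites in the diff (distributed_init, distributed_broadcast, distributed_barrier, distributed_reduce). The include guard, the <nccl.h> include, and the comments are assumptions; the real header may declare these differently.

/* llmc/distributed.h -- NOT part of this patch; a hypothetical sketch of the
 * interface implied by the call sites in the diff above. */
#ifndef LLMC_DISTRIBUTED_H
#define LLMC_DISTRIBUTED_H

#include <nccl.h>  // for ncclUniqueId (train_gpt2.cu no longer includes <nccl.h> directly)

// Connect to the coordination server at server_ip:server_port and register
// this process as process_rank out of num_processes.
void distributed_init(int num_processes, int process_rank, const char *server_ip, int server_port);

// Rank 0 sends the NCCL unique id to all other ranks; the other ranks receive it.
void distributed_broadcast(ncclUniqueId *nccl_id);

// Block until every participating process has reached this point.
void distributed_barrier(void);

// Sum a single float across all processes, writing the result back in place.
void distributed_reduce(float *value);

#endif // LLMC_DISTRIBUTED_H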
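
Usage note: with this change the trainer no longer has to be launched under mpirun. Each process is started directly and told the world size and its own rank via the new flags parsed in the last hunk (-p num_processes, -k process_rank); the coordination server address defaults to 127.0.0.1:8090, and the local device is chosen as process_rank % gpus_per_node (gpus_per_node defaults to 8). Assuming the usual ./train_gpt2cu build, a two-GPU run on one node would presumably be launched as:

./train_gpt2cu -p 2 -k 0 &
./train_gpt2cu -p 2 -k 1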