Replace MPI functionality in train_gpt2.cu with a distributed interface
Chinthaka Gamanayakege committed Jun 10, 2024
1 parent 195b9e1 commit af56fa6
Showing 1 changed file with 26 additions and 72 deletions.
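The new header llmc/distributed.h is not included in this diff. NCCL still handles the GPU collectives; only the CPU-side coordination that previously went through MPI (process-group init, NCCL unique-id broadcast, scalar all-reduce, barrier) moves behind the new interface. Below is a minimal sketch of the interface the call sites in the diff appear to assume; the parameter types, comments, and everything beyond the four function names are assumptions, not the committed header.

// Sketch of the interface train_gpt2.cu calls into after this change.
// NOTE: hypothetical reconstruction from the call sites in this diff;
// the real llmc/distributed.h may differ.
#ifndef LLMC_DISTRIBUTED_H
#define LLMC_DISTRIBUTED_H

#include <nccl.h>

// Join the process group as rank process_rank of num_processes, coordinating
// through a server at server_ip:server_port
// (replaces MPI_Init / MPI_Comm_rank / MPI_Comm_size).
void distributed_init(int num_processes, int process_rank, const char *server_ip, int server_port);

// Send rank 0's NCCL unique id to every other rank (replaces MPI_Bcast).
void distributed_broadcast(ncclUniqueId *nccl_id);

// In-place sum of a single float across all ranks (replaces MPI_Allreduce with MPI_SUM).
void distributed_reduce(float *value);

// Wait until all ranks reach this point (replaces MPI_Barrier).
void distributed_barrier(void);

#endif // LLMC_DISTRIBUTED_H

On the launch side, ranks are no longer spawned by mpirun: each process is told the world size and its rank (-p, -k) plus the coordination server's address (-i, -j) on the command line, and selects its GPU as process_rank % gpus_per_node.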
train_gpt2.cu: 26 additions & 72 deletions
@@ -63,8 +63,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
 #include "llmc/global_norm.cuh"
 // ----------- Multi-GPU support -----------
 #ifdef MULTI_GPU
-#include <mpi.h>
-#include <nccl.h>
+#include "llmc/distributed.h"
 #endif

 // ----------------------------------------------------------------------------
@@ -92,17 +91,6 @@ void nccl_check(ncclResult_t status, const char *file, int line) {
 }
 #define ncclCheck(err) (nccl_check(err, __FILE__, __LINE__))

-void mpi_check(int status, const char *file, int line) {
-    if (status != MPI_SUCCESS) {
-        char mpi_error[4096];
-        int mpi_error_len = 0;
-        assert(MPI_Error_string(status, &mpi_error[0], &mpi_error_len) == MPI_SUCCESS);
-        printf("[MPI ERROR] at file %s:%d:\n%.*s\n", file, line, mpi_error_len, mpi_error);
-        exit(EXIT_FAILURE);
-    }
-}
-#define mpiCheck(err) (mpi_check(err, __FILE__, __LINE__))
-
 #endif // MULTI_GPU

 // ----------------------------------------------------------------------------
@@ -130,61 +118,21 @@ typedef struct {
 // one global variable to hold the multi-GPU configuration for this process
 MultiGpuConfig multi_gpu_config;

+MultiGpuConfig multi_gpu_config_init(int process_rank, int num_processes, int gpus_per_node, const char *server_ip, int server_port) {
 #ifdef MULTI_GPU
-// Determine which GPU this process should use.
-// Processes on the same machines use different GPU indicies. Processes on other machines don't.
-// Copied from NCCL examples: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/examples.html#example-2-one-device-per-process-or-thread
-int multi_gpu_get_local_device_idx(int process_rank, int num_processes) {
-    char hostname[1024];
-    hostname[1023] = '\0';
-    // All processes on the same machine will share the same hostname.
-    gethostname(hostname, 1023);
-    for (int i=0; i < 1024; i++) {
-        if (hostname[i] == '.') {
-            hostname[i] = '\0';
-            break;
-        }
-    }
-    uint64_t hostname_hash = 5381;
-    for (int c = 0; hostname[c] != '\0'; c++){ hostname_hash = ((hostname_hash << 5) + hostname_hash) ^ hostname[c]; }
-
-    // Distribute all hostname hashes to all processes.
-    uint64_t* all_hostsname_hashes = (uint64_t*)malloc(num_processes * sizeof(uint64_t));
-    all_hostsname_hashes[process_rank] = hostname_hash;
-    mpiCheck(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_hostsname_hashes, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD));
-
-    // Identify which GPU we need to use.
-    int local_device_idx = 0;
-    for (int current_process = 0; current_process < num_processes; ++current_process) {
-        if (current_process == process_rank) {
-            // Found my gpu, local_device_idx now has my target GPU index.
-            break;
-        }
-        if (all_hostsname_hashes[current_process] == all_hostsname_hashes[process_rank]) {
-            // This process ID runs on the same machine, but it's not me, skip this GPU
-            local_device_idx++;
-        }
-    }
-
-    free(all_hostsname_hashes);
-    return local_device_idx;
-}
-#endif
+    distributed_init(num_processes, process_rank, server_ip, server_port);

-MultiGpuConfig multi_gpu_config_init(int *argc, char ***argv) {
-#ifdef MULTI_GPU
-    // Initialize MPI.
     MultiGpuConfig result;
-    mpiCheck(MPI_Init(argc, argv));
-    mpiCheck(MPI_Comm_rank(MPI_COMM_WORLD, &result.process_rank));
-    mpiCheck(MPI_Comm_size(MPI_COMM_WORLD, &result.num_processes));
-    result.local_device_idx = multi_gpu_get_local_device_idx(result.process_rank, result.num_processes);
+    result.process_rank = process_rank;
+    result.num_processes = num_processes;
+    result.local_device_idx = process_rank % gpus_per_node;
+
     cudaCheck(cudaSetDevice(result.local_device_idx));
     ncclUniqueId nccl_id;
     if (result.process_rank == 0) {
         ncclCheck(ncclGetUniqueId(&nccl_id));
     }
-    mpiCheck(MPI_Bcast((void *)&nccl_id, sizeof(nccl_id), MPI_BYTE, 0, MPI_COMM_WORLD));
+    distributed_broadcast(&nccl_id);
     ncclCheck(ncclCommInitRank(&result.nccl_comm, result.num_processes, nccl_id, result.process_rank));
     return result;
 #else
@@ -201,15 +149,12 @@ MultiGpuConfig multi_gpu_config_init(int *argc, char ***argv) {
 void multi_gpu_config_free(const MultiGpuConfig* multi_gpu_config) {
 #ifdef MULTI_GPU
     ncclCheck(ncclCommDestroy(multi_gpu_config->nccl_comm));
-    mpiCheck(MPI_Finalize());
 #endif
 }

 void multi_gpu_barrier(const MultiGpuConfig* multi_gpu_config) {
 #ifdef MULTI_GPU
-    if (multi_gpu_config->num_processes > 1) {
-        mpiCheck(MPI_Barrier(MPI_COMM_WORLD));
-    }
+    distributed_barrier();
 #endif
 }

@@ -1062,13 +1007,9 @@ void gpt2_backward(GPT2 *model, int* inputs) {
 // Compute sum of a single CPU value across all GPU processes. No-op when multi-GPU is disabled.
 float multi_gpu_cpu_float_sum(float value) {
 #ifdef MULTI_GPU
-    // note MPI doesn't support all reduce with mean, only sum
-    float result;
-    mpiCheck(MPI_Allreduce(&value, &result, 1, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD));
-    return result;
-#else
-    return value;
+    distributed_reduce(&value);
 #endif
+    return value;
 }

 // Averages out the loss and gradients across all GPUs. No-op when multi-GPU is disabled.
@@ -1438,8 +1379,6 @@ void error_usage() {
 // ----------------------------------------------------------------------------
 // main training loop
 int main(int argc, char *argv[]) {
-    multi_gpu_config = multi_gpu_config_init(&argc, &argv);
-
     // read in the (optional) command line arguments
     const char* train_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_train.bin";
     const char* val_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_val.bin";
@@ -1465,6 +1404,13 @@ int main(int argc, char *argv[]) {
     int recompute = 1; // recompute during backward setting, 0 = none, 1 = recompute gelu
     int zero_stage = 0; // Zero Optimization Stage for Multi-GPU training
     int hellaswag_eval = 0;
+
+    int num_processes = 1;
+    int process_rank = 0;
+    int gpus_per_node = 8;
+    char *server_ip = "127.0.0.1";
+    int server_port = 8090;
+
     for (int i = 1; i < argc; i+=2) {
         if (i + 1 >= argc) { error_usage(); } // must have arg after flag
         if (argv[i][0] != '-') { error_usage(); } // must start with dash
@@ -1494,8 +1440,16 @@ int main(int argc, char *argv[]) {
         else if (argv[i][1] == 'z') { zero_stage = atoi(argv[i+1]); }
         else if (argv[i][1] == 'r') { recompute = atoi(argv[i+1]); }
         else if (argv[i][1] == 'h') { hellaswag_eval = atoi(argv[i+1]); }
+        else if (argv[i][1] == 'p') { num_processes = atoi(argv[i+1]); }
+        else if (argv[i][1] == 'k') { process_rank = atoi(argv[i+1]); }
+        else if (argv[i][1] == 'i') { server_ip = argv[i+1]; }
+        else if (argv[i][1] == 'j') { server_port = atoi(argv[i+1]); }
+
         else { error_usage(); }
     }
+
+    multi_gpu_config = multi_gpu_config_init(process_rank, num_processes, gpus_per_node, server_ip, server_port);
+
     // should do a bit more error checking here
     assert(warmup_iterations >= 0);
     if (output_log_dir != NULL) {
