From 1a6c6d93c19995dbd538c87beddea297c64b834c Mon Sep 17 00:00:00 2001 From: Ben Sussman Date: Fri, 31 Jul 2020 18:00:38 -0400 Subject: [PATCH 1/2] Add preload mode, add comments --- models/4_pytorch_distributed_horovod.py | 120 ++++++++++++++++-------- 1 file changed, 81 insertions(+), 39 deletions(-) diff --git a/models/4_pytorch_distributed_horovod.py b/models/4_pytorch_distributed_horovod.py index 5206a05..236dc01 100644 --- a/models/4_pytorch_distributed_horovod.py +++ b/models/4_pytorch_distributed_horovod.py @@ -1,3 +1,4 @@ +import argparse import torch import torchvision import numpy as np @@ -15,61 +16,77 @@ class PascalVOCSegmentationDataset(Dataset): def __init__(self, raw): super().__init__() self._dataset = raw - self.resize_img = torchvision.transforms.Resize((256, 256), interpolation=PIL.Image.BILINEAR) - self.resize_segmap = torchvision.transforms.Resize((256, 256), interpolation=PIL.Image.NEAREST) - + self.resize_img = torchvision.transforms.Resize( + (256, 256), interpolation=PIL.Image.BILINEAR + ) + self.resize_segmap = torchvision.transforms.Resize( + (256, 256), interpolation=PIL.Image.NEAREST + ) + def __len__(self): return len(self._dataset) - + def __getitem__(self, idx): img, segmap = self._dataset[idx] img, segmap = self.resize_img(img), self.resize_segmap(segmap) img, segmap = np.array(img), np.array(segmap) - img, segmap = (img / 255).astype('float32'), segmap.astype('int32') + img, segmap = (img / 255).astype("float32"), segmap.astype("int32") img = np.transpose(img, (-1, 0, 1)) # The PASCAL VOC dataset PyTorch provides labels the edges surrounding classes in 255-valued # pixels in the segmentation map. However, PyTorch requires class values to be contiguous # in range 0 through n_classes, so we must relabel these pixels to 21. segmap[segmap == 255] = 21 - + return img, segmap -def get_dataloader(): - _PascalVOCSegmentationDataset = torchvision.datasets.VOCSegmentation( - '/mnt/pascal_voc_segmentation/', year='2012', image_set='train', download=True, - transform=None, target_transform=None, transforms=None + +def download_dataloader(shouldDownload=False): + return torchvision.datasets.VOCSegmentation( + "/mnt/pascal_voc_segmentation/", + year="2012", + image_set="train", + download=shouldDownload, + transform=None, + target_transform=None, + transforms=None, ) - dataset = PascalVOCSegmentationDataset(_PascalVOCSegmentationDataset) + + +def download_model(): + return torchvision.models.segmentation.deeplabv3_resnet101( + pretrained=False, progress=True, num_classes=22, aux_loss=None + ) + + +def get_dataloader(): + dataset = PascalVOCSegmentationDataset(download_dataloader()) # NEW # Distributed sampler. sampler = torch.utils.data.distributed.DistributedSampler( dataset, num_replicas=hvd.size(), rank=hvd.rank() ) - dataloader = DataLoader( - dataset, batch_size=8, shuffle=False, sampler=sampler - ) - + dataloader = DataLoader(dataset, batch_size=8, shuffle=False, sampler=sampler) + return dataloader, sampler + def get_model(): # num_classes is 22. PASCAL VOC includes 20 classes of interest, 1 background class, and the 1 # special border class mentioned in the previous comment. 20 + 1 + 1 = 22. - DeepLabV3 = torchvision.models.segmentation.deeplabv3_resnet101( - pretrained=False, progress=True, num_classes=22, aux_loss=None - ) - model = DeepLabV3 + model = download_model() model.cuda() model.train() - + return model + def train(NUM_EPOCHS): for epoch in range(1, NUM_EPOCHS + 1): # NEW: # set epoch to sampler for shuffling. sampler.set_epoch(epoch) - + losses = [] for i, (batch, segmap) in enumerate(dataloader): @@ -78,7 +95,7 @@ def train(NUM_EPOCHS): batch = batch.cuda() segmap = segmap.cuda() - output = model(batch)['out'] + output = model(batch)["out"] loss = criterion(output, segmap.type(torch.int64)) loss.backward() optimizer.step() @@ -90,7 +107,7 @@ def train(NUM_EPOCHS): # ) if hvd.rank() == 0: - writer.add_scalar('training loss', curr_loss) + writer.add_scalar("training loss", curr_loss) losses.append(curr_loss) # print( @@ -98,55 +115,80 @@ def train(NUM_EPOCHS): # f'avg loss: {np.mean(losses)}; median loss: {np.min(losses)}' # ) if hvd.rank() == 0 and epoch % 5 == 0: - if not os.path.exists('/spell/checkpoints/'): - os.mkdir('/spell/checkpoints/') - torch.save(model.state_dict(), f'/spell/checkpoints/model_{epoch}.pth') + if not os.path.exists("/spell/checkpoints/"): + os.mkdir("/spell/checkpoints/") + torch.save(model.state_dict(), f"/spell/checkpoints/model_{epoch}.pth") if hvd.rank() == 0: - torch.save(model.state_dict(), f'/spell/checkpoints/model_final.pth') + torch.save(model.state_dict(), f"/spell/checkpoints/model_final.pth") + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="Pytorch Distributed Horovod") + parser.add_argument("--preload-mode", action="store_true") + parser.set_defaults(feature=True) + args = parser.parse_args() + + if args.preload_mode: + print("PRELOAD MODE!") + download_dataloader(shouldDownload=True) + download_model() + exit() - -if __name__ == '__main__': # NEW: # Init horovod + print("TRAINING MODE! INIT HOROVOD") hvd.init() + print("SET DEVICE + THREADS") torch.cuda.set_device(hvd.local_rank()) torch.set_num_threads(1) - - writer = SummaryWriter(f'/spell/tensorboards/model_4') + + writer = SummaryWriter(f"/spell/tensorboards/model_4") # since the background class doesn't matter nearly as much as the classes of interest to the # overall task a more selective loss would be more appropriate, however this training script # is merely a benchmark so we'll just use simple cross-entropy loss + print("criterion = nn.CrossEntropyLoss()") criterion = nn.CrossEntropyLoss() - + # NEW: # Download the data on only one thread. Have the rest wait until the download finishes. if hvd.local_rank() == 0: + print("hvd.local_rank() == 0; calling get_model()") get_model() + print("hvd.local_rank() == 0; calling get_dataloader()") get_dataloader() + print("hvd.join()") hvd.join() print(f"Rank {hvd.rank() + 1}/{hvd.size()} process cleared download barrier.") - + + print("MAIN: model = get_model()") model = get_model() + print("MAIN: dataloader, sampler = get_dataloader()") dataloader, sampler = get_dataloader() - + # NEW: # Scale learning learning rate by size. + print("optimizer = Adam") optimizer = Adam(model.parameters(), lr=1e-3 * hvd.size()) # New: # Broadcast parameters & optimizer state. + print("hvd.broadcast.*") hvd.broadcast_parameters(model.state_dict(), root_rank=0) hvd.broadcast_optimizer_state(optimizer, root_rank=0) # NEW: # (optional) Free-ish compression (reduces over-the-wire size -> increases speed). + print("compression = hvd.Compression.fp16") compression = hvd.Compression.fp16 - + # NEW: # Wrap optimizer with DistributedOptimizer. - optimizer = hvd.DistributedOptimizer(optimizer, - named_parameters=model.named_parameters(), - compression=compression, - op=hvd.Average) + optimizer = hvd.DistributedOptimizer( + optimizer, + named_parameters=model.named_parameters(), + compression=compression, + op=hvd.Average, + ) train(20) From 1234f5c8a5c8810b80855acbac3d38f6eacf2398 Mon Sep 17 00:00:00 2001 From: Ben Sussman Date: Fri, 31 Jul 2020 18:19:36 -0400 Subject: [PATCH 2/2] try mnt instead of /mnt --- models/4_pytorch_distributed_horovod.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/4_pytorch_distributed_horovod.py b/models/4_pytorch_distributed_horovod.py index 236dc01..31115f2 100644 --- a/models/4_pytorch_distributed_horovod.py +++ b/models/4_pytorch_distributed_horovod.py @@ -43,7 +43,7 @@ def __getitem__(self, idx): def download_dataloader(shouldDownload=False): return torchvision.datasets.VOCSegmentation( - "/mnt/pascal_voc_segmentation/", + "mnt/pascal_voc_segmentation/", year="2012", image_set="train", download=shouldDownload,