From 49c63fc329323d702633588fd031bb100f6b6bcd Mon Sep 17 00:00:00 2001
From: nyoki-mtl
Date: Sun, 18 Aug 2019 03:49:45 +0900
Subject: [PATCH] Fix dataloader

---
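Notes:

* Every dataset's __getitem__ now returns a third element, the image
  filename stem, so any consumer that unpacks two values must be updated
  (train.py is updated below). A minimal sketch of the new loader
  contract; the class name and constructor arguments are illustrative,
  not taken from the repo:

      from torch.utils.data import DataLoader
      from dataset.apolloscape import ApolloScape  # assumed class name

      loader = DataLoader(ApolloScape(split='train'), batch_size=8)
      for images, labels, stems in loader:
          # The default collate_fn gathers the stems into a list of
          # strings, handy for naming per-image predictions.
          print(images.shape, labels.shape, stems)

* apolloscape.py now augments only the 'train' split; previously every
  split except 'valid' was augmented, so test-time images were being
  augmented as well.
* train.py now validates every epoch instead of every 10th; the
  (i_epoch + 1) % 1 == 0 check is always true, presumably kept in the
  modulus form so the interval is easy to change back.
* src/tpu/ gains a TPU MNIST training smoke test, apparently adapted
  from the pytorch/xla test suite, plus empty stubs for its
  common_utils/test_utils helpers.
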
 src/dataset/apolloscape.py |   6 +-
 src/dataset/cityscapes.py  |   2 +-
 src/dataset/pascal_voc.py  |   4 +-
 src/tpu/common_utils.py    |   0
 src/tpu/test_utils.py      |   0
 src/tpu/tpu_check.py       |   0
 src/tpu/tpu_train_mnist.py | 155 +++++++++++++++++++++++++++++++++++++
 src/train.py               |   6 +-
 8 files changed, 164 insertions(+), 9 deletions(-)
 create mode 100644 src/tpu/common_utils.py
 create mode 100644 src/tpu/test_utils.py
 create mode 100644 src/tpu/tpu_check.py
 create mode 100644 src/tpu/tpu_train_mnist.py

diff --git a/src/dataset/apolloscape.py b/src/dataset/apolloscape.py
index f9f31a6..8074a58 100644
--- a/src/dataset/apolloscape.py
+++ b/src/dataset/apolloscape.py
@@ -87,7 +87,7 @@ def __getitem__(self, index):
         resized = self.resizer(image=img, mask=lbl)
         img, lbl = resized['image'], resized['mask']
 
-        if self.split != 'valid':
+        if self.split == 'train':
             augmented = self.augmenter(image=img, mask=lbl)
             img, lbl = augmented['image'], augmented['mask']
 
@@ -97,7 +97,7 @@ def __getitem__(self, index):
         img = self.img_transformer(img)
         lbl = self.lbl_transformer(lbl)
 
-        return img, lbl
+        return img, lbl, img_path.stem
 
 
 if __name__ == '__main__':
@@ -110,7 +110,7 @@ def __getitem__(self, index):
     print(len(dataset))
 
     for i, batched in enumerate(dataloader):
-        images, labels = batched
+        images, labels, _ = batched
         if i == 0:
             fig, axes = plt.subplots(8, 2, figsize=(10, 30))
             plt.tight_layout()
diff --git a/src/dataset/cityscapes.py b/src/dataset/cityscapes.py
index 289afb3..a5ca32d 100644
--- a/src/dataset/cityscapes.py
+++ b/src/dataset/cityscapes.py
@@ -129,7 +129,7 @@ def encode_mask(self, lbl):
     print(len(dataset))
 
     for i, batched in enumerate(dataloader):
-        images, labels = batched
+        images, labels, _ = batched
         if i == 0:
             fig, axes = plt.subplots(8, 2, figsize=(20, 48))
             plt.tight_layout()
diff --git a/src/dataset/pascal_voc.py b/src/dataset/pascal_voc.py
index bff0957..200ef67 100644
--- a/src/dataset/pascal_voc.py
+++ b/src/dataset/pascal_voc.py
@@ -112,7 +112,7 @@ def __getitem__(self, index):
         img = img.transpose(2, 0, 1)
         img = torch.FloatTensor(img)
         lbl = torch.LongTensor(lbl)
-        return img, lbl
+        return img, lbl, img_path.stem
 
 
 if __name__ == '__main__':
@@ -133,7 +133,7 @@ def __getitem__(self, index):
     print(len(dataset))
 
     for i, batched in enumerate(dataloader):
-        images, labels = batched
+        images, labels, _ = batched
         if i == 0:
             fig, axes = plt.subplots(8, 2, figsize=(20, 48))
             plt.tight_layout()
diff --git a/src/tpu/common_utils.py b/src/tpu/common_utils.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/tpu/test_utils.py b/src/tpu/test_utils.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/tpu/tpu_check.py b/src/tpu/tpu_check.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/tpu/tpu_train_mnist.py b/src/tpu/tpu_train_mnist.py
new file mode 100644
index 0000000..8f34ecd
--- /dev/null
+++ b/src/tpu/tpu_train_mnist.py
@@ -0,0 +1,155 @@
+import test_utils
+
+FLAGS = test_utils.parse_common_options(
+    datadir='/tmp/mnist-data',
+    batch_size=128,
+    momentum=0.5,
+    lr=0.01,
+    target_accuracy=98.0,
+    num_epochs=18)
+
+from common_utils import TestCase, run_tests
+import os
+import shutil
+import time
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+import torch_xla
+import torch_xla_py.data_parallel as dp
+import torch_xla_py.utils as xu
+import torch_xla_py.xla_model as xm
+import unittest
+
+
+class MNIST(nn.Module):
+
+  def __init__(self):
+    super(MNIST, self).__init__()
+    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
+    self.bn1 = nn.BatchNorm2d(10)
+    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+    self.bn2 = nn.BatchNorm2d(20)
+    self.fc1 = nn.Linear(320, 50)
+    self.fc2 = nn.Linear(50, 10)
+
+  def forward(self, x):
+    x = F.relu(F.max_pool2d(self.conv1(x), 2))
+    x = self.bn1(x)
+    x = F.relu(F.max_pool2d(self.conv2(x), 2))
+    x = self.bn2(x)
+    x = torch.flatten(x, 1)
+    x = F.relu(self.fc1(x))
+    x = self.fc2(x)
+    return F.log_softmax(x, dim=1)
+
+
+def train_mnist():
+  torch.manual_seed(1)
+
+  if FLAGS.fake_data:
+    train_loader = xu.SampleGenerator(
+        data=(torch.zeros(FLAGS.batch_size, 1, 28,
+                          28), torch.zeros(FLAGS.batch_size,
+                                           dtype=torch.int64)),
+        sample_count=60000 // FLAGS.batch_size)
+    test_loader = xu.SampleGenerator(
+        data=(torch.zeros(FLAGS.batch_size, 1, 28,
+                          28), torch.zeros(FLAGS.batch_size,
+                                           dtype=torch.int64)),
+        sample_count=10000 // FLAGS.batch_size)
+  else:
+    train_loader = torch.utils.data.DataLoader(
+        datasets.MNIST(
+            FLAGS.datadir,
+            train=True,
+            download=True,
+            transform=transforms.Compose([
+                transforms.ToTensor(),
+                transforms.Normalize((0.1307,), (0.3081,))
+            ])),
+        batch_size=FLAGS.batch_size,
+        shuffle=True,
+        num_workers=FLAGS.num_workers)
+    test_loader = torch.utils.data.DataLoader(
+        datasets.MNIST(
+            FLAGS.datadir,
+            train=False,
+            transform=transforms.Compose([
+                transforms.ToTensor(),
+                transforms.Normalize((0.1307,), (0.3081,))
+            ])),
+        batch_size=FLAGS.batch_size,
+        shuffle=True,
+        num_workers=FLAGS.num_workers)
+
+  devices = (
+      xm.get_xla_supported_devices(
+          max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
+  # Scale learning rate to num cores
+  lr = FLAGS.lr * max(len(devices), 1)
+  # Pass [] as device_ids to run using the PyTorch/CPU engine.
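+  # dp.DataParallel replicates the network on each device in `devices` and
+  # runs a loop function once per replica, returning per-replica results as a list.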
+  model_parallel = dp.DataParallel(MNIST, device_ids=devices)
+
+  def train_loop_fn(model, loader, device, context):
+    loss_fn = nn.NLLLoss()
+    optimizer = context.getattr_or(
+        'optimizer',
+        lambda: optim.SGD(model.parameters(), lr=lr, momentum=FLAGS.momentum))
+    tracker = xm.RateTracker()
+
+    model.train()
+    for x, (data, target) in loader:
+      optimizer.zero_grad()
+      output = model(data)
+      loss = loss_fn(output, target)
+      loss.backward()
+      xm.optimizer_step(optimizer)
+      tracker.add(FLAGS.batch_size)
+      if x % FLAGS.log_steps == 0:
+        print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(device, x, loss.item(),
+                                                        tracker.rate()))
+
+  def test_loop_fn(model, loader, device, context):
+    total_samples = 0
+    correct = 0
+    model.eval()
+    for x, (data, target) in loader:
+      output = model(data)
+      pred = output.max(1, keepdim=True)[1]
+      correct += pred.eq(target.view_as(pred)).sum().item()
+      total_samples += data.size()[0]
+
+    print('[{}] Accuracy={:.2f}%'.format(device,
+                                         100.0 * correct / total_samples))
+    return correct / total_samples
+
+  accuracy = 0.0
+  for epoch in range(1, FLAGS.num_epochs + 1):
+    model_parallel(train_loop_fn, train_loader)
+    accuracies = model_parallel(test_loop_fn, test_loader)
+    accuracy = sum(accuracies) / len(accuracies)
+    if FLAGS.metrics_debug:
+      print(torch_xla._XLAC._xla_metrics_report())
+
+  return accuracy * 100.0
+
+
+class TrainMnist(TestCase):
+
+  def tearDown(self):
+    super(TrainMnist, self).tearDown()
+    if FLAGS.tidy and os.path.isdir(FLAGS.datadir):
+      shutil.rmtree(FLAGS.datadir)
+
+  def test_accuracy(self):
+    self.assertGreaterEqual(train_mnist(), FLAGS.target_accuracy)
+
+
+# Run the tests.
+torch.set_default_tensor_type('torch.FloatTensor')
+run_tests()
diff --git a/src/train.py b/src/train.py
index bf38067..a484417 100644
--- a/src/train.py
+++ b/src/train.py
@@ -145,7 +145,7 @@
     model.train()
     with tqdm(train_loader) as _tqdm:
         for batched in _tqdm:
-            images, labels = batched
+            images, labels, _ = batched
             if fp16:
                 images = images.half()
             images, labels = images.to(device), labels.to(device)
@@ -182,14 +182,14 @@
     torch.save(model.state_dict(), output_dir.joinpath('model_tmp.pth'))
     torch.save(optimizer.state_dict(), output_dir.joinpath('opt_tmp.pth'))
 
-    if (i_epoch + 1) % 10 == 0:
+    if (i_epoch + 1) % 1 == 0:
         valid_losses = []
         valid_ious = []
         model.eval()
        with torch.no_grad():
             with tqdm(valid_loader) as _tqdm:
                 for batched in _tqdm:
-                    images, labels = batched
+                    images, labels, _ = batched
                     if fp16:
                         images = images.half()
                     images, labels = images.to(device), labels.to(device)