Skip to content

Commit

Permalink
Fix dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
nyoki-mtl committed Aug 17, 2019
1 parent 7a30cab commit 49c63fc
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 9 deletions.
6 changes: 3 additions & 3 deletions src/dataset/apolloscape.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __getitem__(self, index):
resized = self.resizer(image=img, mask=lbl)
img, lbl = resized['image'], resized['mask']

if self.split != 'valid':
if self.split == 'train':
augmented = self.augmenter(image=img, mask=lbl)
img, lbl = augmented['image'], augmented['mask']

Expand All @@ -97,7 +97,7 @@ def __getitem__(self, index):
img = self.img_transformer(img)
lbl = self.lbl_transformer(lbl)

return img, lbl
return img, lbl, img_path.stem


if __name__ == '__main__':
Expand All @@ -110,7 +110,7 @@ def __getitem__(self, index):
print(len(dataset))

for i, batched in enumerate(dataloader):
images, labels = batched
images, labels, _ = batched
if i == 0:
fig, axes = plt.subplots(8, 2, figsize=(10, 30))
plt.tight_layout()
Expand Down
2 changes: 1 addition & 1 deletion src/dataset/cityscapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def encode_mask(self, lbl):
print(len(dataset))

for i, batched in enumerate(dataloader):
images, labels = batched
images, labels, _ = batched
if i == 0:
fig, axes = plt.subplots(8, 2, figsize=(20, 48))
plt.tight_layout()
Expand Down
4 changes: 2 additions & 2 deletions src/dataset/pascal_voc.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __getitem__(self, index):
img = img.transpose(2, 0, 1)
img = torch.FloatTensor(img)
lbl = torch.LongTensor(lbl)
return img, lbl
return img, lbl, img_path.stem


if __name__ == '__main__':
Expand All @@ -133,7 +133,7 @@ def __getitem__(self, index):
print(len(dataset))

for i, batched in enumerate(dataloader):
images, labels = batched
images, labels, _ = batched
if i == 0:
fig, axes = plt.subplots(8, 2, figsize=(20, 48))
plt.tight_layout()
Expand Down
Empty file added src/tpu/common_utils.py
Empty file.
Empty file added src/tpu/test_utils.py
Empty file.
Empty file added src/tpu/tpu_check.py
Empty file.
153 changes: 153 additions & 0 deletions src/tpu/tpu_train_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import test_utils

# Parse the run options once, at import time. The rest of this script reads
# them as attributes of FLAGS (e.g. FLAGS.datadir, FLAGS.batch_size,
# FLAGS.lr, FLAGS.momentum, FLAGS.num_epochs, FLAGS.target_accuracy).
# NOTE(review): the keyword values below are presumably per-script defaults
# that the command line can override -- confirm against
# test_utils.parse_common_options.
# NOTE(review): this call sits before the remaining imports, presumably on
# purpose; confirm before reordering it.
FLAGS = test_utils.parse_common_options(
datadir='/tmp/mnist-data',
batch_size=128,
momentum=0.5,
lr=0.01,
target_accuracy=98.0,
num_epochs=18)

from common_utils import TestCase, run_tests
import os
import shutil
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch_xla
import torch_xla_py.data_parallel as dp
import torch_xla_py.utils as xu
import torch_xla_py.xla_model as xm
import unittest


class MNIST(nn.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    Two (conv -> max-pool -> ReLU -> batch-norm) stages are followed by two
    fully connected layers. The network emits log-probabilities over 10
    classes, which is the input that ``nn.NLLLoss`` expects.
    """

    def __init__(self):
        super(MNIST, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.bn1 = nn.BatchNorm2d(10)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.bn2 = nn.BatchNorm2d(20)
        # 320 = 20 channels * 4 * 4: a 28x28 input shrinks to 24 -> 12 after
        # the first conv/pool stage and to 8 -> 4 after the second one.
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: float tensor of shape ``(batch, 1, 28, 28)``.

        Returns:
            Tensor of shape ``(batch, 10)`` whose rows are log-softmax
            outputs.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = self.bn1(x)
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = self.bn2(x)
        # Keep the batch dimension and flatten everything else for the
        # fully connected layers.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train_mnist():
    """Train the MNIST model and return its final test accuracy in percent.

    Builds train/test loaders (synthetic all-zero batches when
    ``FLAGS.fake_data`` is set, the torchvision MNIST dataset otherwise),
    wraps the ``MNIST`` model class in ``dp.DataParallel`` over the selected
    devices, and alternates one training pass and one evaluation pass per
    epoch for ``FLAGS.num_epochs`` epochs.

    Returns:
        The accuracy of the last epoch, averaged over the per-device results
        and scaled to a percentage. ``0.0`` if no epoch ran.
    """
    torch.manual_seed(1)

    if FLAGS.fake_data:
        # Synthetic all-zero batches with the MNIST sample shape; the sample
        # counts mirror the 60000/10000 train/test split sizes.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=60000 // FLAGS.batch_size)
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=10000 // FLAGS.batch_size)
    else:
        # Both splits use the same preprocessing, so build it once.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                FLAGS.datadir,
                train=True,
                download=True,
                transform=transform),
            batch_size=FLAGS.batch_size,
            shuffle=True,
            num_workers=FLAGS.num_workers)
        # Evaluation does not benefit from shuffling; a fixed order keeps
        # the evaluation pass reproducible.
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                FLAGS.datadir,
                train=False,
                transform=transform),
            batch_size=FLAGS.batch_size,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    devices = (
        xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
        if FLAGS.num_cores != 0 else [])
    # Scale learning rate to num cores
    lr = FLAGS.lr * max(len(devices), 1)
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    model_parallel = dp.DataParallel(MNIST, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        """Run one training pass over ``loader`` on a single device."""
        loss_fn = nn.NLLLoss()
        # NOTE(review): getattr_or presumably creates the optimizer on first
        # use and returns the stored one afterwards, so momentum state
        # survives across epochs -- confirm in torch_xla_py.data_parallel.
        optimizer = context.getattr_or(
            'optimizer',
            lambda: optim.SGD(
                model.parameters(), lr=lr, momentum=FLAGS.momentum))
        tracker = xm.RateTracker()

        model.train()
        # The loader yields (step_index, (data, target)) pairs.
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(
                    device, x, loss.item(), tracker.rate()))

    def test_loop_fn(model, loader, device, context):
        """Evaluate ``model`` on ``loader``; return accuracy as a fraction."""
        total_samples = 0
        correct = 0
        model.eval()
        # Gradients are never used during evaluation, so skip tracking them.
        with torch.no_grad():
            for x, (data, target) in loader:
                output = model(data)
                # Index of the largest log-probability is the predicted class.
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).sum().item()
                total_samples += data.size()[0]

        print('[{}] Accuracy={:.2f}%'.format(
            device, 100.0 * correct / total_samples))
        return correct / total_samples

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        # One result per device; report their unweighted mean.
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = sum(accuracies) / len(accuracies)
        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    return accuracy * 100.0


class TrainMnist(TestCase):
    """Test case that trains MNIST and checks the accuracy it reaches."""

    def tearDown(self):
        """Delete the dataset directory after the test if FLAGS.tidy is set."""
        super(TrainMnist, self).tearDown()
        if FLAGS.tidy and os.path.isdir(FLAGS.datadir):
            shutil.rmtree(FLAGS.datadir)

    # NOTE: the spelling "accurracy" is kept on purpose; the method name is
    # the test's external identifier and renaming it would break any
    # by-name test selection.
    def test_accurracy(self):
        """Train the model and require at least FLAGS.target_accuracy percent."""
        self.assertGreaterEqual(train_mnist(), FLAGS.target_accuracy)


# Run the tests.
# These statements are not behind an `if __name__ == '__main__'` guard, so
# they execute whenever this module is imported or run. The default tensor
# type is set to 'torch.FloatTensor' first, then control passes to the
# run_tests helper imported from common_utils.
torch.set_default_tensor_type('torch.FloatTensor')
run_tests()
6 changes: 3 additions & 3 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@
model.train()
with tqdm(train_loader) as _tqdm:
for batched in _tqdm:
images, labels = batched
images, labels, _ = batched
if fp16:
images = images.half()
images, labels = images.to(device), labels.to(device)
Expand Down Expand Up @@ -182,14 +182,14 @@
torch.save(model.state_dict(), output_dir.joinpath('model_tmp.pth'))
torch.save(optimizer.state_dict(), output_dir.joinpath('opt_tmp.pth'))

if (i_epoch + 1) % 10 == 0:
if (i_epoch + 1) % 1 == 0:
valid_losses = []
valid_ious = []
model.eval()
with torch.no_grad():
with tqdm(valid_loader) as _tqdm:
for batched in _tqdm:
images, labels = batched
images, labels, _ = batched
if fp16:
images = images.half()
images, labels = images.to(device), labels.to(device)
Expand Down

0 comments on commit 49c63fc

Please sign in to comment.