diff --git a/torchtest/torchtest.py b/torchtest/torchtest.py index f0ee9ef..3185095 100644 --- a/torchtest/torchtest.py +++ b/torchtest/torchtest.py @@ -36,6 +36,38 @@ def setup(seed=0): """Set random seed for torch""" torch.manual_seed(seed) +def _pack_batch(x, device): + """ Packages object ``x`` into a tuple to be unpacked. + + Recursively transfers all tensor objects to device + + Parameters + ---------- + x : torch.Tensor or tuple containing torch.Tensor + device : str + + Returns + ------- + tuple + positional arguments + """ + + def _helper(x): + if isinstance(x, torch.Tensor): + x = x.to(device) + return x + + output = [_helper(item) for item in x] + return output + + + if isinstance(x, torch.Tensor): + # For backwards compatability + return (x.to(device),) + else: + return _helper(x) + + def _train_step(model, loss_fn, optim, batch, device): """Run a training step on model for a given batch of data @@ -62,14 +94,14 @@ def _train_step(model, loss_fn, optim, batch, device): # clear gradient optim.zero_grad() # inputs and targets - inputs, targets = batch[0], batch[1] + inputs, targets = batch[0], batch[1] # Need to recursively move these to device # move data to DEVICE - inputs = inputs.to(device) - targets = targets.to(device) + inputs = _pack_batch(inputs, device) + targets = _pack_batch(targets, device) # forward - likelihood = model(inputs) + likelihood = model(*inputs) # calc loss - loss = loss_fn(likelihood, targets) + loss = loss_fn(likelihood, *targets) # backward loss.backward() # optimization step