train.py
import os
import torch
import torch.nn.functional as F
from tqdm import tqdm
import logging

from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader
from torchmeta.utils.gradient_based import gradient_update_parameters

from model import ConvolutionalNeuralNetwork
from utils import get_accuracy

logger = logging.getLogger(__name__)

def train(args):
    logger.warning('This script is an example to showcase the MetaModule and '
                   'data-loading features of Torchmeta, and as such has been '
                   'very lightly tested. For a better tested implementation of '
                   'Model-Agnostic Meta-Learning (MAML) using Torchmeta with '
                   'more features (including multi-step adaptation and '
                   'different datasets), please check `https://github.com/'
                   'tristandeleu/pytorch-maml`.')
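    # Build a meta-training set of N-way k-shot Omniglot classification
    # tasks: each task samples `num_ways` classes, with `num_shots` support
    # examples and 15 query ("test") examples per class.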
    dataset = omniglot(args.folder,
                       shots=args.num_shots,
                       ways=args.num_ways,
                       shuffle=True,
                       test_shots=15,
                       meta_train=True,
                       download=args.download)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    model = ConvolutionalNeuralNetwork(1,
                                       args.num_ways,
                                       hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # Training loop
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=args.device)
            test_targets = test_targets.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
            for task_idx, (train_input, train_target, test_input,
                    test_target) in enumerate(zip(train_inputs, train_targets,
                    test_inputs, test_targets)):
                train_logit = model(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)

                model.zero_grad()
                params = gradient_update_parameters(model,
                                                    inner_loss,
                                                    step_size=args.step_size,
                                                    first_order=args.first_order)
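                # `params` holds the task-adapted parameters after one inner
                # gradient step, theta' = theta - step_size * grad(inner_loss);
                # with --first-order, second-order gradients through this step
                # are dropped, trading meta-gradient accuracy for speed.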
                test_logit = model(test_input, params=params)
                outer_loss += F.cross_entropy(test_logit, test_target)

                with torch.no_grad():
                    accuracy += get_accuracy(test_logit, test_target)

            # Average the query-set loss and accuracy over the batch of tasks
            outer_loss.div_(args.batch_size)
            accuracy.div_(args.batch_size)

            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            # Stop after exactly `num_batches` batches (the `>=` without the
            # `- 1` would train one batch past the tqdm total).
            if batch_idx >= args.num_batches - 1:
                break
    # Save model
    if args.output_folder is not None:
        filename = os.path.join(args.output_folder, 'maml_omniglot_'
            '{0}shot_{1}way.th'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)
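        # The checkpoint can be restored later with
        # model.load_state_dict(torch.load(filename)).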

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser('Model-Agnostic Meta-Learning (MAML)')

    parser.add_argument('folder', type=str,
        help='Path to the folder the data is downloaded to.')
    parser.add_argument('--num-shots', type=int, default=5,
        help='Number of examples per class (k in "k-shot", default: 5).')
    parser.add_argument('--num-ways', type=int, default=5,
        help='Number of classes per task (N in "N-way", default: 5).')
    parser.add_argument('--first-order', action='store_true',
        help='Use the first-order approximation of MAML.')
    parser.add_argument('--step-size', type=float, default=0.4,
        help='Step-size for the gradient step for adaptation (default: 0.4).')
    parser.add_argument('--hidden-size', type=int, default=64,
        help='Number of channels for each convolutional layer (default: 64).')
    parser.add_argument('--output-folder', type=str, default=None,
        help='Path to the output folder for saving the model (optional).')
    parser.add_argument('--batch-size', type=int, default=16,
        help='Number of tasks in a mini-batch of tasks (default: 16).')
    parser.add_argument('--num-batches', type=int, default=100,
        help='Number of batches the model is trained over (default: 100).')
    parser.add_argument('--num-workers', type=int, default=1,
        help='Number of workers for data loading (default: 1).')
    parser.add_argument('--download', action='store_true',
        help='Download the Omniglot dataset in the data folder.')
    parser.add_argument('--use-cuda', action='store_true',
        help='Use CUDA if available.')

    args = parser.parse_args()
    args.device = torch.device('cuda' if args.use_cuda
        and torch.cuda.is_available() else 'cpu')

    train(args)
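
# Example invocation (paths are illustrative):
#   python train.py /path/to/data --num-shots 5 --num-ways 5 --download \
#       --output-folder ./results --use-cuda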