"""
Experiment to see if we can create a loc2vec as detailed in the blogpost.
bloglink: https://www.sentiance.com/2018/05/03/venue-mapping/
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.utils.data.sampler import WeightedRandomSampler
from torch.utils.data import DataLoader
from torchvision import transforms
# For Mixed precision training
from apex import amp
# Set up the network and training parameters
from trainer import fit
# Strategies for selecting triplets within a minibatch
from utils import HardestNegativeTripletSelector
from utils import RandomNegativeTripletSelector, SemihardNegativeTripletSelector
from datasets import GeoTileDataset
from networks import Loc2Vec
from losses import OnlineTripletLoss
from config import IMG_SIZE, LOG_INTERVAL, N_EPOCHS, BATCH_SIZE, MARGIN, TILE_FILE


def main():
    cuda = torch.cuda.is_available()
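
    # The anchor (center) tile gets heavier augmentation (random affine, flips,
    # center crop) than the surrounding context tiles, which are only resized
    # and normalized.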
    anchor_transform = transforms.Compose([
        transforms.RandomAffine(degrees=90, translate=(0.25, 0.25)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.CenterCrop(128),
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])])

    train_transforms = transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])])

    # Let's use 12 while developing as it reduces the start time.
    dset_train = GeoTileDataset(TILE_FILE,
                                transform=train_transforms,
                                center_transform=anchor_transform)
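
    # Weight each tile by its `frequency` column so the sampler draws locations
    # roughly in proportion to how often they occur (assuming that is what
    # `frequency` records).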
    pd_files = dset_train.get_file_df()
    weights = pd_files.frequency
    train_sampler = WeightedRandomSampler(weights, len(dset_train))

    # Should num_workers be 1?
    kwargs = {'num_workers': 8, 'pin_memory': True} if cuda else {}
    online_train_loader = DataLoader(dset_train, batch_size=BATCH_SIZE,
                                     sampler=train_sampler,
                                     **kwargs)

    model = Loc2Vec()
    if cuda:
        model.cuda()
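
    # The online triplet loss mines triplets within each minibatch; it is
    # constructed here with the hardest-, semi-hard- and random-negative
    # selectors.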
    loss_fn = OnlineTripletLoss(MARGIN,
                                HardestNegativeTripletSelector(MARGIN),
                                SemihardNegativeTripletSelector(MARGIN),
                                RandomNegativeTripletSelector(MARGIN))
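
    # Adam with a step scheduler: the learning rate is multiplied by 0.1 every
    # 16 scheduler steps (i.e. every 16 epochs if fit() steps it once per epoch).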
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = lr_scheduler.StepLR(optimizer, 16, gamma=0.1, last_epoch=-1)

    # Mixed precision training
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
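    # NOTE: with apex amp the backward pass (inside trainer.fit) needs to go
    # through `amp.scale_loss(loss, optimizer)` for loss scaling to take effect.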

    # if torch.cuda.device_count() > 1:
    #     print("Let's use", torch.cuda.device_count(), "GPUs!")
    #     model = nn.DataParallel(model)
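
    # The training loader is also passed as the validation loader, so any
    # validation metrics reported by fit() are computed on the training data.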
    fit(online_train_loader, online_train_loader, model, loss_fn, optimizer, scheduler,
        N_EPOCHS, cuda, LOG_INTERVAL)
if __name__ == "__main__":
main()