-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdata_loader.py
94 lines (75 loc) · 3.09 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import torch
from torchvision import datasets
class ImageNetCategory():
    """Callable that extracts the category name from an image path.

    For ImageNet-like directory structures without sessions/conditions:
    .../{category}/{img_name}
    """

    def __call__(self, full_path):
        """Return the name of the directory that directly contains the image.

        Args:
            full_path: Path to an image file laid out as
                ``.../{category}/{img_name}``.

        Returns:
            The ``{category}`` path component as a string.

        Raises:
            IndexError: If ``full_path`` has no parent directory component.
        """
        # Normalise the platform separator first so that paths built with
        # os.path.join on Windows (backslashes) split the same way as POSIX
        # paths. On POSIX os.sep is already "/", so this is a no-op there.
        components = full_path.replace(os.sep, "/").split("/")
        return components[-2]
class ImageNetDataset(datasets.ImageFolder):
    """ImageFolder variant whose items carry a category name and the file path.

    Extends torchvision.datasets.ImageFolder. Each item is the 3-tuple
    ``(sample, category, path)``, where ``category`` is whatever
    ``ImageNetCategory`` derives from the image path; ImageFolder's own
    target is not returned.
    """

    def __init__(self, *args, **kwargs):
        # All arguments are forwarded untouched to ImageFolder.
        super().__init__(*args, **kwargs)

    def __getitem__(self, index):
        """Return ``(sample, category, path)`` for the item at ``index``.

        This is the method the DataLoader calls for every item.
        """
        # ImageFolder produces the sample; its own target is discarded in
        # favour of the category derived from the file path.
        sample, _ = super().__getitem__(index)
        image_path = self.imgs[index][0]
        category = ImageNetCategory()(image_path)
        return (sample, category, image_path)
class ImageNetClipDataset(datasets.ImageFolder):
    """ImageFolder variant whose targets are re-mapped through a lookup table.

    Extends torchvision.datasets.ImageFolder. The table passed as
    ``mappings`` is keyed by image file name (basename of the path). Depending
    on ``label_type`` the returned target is:

    * ``HARD_LABELS``: the mapped value translated through ``class_to_idx``;
    * ``SOFT_LABELS``: the mapped value itself, returned unchanged;
    * anything else: the target ImageFolder produced.

    Adapted from:
    https://gist.github.com/andrewjong/6b02ff237533b3b2c554701fb53d5c4d
    """

    SOFT_LABELS = "soft_labels"
    HARD_LABELS = "hard_labels"

    def __init__(self, label_type, mappings, *args, **kwargs):
        """Store the label mode and lookup table, then initialise ImageFolder.

        Args:
            label_type: Compared against ``HARD_LABELS`` / ``SOFT_LABELS`` in
                ``__getitem__``; any other value keeps ImageFolder's target.
            mappings: Lookup table indexed by image file name.
            *args, **kwargs: Forwarded to ImageFolder.
        """
        self.label_type = label_type
        self.clip_class_mapping = mappings
        super().__init__(*args, **kwargs)

    def _get_new_template_hard_labels(self, image_path):
        """Return the ``class_to_idx`` entry for the class mapped to this file."""
        mapped_class = self.clip_class_mapping[os.path.basename(image_path)]
        return self.class_to_idx[mapped_class]

    def _get_new_template_soft_labels(self, image_path):
        """Return the mapping entry for this file as-is."""
        return self.clip_class_mapping[os.path.basename(image_path)]

    def __getitem__(self, index):
        """Return ``(sample, target)`` with the target chosen by ``label_type``.

        This is the method the DataLoader calls for every item.
        """
        # ImageFolder supplies the sample and its default target.
        sample, default_target = super().__getitem__(index)
        image_path = self.imgs[index][0]

        if self.label_type == ImageNetClipDataset.HARD_LABELS:
            return (sample, self._get_new_template_hard_labels(image_path))
        if self.label_type == ImageNetClipDataset.SOFT_LABELS:
            return (sample, self._get_new_template_soft_labels(image_path))
        # Unrecognised label type: keep ImageFolder's target.
        return (sample, default_target)
def data_loader(transform, args):
    """Build a shuffling DataLoader over an ``ImageNetDataset``.

    Args:
        transform: Passed to ``ImageNetDataset`` as its second positional
            argument.
        args: Object exposing ``data_dir``, ``batch_size`` and
            ``num_workers`` attributes.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=True``.
    """
    dataset = ImageNetDataset(args.data_dir, transform)
    loader_options = {
        "batch_size": args.batch_size,
        "shuffle": True,
        "num_workers": args.num_workers,
    }
    return torch.utils.data.DataLoader(dataset, **loader_options)