-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsiamese_dataset.py
49 lines (40 loc) · 1.66 KB
/
siamese_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from torch.utils.data import Dataset
import random
import torch
import numpy as np
import PIL
from PIL import Image
class SiameseNetworkDataset(Dataset):
    """Dataset yielding random image pairs for training a Siamese network.

    Wraps an object exposing ``imgs``, a sequence whose items hold an image
    path at position 0 and a class label at position 1 (presumably a
    torchvision ``ImageFolder``; only ``imgs`` is used here).

    Every item is a freshly sampled random pair: roughly half of the pairs
    share a class and half do not.

    Args:
        imageFolderDataset: Source of ``(path, label)`` samples via ``.imgs``.
        transform: Optional callable applied to each image of the pair.
        should_invert: When true, invert each grayscale image after loading.
    """

    # Becomes True once the sample list has been confirmed to contain at
    # least two distinct labels, so the O(n) scan runs at most once.
    _two_classes_verified = False

    def __init__(self, imageFolderDataset, transform=None, should_invert=True):
        self.imageFolderDataset = imageFolderDataset
        self.transform = transform
        self.should_invert = should_invert

    def __getitem__(self, index):
        """Return ``(img0, img1, label)`` for a randomly sampled pair.

        ``index`` is ignored: pairs are drawn at random on every call.
        ``label`` is a float32 tensor of shape ``(1,)`` holding ``1.0`` when
        the two images belong to different classes and ``0.0`` otherwise.

        Raises:
            IndexError: If the wrapped dataset has no samples.
            ValueError: If all samples share a single class label, because
                a different-class pair could never be produced.
        """
        samples = self.imageFolderDataset.imgs
        anchor = random.choice(samples)
        self._require_two_classes(samples)

        # Aim for ~50% same-class pairs; rejection-sample the partner until
        # its class relation to the anchor matches the coin flip.
        want_same_class = bool(random.randint(0, 1))
        while True:
            partner = random.choice(samples)
            if (partner[1] == anchor[1]) == want_same_class:
                break

        img0 = self._load_grayscale(anchor[0])
        img1 = self._load_grayscale(partner[0])

        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)

        is_different = float(partner[1] != anchor[1])
        return img0, img1, torch.tensor([is_different], dtype=torch.float32)

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.imageFolderDataset.imgs)

    def _require_two_classes(self, samples):
        """Raise ``ValueError`` unless ``samples`` spans at least two labels."""
        if self._two_classes_verified:
            return
        first_label = samples[0][1]
        if not any(sample[1] != first_label for sample in samples):
            raise ValueError(
                "SiameseNetworkDataset needs samples from at least two "
                "classes to build different-class pairs"
            )
        self._two_classes_verified = True

    def _load_grayscale(self, path):
        """Load ``path`` as a mode-"L" image, inverted if configured."""
        # convert() returns an independent image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(path) as raw:
            gray = raw.convert("L")
        if self.should_invert:
            # A bare `import PIL` does not load the ImageOps submodule, so
            # import it explicitly instead of relying on `PIL.ImageOps`.
            from PIL import ImageOps

            gray = ImageOps.invert(gray)
        return gray