-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
43 lines (38 loc) · 1.36 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import copy
import os
import shutil
import numpy as np
import torch
# Stores up to 50 generated images in a pool and samples from it once the pool is full
# (Shrivastava et al.'s strategy)
class Sample_from_Pool(object):
    """Pool that replays previously seen items.

    The first ``max_elements`` items passed in are stored and also returned
    unchanged. Once the pool is full, each incoming item is handled as follows:

    * With probability 0.5, it is swapped for a randomly chosen stored item.
      The stored item is returned and the incoming one takes its slot.
    * Otherwise, the incoming item is returned as-is and the pool is left
      untouched.

    A ``max_elements`` of 0 or less disables the pool, and items are passed
    straight through.
    """

    def __init__(self, max_elements=50):
        # Capacity of the pool; <= 0 means "no pool, pass items through".
        self.max_elements = max_elements
        # Count of stored items, incremented in lockstep with appends below.
        self.cur_elements = 0
        self.items = []

    def __call__(self, in_items):
        """Return a list with one (possibly replayed) item per input item.

        Args:
            in_items: an iterable of items. Presumably a batch of generated
                images iterated along its first dimension; confirm against
                callers.

        Returns:
            list: the items to use in place of ``in_items``, same length.
        """
        # Without this guard, a non-positive capacity would skip the fill
        # branch and reach np.random.randint(0, max_elements), which raises
        # ValueError because low >= high.
        if self.max_elements <= 0:
            return list(in_items)

        return_items = []
        for in_item in in_items:
            if self.cur_elements < self.max_elements:
                # Still filling: keep the item and also hand it back.
                self.items.append(in_item)
                self.cur_elements += 1
                return_items.append(in_item)
            elif np.random.random() > 0.5:
                # Replay: return a random stored item, store the new one.
                idx = np.random.randint(0, self.max_elements)
                tmp = copy.copy(self.items[idx])
                self.items[idx] = in_item
                return_items.append(tmp)
            else:
                return_items.append(in_item)
        return return_items
def set_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of the given network(s).

    Args:
        nets: a single ``torch.nn.Module`` or an iterable of modules.
        requires_grad: value assigned to each parameter's ``requires_grad``
            flag. The default ``False`` freezes the parameters.
    """
    # Accept a bare module as well as a list/tuple (mirroring ``cuda``).
    # Iterating a plain nn.Module would otherwise raise TypeError. Wrapping
    # a container module is harmless because .parameters() recurses into
    # its submodules.
    if isinstance(nets, torch.nn.Module):
        nets = [nets]
    for net in nets:
        for param in net.parameters():
            param.requires_grad = requires_grad
def cuda(xs, device):
    """Move a tensor, or each tensor in a list/tuple, to ``device``.

    The move only happens when CUDA is available. Otherwise ``xs`` is
    returned unchanged, so callers can write ``x = cuda(x, device)``
    unconditionally.

    Args:
        xs: an object with a ``.to`` method (e.g. a tensor), or a list/tuple
            of such objects.
        device: target passed to ``.to()``.

    Returns:
        The moved object, a list of moved objects for list/tuple input, or
        ``xs`` itself when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        # Without this return the function fell off the end and returned
        # None on CPU-only machines, clobbering the caller's tensors.
        return xs
    if isinstance(xs, (list, tuple)):
        return [x.to(device) for x in xs]
    return xs.to(device)