-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathindex.py
65 lines (55 loc) · 2.13 KB
/
index.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from datautils.datareader import read_data
from datautils.dataset import COD10KDataset
from torch.utils.data import DataLoader
import random
import numpy as np
import torch
import argparse
from tqdm import tqdm
import warnings
from experiments.style_transfer import style_transfer
try:
from experiments.synthetic import run_synthetic_pipeline
except ModuleNotFoundError:
warnings.warn("Unable to import synthetic data generation! Check package diffusers if using this")
try:
from experiments.sam_baseline import run_sam_pipeline
except ModuleNotFoundError:
warnings.warn("Unable to import sam baseline! Check package groundingdino and segment_anything if using this")
def run_style_transfer_pipeline(args):
    """Run style transfer on the first batch of the 'Train' split.

    Builds a ``COD10KDataset`` from the paths returned by ``read_data('Train')``,
    wraps it in a ``DataLoader`` using ``args.batch_size``, and calls
    ``style_transfer`` exactly once, on the first batch. If the loader yields
    no batches, nothing is run.

    Args:
        args: Parsed command-line namespace. ``args.batch_size`` is read here;
            the whole namespace is also forwarded to ``style_transfer``.
    """
    train_paths = read_data('Train')
    loader = DataLoader(COD10KDataset(train_paths), batch_size=args.batch_size)
    for first_batch in loader:
        # Positional hyper-parameters for style_transfer. Their meaning is
        # defined in experiments.style_transfer (not visible here).
        # NOTE(review): consider passing these by keyword once that
        # signature is confirmed.
        style_transfer(
            first_batch['img'],
            (1, 4, 6, 7),
            3,
            6e-2,
            (2000, 512, 12, 1),
            6e-2,
            args,
        )
        # Only the first batch is processed; stop after it.
        break
if __name__ == "__main__":
seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
parser = argparse.ArgumentParser()
parser.add_argument('--mode', type=str, default="style_transfer")
parser.add_argument('--device', type=str, default="available")
parser.add_argument('--batch_size', type = int, default = 1)
parser.add_argument('--model_name', type=str, default='squeezenet')
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--max_iter', type=int, default=200)
parser.add_argument('--resize_size', type=int, default=256)
parser.add_argument('--synthetic_path', type=str, default="generated")
args = parser.parse_args()
if args.mode == "style_transfer":
run_style_transfer_pipeline(args)
elif args.mode == "synthetic":
run_synthetic_pipeline(args)
elif args.mode == "sam":
run_sam_pipeline(args)