-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathevaluate_chestxray_stylegan.py
47 lines (38 loc) · 1.32 KB
/
evaluate_chestxray_stylegan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import sys; sys.path.append('../stylegan2-ada-pytorch')
import legacy
import dnnlib
import os
import torch
import argparse
import torchvision.utils as vutils
from evaluate_samples_chestxray import load_class_balanced_real_data
from fid import run_fid
# --- Configuration -----------------------------------------------------------
# Every tensor below is moved to this device; as written the script needs CUDA.
device = 'cuda:0'
# Network snapshot pickle to evaluate (absolute, cluster-specific path).
network = '/scratch/hdd001/home/wangkuan/stylegan/run_scripts/May12-train-gan-chestxray.sh-1/auto/00003-chestxray-aux-auto2-resumefromprev/network-snapshot-000121.pkl'

# --- Load the StyleGAN generator ---------------------------------------------
print('Loading networks from "%s"...' % network)
with dnnlib.util.open_url(network) as f:
    # Only the 'G_ema' entry of the loaded pickle is used
    # (presumably the moving-average generator weights -- confirm).
    G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore
# Passed unchanged to every G.synthesis call in this script.
noise_mode = 'const'

# --- Qualitative check: save one grid of generated samples -------------------
n_preview = 50  # number of latents drawn for the preview grid
with torch.no_grad():
    # Latents are drawn on the CPU, then moved to `device` and cast to float64.
    z_nuisance = torch.randn(n_preview, G.z_dim).to(device).double()
    # None is passed as the second (conditioning) argument of the mapping
    # network -- presumably an unconditional model; confirm.
    w_nuisance = G.mapping(z_nuisance, None)
    x = G.synthesis(w_nuisance, noise_mode=noise_mode)
# normalize=True makes save_image rescale the tensor by its min/max to [0, 1].
vutils.save_image(x, 'cdb.jpeg', normalize=True)
# --- FID between real data and generator samples ------------------------------
# Real data: the second value returned by the loader is not needed here.
target_x, _ = load_class_balanced_real_data()
# Tile the channel axis three times before handing the batch to run_fid
# (assumes the loader returns single-channel NCHW images -- TODO confirm).
target_x = target_x.repeat(1, 3, 1, 1)

# Generated data: n_batches * batch_size = 3000 samples in total.
n_batches = 30
batch_size = 100
fake_batches = []
with torch.no_grad():
    for _ in range(n_batches):
        z_nuisance = torch.randn(batch_size, G.z_dim).to(device).double()
        w_nuisance = G.mapping(z_nuisance, None)
        x = G.synthesis(w_nuisance, noise_mode=noise_mode)
        # Copy each batch to the CPU before the next one is generated.
        fake_batches.append(x.cpu())
fake = torch.cat(fake_batches)
# Same 3x channel tiling as applied to the real data above.
fake = fake.repeat(1, 3, 1, 1)

# Score the generated set against the real set and report the result.
fid = run_fid(target_x, fake)
print(f'FID: {fid}')