-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_reception_filed.py
77 lines (50 loc) · 1.6 KB
/
test_reception_filed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#%%
from lib.arch.unet import build_unet_model
from config import load_cfg
import torch
import argparse
import sys
import numpy as np
from torchsummary import summary
# Module-level store written to by the hooks that get_activation() builds:
# maps a caller-chosen name to the most recently captured output tensor.
activation = {}


def get_activation(name):
    """Create a forward hook that records a module's output under ``name``.

    The returned callable has the ``(module, inputs, output)`` signature used
    by ``torch.nn.Module.register_forward_hook``. Every time it fires it
    overwrites ``activation[name]`` with ``output.detach()``, so the stored
    tensor is cut off from the autograd graph.

    Args:
        name: Key under which the captured output is stored in ``activation``.

    Returns:
        The hook function (it returns ``None``, so the module's output is
        left unchanged).
    """
    def _store_output(module, inputs, output):
        # Only the output is recorded; the module and its inputs are unused.
        activation[name] = output.detach()

    return _store_output
# Build the UNet described by the brain-region config and put it in eval mode.
cfg = load_cfg("config/brain_region_unet.yaml")
# Fall back to CPU so the probe still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = build_unet_model(cfg)
model.to(device)
model.eval()
# torchsummary defaults to device="cuda"; pass the chosen device explicitly.
summary(model, (1, 128, 128, 128), device=device)
print(model)

# Impulse input: one voxel set to 1 at the centre of an otherwise-zero
# (1, 1, 128, 128, 128) volume. NOTE(review): presumably the spread of the
# response to this impulse is used to gauge the hooked layer's receptive
# field (cf. the file name) -- confirm.
test_input = torch.zeros(size=(1, 1, 128, 128, 128))
test_input[:, :, 64, 64, 64] = 1
test_input = test_input.to(device)

# Register a forward hook at the destination layer.
# NOTE(review): the key 'conv_in' does not describe the hooked module (the
# `act` of model.down_layers[-1][2]); it is only used to look the tensor up.
extract_layer_name = 'conv_in'
hook1 = model.down_layers[-1][2].act.register_forward_hook(
    get_activation(extract_layer_name))
try:
    # Inference only: no_grad avoids building an autograd graph for the
    # full-volume forward pass (the hook detaches its capture anyway).
    with torch.no_grad():
        out = model(test_input)
finally:
    # Always unregister the hook, even if the forward pass raises.
    hook1.remove()

# Captured tensor -> NumPy (already detached by the hook). Drop only the
# batch axis (size 1 here) so a single-channel output is not collapsed too;
# squeeze(axis=0) raises if axis 0 is not size 1.
feats = activation[extract_layer_name].cpu().numpy()
feats = np.squeeze(feats, axis=0)
# Visualize the activation map via a summary statistic over axis 0.
# NOTE(review): assumes the hooked output is (N, C, D, H, W), so axis 0 here
# is the channel axis and feats_std is a 3-D spatial map -- confirm.
feats_std = np.std(feats, axis=0)
print(f"shape of feats is {feats.shape}")
# Central slice along each axis, using that axis's own length so that
# non-cubic feature maps are sliced at their true centre.
feats_2d_midz = feats_std[feats_std.shape[0] // 2, :, :]
feats_2d_midy = feats_std[:, feats_std.shape[1] // 2, :]
feats_2d_midx = feats_std[:, :, feats_std.shape[2] // 2]

import matplotlib.pyplot as plt

# Three side-by-side panels (the 27x9 figure size fits three 9x9 panels).
fig, axes = plt.subplots(1, 3, figsize=(27, 9))
for ax, img, title in zip(axes,
                          (feats_2d_midz, feats_2d_midy, feats_2d_midx),
                          ('midz', 'midy', 'midx')):
    ax.imshow(img)
    ax.set_title(title)
plt.show()
# %%
# %%