-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
26 lines (17 loc) · 814 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from src.components.backbones.ResNet import ResNet18
from src.components.necks.ChannelReduction import ChannelReducer
from src.utils.common import count_parameters
import torch
def main() -> None:
    """Smoke-test the ResNet18 backbone followed by the ChannelReducer neck.

    Prints the backbone's parameter count, runs a single random 512x512
    3-channel tensor through the backbone and then the neck, and prints the
    shape of each of the four feature maps the neck returns.
    """
    # Use the GPU when one is available. The modules and the input must all
    # live on the same device, so everything below is moved to `device`.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): assumes ResNet18 and ChannelReducer are torch.nn.Module
    # subclasses (needed for .to()/.eval()) — confirm in src.components.
    model = ResNet18().to(device)
    print("num params: ", count_parameters(model))

    channel_reducer = ChannelReducer(target_channels=128).to(device)

    # This is a shape check only. eval() keeps layers such as BatchNorm (if
    # any) from updating running statistics. no_grad() avoids building an
    # autograd graph that would never be used.
    model.eval()
    channel_reducer.eval()

    x = torch.randn((1, 3, 512, 512), device=device)
    with torch.no_grad():
        # Explicit 4-way unpacking doubles as a check that the backbone
        # returns exactly four feature maps.
        fr_4, fr_8, fr_16, fr_32 = model(x)
        fr_4, fr_8, fr_16, fr_32 = channel_reducer(fr_4, fr_8, fr_16, fr_32)

    for name, feature in (
        ("fr_4", fr_4),
        ("fr_8", fr_8),
        ("fr_16", fr_16),
        ("fr_32", fr_32),
    ):
        print(f"==>> {name}.shape: {feature.shape}")


if __name__ == "__main__":
    main()