-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDemo_getmodesize.py
82 lines (58 loc) · 2.15 KB
/
Demo_getmodesize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch, os, cv2
import torchvision
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import math, shutil, copy
from model import STRNN_final
def getModelSize(model):
    """Print a model's memory footprint and return its parameter size.

    Sums ``nelement() * element_size()`` over every item yielded by
    ``model.parameters()`` and by ``model.buffers()``, prints the parameter,
    buffer and combined sizes in MB (bytes / 1024 / 1024), and returns the
    parameter size only.

    Args:
        model: Object exposing ``parameters()`` and ``buffers()`` iterables
            whose items provide ``nelement()`` and ``element_size()``
            (presumably a ``torch.nn.Module``; confirm against callers).

    Returns:
        Parameter size in MB. The buffer size is printed but is not part of
        the returned value.
    """
    param_size = sum(
        param.nelement() * param.element_size()
        for param in model.parameters()
    )
    buffer_size = sum(
        buffer.nelement() * buffer.element_size()
        for buffer in model.buffers()
    )

    param_mb = param_size / 1024 / 1024
    buffer_mb = buffer_size / 1024 / 1024
    all_mb = (param_size + buffer_size) / 1024 / 1024

    print('param size:{:.2f} MB'.format(param_mb))
    print('buffer size:{:.2f} MB'.format(buffer_mb))
    print('total size:{:.2f} MB'.format(all_mb))

    # Callers in this file subtract sub-module sizes from the whole-model
    # size, so only the parameter figure is returned.
    return param_mb
if __name__ == '__main__':
    # Measure the parameter size of the whole STRNN model and of each of its
    # sub-modules, then print a summary including whatever part of the total
    # the listed sub-modules do not account for.
    model = STRNN_final()

    # (heading printed before measuring, attribute on the model); None means
    # the model itself. Attributes are looked up lazily, one per iteration,
    # so the print/measure order matches a straight-line script.
    targets = (
        ('\nmodel', None),
        ('\nSt-Net', 'feat_sm'),
        ('\nOF-Net', 'feat_of'),
        ('\nSR-Fu', 'att_channel'),
        ('\nfeat_fu', 'feat_fu'),
        ('\nAA-LSTM', 'att_lstm'),
        ('\nconv_out', 'conv_out'),
    )
    sizes = {}
    for heading, attr_name in targets:
        print(heading)
        module = model if attr_name is None else getattr(model, attr_name)
        sizes[attr_name] = getModelSize(module)

    model_size = sizes[None]
    st_size = sizes['feat_sm']
    of_size = sizes['feat_of']
    srfu_size = sizes['att_channel']
    aalstm_size = sizes['att_lstm']
    # feat_fu and conv_out are reported together under "Other".
    other_size = 0
    other_size += sizes['feat_fu']
    other_size += sizes['conv_out']

    # Remainder of the whole-model size after subtracting every listed part.
    leftover = (model_size - st_size - of_size - srfu_size
                - aalstm_size - other_size)

    print('\n\n-----Params Size-----')
    summary = (
        ('Total STRNN: %.2f MB', model_size),
        ('St-Net: %.2f MB', st_size),
        ('OF-Net: %.2f MB', of_size),
        ('SR-Fu: %.2f MB', srfu_size),
        ('AA-LSTM: %.2f MB', aalstm_size),
        ('Other: %.2f MB', other_size),
        ('diff: %.2f MB', leftover),
    )
    for template, value in summary:
        print(template % value)

    # NOTE: the two figures below are fixed literals, not values computed by
    # this script.
    print('\nTotal STRNN (param+buffer)): 361.07 MB')
    print('Saved STRNN model file size: 361.26 MB')
    print('done')