forked from XinJCheng/CSPN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathupdate_model.py
40 lines (32 loc) · 1.19 KB
/
update_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Mon Feb 5 16:19:25 2018
@author: Xinjing Cheng
@email : [email protected]
"""
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import string
# update pretrained model params according to my model params
def update_model(my_model, pretrained_dict):
    """Return ``my_model``'s state dict with matching pretrained entries copied in.

    Only names that already exist in the model's state dict are taken from
    ``pretrained_dict``; any extra names in ``pretrained_dict`` are ignored.
    The model itself is not modified -- the caller is expected to pass the
    returned dict to ``load_state_dict``.

    Args:
        my_model: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        pretrained_dict: mapping of parameter name -> value to copy from.

    Returns:
        The dict returned by ``my_model.state_dict()``, updated in place with
        the overlapping entries.
    """
    merged = my_model.state_dict()
    # Overwrite only keys the model already has; the key set never grows,
    # so membership tests stay valid while we assign.
    for name, value in pretrained_dict.items():
        if name in merged:
            merged[name] = value
    return merged
# The offline-saved model has 'module.' in front of every key name (likely because it was saved while wrapped in nn.DataParallel); strip that prefix.
def remove_moudle(remove_dict):
    """Strip a leading ``'module.'`` prefix from state-dict keys.

    Each key that starts with ``'module.'`` has that prefix removed. Keys
    without it are kept unchanged. If no key carries the prefix, the input
    mapping is returned as-is (the same object).

    Args:
        remove_dict: mapping of parameter name -> value, e.g. a loaded
            checkpoint's state dict. May be empty.

    Returns:
        A new dict with the prefix removed from the affected keys, or
        ``remove_dict`` itself when no key needed changing.
    """
    prefix = 'module.'
    # Check the prefix per key (not a substring test on one sampled key) so
    # keys such as 'submodule.w' or un-prefixed keys are never truncated.
    # ``str.startswith`` also works on both Python 2 and 3, unlike
    # ``string.find`` which no longer exists in Python 3.
    if not any(k.startswith(prefix) for k in remove_dict):
        return remove_dict
    print("==> model dict with addtional module, remove it...")
    return {(k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in remove_dict.items()}
def update_conv_spn_model(out_dict, in_dict):
    """Keep only the entries of ``in_dict`` whose keys also exist in ``out_dict``.

    Args:
        out_dict: reference mapping whose keys define what is kept.
        in_dict: mapping to filter; it is not modified.

    Returns:
        A new dict containing the overlapping entries of ``in_dict``, in
        ``in_dict``'s iteration order.
    """
    kept = {}
    for key, value in in_dict.items():
        if key in out_dict:
            kept[key] = value
    return kept