-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathchangepth.py
44 lines (30 loc) · 1.44 KB
/
changepth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import numpy as np
import torch
# Convert a SAM checkpoint whose image-encoder attention blocks store a fused
# ``attn.qkv`` projection into one with separate ``attn.q`` / ``attn.k`` /
# ``attn.v`` entries (presumably for an encoder variant that uses three
# independent projection layers -- confirm against the consuming model).
#
# NOTE(review): the previous version hard-coded 24 blocks and a 1024-row chunk
# per projection, which only fits a ViT-L-sized encoder even though the files
# are named ``sam_vit_b_*``. The split below derives both from the checkpoint
# itself, so the block count and width no longer need to be known up front.

SEED = 0
SRC_CHECKPOINT = "checkpoints/sam_vit_b_0b3195.pth"
DST_CHECKPOINT = "checkpoints/sam_vit_b_qkv.pth"
BLOCK_PREFIX = "image_encoder.blocks."
_QKV_SUFFIXES = (".attn.qkv.weight", ".attn.qkv.bias")


def split_qkv(state_dict, prefix=BLOCK_PREFIX):
    """Split every fused ``attn.qkv`` weight/bias under *prefix*, in place.

    Each fused tensor is cut into three equal chunks along dim 0 (q, k, v in
    that order) and stored as ``<base>.attn.{q,k,v}.<weight|bias>``; the fused
    key is removed. The chunks are views of the fused tensor, just like the
    slices produced by the original hard-coded loop.

    Args:
        state_dict: mapping of parameter names to tensors; modified in place.
        prefix: only keys starting with this string are considered.

    Returns:
        The same ``state_dict`` object, for convenience.

    Raises:
        ValueError: if no fused qkv keys are found (e.g. the checkpoint was
            already converted), or if a fused tensor's first dimension is not
            divisible by 3. Validation happens before any mutation, so the
            dict is left untouched on error.
    """
    fused_keys = [
        key
        for key in state_dict
        if key.startswith(prefix) and key.endswith(_QKV_SUFFIXES)
    ]
    if not fused_keys:
        raise ValueError(
            f"no fused '*.attn.qkv.*' keys found under prefix {prefix!r}"
        )
    # Validate everything first so a bad tensor cannot leave a half-converted dict.
    for key in fused_keys:
        rows = state_dict[key].shape[0]
        if rows % 3 != 0:
            raise ValueError(f"{key}: first dimension {rows} is not divisible by 3")

    for fused_key in fused_keys:
        # "image_encoder.blocks.0.attn.qkv.weight" -> ("image_encoder.blocks.0.attn", "weight")
        base, param = fused_key.rsplit(".qkv.", 1)
        fused = state_dict.pop(fused_key)
        for name, part in zip(("q", "k", "v"), fused.chunk(3, dim=0)):
            state_dict[f"{base}.{name}.{param}"] = part
    return state_dict


def main(src=SRC_CHECKPOINT, dst=DST_CHECKPOINT, seed=SEED):
    """Load *src*, split its qkv projections, print every entry, save to *dst*."""
    # Seeding is kept from the original script; nothing below draws random numbers.
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # The conversion never touches the GPU, so only report CUDA when present
    # instead of crashing on CPU-only hosts.
    if torch.cuda.is_available():
        print("CUDA visible devices: " + str(torch.cuda.device_count()))
        print("CUDA Device Name: " + str(torch.cuda.get_device_name(torch.device("cuda"))))
    else:
        print("CUDA not available; converting on CPU")

    # map_location="cpu" lets a checkpoint saved from CUDA tensors load without a GPU.
    state_dict = torch.load(src, map_location="cpu")
    split_qkv(state_dict)
    for key, value in state_dict.items():
        print(key, value.shape)
    torch.save(state_dict, dst)


if __name__ == "__main__":
    main()