-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path: convert_torch_to_paddle.py
30 lines (23 loc) · 1.31 KB
/
convert_torch_to_paddle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from collections import OrderedDict
def convert_pytorch_checkpoint_to_paddle(pytorch_checkpoint_path,
                                         paddle_dump_path):
    """Convert a PyTorch checkpoint into a Paddle state-dict file.

    Loads the PyTorch ``state_dict`` at *pytorch_checkpoint_path*, strips a
    leading ``'transformer.'`` prefix from parameter names, transposes 2-D
    linear weights (PyTorch stores them as (out, in); Paddle expects
    (in, out) — embedding tables keep their layout), casts everything to
    float32 numpy arrays, and saves the result with ``paddle.save`` to
    *paddle_dump_path*.

    Args:
        pytorch_checkpoint_path: path to the PyTorch ``.bin`` checkpoint.
        paddle_dump_path: destination path for the ``.pdparams`` file.
    """
    # Imports are local so merely importing this module does not require
    # both frameworks to be installed.
    import torch
    import paddle
    pytorch_state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
    paddle_state_dict = OrderedDict()
    for key, tensor in pytorch_state_dict.items():
        # Strip only the leading 'transformer.' prefix; a bare .replace()
        # would also mangle any later occurrence of the substring.
        if key.startswith('transformer.'):
            key = key[len('transformer.'):]
        # Transpose 2-D weight matrices except embedding tables
        # (PyTorch Linear: (out_features, in_features) -> Paddle: (in, out)).
        if 'weight' in key and tensor.ndim == 2 and 'embedding' not in key:
            tensor = tensor.transpose(0, 1)
        paddle_state_dict[key] = tensor.detach().cpu().numpy().astype('float32')
    paddle.save(paddle_state_dict, paddle_dump_path)
    # NOTE: the original printed `v.dtype` here, which raised
    # UnboundLocalError on an empty state dict; report the paths instead.
    print('converted', pytorch_checkpoint_path, '->', paddle_dump_path)
if __name__ == "__main__":
    # Convert each SqueezeBERT checkpoint; the .pdparams file is written
    # next to its source checkpoint.  (The previous `import torch` here was
    # dead code — the converter imports torch itself.)
    for model_dir in ('squeezebert-uncased',
                      'squeezebert-mnli-headless',
                      'squeezebert-mnli'):
        convert_pytorch_checkpoint_to_paddle(
            f'./models/{model_dir}/pytorch_model.bin',
            f'./models/{model_dir}/model_state.pdparams')