convert-pretrain-to-detectron2.py
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import pickle as pkl
import sys

import torch

if __name__ == "__main__":
    input = sys.argv[1]

    # Load the MoCo pretraining checkpoint on CPU and grab its state dict.
    obj = torch.load(input, map_location="cpu")
    obj = obj["state_dict"]

    newmodel = {}
    for k, v in obj.items():
        # Keep only the query encoder's backbone weights.
        if not k.startswith("module.encoder_q."):
            continue
        old_k = k
        k = k.replace("module.encoder_q.", "")
        # Rename torchvision ResNet modules to Detectron2's convention:
        # conv1/bn1 -> stem.*, layer1-4 -> res2-5, bnN -> convN.norm,
        # downsample -> shortcut.
        if "layer" not in k:
            k = "stem." + k
        for t in [1, 2, 3, 4]:
            k = k.replace("layer{}".format(t), "res{}".format(t + 1))
        for t in [1, 2, 3]:
            k = k.replace("bn{}".format(t), "conv{}.norm".format(t))
        k = k.replace("downsample.0", "shortcut")
        k = k.replace("downsample.1", "shortcut.norm")
        print(old_k, "->", k)
        newmodel[k] = v.numpy()

    # Detectron2 loads pickled dicts of numpy arrays; matching_heuristics
    # lets it fuzzy-match these keys against the model it builds.
    res = {"model": newmodel, "__author__": "MOCO", "matching_heuristics": True}

    with open(sys.argv[2], "wb") as f:
        pkl.dump(res, f)
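A minimal usage sketch, with placeholder filenames (the script takes the MoCo pretraining checkpoint as its first argument and writes the Detectron2-compatible pickle to the path given as the second):

    python convert-pretrain-to-detectron2.py moco_pretrain.pth.tar moco_pretrain.pkl

The resulting .pkl can then be referenced from a Detectron2 config via MODEL.WEIGHTS.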