-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathconvert-pth-to-ggml.py
162 lines (132 loc) · 5.14 KB
/
convert-pth-to-ggml.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""
This script converts the PyTorch weights of a Vision Transformer to the ggml file format.
It accepts a timm model name and returns the converted weights in the same directory as the script.
You can also specify the float type: 0 for float32, 1 for float16. Use float16 (for now patch_embed.proj.weight only supports float16 in ggml).
It can also be used to list some of the available pre-trained models.
For now only the original ViT model family is supported.
usage: convert-pth-to-ggml.py [-h] [--model_name MODEL_NAME] [--ftype {0,1}] [--list [LIST]]
Convert PyTorch weights of a Vision Transformer to the ggml file format.
optional arguments:
-h, --help show this help message and exit
--model_name MODEL_NAME
timm model name
--ftype {0,1} float type: 0 for float32, 1 for float16
--list [LIST] List some examples of the supported model names.
"""
import argparse
import struct
import sys
import numpy as np
import timm
from timm.data import ImageNetInfo, infer_imagenet_subset
GGML_MAGIC = 0x67676d6c
def main():
    """Convert a pretrained timm ViT checkpoint to the ggml file format.

    Parses command-line arguments, downloads the requested timm model,
    resolves human-readable class labels, and serializes hyperparameters,
    the id->label table, and all weight tensors to a single binary file
    in the current directory.
    """
    # Set up argument parser
    parser = argparse.ArgumentParser(
        description="Convert PyTorch weights of a Vision Transformer to the ggml file format."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="vit_base_patch8_224.augreg2_in21k_ft_in1k",
        help="timm model name",
    )
    parser.add_argument(
        "--ftype",
        type=int,
        choices=[0, 1],
        default=1,
        help="float type: 0 for float32, 1 for float16",
    )
    # BUG FIX: the original used type=bool with nargs="?", so any non-empty
    # argument (even "--list False") parsed as True. store_true is the
    # idiomatic flag form; a bare `--list` still enables listing.
    parser.add_argument(
        "--list",
        action="store_true",
        help="List some examples of the supported model names.",
    )
    args = parser.parse_args()
    # List some available model names
    if args.list:
        print("Here are some model names (not all are supported!) : ")
        model_sizes = ["tiny", "small", "base", "large"]
        for size in model_sizes:
            print(f"---- {size.upper()} ----")
            print(", ".join(timm.list_pretrained(f"vit_{size}*")))
        # BUG FIX: listing models is a successful operation; exit with
        # status 0 rather than the error status 1 the original used.
        sys.exit(0)
    # Output file name
    fname_out = f"./ggml-model-{['f32', 'f16'][args.ftype]}.gguf"
    # Load the pretrained model
    timm_model = timm.create_model(args.model_name, pretrained=True)
    # Create id2label dictionary
    # if no labels added to config, use imagenet labeller in timm
    imagenet_subset = infer_imagenet_subset(timm_model)
    if imagenet_subset:
        dataset_info = ImageNetInfo(imagenet_subset)
        id2label = {
            i: dataset_info.index_to_description(i)
            for i in range(dataset_info.num_classes())
        }
    else:
        # BUG FIX: typos in the user-facing message ("fallaback", spacing).
        print(
            f"Unable to infer class labels for {args.model_name}. Will use fallback label names (i.e. ints)"
        )
        # fallback label names
        id2label = {i: f"LABEL_{i}" for i in range(timm_model.num_classes)}
    # Hyperparameters, written in this fixed order; the reader relies on it.
    hparams = {
        "hidden_size": timm_model.embed_dim,
        "num_hidden_layers": len(timm_model.blocks),
        "num_attention_heads": timm_model.blocks[0].attn.num_heads,
        "num_classes": timm_model.num_classes,
        "patch_size": timm_model.patch_embed.patch_size[0],
        "img_size": timm_model.patch_embed.img_size[0],
    }
    # Write to file
    with open(fname_out, "wb") as fout:
        fout.write(struct.pack("i", GGML_MAGIC))  # Magic: ggml in hex
        for param in hparams.values():
            fout.write(struct.pack("i", param))
        fout.write(struct.pack("i", args.ftype))
        # Write id2label dictionary to the file
        write_id2label(fout, id2label)
        # Process and write model weights
        for k, v in timm_model.state_dict().items():
            # norm_pre is skipped: the ggml inference code does not use it.
            if k.startswith("norm_pre"):
                print(f"the model {args.model_name} contains a pre_norm")
                print(k)
                continue
            print(
                "Processing variable: " + k + " with shape: ",
                v.shape,
                " and type: ",
                v.dtype,
            )
            process_and_write_variable(fout, k, v, args.ftype)
        print("Done. Output file: " + fname_out)
def write_id2label(file, id2label):
    """Serialize the class-id -> label mapping into *file*.

    Layout: an int32 entry count, then for each entry an int32 class id,
    an int32 byte length, and the UTF-8 encoded label text.
    """
    file.write(struct.pack("i", len(id2label)))
    for class_id, label in id2label.items():
        encoded = label.encode("utf-8")
        # "ii" packs the id and the length back-to-back, 4 bytes each.
        file.write(struct.pack("ii", class_id, len(encoded)))
        file.write(encoded)
def process_and_write_variable(file, name, tensor, ftype):
data = tensor.numpy()
ftype_cur = (
1
if ftype == 1 and tensor.ndim != 1 and name not in ["pos_embed", "cls_token"]
else 0
)
data = data.astype(np.float32) if ftype_cur == 0 else data.astype(np.float16)
if name == "patch_embed.proj.bias":
data = data.reshape(1, data.shape[0], 1, 1)
str_name = name.encode("utf-8")
file.write(struct.pack("iii", len(data.shape), len(str_name), ftype_cur))
for dim_size in reversed(data.shape):
file.write(struct.pack("i", dim_size))
file.write(str_name)
data.tofile(file)
# Script entry point: run the conversion only when executed directly.
if __name__ == "__main__":
    main()