-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathupload_to_hf.py
104 lines (93 loc) · 3.56 KB
/
upload_to_hf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import sys
import tempfile
from huggingface_hub import HfApi, create_repo
from huggingface_hub.utils._errors import BadRequestError, HfHubHTTPError
# Markdown reminder that is printed by main() and inserted right after the
# closing "---" of the README front matter whenever the base_model line had
# to be stripped from the model card before upload.
TODO_COMMENT = "**TODO: Add base_model description to model card section in Hugging Face Hub**"
def _is_conflict(error):
    """Return True when *error* reports HTTP 409 Conflict (target already exists)."""
    response = getattr(error, "response", None)
    if getattr(response, "status_code", None) == 409:
        return True
    # Fall back to the message prefix for errors that carry no response object.
    return str(error).startswith("409 Client Error: Conflict for url: ")


def main():
    """Upload every regular file of a converted checkpoint directory to the Hub.

    Command line::

        upload_to_hf.py <converted_ckpt_dir> <repo_name> <branch_name>

    Steps:

    1. Create ``repo_name`` as a private model repo; an already existing repo
       is reused as the upload target.
    2. Unless ``branch_name`` is ``"main"``, create the branch. An already
       existing branch aborts the script with exit status 1.
    3. Upload each regular file directly inside the directory (sub-directories
       are skipped), one commit per file, to the given branch.

    If the upload of ``README.md`` is rejected with ``BadRequestError``, the
    script assumes the ``base_model`` entry of the model card is the cause and
    retries exactly once with a temporary copy that has that entry removed.
    The temporary copy is always deleted before the function returns or raises.

    Raises:
        SystemExit: status 2 on missing arguments, status 1 when the branch
            already exists.
        HfHubHTTPError: any Hub error that is not handled above (this includes
            ``BadRequestError`` for files that cannot be fixed by the retry).
    """
    if len(sys.argv) < 4:
        print(
            "usage: upload_to_hf.py <converted_ckpt_dir> <repo_name> <branch_name>",
            file=sys.stderr,
        )
        sys.exit(2)
    converted_ckpt, repo_name, branch_name = sys.argv[1:4]

    try:
        create_repo(repo_name, repo_type="model", private=True)
    except HfHubHTTPError as e:
        if not _is_conflict(e):
            raise
        print(f"repo {repo_name} already exists and will be upload target.")

    api = HfApi()
    if branch_name != "main":
        try:
            api.create_branch(
                repo_id=repo_name,
                repo_type="model",
                branch=branch_name,
            )
        except HfHubHTTPError as e:
            if not _is_conflict(e):
                raise
            print(f"branch {branch_name} already exists, try again...")
            # sys.exit instead of the site-provided exit() builtin, which is
            # not guaranteed to exist (e.g. python -S, frozen executables).
            sys.exit(1)

    print(f"to upload: {converted_ckpt}")
    uploaded_count = 0
    temp_path = None
    base_model_line = None
    try:
        for file_name in os.listdir(converted_ckpt):
            path_or_file = os.path.join(converted_ckpt, file_name)
            if not os.path.isfile(path_or_file):
                print(f"skipping {file_name} ...")
                continue
            print(f"uploading {file_name} to branch {branch_name} ...")
            while True:
                try:
                    api.upload_file(
                        path_or_fileobj=path_or_file,
                        path_in_repo=file_name,
                        repo_id=repo_name,
                        repo_type="model",
                        commit_message=f"Upload {file_name}",
                        revision=branch_name,
                    )
                except BadRequestError:
                    # Only README.md gets a second chance, and only once:
                    # temp_path being set means the retry was already used.
                    if not temp_path and file_name == "README.md":
                        temp_path, base_model_line = copy_with_base_model_filter(path_or_file)
                        if temp_path:
                            print(f"{base_model_line} does not exist in Hugging Face hub. Retrying upload README.md with the base_model line removed.", file=sys.stderr)
                            path_or_file = temp_path
                            continue
                    raise
                print(f"successfully uploaded {file_name}")
                uploaded_count += 1
                break

        print(f"{uploaded_count} files were uploaded to {repo_name} {branch_name}")
        if temp_path:
            print(f"{base_model_line} does not exist in Hugging Face hub.")
            print(TODO_COMMENT)
    finally:
        # Remove the filtered README copy on failure paths as well, so an
        # aborted run does not leave a stray file in the temp directory.
        if temp_path:
            os.remove(temp_path)
def copy_with_base_model_filter(path):
    """Copy a model card to a temp file with its ``base_model`` entry removed.

    The front matter is the region between the first ``---`` line and the
    next ``---`` line. Inside it, the first top-level ``base_model:`` line is
    dropped together with the lines that continue its value (indented lines
    or ``- item`` list entries directly below it), so the remaining YAML stays
    well formed. ``TODO_COMMENT`` is inserted on its own line right after the
    closing ``---`` as a reminder to restore the entry.

    Args:
        path: Path of the UTF-8 encoded README / model card to read.

    Returns:
        ``(temp_path, base_model_line)`` when a ``base_model`` entry was found:
        ``temp_path`` is the filtered UTF-8 copy (the caller must delete it)
        and ``base_model_line`` is the removed text, with continuation lines
        joined to the first line by single spaces. ``(None, None)`` when the
        front matter has no ``base_model`` entry; no file is created then.
    """
    # None: front matter not reached yet, True: inside it, False: past it.
    in_model_card = None
    removed = []
    dropping_value = False
    kept = []
    with open(path, "r", encoding="utf8") as fin:
        for raw in fin:
            line = raw.rstrip("\r\n")
            if in_model_card:
                if not removed and line.startswith("base_model:"):
                    removed.append(line)
                    dropping_value = True
                    continue
                if dropping_value:
                    # "---" never matches here, so the closing marker always
                    # terminates the dropped value.
                    if line == "-" or line.startswith(("- ", " ", "\t")):
                        removed.append(line.strip())
                        continue
                    dropping_value = False
            if line == "---":
                if in_model_card is None:
                    in_model_card = True
                elif in_model_card:
                    in_model_card = False
                    if removed:
                        line += "\n" + TODO_COMMENT
            kept.append(line)

    if not removed:
        return None, None

    # Write with the same encoding the source was read with; the locale
    # default may be unable to represent characters of the model card.
    with tempfile.NamedTemporaryFile(mode="w", encoding="utf8", delete=False) as fout:
        print(*kept, sep="\n", file=fout)
    return fout.name, " ".join(part for part in removed if part)
# Run the upload only when executed as a script, so importing this module
# (e.g. to reuse copy_with_base_model_filter) has no side effects.
if __name__ == "__main__":
    main()