-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsafetensors_hack.py
102 lines (81 loc) · 4.05 KB
/
safetensors_hack.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import io
import os
import mmap
import torch
import json
import hashlib
import safetensors
import safetensors.torch
import sd_models
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
UntypedStorage = torch.storage.UntypedStorage if hasattr(torch.storage, 'UntypedStorage') else torch.storage._UntypedStorage
def read_metadata(filename):
"""Reads the JSON metadata from a .safetensors file"""
with open(filename, mode="r", encoding="utf8") as file_obj:
with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m:
header = m.read(8)
n = int.from_bytes(header, "little")
metadata_bytes = m.read(n)
metadata = json.loads(metadata_bytes)
return metadata.get("__metadata__", {})
def load_file(filename, device):
""""Loads a .safetensors file without memory mapping that locks the model file.
Works around safetensors issue: https://github.com/huggingface/safetensors/issues/164"""
with open(filename, mode="r", encoding="utf8") as file_obj:
with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m:
header = m.read(8)
n = int.from_bytes(header, "little")
metadata_bytes = m.read(n)
metadata = json.loads(metadata_bytes)
size = os.stat(filename).st_size
storage = UntypedStorage.from_file(filename, False, size)
offset = n + 8
md = metadata.get("__metadata__", {})
return {name: create_tensor(storage, info, offset) for name, info in metadata.items() if name != "__metadata__"}, md
def hash_file(filename):
"""Hashes a .safetensors file using the new hashing method.
Only hashes the weights of the model."""
hash_sha256 = hashlib.sha256()
blksize = 1024 * 1024
with open(filename, mode="r", encoding="utf8") as file_obj:
with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m:
header = m.read(8)
n = int.from_bytes(header, "little")
with open(filename, mode="rb") as file_obj:
offset = n + 8
file_obj.seek(offset)
for chunk in iter(lambda: file_obj.read(blksize), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest()
def legacy_hash_file(filename):
"""Hashes a model file using the legacy `sd_models.model_hash()` method."""
hash_sha256 = hashlib.sha256()
metadata = read_metadata(filename)
# For compatibility with legacy models: This replicates the behavior of
# sd_models.model_hash as if there were no user-specified metadata in the
# .safetensors file. That leaves the training parameters, which are
# immutable. It is important the hash does not include the embedded user
# metadata as that would mean the hash could change every time the user
# updates the name/description/etc. The new hashing method fixes this
# problem by only hashing the region of the file containing the tensors.
if any(not k.startswith("ss_") for k in metadata):
# Strip the user metadata, re-serialize the file as if it were freshly
# created from sd-scripts, and hash that with model_hash's behavior.
tensors, metadata = load_file(filename, "cpu")
metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
model_bytes = safetensors.torch.save(tensors, metadata)
hash_sha256.update(model_bytes[0x100000:0x110000])
return hash_sha256.hexdigest()[0:8]
else:
# This should work fine with model_hash since when the legacy hashing
# method was being used the user metadata system hadn't been implemented
# yet.
return sd_models.model_hash(filename)
DTYPES = {"F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16}
def create_tensor(storage, info, offset):
"""Creates a tensor without holding on to an open handle to the parent model
file."""
dtype = DTYPES[info["dtype"]]
shape = info["shape"]
start, stop = info["data_offsets"]
return torch.asarray(storage[start + offset : stop + offset], dtype=torch.uint8).view(dtype=dtype).reshape(shape).clone().detach()