-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathsqlite_vec_demo.py
58 lines (45 loc) · 1.36 KB
/
sqlite_vec_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import sqlite3
import sqlite_vec
from typing import List
import struct
def serialize_f32(vector: List[float]) -> bytes:
    """Pack *vector* into one contiguous blob of C floats.

    The format string has no byte-order prefix, so ``struct`` uses its
    native mode: one native C float per element, back to back. This raw
    blob is the form in which vectors are handed to sqlite-vec in this file.
    An empty list yields ``b""``.
    """
    packer = struct.Struct(f"{len(vector)}f")
    return packer.pack(*vector)
def main() -> None:
    """Run a small k-nearest-neighbour demo against the sqlite-vec extension.

    Opens an in-memory SQLite database and loads sqlite-vec into it. Prints
    the SQLite and sqlite-vec versions. Stores five 4-element float vectors
    in a ``vec0`` virtual table. Finally prints the ``(rowid, distance)``
    rows returned by a ``MATCH ... and k = 3`` query for a fixed query vector.
    The connection is always closed before returning, including on error.
    """
    db = sqlite3.connect(":memory:")
    try:
        # Extension loading is off by default in Python's sqlite3; enable it
        # only long enough to load sqlite-vec, then switch it back off.
        db.enable_load_extension(True)
        sqlite_vec.load(db)
        db.enable_load_extension(False)

        sqlite_version, vec_version = db.execute(
            "select sqlite_version(), vec_version()"
        ).fetchone()
        print(f"sqlite_version={sqlite_version}, vec_version={vec_version}")

        items = [
            (1, [0.1, 0.1, 0.1, 0.1]),
            (2, [0.2, 0.2, 0.2, 0.2]),
            (3, [0.3, 0.3, 0.3, 0.3]),
            (4, [0.4, 0.4, 0.4, 0.4]),
            (5, [0.5, 0.5, 0.5, 0.5]),
        ]
        query = [0.3, 0.3, 0.3, 0.3]

        # The column is declared float[4]; every vector above has 4 elements.
        db.execute("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4])")

        # `with db` commits the inserts as one transaction (rollback on
        # error); it does NOT close the connection.
        with db:
            db.executemany(
                "INSERT INTO vec_items(rowid, embedding) VALUES (?, ?)",
                [(rowid, serialize_f32(vector)) for rowid, vector in items],
            )

        # sqlite-vec KNN query: MATCH against the serialized query vector,
        # with `k = 3` asking for the 3 nearest stored vectors.
        rows = db.execute(
            """
            SELECT
                rowid,
                distance
            FROM vec_items
            WHERE embedding MATCH ?
            and k = 3
            """,
            [serialize_f32(query)],
        ).fetchall()
        print(rows)
    finally:
        # The connection's context manager only scopes transactions, so
        # release the database handle explicitly.
        db.close()
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()