-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathingest_embeddings.py
172 lines (141 loc) · 4.91 KB
/
ingest_embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import redis
from redis_client import RedisClient
from redis.commands.search.field import TagField, TextField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
import os
import json
import numpy as np
from ignore_filter import IgnoreFilter
import argparse
import yaml
# Define index name and document prefix
INDEX_NAME = "index"
DOC_PREFIX = "doc:"
VECTOR_DIMENSIONS = 1536
def delete_index():
try:
r.ft(INDEX_NAME).dropindex(delete_documents=True)
print("Index deleted!")
except:
print("Index does not exist!")
def create_index(vector_dimensions: int):
try:
# check to see if index exists
r.ft(INDEX_NAME).info()
print("Index already exists!")
except:
# schema
schema = (
TagField("tag"),
TextField("content"),
VectorField(
"vector",
"FLAT", # FLAT OR HSNW
{
"TYPE": "FLOAT32", # FLOAT32 or FLOAT64
"DIM": vector_dimensions, # Number of Vector Dimensions
"DISTANCE_METRIC": "COSINE", # Vector Search Distance Metric
},
),
TextField("link"),
TextField("filename"),
)
# index Definition
definition = IndexDefinition(prefix=[DOC_PREFIX], index_type=IndexType.HASH)
# create Index
r.ft(INDEX_NAME).create_index(fields=schema, definition=definition)
# Function to insert embeddings into Redis index
def insert_embedding(pipe: redis.client.Pipeline, file_name: str):
# Load the JSON data from file
with open(file_name, "r") as f:
json_data = json.load(f)
# Extract the embedding from the JSON data
embedding = json_data["data"][0]["embedding"]
# Open the file with the same name but ending in .md or .txt
base_name = os.path.splitext(os.path.basename(file_name))[0]
file_dir = os.path.dirname(file_name)
content = None
for extension in [".md", ".txt"]:
try:
with open(
os.path.join(file_dir, base_name + extension), "r", encoding="utf8"
) as f:
content = f.read()
break
except FileNotFoundError:
continue
if content is None:
print(f"No text file (.md or .txt) found for {file_name}")
return
yaml_file_path = os.path.join(file_dir, base_name + ".yaml")
pipe.hset(
f"{DOC_PREFIX}{file_dir}/{base_name}",
mapping={
"content": content,
"vector": np.array(embedding).astype(np.float32).tobytes(),
"tag": "openai",
"link": get_link_from_yaml(yaml_file_path),
"filename": base_name + extension
},
)
print(f"Inserted {file_name}")
def get_link_from_yaml(file_path):
try:
# Open the file
with open(file_path, 'r') as file:
# Load the YAML data
data = yaml.safe_load(file)
# Get the value of the key 'link'
value = data.get('link')
except FileNotFoundError:
value = ""
return value
def process_file(pipe, file_path, ignore_filter):
if file_path.endswith(".json") and not ignore_filter.is_ignored(file_path):
insert_embedding(pipe, file_path)
def process_folder(pipe, folder_abs_path, ignore_filter):
for root, dirs, files in os.walk(folder_abs_path):
# Remove ignored directories
dirs[:] = [
d for d in dirs if not ignore_filter.is_ignored(os.path.join(root, d))
]
for file in files:
file_path = os.path.join(root, file)
process_file(pipe, file_path, ignore_filter)
# Argument Parser
parser = argparse.ArgumentParser(
description="Insert embeddings for specific file or directory into Redis index"
)
parser.add_argument(
"path",
type=str,
nargs="?",
default=".",
help="The path to the specific file or directory",
)
parser.add_argument(
"--reset", action="store_true", help="Reset the index before inserting embeddings"
)
args = parser.parse_args()
# Determine the directory of this script
script_directory = os.path.dirname(os.path.realpath(__file__))
# Load ignore filter
embedding_ignore_path = os.path.join(script_directory, ".embedding_ignore")
ignore_filter = IgnoreFilter(embedding_ignore_path)
# Create the index
with RedisClient() as r:
# Delete the index if the reset flag was provided
if args.reset:
delete_index()
create_index(vector_dimensions=VECTOR_DIMENSIONS)
# Check if path is a file or a directory
if os.path.isfile(args.path):
pipe = r.pipeline()
process_file(pipe, args.path, ignore_filter)
pipe.execute()
elif os.path.isdir(args.path):
pipe = r.pipeline()
process_folder(pipe, args.path, ignore_filter)
pipe.execute()
else:
print(f"{args.path} is not a valid file or directory.")