Skip to content

Commit

Permalink
rewrite generate_missing_mappings tool in the newly added Application…
Browse files Browse the repository at this point in the history
…Command style
  • Loading branch information
chesterbot01 committed Aug 1, 2024
1 parent 48a6799 commit ab60513
Show file tree
Hide file tree
Showing 2 changed files with 319 additions and 284 deletions.
318 changes: 318 additions & 0 deletions app/commands/generate_missing_mappings_command.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
# frozen_string_literal: true

# Generates mapping rules for Shopify categories that have no mapping yet.
#
# For every output taxonomy that already has mapping rules from the latest
# Shopify taxonomy version, the command:
#   1. finds the Shopify categories without a rule,
#   2. indexes the output taxonomy's pre-computed embeddings in Qdrant,
#   3. picks the nearest output category for each unmapped Shopify category,
#   4. asks a GPT model to grade the proposed mapping, and
#   5. appends the new rules to the taxonomy's `from_shopify.yml`, recording
#      mappings the grader disagrees with in a message file.
class GenerateMissingMappingsCommand < ApplicationCommand
  MAX_RETRIES = 3
  QDRANT_PORT = 6333
  # How many times (one second apart) to poll for the Qdrant port after
  # starting the container before giving up on waiting.
  QDRANT_STARTUP_ATTEMPTS = 15
  EMBEDDING_MODEL = "text-embedding-3-small"
  # Vector size of the Qdrant collections. Per OpenAI's documentation this is
  # the default dimension of text-embedding-3-small; it must also match the
  # stored embedding files.
  EMBEDDING_DIMENSIONS = 1536
  # Number of points sent to Qdrant per upsert request.
  UPSERT_BATCH_SIZE = 100
  MAPPING_GRADER_GPT_MODEL = "gpt-4"
  LATEST_SHOPIFY_VERSION = "shopify/#{File.read("VERSION").strip}"
  # Name of the environment variable holding the OpenAI proxy access token.
  OPENAI_ACCESS_TOKEN_ENV = "OPENAI_ACCESS_TOKEN"
  DISAGREE_MESSAGES_PATH = "tmp/mapping_update_message.txt"

  usage do
    no_command
  end

  # Entry point: finds unmapped categories and generates mappings for them.
  # Does nothing beyond the search when every category is already mapped.
  def execute
    frame("Generating missing mappings") do
      missing_mapping_groups = search_unmapped_shopify_categories
      # `next` ends the block normally instead of unwinding through `frame`
      # with a non-local `return` (frame's implementation is not visible here).
      next if missing_mapping_groups.empty?

      initialize_clients
      generate_missing_mappings_for_groups(missing_mapping_groups)
    end
  end

  private

  # Collects, per output taxonomy, the Shopify categories that have no mapping
  # rule whose input version is LATEST_SHOPIFY_VERSION.
  #
  # @return [Array<Hash>] one hash per output taxonomy with missing mappings,
  #   with keys :input_taxonomy, :output_taxonomy and :category_ids_full_names
  #   (a Hash of category id => category full name, ordered by id).
  def search_unmapped_shopify_categories
    groups = []
    spinner("Searching shopify categories that lack mappings") do
      all_category_ids = Set.new(Category.all.pluck(:id))
      rules_by_output_version = MappingRule
        .where(input_version: LATEST_SHOPIFY_VERSION)
        .group_by(&:output_version)

      rules_by_output_version.each do |output_version, mappings|
        # Rule inputs store a slash-separated identifier; the category id is
        # its last segment.
        mapped_category_ids = Set.new(
          mappings.map { |mapping| mapping.input.product_category_id.split("/").last },
        )
        unmapped_category_ids = all_category_ids - mapped_category_ids
        category_ids_full_names = unmapped_category_ids.sort.filter_map do |id|
          full_name = Category.find(id)&.full_name
          [id, full_name] if full_name
        end.to_h
        next if category_ids_full_names.empty?

        groups << {
          input_taxonomy: mappings.first.input_version,
          output_taxonomy: output_version,
          category_ids_full_names: category_ids_full_names,
        }
      end
    end
    groups
  end

  # Creates the OpenAI and Qdrant clients used by the rest of the command.
  def initialize_clients
    frame("Initializing clients for performing semantic search") do
      @openai_client = create_openai_client
      @qdrant_client = create_qdrant_client
    end
  end

  # Builds the OpenAI client that talks to the Shopify OpenAI proxy.
  #
  # The access token is read from the environment so that no credential is
  # stored in the repository.
  #
  # @raise [KeyError] when the access token environment variable is not set
  def create_openai_client
    access_token = ENV.fetch(OPENAI_ACCESS_TOKEN_ENV) do
      raise KeyError, "Set #{OPENAI_ACCESS_TOKEN_ENV} to an OpenAI proxy access token before running this command"
    end

    OpenAI::Client.new(
      access_token: access_token,
      uri_base: "https://openai-proxy.shopify.ai/v1",
      request_timeout: 10,
    )
  end

  # Builds the Qdrant client, starting a local Qdrant server first if needed.
  def create_qdrant_client
    ensure_qdrant_server_running
    Qdrant::Client.new(url: "http://localhost:#{QDRANT_PORT}")
  end

  # Starts a Qdrant container with podman unless something is already using
  # QDRANT_PORT, then waits (bounded) for the port to open.
  def ensure_qdrant_server_running
    return if qdrant_port_in_use?

    pid = Process.spawn(
      "podman", "run", "-p", "#{QDRANT_PORT}:#{QDRANT_PORT}", "qdrant/qdrant",
      out: File::NULL,
      err: File::NULL,
    )
    Process.detach(pid)
    logger.info("Started Qdrant server in the background with PID #{pid}.")
    wait_for_qdrant_server
  end

  # @return [Boolean, nil] truthy when `lsof` reports a process on QDRANT_PORT;
  #   nil when `lsof` could not be executed
  def qdrant_port_in_use?
    system("lsof", "-i:#{QDRANT_PORT}", out: File::NULL, err: File::NULL)
  end

  # Polls until QDRANT_PORT is in use. The container is started asynchronously,
  # so using the client right away can race with the server start-up. On
  # timeout the command continues anyway (e.g. `lsof` may be unavailable) and
  # any real connection problem surfaces on the first Qdrant request.
  def wait_for_qdrant_server
    QDRANT_STARTUP_ATTEMPTS.times do
      return if qdrant_port_in_use?

      sleep(1)
    end
    logger.info("Qdrant port #{QDRANT_PORT} not detected after #{QDRANT_STARTUP_ATTEMPTS} attempts; continuing anyway.")
  end

  # Loads, indexes and maps every group of unmapped categories.
  #
  # @param missing_mapping_groups [Array<Hash>] groups as returned by
  #   #search_unmapped_shopify_categories
  def generate_missing_mappings_for_groups(missing_mapping_groups)
    missing_mapping_groups.each do |missing_mapping_group|
      input_taxonomy = missing_mapping_group[:input_taxonomy]
      output_taxonomy = missing_mapping_group[:output_taxonomy]
      # The collection name is the taxonomy name with "/" and "-" replaced by "_".
      index_name = output_taxonomy.tr("/-", "_")
      frame("Generating mappings for #{input_taxonomy} -> #{output_taxonomy}") do
        embedding_data = load_embedding_data(output_taxonomy)
        index_embedding_data(embedding_data:, index_name:)
        generate_and_evaluate_mappings_for_group(missing_mapping_group:, index_name:)
      end
    end
  end

  # Reads the embedding partitions (`_*.txt`) of an output taxonomy.
  #
  # Each non-blank line is parsed as `<name>:<v1>, <v2>, ...`. The line is
  # split on its last colon because the vector part never contains one, while
  # a category name might.
  #
  # @param output_taxonomy [String] taxonomy directory name under data/integrations
  # @return [Hash{String => Array<Float>}] embedding vector by category name
  # @raise [ArgumentError] when a non-blank line has no ":" separator
  def load_embedding_data(output_taxonomy)
    embeddings = {}
    spinner("Loading embeddings for #{output_taxonomy}") do
      partitions = Dir.glob(File.join("data/integrations/#{output_taxonomy}/embeddings", "_*.txt"))
      partitions.each do |partition|
        File.foreach(partition) do |raw_line|
          line = raw_line.chomp
          next if line.strip.empty?

          word, separator, vector_str = line.rpartition(":")
          raise ArgumentError, "Malformed embedding line in #{partition}: #{line[0, 80].inspect}" if separator.empty?

          embeddings[word] = vector_str.split(", ").map { |num| BigDecimal(num).to_f }
        end
      end
    end
    embeddings
  end

  # Reads `full_names.yml` of an output taxonomy.
  #
  # @param output_taxonomy [String] taxonomy directory name under data/integrations
  # @return [Hash] category id by category full name
  def load_destination_taxonomy_ids(output_taxonomy)
    logger.debug("Loading destination taxonomy IDs for #{output_taxonomy}")
    categories = YAML.load_file("data/integrations/#{output_taxonomy}/full_names.yml")
    categories.each_with_object({}) do |category, ids_by_full_name|
      ids_by_full_name[category["full_name"]] = category["id"]
    end
  end

  # Recreates the Qdrant collection `index_name` and uploads all embeddings.
  #
  # @param embedding_data [Hash{String => Array<Float>}] vectors by category name
  # @param index_name [String] collection name, also used as the payload key
  def index_embedding_data(embedding_data:, index_name:)
    spinner("Indexing embeddings for #{index_name}") do
      # Start from a clean collection so stale points never survive a re-run.
      @qdrant_client.collections.delete(collection_name: index_name)
      @qdrant_client.collections.create(
        collection_name: index_name,
        vectors: { size: EMBEDDING_DIMENSIONS, distance: "Cosine" },
      )

      # Point ids are 1-based positions in the embedding hash.
      points = embedding_data.map.with_index(1) do |(category_name, vector), point_id|
        { id: point_id, vector: vector, payload: { index_name => category_name } }
      end

      points.each_slice(UPSERT_BATCH_SIZE) do |batch|
        @qdrant_client.points.upsert(collection_name: index_name, points: batch)
      end
    end
  end

  # Generates a mapping for every unmapped category of the group, appends the
  # new rules to the taxonomy's `from_shopify.yml` (kept sorted by input
  # category id) and records mappings the grader answered "No" for.
  #
  # @param missing_mapping_group [Hash] one group from #search_unmapped_shopify_categories
  # @param index_name [String] Qdrant collection to search
  def generate_and_evaluate_mappings_for_group(missing_mapping_group:, index_name:)
    spinner("Generating and evaluating mappings for each category in the group") do
      output_taxonomy = missing_mapping_group[:output_taxonomy]
      destination_taxonomy_ids_by_full_name = load_destination_taxonomy_ids(output_taxonomy)
      mapping_file_path = "data/integrations/#{output_taxonomy}/mappings/from_shopify.yml"
      mapping_data = YAML.load_file(mapping_file_path)

      disagree_messages = []
      missing_mapping_group[:category_ids_full_names].each do |source_category_id, source_category_name|
        generated_mapping = generate_mapping(
          source_category_id:,
          source_category_name:,
          index_name:,
          destination_taxonomy_ids_by_full_name:,
        )
        # nil means no usable destination was found; leave the category unmapped.
        next unless generated_mapping

        mapping_data["rules"] << generated_mapping[:new_entry]
        disagree_messages << generated_mapping[:mapping_to_be_graded] if generated_mapping[:grading_result] == "No"
      end

      mapping_data["rules"].sort_by! { |rule| rule["input"]["product_category_id"] }
      File.write(mapping_file_path, mapping_data.to_yaml)

      write_disagree_messages(disagree_messages) if disagree_messages.any?
    end
  end

  # Proposes and grades a mapping for a single Shopify category.
  #
  # @param source_category_id [String] Shopify category id
  # @param source_category_name [String] Shopify category full name
  # @param index_name [String] Qdrant collection to search
  # @param destination_taxonomy_ids_by_full_name [Hash] destination id by full name
  # @return [Hash, nil] hash with :new_entry, :mapping_to_be_graded and
  #   :grading_result, or nil when the search produced no candidate with a
  #   known destination id (writing an empty id would corrupt the mapping file)
  def generate_mapping(source_category_id:, source_category_name:, index_name:,
    destination_taxonomy_ids_by_full_name:)
    logger.debug("Generating mapping for #{source_category_name}")
    category_embedding = get_embeddings(source_category_name)
    top_candidate = search_top_candidate(query_embedding: category_embedding, index_name:)
    destination_category_id = destination_taxonomy_ids_by_full_name[top_candidate]

    if destination_category_id.nil?
      logger.info(
        "Skipping #{source_category_name}: no destination category ID found for candidate #{top_candidate.inspect}",
      )
      return
    end

    new_entry = {
      "input" => { "product_category_id" => source_category_id },
      "output" => { "product_category_id" => [destination_category_id.to_s] },
    }

    mapping_to_be_graded = {
      from_category_id: source_category_id,
      from_category: source_category_name,
      to_category_id: destination_category_id.to_s,
      to_category: top_candidate,
    }

    logger.debug("Grading mapping for #{source_category_name} -> #{top_candidate}")
    grading_result = grade_taxonomy_mapping(mapping_to_be_graded)

    { new_entry: new_entry, mapping_to_be_graded: mapping_to_be_graded, grading_result: grading_result }
  end

  # Requests the embedding vector of `text` from the embeddings endpoint.
  #
  # @param text [String] text to embed
  # @return [Array<Float>] embedding vector
  # @raise [StandardError] when the request still fails after MAX_RETRIES retries
  def get_embeddings(text)
    with_retries do
      response = @openai_client.embeddings(
        parameters: {
          model: EMBEDDING_MODEL,
          input: text,
        },
      )
      embedding = response.dig("data", 0, "embedding")
      # Raise inside the retry block so a response without a vector is retried
      # instead of being passed on to the vector search as nil.
      raise "Embedding response for #{text.inspect} contains no embedding" if embedding.nil?

      embedding
    end
  end

  # Finds the closest indexed category for an embedding.
  #
  # @param query_embedding [Array<Float>] vector to search with
  # @param index_name [String] Qdrant collection and payload key
  # @return [String, nil] name of the nearest category, or nil when the search
  #   response holds no hit
  def search_top_candidate(query_embedding:, index_name:)
    result = @qdrant_client.points.search(
      collection_name: index_name,
      vector: query_embedding,
      with_payload: true,
      limit: 1,
    )
    result.dig("result", 0, "payload", index_name)
  end

  # Asks the grader model whether it agrees with a proposed mapping.
  #
  # @param mapping [Hash] mapping with :from_category_id, :from_category,
  #   :to_category_id and :to_category
  # @return [String] the "agree_with_mapping" value of the first graded mapping
  #   (the prompt asks for "Yes" or "No")
  # @raise [StandardError] when the request or the parsing of the reply still
  #   fails after MAX_RETRIES retries
  def grade_taxonomy_mapping(mapping)
    with_retries do
      response = @openai_client.chat(
        parameters: {
          model: MAPPING_GRADER_GPT_MODEL,
          messages: [
            { role: "system", content: system_prompts_of_taxonomy_mapping_grader },
            { role: "user", content: [mapping].to_json },
          ],
          temperature: 0,
        },
      )
      JSON.parse(response.dig("choices", 0, "message", "content")).first["agree_with_mapping"]
    end
  end

  # Runs the block, retrying up to MAX_RETRIES times (one second apart) when it
  # raises a StandardError; the last error is re-raised once retries run out.
  #
  # @yield the operation to attempt
  # @return the block's return value
  def with_retries
    retries = 0
    begin
      yield
    rescue StandardError => e
      retries += 1
      if retries <= MAX_RETRIES
        logger.debug("Received error: #{e.message}. Retrying (#{retries}/#{MAX_RETRIES})...")
        sleep(1)
        retry
      else
        logger.fatal("Failed after #{MAX_RETRIES} retries.")
        raise
      end
    end
  end

  # Appends the mappings the grader disagreed with to DISAGREE_MESSAGES_PATH.
  #
  # @param disagree_messages [Array<Hash>] mappings graded "No"
  def write_disagree_messages(disagree_messages)
    # Opening the file in append mode fails if its directory is missing.
    output_dir = File.dirname(DISAGREE_MESSAGES_PATH)
    Dir.mkdir(output_dir) unless Dir.exist?(output_dir)

    File.open(DISAGREE_MESSAGES_PATH, "a") do |file|
      file.puts "❗AI Grader disagrees with the following mappings:"
      disagree_messages.each do |mapping|
        mapping.each { |key, value| file.puts "#{key}:#{value}" }
        file.puts
      end
    end
  end

  # System prompt sent to the grader model.
  #
  # NOTE(review): the prompt text is runtime behavior and is kept unchanged.
  # It contains typos ("who evaluate", "irrevant") and its JSON examples use
  # trailing commas, which is not valid JSON; the reply is parsed with
  # JSON.parse, so consider fixing the examples in a dedicated prompt change.
  #
  # @return [String] the grader instructions
  def system_prompts_of_taxonomy_mapping_grader
    <<~CONTEXT
      You are a taxonomy mapping expert who evaluate the accuracy of product category mappings between two taxonomies.
      Your task is to review and grade the accuracy of the mappings, Yes or No, based on the following criteria:
      1. Mark a mapping as Yes, i.e. correct, if two categories of a mapping are highly relevant to each other and similar
      in terms of product type, function, or purpose.
      For example:
      - "Apparel & Accessories" and "Clothing, Shoes & Jewelry"
      - "Apparel & Accessories > Clothing > One-Pieces" and "Clothing, Shoes & Accessories > Women > Women's Clothing > Jumpsuits & Rompers"
      2. Mark a mapping as No, i.e. incorrect, if two categories of a mapping are irrevant to each other
      in terms of product type, function, or purpose.
      For example:
      - "Apparel & Accessories > Clothing > Dresses" and "Clothing, Shoes & Jewelry>Shoe, Jewelry & Watch Accessories"
      - "Apparel & Accessories" and "Clothing, Shoes & Jewelry>Luggage & Travel Gear"
      Note, the character ">" in a category name indicates the start of a new category level. For example:
      "sporting goods > exercise & fitness > cardio equipment"'s ancestor categories are "sporting goods > exercise & fitness" and "sporting goods".
      You will receive a list of mappings. Each mapping contains a from_category name and a to_category name.
      e.g. user's prompt in json format:
      [
      {
      "from_category_id": "111",
      "from_category": "Apparel & Accessories > Jewelry > Smart Watches",
      "to_category_id": "222",
      "to_category": "Clothing, Shoes & Jewelry>Men's Fashion>Men's Watches>Men's Smartwatches",
      },
      {
      "from_category_id": "333",
      "from_category": "Apparel & Accessories > Clothing > One-Pieces",
      "to_category_id": "444",
      "to_category": "Clothing, Shoes & Accessories > Women > Women's Clothing > Outfits & Sets",
      },
      ]
      You evaluate accuracy of every mapping and reply in the following format. Do not change the order of mappings in your reply.
      e.g. your response in json format:
      [
      {
      "from_category_id": "111",
      "from_category": "Apparel & Accessories > Jewelry > Smart Watches",
      "to_category_id": "222",
      "to_category": "Clothing, Shoes & Jewelry>Men's Fashion>Men's Watches>Men's Smartwatches",
      "agree_with_mapping": "Yes",
      },
      {
      "from_category_id": "333",
      "from_category": "Apparel & Accessories > Clothing > One-Pieces",
      "to_category_id": "444",
      "to_category": "Clothing, Shoes & Accessories > Women > Women's Clothing > Outfits & Sets",
      "agree_with_mapping": "No",
      },
      ]
    CONTEXT
  end
end
Loading

0 comments on commit ab60513

Please sign in to comment.