Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Q2A handling #20

Merged
merged 5 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 45 additions & 12 deletions BingusApi/Controllers/FaqController.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
using BingusLib.Config;
using BingusLib.FaqHandling;
using BingusLib.HNSW;
using HNSW.Net;
using Microsoft.AspNetCore.Mvc;
using static BingusLib.FaqHandling.FaqHandler;

namespace BingusApi.Controllers;

Expand All @@ -16,12 +17,16 @@ public class FaqController : ControllerBase
private static readonly int MaxLength = 4000;

private readonly FaqHandler _faqHandler;
private readonly BingusConfig _bingusConfig;
private readonly FaqConfig _faqConfig;
private readonly FaqDict? _faqDict;

public FaqController(FaqHandler faqHandler, FaqConfig faqConfig)
public FaqController(FaqHandler faqHandler, BingusConfig bingusConfig, FaqConfig faqConfig)
{
_faqHandler = faqHandler;
_bingusConfig = bingusConfig;
_faqConfig = faqConfig;
_faqDict = bingusConfig.UseQ2A ? new FaqDict(faqConfig) : null;
}

private static FaqEntry GetEntry(ILazyItem<float[]> item)
Expand All @@ -43,18 +48,29 @@ public IEnumerable<FaqEntryResponse> Search(string question, int responseCount =

// Actually query a larger set amount to reduce duplicates in the response,
// but one result will never have duplicates
var results = _faqHandler.Search(question, responseCount > 1 ? SearchAmount : 1);
var searchAmount = _bingusConfig.UseQ2A
? responseCount
: (responseCount > 1 ? SearchAmount : 1);
var results = _faqHandler.Search(question, searchAmount);

IEnumerable<SmallWorld<ILazyItem<float[]>, float>.KNNSearchResult> filteredResults =
results;

// Only consider duplicates if Q2Q, there will only be one for Q2A
if (!_bingusConfig.UseQ2A)
{
// Group the duplicates
// Select the highest relevance entry for each duplicate group
filteredResults = filteredResults
.GroupBy(result => GetEntry(result.Item).Answer)
.Select(groupedResults =>
groupedResults.MinBy(result => result.Distance) ?? groupedResults.First()
);
}

// Format the entry JSON
// Group the duplicates
// Select the highest relevance entry for each duplicate group
// Sort the entries by relevance
// Take only the requested number of results
var responses = results
.GroupBy(result => GetEntry(result.Item).Answer)
.Select(groupedResults =>
groupedResults.MinBy(result => result.Distance) ?? groupedResults.First()
)
var response = filteredResults
.OrderByDescending(result => -result.Distance)
.Take(responseCount)
.Select(result =>
Expand All @@ -69,7 +85,24 @@ public IEnumerable<FaqEntryResponse> Search(string question, int responseCount =
};
});

return responses;
var dictAnswer = _faqDict?.Search(question);
if (dictAnswer != null)
{
response = response
.Where(result => result.Text != dictAnswer.Answer)
.Prepend(
new FaqEntryResponse()
{
Relevance = 100f,
MatchedQuestion = dictAnswer.Question,
Title = dictAnswer.Title,
Text = dictAnswer.Answer,
}
)
.Take(responseCount);
}

return response;
}

[HttpGet(template: "Config", Name = "Config")]
Expand Down
7 changes: 3 additions & 4 deletions BingusApi/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ string GetConfig(string fileName)
);

// Initialize Bingus library dependencies
builder.Services.AddSingleton<HttpClient>();
builder.Services.AddHttpClient();
builder.Services.AddSingleton(sp => sp.GetRequiredService<BingusConfig>().GetSentenceEncoder(sp));
builder.Services.AddSingleton(CosineDistance.SIMDForUnits);
builder.Services.AddSingleton<IProvideRandomValues>(sp => new SeededRandom(
Expand Down Expand Up @@ -104,9 +104,8 @@ string GetConfig(string fileName)
app.MapControllers();

// Load FAQ
var useQ2A = app.Services.GetService<BingusConfig>()?.UseQ2A ?? false;
var faqConf = app.Services.GetRequiredService<FaqConfig>();
var useQ2A = app.Services.GetRequiredService<BingusConfig>().UseQ2A;
app.Services.GetRequiredService<FaqHandler>()
.AddItems(useQ2A ? faqConf.AnswerEntryEnumerator() : faqConf.QaEntryEnumerator(), useQ2A);
.AddItems(app.Services.GetRequiredService<FaqConfig>(), useQ2A);

app.Run();
30 changes: 30 additions & 0 deletions BingusLib/FaqHandling/FaqDict.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
namespace BingusLib.FaqHandling
{
    /// <summary>
    /// Exact-match lookup of FAQ entries keyed by their question text.
    /// Matching ignores leading/trailing whitespace and letter case.
    /// </summary>
    public class FaqDict
    {
        // The comparer performs the case folding at lookup time, so no lower-cased copy of
        // every stored question and every incoming query has to be allocated.
        private readonly Dictionary<string, FaqEntry> _faqDict = new(
            StringComparer.OrdinalIgnoreCase
        );

        /// <summary>
        /// Builds the lookup from the (title, question, answer) tuples produced by
        /// <c>faqConfig.QaEntryEnumerator()</c>.
        /// </summary>
        /// <param name="faqConfig">The FAQ configuration to read entries from.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="faqConfig"/> is <c>null</c>.
        /// </exception>
        public FaqDict(FaqConfig faqConfig)
            : this(
                (
                    faqConfig ?? throw new ArgumentNullException(nameof(faqConfig))
                ).QaEntryEnumerator()
            ) { }

        /// <summary>
        /// Builds the lookup from a sequence of (title, question, answer) tuples.
        /// </summary>
        /// <param name="tqaMapping">The entries to index by their question.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="tqaMapping"/> is <c>null</c>.
        /// </exception>
        public FaqDict(IEnumerable<(string title, string question, string answer)> tqaMapping)
        {
            ArgumentNullException.ThrowIfNull(tqaMapping);

            foreach (var (title, question, answer) in tqaMapping)
            {
                // Indexer assignment: when two questions clean to the same key,
                // the later entry replaces the earlier one.
                _faqDict[CleanQuery(question)] = new FaqEntry()
                {
                    Title = title,
                    Question = question,
                    Answer = answer,
                };
            }
        }

        // Case is handled by the dictionary's comparer; only surrounding whitespace is removed here.
        private static string CleanQuery(string query) => query.Trim();

        /// <summary>
        /// Looks up the entry whose question equals <paramref name="query"/>, ignoring
        /// surrounding whitespace and letter case.
        /// </summary>
        /// <param name="query">The question text to look up.</param>
        /// <returns>
        /// The matching entry, or <c>null</c> when there is no match or
        /// <paramref name="query"/> is <c>null</c>.
        /// </returns>
        public FaqEntry? Search(string query)
        {
            // A null query cannot match anything; report "no match" instead of failing in Trim().
            if (query is null)
            {
                return null;
            }

            return _faqDict.TryGetValue(CleanQuery(query), out var entry) ? entry : null;
        }
    }
}
12 changes: 12 additions & 0 deletions BingusLib/FaqHandling/FaqEntry.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using MathNet.Numerics.LinearAlgebra;

namespace BingusLib.FaqHandling
{
    /// <summary>
    /// A single FAQ item: a title, a question and its answer, plus an optional vector.
    /// </summary>
    public record class FaqEntry
    {
        /// <summary>Title of the entry; defaults to the empty string.</summary>
        public string Title { get; set; } = string.Empty;

        /// <summary>Question text of the entry; defaults to the empty string.</summary>
        public string Question { get; set; } = string.Empty;

        /// <summary>Answer text of the entry; defaults to the empty string.</summary>
        public string Answer { get; set; } = string.Empty;

        /// <summary>
        /// Optional vector attached to the entry; <c>null</c> until one is assigned.
        /// NOTE(review): presumably the sentence embedding used for search; confirm against the caller.
        /// </summary>
        public Vector<float>? Vector { get; set; }
    }
}
17 changes: 7 additions & 10 deletions BingusLib/FaqHandling/FaqHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,8 @@

namespace BingusLib.FaqHandling
{
public record FaqHandler
public class FaqHandler
{
public record FaqEntry
{
public string Title { get; set; } = "";
public string Question { get; set; } = "";
public string Answer { get; set; } = "";
public Vector<float>? Vector { get; set; }
}

private readonly ILogger<FaqHandler>? _logger;

private readonly IEmbeddingStore? _embeddingStore;
Expand All @@ -41,9 +33,14 @@ public FaqHandler(
_hnswHandler = new(distanceFunction, randomProvider, parameters);
}

public void AddItems(FaqConfig faq, bool useQ2A = true)
{
AddItems(useQ2A ? faq.AnswerEntryEnumerator() : faq.QaEntryEnumerator(), useQ2A);
}

public void AddItems(
IEnumerable<(string title, string question, string answer)> tqaMapping,
bool useQ2A = false
bool useQ2A = true
)
{
var hnswItems = new List<LazyKeyItem<FaqEntry, float[]>>();
Expand Down
2 changes: 1 addition & 1 deletion BingusLib/HNSW/LazyItem.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
namespace BingusLib.HNSW
{
public class LazyItem<TItem> : ILazyItem<TItem>
public record class LazyItem<TItem> : ILazyItem<TItem>
{
private readonly Func<TItem> _getValue;
public TItem Value => _getValue();
Expand Down
2 changes: 1 addition & 1 deletion BingusLib/HNSW/LazyItemValue.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
namespace BingusLib.HNSW
{
public readonly struct LazyItemValue<TItem> : ILazyItem<TItem>
public readonly record struct LazyItemValue<TItem> : ILazyItem<TItem>
{
private readonly Func<TItem> _getValue;
public TItem Value => _getValue();
Expand Down
2 changes: 1 addition & 1 deletion BingusLib/HNSW/LazyKeyItem.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

namespace BingusLib.HNSW
{
public class LazyKeyItem<TKey, TItem> : LazyItem<TItem>
public record class LazyKeyItem<TKey, TItem> : LazyItem<TItem>
{
private readonly TKey _key;
public TKey Key => _key;
Expand Down Expand Up @@ -42,7 +42,7 @@
)
{
var formatter = options.Resolver.GetFormatter<TKey>();
formatter.Serialize(ref writer, value.Key, options);

Check warning on line 45 in BingusLib/HNSW/LazyKeyItem.cs

View workflow job for this annotation

GitHub Actions / build

Dereference of a possibly null reference.

Check warning on line 45 in BingusLib/HNSW/LazyKeyItem.cs

View workflow job for this annotation

GitHub Actions / build

Dereference of a possibly null reference.

Check warning on line 45 in BingusLib/HNSW/LazyKeyItem.cs

View workflow job for this annotation

GitHub Actions / build

Dereference of a possibly null reference.

Check warning on line 45 in BingusLib/HNSW/LazyKeyItem.cs

View workflow job for this annotation

GitHub Actions / build

Dereference of a possibly null reference.
}

public LazyKeyItem<TKey, TItem> Deserialize(
Expand All @@ -51,7 +51,7 @@
)
{
var formatter = options.Resolver.GetFormatter<TKey>();
var key = formatter.Deserialize(ref reader, options);

Check warning on line 54 in BingusLib/HNSW/LazyKeyItem.cs

View workflow job for this annotation

GitHub Actions / build

Dereference of a possibly null reference.

Check warning on line 54 in BingusLib/HNSW/LazyKeyItem.cs

View workflow job for this annotation

GitHub Actions / build

Dereference of a possibly null reference.

Check warning on line 54 in BingusLib/HNSW/LazyKeyItem.cs

View workflow job for this annotation

GitHub Actions / build

Dereference of a possibly null reference.

Check warning on line 54 in BingusLib/HNSW/LazyKeyItem.cs

View workflow job for this annotation

GitHub Actions / build

Dereference of a possibly null reference.
return new LazyKeyItem<TKey, TItem>(key, () => ResolveItem(key));
}
}
Expand Down