Skip to content

Commit

Permalink
Improve Q2A handling (#20)
Browse files Browse the repository at this point in the history
* Handle Q2A better in API

* Fix typing

* Add FaqDict for simple queries

* Initialize Q2Q dictionary if in Q2A mode

* Simplify FaqDict usage
  • Loading branch information
ButterscotchV authored Oct 18, 2024
1 parent 72f525a commit c98a1cb
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 29 deletions.
57 changes: 45 additions & 12 deletions BingusApi/Controllers/FaqController.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
using BingusLib.Config;
using BingusLib.FaqHandling;
using BingusLib.HNSW;
using HNSW.Net;
using Microsoft.AspNetCore.Mvc;
using static BingusLib.FaqHandling.FaqHandler;

namespace BingusApi.Controllers;

Expand All @@ -16,12 +17,16 @@ public class FaqController : ControllerBase
private static readonly int MaxLength = 4000;

private readonly FaqHandler _faqHandler;
private readonly BingusConfig _bingusConfig;
private readonly FaqConfig _faqConfig;
private readonly FaqDict? _faqDict;

public FaqController(FaqHandler faqHandler, FaqConfig faqConfig)
public FaqController(FaqHandler faqHandler, BingusConfig bingusConfig, FaqConfig faqConfig)
{
_faqHandler = faqHandler;
_bingusConfig = bingusConfig;
_faqConfig = faqConfig;
_faqDict = bingusConfig.UseQ2A ? new FaqDict(faqConfig) : null;
}

private static FaqEntry GetEntry(ILazyItem<float[]> item)
Expand All @@ -43,18 +48,29 @@ public IEnumerable<FaqEntryResponse> Search(string question, int responseCount =

// Actually query a larger set amount to reduce duplicates in the response,
// but one result will never have duplicates
var results = _faqHandler.Search(question, responseCount > 1 ? SearchAmount : 1);
var searchAmount = _bingusConfig.UseQ2A
? responseCount
: (responseCount > 1 ? SearchAmount : 1);
var results = _faqHandler.Search(question, searchAmount);

IEnumerable<SmallWorld<ILazyItem<float[]>, float>.KNNSearchResult> filteredResults =
results;

// Only consider duplicates if Q2Q, there will only be one for Q2A
if (!_bingusConfig.UseQ2A)
{
// Group the duplicates
// Select the highest relevance entry for each duplicate group
filteredResults = filteredResults
.GroupBy(result => GetEntry(result.Item).Answer)
.Select(groupedResults =>
groupedResults.MinBy(result => result.Distance) ?? groupedResults.First()
);
}

// Format the entry JSON
// Group the duplicates
// Select the highest relevance entry for each duplicate group
// Sort the entries by relevance
// Take only the requested number of results
var responses = results
.GroupBy(result => GetEntry(result.Item).Answer)
.Select(groupedResults =>
groupedResults.MinBy(result => result.Distance) ?? groupedResults.First()
)
var response = filteredResults
.OrderByDescending(result => -result.Distance)
.Take(responseCount)
.Select(result =>
Expand All @@ -69,7 +85,24 @@ public IEnumerable<FaqEntryResponse> Search(string question, int responseCount =
};
});

return responses;
var dictAnswer = _faqDict?.Search(question);
if (dictAnswer != null)
{
response = response
.Where(result => result.Text != dictAnswer.Answer)
.Prepend(
new FaqEntryResponse()
{
Relevance = 100f,
MatchedQuestion = dictAnswer.Question,
Title = dictAnswer.Title,
Text = dictAnswer.Answer,
}
)
.Take(responseCount);
}

return response;
}

[HttpGet(template: "Config", Name = "Config")]
Expand Down
7 changes: 3 additions & 4 deletions BingusApi/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ string GetConfig(string fileName)
);

// Initialize Bingus library dependencies
builder.Services.AddSingleton<HttpClient>();
builder.Services.AddHttpClient();
builder.Services.AddSingleton(sp => sp.GetRequiredService<BingusConfig>().GetSentenceEncoder(sp));
builder.Services.AddSingleton(CosineDistance.SIMDForUnits);
builder.Services.AddSingleton<IProvideRandomValues>(sp => new SeededRandom(
Expand Down Expand Up @@ -104,9 +104,8 @@ string GetConfig(string fileName)
app.MapControllers();

// Load FAQ
var useQ2A = app.Services.GetService<BingusConfig>()?.UseQ2A ?? false;
var faqConf = app.Services.GetRequiredService<FaqConfig>();
var useQ2A = app.Services.GetRequiredService<BingusConfig>().UseQ2A;
app.Services.GetRequiredService<FaqHandler>()
.AddItems(useQ2A ? faqConf.AnswerEntryEnumerator() : faqConf.QaEntryEnumerator(), useQ2A);
.AddItems(app.Services.GetRequiredService<FaqConfig>(), useQ2A);

app.Run();
30 changes: 30 additions & 0 deletions BingusLib/FaqHandling/FaqDict.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
namespace BingusLib.FaqHandling
{
    /// <summary>
    /// Exact-match lookup from a question string to its <see cref="FaqEntry"/>.
    /// Matching ignores surrounding whitespace and letter case.
    /// </summary>
    public class FaqDict
    {
        // Case-insensitivity is handled by the comparer rather than by lower-casing
        // the keys: lower-casing for comparison is lossy for some characters (CA1308)
        // and allocates a new string on every lookup.
        private readonly Dictionary<string, FaqEntry> _faqDict = new(
            StringComparer.OrdinalIgnoreCase
        );

        /// <summary>
        /// Builds the lookup from the entries yielded by
        /// <c>faqConfig.QaEntryEnumerator()</c>.
        /// </summary>
        /// <param name="faqConfig">The FAQ configuration to read entries from.</param>
        public FaqDict(FaqConfig faqConfig)
            : this(faqConfig.QaEntryEnumerator()) { }

        /// <summary>
        /// Builds the lookup from a sequence of (title, question, answer) tuples.
        /// </summary>
        /// <param name="tqaMapping">
        /// The entries to index, keyed by their cleaned question text.
        /// </param>
        public FaqDict(IEnumerable<(string title, string question, string answer)> tqaMapping)
        {
            foreach (var (title, question, answer) in tqaMapping)
            {
                // Indexer assignment: when two questions clean to the same key,
                // the later entry replaces the earlier one.
                _faqDict[CleanQuery(question)] = new FaqEntry()
                {
                    Title = title,
                    Question = question,
                    Answer = answer,
                };
            }
        }

        /// <summary>
        /// Strips leading and trailing whitespace so padded input still matches.
        /// Letter case is left untouched; the dictionary comparer ignores it.
        /// </summary>
        private static string CleanQuery(string query) => query.Trim();

        /// <summary>
        /// Looks up the entry whose question matches <paramref name="query"/>,
        /// ignoring case and surrounding whitespace.
        /// </summary>
        /// <param name="query">The question text to look up.</param>
        /// <returns>The matching entry, or <c>null</c> when there is none.</returns>
        public FaqEntry? Search(string query)
        {
            return _faqDict.TryGetValue(CleanQuery(query), out var entry) ? entry : null;
        }
    }
}
12 changes: 12 additions & 0 deletions BingusLib/FaqHandling/FaqEntry.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using MathNet.Numerics.LinearAlgebra;

namespace BingusLib.FaqHandling
{
    /// <summary>
    /// A single FAQ item: a title, a question, the answer to it, and an
    /// optional numeric vector.
    /// </summary>
    public record class FaqEntry
    {
        /// <summary>Title of the entry; empty until assigned.</summary>
        public string Title { get; set; } = string.Empty;

        /// <summary>Question text of the entry; empty until assigned.</summary>
        public string Question { get; set; } = string.Empty;

        /// <summary>Answer text of the entry; empty until assigned.</summary>
        public string Answer { get; set; } = string.Empty;

        /// <summary>
        /// Optional vector attached to the entry; <c>null</c> when not set.
        /// NOTE(review): presumably the sentence embedding for this entry, confirm
        /// against the code that populates it.
        /// </summary>
        public Vector<float>? Vector { get; set; }
    }
}
17 changes: 7 additions & 10 deletions BingusLib/FaqHandling/FaqHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,8 @@

namespace BingusLib.FaqHandling
{
public record FaqHandler
public class FaqHandler
{
public record FaqEntry
{
public string Title { get; set; } = "";
public string Question { get; set; } = "";
public string Answer { get; set; } = "";
public Vector<float>? Vector { get; set; }
}

private readonly ILogger<FaqHandler>? _logger;

private readonly IEmbeddingStore? _embeddingStore;
Expand All @@ -41,9 +33,14 @@ public FaqHandler(
_hnswHandler = new(distanceFunction, randomProvider, parameters);
}

public void AddItems(FaqConfig faq, bool useQ2A = true)
{
AddItems(useQ2A ? faq.AnswerEntryEnumerator() : faq.QaEntryEnumerator(), useQ2A);
}

public void AddItems(
IEnumerable<(string title, string question, string answer)> tqaMapping,
bool useQ2A = false
bool useQ2A = true
)
{
var hnswItems = new List<LazyKeyItem<FaqEntry, float[]>>();
Expand Down
2 changes: 1 addition & 1 deletion BingusLib/HNSW/LazyItem.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
namespace BingusLib.HNSW
{
public class LazyItem<TItem> : ILazyItem<TItem>
public record class LazyItem<TItem> : ILazyItem<TItem>
{
private readonly Func<TItem> _getValue;
public TItem Value => _getValue();
Expand Down
2 changes: 1 addition & 1 deletion BingusLib/HNSW/LazyItemValue.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
namespace BingusLib.HNSW
{
public readonly struct LazyItemValue<TItem> : ILazyItem<TItem>
public readonly record struct LazyItemValue<TItem> : ILazyItem<TItem>
{
private readonly Func<TItem> _getValue;
public TItem Value => _getValue();
Expand Down
2 changes: 1 addition & 1 deletion BingusLib/HNSW/LazyKeyItem.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

namespace BingusLib.HNSW
{
public class LazyKeyItem<TKey, TItem> : LazyItem<TItem>
public record class LazyKeyItem<TKey, TItem> : LazyItem<TItem>
{
private readonly TKey _key;
public TKey Key => _key;
Expand Down

0 comments on commit c98a1cb

Please sign in to comment.