From 1bdf61a3271a560eb8dd2acd68df6fabd03edde4 Mon Sep 17 00:00:00 2001 From: Rick Hennigan Date: Thu, 9 Jan 2025 12:20:53 -0500 Subject: [PATCH 1/2] Added `RegisterVectorDatabase` --- Source/Chatbook/Main.wl | 2 + Source/Chatbook/PromptGenerators/Common.wl | 2 + .../PromptGenerators/RelatedDocumentation.wl | 179 ++++++++++--- .../PromptGenerators/VectorDatabases.wl | 253 +++++++++++++++++- Source/Chatbook/Utils.wl | 11 +- 5 files changed, 394 insertions(+), 53 deletions(-) diff --git a/Source/Chatbook/Main.wl b/Source/Chatbook/Main.wl index 6669012c..317088fc 100644 --- a/Source/Chatbook/Main.wl +++ b/Source/Chatbook/Main.wl @@ -76,6 +76,7 @@ BeginPackage[ "Wolfram`Chatbook`" ]; `LogChatTiming; `MakeExpressionURI; `RebuildChatSearchIndex; +`RegisterVectorDatabase; `RelatedDocumentation; `RelatedWolframAlphaQueries; `RemoveChatFromSearchIndex; @@ -245,6 +246,7 @@ $ChatbookProtectedNames = "Wolfram`Chatbook`" <> # & /@ { "LogChatTiming", "MakeExpressionURI", "RebuildChatSearchIndex", + "RegisterVectorDatabase", "RelatedDocumentation", "RelatedWolframAlphaQueries", "RemoveChatFromSearchIndex", diff --git a/Source/Chatbook/PromptGenerators/Common.wl b/Source/Chatbook/PromptGenerators/Common.wl index f1ba4973..8730c280 100644 --- a/Source/Chatbook/PromptGenerators/Common.wl +++ b/Source/Chatbook/PromptGenerators/Common.wl @@ -4,7 +4,9 @@ BeginPackage[ "Wolfram`Chatbook`PromptGenerators`Common`" ]; HoldComplete[ `$$prompt, + `$defaultSources, `getSmallContextString, + `getSnippets, `insertContextPrompt, `vectorDBSearch ]; diff --git a/Source/Chatbook/PromptGenerators/RelatedDocumentation.wl b/Source/Chatbook/PromptGenerators/RelatedDocumentation.wl index ccb4acb9..970d1067 100644 --- a/Source/Chatbook/PromptGenerators/RelatedDocumentation.wl +++ b/Source/Chatbook/PromptGenerators/RelatedDocumentation.wl @@ -39,13 +39,15 @@ $unfilteredItemsPerSource = 10; (* ::**************************************************************************************************************:: *) (* ::Section::Closed:: *) (*Messages*) -Chatbook::CloudDownloadError = "Unable to download required data from the cloud. Please try again later."; -Chatbook::InvalidSources = "Invalid value for the \"Sources\" option: `1`."; +Chatbook::CloudDownloadError = "Unable to download required data from the cloud. Please try again later."; +Chatbook::InvalidSources = "Invalid value for the \"Sources\" option: `1`."; +Chatbook::SnippetFunctionOutputFailure = "The snippet function `1` returned a list of length `2` for `3` values."; +Chatbook::SnippetFunctionLengthFailure = "The snippet function `1` returned a list of length `2` for `3` values."; (* ::**************************************************************************************************************:: *) (* ::Section::Closed:: *) (*$RelatedDocumentationSources*) -$RelatedDocumentationSources = $defaultSources; +$RelatedDocumentationSources := $defaultSources; (* ::**************************************************************************************************************:: *) (* ::Section::Closed:: *) @@ -94,20 +96,28 @@ RelatedDocumentation[ prompt_, Automatic, count_, opts: OptionsPattern[ ] ] := RelatedDocumentation[ prompt: $$prompt, "URIs", Automatic, opts: OptionsPattern[ ] ] := catchMine @ Enclose[ (* TODO: filter results *) - ConfirmMatch[ vectorDBSearch[ getSources @ OptionValue[ "Sources" ], prompt, "Values" ], { ___String }, "Queries" ], + URL /@ ConfirmMatch[ + vectorDBSearch[ getSources @ OptionValue[ "Sources" ], prompt, "Values" ], + { ___String }, + "Values" + ], throwInternalFailure ]; RelatedDocumentation[ All, "URIs", Automatic, opts: OptionsPattern[ ] ] := catchMine @ Enclose[ (* TODO: filter results *) - Union @ ConfirmMatch[ vectorDBSearch[ getSources @ OptionValue[ "Sources" ], All ], { __String }, "QueryList" ], + URL /@ Union @ ConfirmMatch[ + vectorDBSearch[ getSources @ OptionValue[ "Sources" ], All ], + { ___String }, + "Values" + ], throwInternalFailure ]; RelatedDocumentation[ prompt: $$prompt, "Snippets", Automatic, opts: OptionsPattern[ ] ] := catchMine @ Enclose[ ConfirmMatch[ (* TODO: filter results *) - DeleteMissing[ makeDocSnippets @ vectorDBSearch[ getSources @ OptionValue[ "Sources" ], prompt, "Values" ] ], + DeleteMissing[ makeDocSnippets @ vectorDBSearch[ getSources @ OptionValue[ "Sources" ], prompt, "Results" ] ], { ___String }, "Snippets" ], @@ -118,7 +128,7 @@ RelatedDocumentation[ prompt_, property_, UpTo[ n_Integer ], opts: OptionsPatter catchMine @ RelatedDocumentation[ prompt, property, n, opts ]; RelatedDocumentation[ prompt_, property_, n_Integer, opts: OptionsPattern[ ] ] := catchMine @ Enclose[ - Take[ ConfirmMatch[ RelatedDocumentation[ prompt, property, Automatic, opts ], { ___String } ], UpTo @ n ], + Take[ ConfirmBy[ RelatedDocumentation[ prompt, property, Automatic, opts ], ListQ ], UpTo @ n ], throwInternalFailure ]; @@ -163,7 +173,8 @@ RelatedDocumentation[ prompt_, "Prompt", n_Integer, opts: OptionsPattern[ ] ] := $rerankMethod = Replace[ OptionValue[ "RerankMethod" ], $$unspecified :> CurrentChatSettings[ "DocumentationRerankMethod" ] - ] + ], + $RelatedDocumentationSources = getSources @ OptionValue[ "Sources" ] }, relatedDocumentationPrompt[ ensureChatMessages @ prompt, @@ -223,23 +234,25 @@ ensureChatMessages // endDefinition; relatedDocumentationPrompt // beginDefinition; relatedDocumentationPrompt[ messages: $$chatMessages, count_, filter_, filterCount_ ] := Enclose[ - Catch @ Module[ { uris, filtered, string }, + Catch @ Module[ { results, filtered, string }, - uris = ConfirmMatch[ - RelatedDocumentation[ messages, "URIs", count ], - { ___String }, - "URIs" - ] // LogChatTiming[ "RelatedDocumentationURIs" ] // withApproximateProgress[ "CheckingDocumentation", 0.2 ]; + results = ConfirmMatch[ + RelatedDocumentation[ messages, "Results", count ], + { ___Association }, + "Results" + ] // LogChatTiming[ "RelatedDocumentationResults" ] // withApproximateProgress[ "CheckingDocumentation", 0.2 ]; + + If[ results === { }, Throw[ "" ] ]; - If[ uris === { }, Throw[ "" ] ]; + results = DeleteDuplicatesBy[ results, Lookup[ "Value" ] ]; filtered = ConfirmMatch[ - filterSnippets[ messages, uris, filter, filterCount ] // LogChatTiming[ "FilterSnippets" ], + filterSnippets[ messages, results, filter, filterCount ] // LogChatTiming[ "FilterSnippets" ], { ___String }, "Filtered" ]; - string = StringTrim @ StringRiffle[ "# "<># & /@ DeleteCases[ filtered, "" ], "\n\n======\n\n" ]; + string = StringTrim @ StringRiffle[ DeleteCases[ filtered, "" ], "\n\n======\n\n" ]; If[ string === "", "", @@ -272,20 +285,20 @@ $relatedDocsStringUnfilteredHeader = filterSnippets // beginDefinition; -filterSnippets[ messages_, uris: { __String }, Except[ True ], filterCount_ ] := Enclose[ - ConfirmMatch[ makeDocSnippets @ uris, { ___String }, "Snippets" ], +filterSnippets[ messages_, results_List, Except[ True ], filterCount_ ] := Enclose[ + ConfirmMatch[ makeDocSnippets @ results, { ___String }, "Snippets" ], throwInternalFailure ]; filterSnippets[ messages_, - uris: { __String }, + results_List, True, filterCount_Integer? Positive ] /; $rerankMethod === None := Enclose[ Catch @ Module[ { snippets }, - snippets = ConfirmMatch[ makeDocSnippets @ uris, { ___String }, "Snippets" ]; + snippets = ConfirmMatch[ makeDocSnippets @ results, { ___String }, "Snippets" ]; Take[ snippets, UpTo[ filterCount ] ] ], throwInternalFailure @@ -294,13 +307,13 @@ filterSnippets[ filterSnippets[ messages_, - uris: { __String }, + results_List, True, filterCount_Integer? Positive ] /; $rerankMethod === "rerank-english-v3.0" (* EXPERIMENTAL *) := Enclose[ - Catch @ Module[ { snippets, inserted, transcript, instructions, resp, results, idx, ranked }, + Catch @ Module[ { snippets, inserted, transcript, instructions, resp, respResults, idx, ranked }, - snippets = ConfirmMatch[ makeDocSnippets @ uris, { ___String }, "Snippets" ]; + snippets = ConfirmMatch[ makeDocSnippets @ results, { ___String }, "Snippets" ]; setProgressDisplay[ "ProgressTextChoosingDocumentation" ]; inserted = insertContextPrompt @ messages; transcript = ConfirmBy[ getSmallContextString @ inserted, StringQ, "Transcript" ]; @@ -323,26 +336,33 @@ filterSnippets[ If[ FailureQ @ resp, throwTop @ resp ]; - results = ConfirmMatch[ resp[ "results" ], { __Association }, "Results" ]; + respResults = ConfirmMatch[ resp[ "results" ], { __Association }, "Results" ]; idx = ConfirmMatch[ - Select[ results, #[ "relevance_score" ] > 0.01 & ][[ All, "index" ]] + 1, + Select[ respResults, #[ "relevance_score" ] > 0.01 & ][[ All, "index" ]] + 1, { ___Integer }, "Indices" ]; ranked = ConfirmMatch[ snippets[[ idx ]], { ___String }, "Ranked" ]; + (* FIXME: need to add handler data here *) + Take[ ranked, UpTo[ filterCount ] ] ], throwInternalFailure ]; -filterSnippets[ messages_, uris: { __String }, True, filterCount_Integer? Positive ] := Enclose[ - Catch @ Module[ { snippets, inserted, transcript, xml, instructions, response, pages }, +filterSnippets[ messages_, results0_List, True, filterCount_Integer? Positive ] := Enclose[ + Catch @ Module[ + { + results, snippets, inserted, transcript, xml, + instructions, response, uriToSnippet, uris, selected, pages + }, - snippets = ConfirmMatch[ makeDocSnippets @ uris, { ___String }, "Snippets" ]; + results = ConfirmMatch[ addDocSnippets @ results0, { ___Association }, "Results" ]; + snippets = ConfirmMatch[ Lookup[ results, "Snippet" ], { ___String }, "Snippets" ]; setProgressDisplay[ "ChoosingDocumentation" ]; inserted = insertContextPrompt @ messages; transcript = ConfirmBy[ getSmallContextString @ inserted, StringQ, "Transcript" ]; @@ -370,7 +390,15 @@ filterSnippets[ messages_, uris: { __String }, True, filterCount_Integer? Positi "Response" ]; - pages = ConfirmMatch[ makeDocSnippets @ selectSnippetsFromJSON[ response, uris ], { ___String }, "Pages" ]; + $lastFilterInstructions = instructions; + $lastFilterResponse = response; + + uriToSnippet = <| #Value -> #Snippet & /@ results |>; + uris = ConfirmMatch[ Keys @ uriToSnippet, { ___String }, "URIs" ]; + selected = ConfirmMatch[ selectSnippetsFromJSON[ response, uris ], { ___String }, "Pages" ]; + pages = ConfirmMatch[ Lookup[ uriToSnippet, selected ], { ___String }, "Pages" ]; + + addHandlerArguments[ "RelatedDocumentation" -> <| "Results" -> uris, "Filtered" -> selected |> ]; pages ], @@ -387,6 +415,7 @@ Your task is to read a chat transcript between a user and assistant, and then se documentation snippets that could help the assistant answer the user's latest message. Each snippet is uniquely identified by a URI (always starts with 'paclet:' or 'https://*.wolframcloud.com'). +You must also include the fragment appearing after the '#' in the URI. Choose up to %%FilteredCount%% documentation snippets that would help answer the user's MOST RECENT message. @@ -486,26 +515,100 @@ snippetXML // endDefinition; (*Documentation Snippets*) $documentationSnippets = <| |>; +(* ::**************************************************************************************************************:: *) +(* ::Subsection::Closed:: *) +(*addDocSnippets*) +addDocSnippets // beginDefinition; + +addDocSnippets[ results: { ___Association } ] := Enclose[ + Module[ { withOrdering, grouped, withSnippets, sorted }, + + withOrdering = MapIndexed[ <| "Position" -> First[ #2 ], #1 |> &, results ]; + grouped = GroupBy[ withOrdering, Lookup[ "SnippetFunction" ] ]; + + withSnippets = ConfirmMatch[ + Flatten @ KeyValueMap[ applySnippetFunction, grouped ], + { ___Association }, + "WithSnippets" + ]; + + sorted = ConfirmMatch[ SortBy[ withSnippets, Lookup[ "Position" ] ], { ___Association }, "Sorted" ]; + + ConfirmAssert[ Length @ sorted === Length @ results, "LengthCheck" ]; + + sorted + ], + throwInternalFailure +]; + +addDocSnippets // endDefinition; + (* ::**************************************************************************************************************:: *) (* ::Subsection::Closed:: *) (*makeDocSnippets*) makeDocSnippets // beginDefinition; -makeDocSnippets[ uris0: { ___String } ] := Enclose[ - Module[ { uris, data, snippets, strings }, - uris = DeleteDuplicates @ uris0; +makeDocSnippets[ results: { ___Association } ] := Enclose[ + Module[ { sorted, snippets }, + sorted = ConfirmMatch[ addDocSnippets @ results, { ___Association }, "Sorted" ]; + snippets = ConfirmMatch[ Lookup[ sorted, "Snippet" ], { ___String }, "Snippets" ]; + ConfirmAssert[ Length @ snippets === Length @ results, "LengthCheck" ]; + DeleteDuplicates @ snippets + ], + throwInternalFailure +]; + +makeDocSnippets // endDefinition; + +(* ::**************************************************************************************************************:: *) +(* ::Subsubsection::Closed:: *) +(*applySnippetFunction*) +applySnippetFunction // beginDefinition; + +applySnippetFunction[ f_, { } ] := { }; + +applySnippetFunction[ f_, data: { ___Association } ] := Enclose[ + Module[ { values, snippets, snippetLen, valuesLen }, + + values = ConfirmMatch[ Lookup[ data, "Value" ], { ___String }, "Values" ]; + snippets = f @ values; + snippetLen = Length @ snippets; + valuesLen = Length @ values; + + If[ ! MatchQ[ snippets, { ___String } ], throwFailure[ "SnippetFunctionOutputFailure", f, snippets ] ]; + If[ snippetLen =!= valuesLen, throwFailure[ "SnippetFunctionLengthFailure", f, snippetLen, valuesLen ] ]; + + ConfirmBy[ + Association /@ Transpose @ { data, Thread[ "Snippet" -> snippets ] }, + AllTrue @ AssociationQ, + "Result" + ] + ] // LogChatTiming @ { "ApplySnippetFunction", f }, + throwInternalFailure +]; + +applySnippetFunction // endDefinition; + +(* ::**************************************************************************************************************:: *) +(* ::Subsection::Closed:: *) +(*getSnippets*) +getSnippets // beginDefinition; + +getSnippets[ uris: { ___String } ] := Enclose[ + Module[ { data, snippets, strings }, data = ConfirmBy[ getDocumentationSnippetData @ uris, AssociationQ, "Data" ]; - snippets = ConfirmMatch[ Values @ data, { ___Association }, "Snippets" ]; + snippets = ConfirmMatch[ Lookup[ data, uris ], { ___Association }, "Snippets" ]; strings = ConfirmMatch[ Lookup[ "String" ] /@ snippets, { ___String }, "Strings" ]; - strings + ConfirmAssert[ Length @ strings === Length @ uris, "LengthCheck" ]; + "# " <> # & /@ strings ], throwInternalFailure ]; -makeDocSnippets[ uri_String ] := - First @ makeDocSnippets @ { uri }; +getSnippets[ uri_String ] := + First @ getSnippets @ { uri }; -makeDocSnippets // endDefinition; +getSnippets // endDefinition; (* ::**************************************************************************************************************:: *) (* ::Subsection::Closed:: *) diff --git a/Source/Chatbook/PromptGenerators/VectorDatabases.wl b/Source/Chatbook/PromptGenerators/VectorDatabases.wl index 83b6b2c5..23ad958d 100644 --- a/Source/Chatbook/PromptGenerators/VectorDatabases.wl +++ b/Source/Chatbook/PromptGenerators/VectorDatabases.wl @@ -16,10 +16,10 @@ HoldComplete[ (* ::Section::Closed:: *) (*Configuration*) $vectorDatabases = <| - "DataRepositoryURIs" -> <| "Version" -> "1.0.0", "Bias" -> 1.0 |>, - "DocumentationURIs" -> <| "Version" -> "1.3.0", "Bias" -> 0.0 |>, - "FunctionRepositoryURIs" -> <| "Version" -> "1.0.0", "Bias" -> 1.0 |>, - "WolframAlphaQueries" -> <| "Version" -> "1.3.0", "Bias" -> 0.0 |> + "DataRepositoryURIs" -> <| "Version" -> "1.0.0", "Bias" -> 1.0, "SnippetFunction" -> getSnippets |>, + "DocumentationURIs" -> <| "Version" -> "1.3.0", "Bias" -> 0.0, "SnippetFunction" -> getSnippets |>, + "FunctionRepositoryURIs" -> <| "Version" -> "1.0.0", "Bias" -> 1.0, "SnippetFunction" -> getSnippets |>, + "WolframAlphaQueries" -> <| "Version" -> "1.3.0", "Bias" -> 0.0, "SnippetFunction" -> Identity |> |>; $vectorDBNames = Keys @ $vectorDatabases; @@ -65,16 +65,225 @@ $cloudVectorDBDirectory := PacletObject[ "Wolfram/NotebookAssistantCloudResourc (*Argument Patterns*) $$vectorDatabase = HoldPattern[ _VectorDatabaseObject? System`Private`ValidQ ]; -$$dbName = Alternatives @@ $vectorDBNames; +$$dbName = _String? vectorDBNameQ; $$dbNames = { $$dbName... }; $$dbNameOrNames = $$dbName | $$dbNames; +$$vectorDatabaseSource = $$vectorDatabase | _? DirectoryQ | _? FileExistsQ; + +(* ::**************************************************************************************************************:: *) +(* ::Subsubsection::Closed:: *) +(*vectorDBNameQ*) +vectorDBNameQ // beginDefinition; +vectorDBNameQ[ name_String ] := MemberQ[ $vectorDBNames, name ]; +vectorDBNameQ[ ___ ] := False; +vectorDBNameQ // endDefinition; + (* ::**************************************************************************************************************:: *) (* ::Subsection::Closed:: *) (*Cache*) $vectorDBSearchCache = <| |>; $embeddingCache = <| |>; +(* ::**************************************************************************************************************:: *) +(* ::Section::Closed:: *) +(*Messages*) +Chatbook::InvalidVectorDatabaseName = "Expected a string for vector database name instead of `1`"; +Chatbook::InvalidVectorDatabaseValues = "Expected a list of strings for vector database values instead of `1`"; +Chatbook::VectorDatabaseValuesFileNotFound = "Values file not found: `1`"; +Chatbook::InvalidVectorDatabaseDimensions = "Dimensions of vectors (`1`) do not match expected dimensions (`2`)"; +Chatbook::InvalidVectorDatabaseValuesLength = "The number of values (`1`) does not match the number of vectors (`2`)"; + +(* ::**************************************************************************************************************:: *) +(* ::Section::Closed:: *) +(*RegisterVectorDatabase*) +RegisterVectorDatabase // beginDefinition; + +RegisterVectorDatabase[ source: $$vectorDatabaseSource ] := + catchMine @ RegisterVectorDatabase[ source, <| |> ]; + +RegisterVectorDatabase[ source: $$vectorDatabaseSource, as_Association ] := + catchMine @ registerVectorDatabase @ toVectorDatabaseInfo[ source, as ]; + +RegisterVectorDatabase // endExportedDefinition; + +(* ::**************************************************************************************************************:: *) +(* ::Subsection::Closed:: *) +(*toVectorDatabaseInfo*) +toVectorDatabaseInfo // beginDefinition; + +toVectorDatabaseInfo[ source: $$vectorDatabaseSource, info0_Association ] := Enclose[ + Module[ { info, name, values, length, dim, bias, version, db, valueFunction }, + + info = info0; + + name = getVectorDatabaseName[ source, info ]; + If[ ! StringQ @ name, throwFailure[ "InvalidVectorDatabaseName", name ] ]; + info[ "Name" ] = name; + + values = getVectorDatabaseValues[ source, info ]; + If[ ! MatchQ[ values, { ___String } ], throwFailure[ "InvalidVectorDatabaseValues", values ] ]; + info[ "Values" ] = values; + + { length, dim } = ConfirmMatch[ + getVectorDatabaseDimensions[ source, info ], + { _Integer, _Integer }, + "Dimensions" + ]; + If[ dim =!= $embeddingDimension, throwFailure[ "InvalidVectorDatabaseDimensions", dim, $embeddingDimension ] ]; + If[ Length @ values =!= length, throwFailure[ "InvalidVectorDatabaseValuesLength", Length @ values, length ] ]; + info[ "Dimensions" ] = { length, dim }; + + bias = Lookup[ info, "Bias", 0.0 ]; + If[ ! NumberQ @ bias, throwFailure[ "InvalidVectorDatabaseBias", bias ] ]; + info[ "Bias" ] = bias; + + version = Lookup[ info, "Version", "1.0.0" ]; + If[ ! StringQ @ version, version = "1.0.0" ]; + info[ "Version" ] = version; + + db = ConfirmMatch[ getVectorDatabaseObject[ source, info ], $$vectorDatabase, "VectorDatabaseObject" ]; + info[ "VectorDatabaseObject" ] = db; + + valueFunction = Replace[ Lookup[ info, "SnippetFunction" ], $$unspecified -> Identity ]; + info[ "SnippetFunction" ] = valueFunction; + + info + ], + throwInternalFailure +]; + +toVectorDatabaseInfo // endDefinition; + +(* ::**************************************************************************************************************:: *) +(* ::Subsubsection::Closed:: *) +(*getVectorDatabaseName*) +getVectorDatabaseName // beginDefinition; + +getVectorDatabaseName[ db: $$vectorDatabase, info_Association ] := Lookup[ info, "Name", db[ "ID" ] ]; + +getVectorDatabaseName // endDefinition; + +(* ::**************************************************************************************************************:: *) +(* ::Subsubsection::Closed:: *) +(*getVectorDatabaseValues*) +getVectorDatabaseValues // beginDefinition; + +getVectorDatabaseValues[ source_, as: KeyValuePattern[ "Values" -> values_? FileExistsQ ] ] := + getVectorDatabaseValues[ values, as ]; + +getVectorDatabaseValues[ source_, KeyValuePattern[ "Values" -> values: Except[ $$unspecified ] ] ] := + values; + +getVectorDatabaseValues[ dir_? DirectoryQ, info_ ] := Enclose[ + Module[ { name, infoFile, valuesFile }, + + name = ConfirmBy[ info[ "Name" ], StringQ, "Name" ]; + infoFile = FileNameJoin @ { dir, name <> ".wxf" }; + valuesFile = FileNameJoin @ { dir, "Values.wxf" }; + + If[ FileExistsQ @ valuesFile, + getVectorDatabaseValues[ valuesFile, info ], + getVectorDatabaseValues[ infoFile, info ] + ] + ], + throwInternalFailure +]; + +getVectorDatabaseValues[ file_? FileExistsQ, info_ ] := + toVectorDatabaseValues @ Developer`ReadWXFFile @ ExpandFileName @ file; + +getVectorDatabaseValues[ db: $$vectorDatabase, info_ ] := + getVectorDatabaseValues[ DirectoryName @ db[ "Location" ], info ]; + +getVectorDatabaseValues // endDefinition; + +(* ::**************************************************************************************************************:: *) +(* ::Subsubsection::Closed:: *) +(*toVectorDatabaseValues*) +toVectorDatabaseValues // beginDefinition; +toVectorDatabaseValues[ v: { ___String } ] := v; +toVectorDatabaseValues[ KeyValuePattern[ "Metadata" -> KeyValuePattern[ Automatic -> v: { ___String } ] ] ] := v; +toVectorDatabaseValues // endDefinition; + +(* ::**************************************************************************************************************:: *) +(* ::Subsubsection::Closed:: *) +(*getVectorDatabaseDimensions*) +getVectorDatabaseDimensions // beginDefinition; + +getVectorDatabaseDimensions[ db: $$vectorDatabase, info_ ] := + db[ "Dimensions" ]; + +getVectorDatabaseDimensions[ dir_? DirectoryQ, info_Association ] := Enclose[ + Module[ { name, file }, + name = ConfirmBy[ info[ "Name" ], StringQ, "Name" ]; + file = ConfirmBy[ FileNameJoin @ { dir, name <> ".wxf" }, FileExistsQ, "File" ]; + getVectorDatabaseDimensions[ file, info ] + ], + throwInternalFailure +]; + +getVectorDatabaseDimensions // endDefinition; + +(* ::**************************************************************************************************************:: *) +(* ::Subsubsection::Closed:: *) +(*getVectorDatabaseObject*) +getVectorDatabaseObject // beginDefinition; + +getVectorDatabaseObject[ source_, KeyValuePattern[ "VectorDatabaseObject" -> db: $$vectorDatabase ] ] := + db; + +getVectorDatabaseObject[ db: $$vectorDatabase, info_ ] := + db; + +getVectorDatabaseObject[ dir_? DirectoryQ, info_ ] := Enclose[ + Module[ { name, file }, + name = ConfirmBy[ info[ "Name" ], StringQ, "Name" ]; + file = ConfirmBy[ FileNameJoin @ { dir, name <> ".wxf" }, FileExistsQ, "File" ]; + getVectorDatabaseObject[ file, info ] + ], + throwInternalFailure +]; + +getVectorDatabaseObject[ file_? FileExistsQ, info_ ] := + VectorDatabaseObject @ Flatten @ File @ file; + +getVectorDatabaseObject // endDefinition; + +(* ::**************************************************************************************************************:: *) +(* ::Subsection::Closed:: *) +(*registerVectorDatabase*) +registerVectorDatabase // beginDefinition; + +registerVectorDatabase[ info_Association ] := Enclose[ + Module[ { name, version, bias, values, db, valueFunction }, + + ConfirmAssert[ DirectoryQ @ $vectorDBDirectory, "DownloadCheck" ]; + + name = ConfirmBy[ info[ "Name" ], StringQ, "Name" ]; + version = ConfirmBy[ info[ "Version" ], StringQ, "Version" ]; + bias = ConfirmBy[ info[ "Bias" ], NumberQ, "Bias" ]; + values = ConfirmMatch[ info[ "Values" ], { ___String }, "Values" ]; + db = ConfirmMatch[ info[ "VectorDatabaseObject" ], $$vectorDatabase, "VectorDatabaseObject" ]; + valueFunction = ConfirmMatch[ info[ "SnippetFunction" ], Except[ $$unspecified ], "SnippetFunction" ]; + + getVectorDB[ name ] = <| "Values" -> values, "VectorDatabaseObject" -> db |>; + $vectorDatabases[ name ] = <| "Version" -> version, "Bias" -> bias, "SnippetFunction" -> valueFunction |>; + + $vectorDBNames = DeleteDuplicates @ Append[ $vectorDBNames, name ]; + + $defaultSources = DeleteDuplicates @ Append[ + ConfirmMatch[ $defaultSources, { ___String }, "DefaultSources" ], + name + ]; + + Success[ "VectorDatabaseRegistered", KeyTake[ info, { "Name", "Version", "Bias", "VectorDatabaseObject" } ] ] + ], + throwInternalFailure +]; + +registerVectorDatabase // endDefinition; + (* ::**************************************************************************************************************:: *) (* ::Section::Closed:: *) (*InstallVectorDatabases*) @@ -437,7 +646,11 @@ vectorDBSearch[ dbName: $$dbName, prompt_String, All ] := (* Main definition for string prompt: *) vectorDBSearch[ dbName: $$dbName, prompt_String, All ] := Enclose[ - Module[ { vectorDBInfo, vectorDB, allValues, embeddingVector, close, indices, distances, values, data, result }, + Module[ + { + vectorDBInfo, vectorDB, allValues, embeddingVector, close, + indices, distances, values, data, result, snippetFunction + }, vectorDBInfo = ConfirmBy[ getVectorDB @ dbName, AssociationQ, "VectorDBInfo" ]; vectorDB = ConfirmMatch[ vectorDBInfo[ "VectorDatabaseObject" ], $$vectorDatabase, "VectorDatabase" ]; @@ -457,13 +670,20 @@ vectorDBSearch[ dbName: $$dbName, prompt_String, All ] := Enclose[ indices = ConfirmMatch[ close[[ All, "Index" ]], { ___Integer }, "Indices" ]; distances = ConfirmMatch[ close[[ All, "Distance" ]], { ___Real }, "Distances" ]; - values = ConfirmBy[ allValues[[ indices ]], ListQ, "Values" ]; ConfirmAssert[ Length @ indices === Length @ distances === Length @ values, "LengthCheck" ]; + snippetFunction = Confirm[ getSnippetFunction @ dbName, "SnippetFunction" ]; + data = MapApply[ - <| "Value" -> #1, "Index" -> #2, "Distance" -> #3, "Source" -> dbName |> &, + <| + "Value" -> #1, + "Index" -> #2, + "Distance" -> #3, + "Source" -> dbName, + "SnippetFunction" -> snippetFunction + |> &, Transpose @ { values, indices, distances } ]; @@ -611,11 +831,7 @@ vectorDBSearch[ names: $$dbNames, prompt_, prop: "Values"|"Results" ] := Enclose If[ prop === "Results", sorted, - ConfirmMatch[ - DeleteDuplicates @ Lookup[ sorted, "Value" ], - { __String }, - "Values" - ] + ConfirmBy[ DeleteDuplicates @ Lookup[ sorted, "Value" ], ListQ, "Values" ] ] ], throwInternalFailure @@ -633,6 +849,15 @@ vectorDBSearch[ All, prompt_, prop_ ] := vectorDBSearch // endDefinition; +(* ::**************************************************************************************************************:: *) +(* ::Subsection::Closed:: *) +(*getSnippetFunction*) +getSnippetFunction // beginDefinition; +getSnippetFunction[ name_String ] := getSnippetFunction[ name, $vectorDatabases[ name, "SnippetFunction" ] ]; +getSnippetFunction[ name_String, $$unspecified ] := Identity; +getSnippetFunction[ name_String, function_ ] := function; +getSnippetFunction // endDefinition; + (* ::**************************************************************************************************************:: *) (* ::Subsubsection::Closed:: *) (*applyBias*) @@ -640,7 +865,7 @@ applyBias // beginDefinition; applyBias[ name_String, results_ ] := applyBias[ $vectorDatabases[ name, "Bias" ], results ]; applyBias[ None | _Missing | 0 | 0.0, results_ ] := results; applyBias[ bias_, results_List ] := (applyBias[ bias, #1 ] &) /@ results; -applyBias[ bias: $$size, as: KeyValuePattern[ "Distance" -> d: $$size ] ] := <| as, "Distance" -> d + bias |>; +applyBias[ bias_? NumberQ, as: KeyValuePattern[ "Distance" -> d: $$size ] ] := <| as, "Distance" -> d + bias |>; applyBias // endDefinition; (* ::**************************************************************************************************************:: *) diff --git a/Source/Chatbook/Utils.wl b/Source/Chatbook/Utils.wl index 78466bcc..ec00a992 100644 --- a/Source/Chatbook/Utils.wl +++ b/Source/Chatbook/Utils.wl @@ -1028,7 +1028,7 @@ LogChatTiming // Attributes = { HoldFirst, SequenceHold }; LogChatTiming[ tag_String ] := Function[ eval, LogChatTiming[ eval, tag ], HoldAllComplete ]; LogChatTiming[ sym_Symbol ] := LogChatTiming @ Evaluate @ Capitalize @ SymbolName @ sym; -LogChatTiming[ tags_List ] := LogChatTiming @ Evaluate @ StringRiffle[ tags, ":" ]; +LogChatTiming[ tags_List ] := LogChatTiming @ Evaluate @ StringRiffle[ timingTag /@ tags, ":" ]; LogChatTiming[ eval: (h_Symbol)[ ___ ] ] := LogChatTiming[ eval, Capitalize @ SymbolName @ h ]; LogChatTiming[ eval_ ] := LogChatTiming[ eval, "None" ]; @@ -1048,6 +1048,15 @@ LogChatTiming // endExportedDefinition; $timings = Internal`Bag[ ]; $timingLog = Internal`Bag[ ]; +(* ::**************************************************************************************************************:: *) +(* ::Subsection::Closed:: *) +(*timingTag*) +timingTag // beginDefinition; +timingTag[ tag_String ] := tag; +timingTag[ symbol_Symbol ] := Capitalize @ SymbolName @ symbol; +timingTag[ other_ ] := StringDelete[ stringTrimMiddle[ ToString @ other, 32 ], Except[ LetterCharacter ] ]; +timingTag // endDefinition; + (* ::**************************************************************************************************************:: *) (* ::Subsection::Closed:: *) (*logChatTiming*) From 86a86a2cf5509ad30dc9361c25f6f34fa90e4b12 Mon Sep 17 00:00:00 2001 From: Rick Hennigan Date: Thu, 9 Jan 2025 12:43:41 -0500 Subject: [PATCH 2/2] Update tests for `RelatedDocumentation` --- Tests/RelatedDocumentation.wlt | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Tests/RelatedDocumentation.wlt b/Tests/RelatedDocumentation.wlt index 2ade3ca0..dfb98913 100644 --- a/Tests/RelatedDocumentation.wlt +++ b/Tests/RelatedDocumentation.wlt @@ -28,15 +28,15 @@ VerificationTest[ (* ::Section::Closed:: *) (*RelatedDocumentation*) VerificationTest[ - uris = RelatedDocumentation[ "What's the biggest pokemon?" ], - { __String }, + urls = RelatedDocumentation[ "What's the biggest pokemon?" ], + { URL[ _String ].. }, SameTest -> MatchQ, TestID -> "RelatedDocumentation-URIs@@Tests/RelatedDocumentation.wlt:30,1-35,2" ] VerificationTest[ Length @ Select[ - uris, + First /@ urls, StringStartsQ @ StringExpression[ "paclet:ref/", "interpreter"|"entity"|"textcontent", @@ -65,14 +65,14 @@ VerificationTest[ ] VerificationTest[ - uris = RelatedDocumentation[ "What's the biggest pokemon?", Automatic, 3 ], - { _String, _String, _String }, + urls = RelatedDocumentation[ "What's the biggest pokemon?", Automatic, 3, "Sources" -> { "Documentation" } ], + { URL[ _String ], URL[ _String ], URL[ _String ] }, SameTest -> MatchQ, TestID -> "RelatedDocumentation-URIs-Count@@Tests/RelatedDocumentation.wlt:67,1-72,2" ] VerificationTest[ - AllTrue[ uris, StringStartsQ[ "paclet:ref/" ] ], + AllTrue[ First /@ urls, StringStartsQ[ "paclet:ref/" ] ], True, SameTest -> MatchQ, TestID -> "RelatedDocumentation-URIs-Match@@Tests/RelatedDocumentation.wlt:74,1-79,2"