Skip to content

Commit

Permalink
Merge pull request #435 from WolframResearch/feature/processing-funct…
Browse files Browse the repository at this point in the history
…ions

Custom Processing Functions
  • Loading branch information
rhennigan authored Oct 31, 2023
2 parents 573713e + 3b8abe7 commit 7bb3fc3
Show file tree
Hide file tree
Showing 12 changed files with 989 additions and 534 deletions.
50 changes: 42 additions & 8 deletions Source/Chatbook/Actions.wl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Needs[ "Wolfram`Chatbook`Prompting`" ];
Needs[ "Wolfram`Chatbook`SendChat`" ];
Needs[ "Wolfram`Chatbook`Serialization`" ];
Needs[ "Wolfram`Chatbook`Services`" ];
Needs[ "Wolfram`Chatbook`Settings`" ];
Needs[ "Wolfram`Chatbook`ToolManager`" ];
Needs[ "Wolfram`Chatbook`Tools`" ];

Expand All @@ -59,6 +60,31 @@ HoldComplete[
System`LLMToolResponse
];

(* ::**************************************************************************************************************:: *)
(* ::Section::Closed:: *)
(*ChatCellEvaluate*)
ChatCellEvaluate // ClearAll;

ChatCellEvaluate[ ] :=
catchMine @ ChatCellEvaluate @ topParentCell @ EvaluationCell[ ];

ChatCellEvaluate[ cell_CellObject ] :=
catchMine @ ChatCellEvaluate[ cell, parentNotebook @ cell ];

ChatCellEvaluate[ cell_CellObject, nbo_NotebookObject ] :=
catchMine @ Block[ { cellPrint = cellPrintAfter @ cell },
Replace[
Reap[ EvaluateChatInput[ cell, nbo ], $chatObjectTag ],
{
{ _, { { chat: HoldPattern[ _ChatObject ] } } } :> chat,
___ :> Null
}
]
];

ChatCellEvaluate[ args___ ] :=
catchMine @ throwFailure[ "InvalidArguments", ChatCellEvaluate, HoldForm @ ChatCellEvaluate @ args ];

(* ::**************************************************************************************************************:: *)
(* ::Section::Closed:: *)
(*ChatbookAction*)
Expand Down Expand Up @@ -324,7 +350,7 @@ rotateTabPage[ cell_CellObject, n_Integer ] := Enclose[
currentPage = ConfirmBy[ pageData[ "CurrentPage" ], IntegerQ, "CurrentPage" ];
newPage = Mod[ currentPage + n, pageCount, 1 ];
encoded = ConfirmMatch[ pageData[ "Pages", newPage ], _String, "EncodedContent" ];
content = ConfirmMatch[ BinaryDeserialize @ BaseDecode @ encoded, TextData[ _String|_List ], "Content" ];
content = ConfirmMatch[ BinaryDeserialize @ BaseDecode @ encoded, TextData[ $$textData ], "Content" ];

writePageContent[ cell, newPage, content ]
],
Expand All @@ -338,13 +364,13 @@ rotateTabPage // endDefinition;
(*writePageContent*)
writePageContent // beginDefinition;

writePageContent[ cell_CellObject, newPage_Integer, content: TextData[ _String | _List ] ] /; $cloudNotebooks := (
writePageContent[ cell_CellObject, newPage_Integer, content: TextData[ $$textData ] ] /; $cloudNotebooks := (
CurrentValue[ cell, { TaggingRules, "PageData", "CurrentPage" } ] = newPage;
CurrentValue[ cell, TaggingRules ] = GeneralUtilities`ToAssociations @ CurrentValue[ cell, TaggingRules ];
NotebookWrite[ cell, ReplacePart[ NotebookRead @ cell, 1 -> content ] ];
)

writePageContent[ cell_CellObject, newPage_Integer, content: TextData[ _String | _List ] ] := (
writePageContent[ cell_CellObject, newPage_Integer, content: TextData[ $$textData ] ] := (
SelectionMove[ cell, All, CellContents, AutoScroll -> False ];
NotebookWrite[ parentNotebook @ cell, content, None, AutoScroll -> False ];
SelectionMove[ cell, After, Cell, AutoScroll -> False ];
Expand Down Expand Up @@ -390,10 +416,11 @@ EvaluateChatInput[ evalCell_CellObject, nbo_NotebookObject, settings_Association
<| "Role" -> "Assistant", "Content" -> $lastChatString |>
]
},
applyHandlerFunction[ settings, "ChatPost", <| "ChatObject" -> chat |> ];
chat
applyHandlerFunction[ settings, "ChatPost", <| "ChatObject" -> chat, "NotebookObject" -> nbo |> ];
Sow[ chat, $chatObjectTag ]
],
applyHandlerFunction[ settings, "ChatPost", <| "ChatObject" -> None |> ];
applyHandlerFunction[ settings, "ChatPost", <| "ChatObject" -> None, "NotebookObject" -> nbo |> ];
Sow[ None, $chatObjectTag ];
Null
];
]
Expand Down Expand Up @@ -478,6 +505,8 @@ waitForLastTask // beginDefinition;

waitForLastTask[ ] := waitForLastTask @ $lastTask;

waitForLastTask[ $Canceled ] := $Canceled;

waitForLastTask[ task_TaskObject ] := (
TaskWait @ task;
runNextTask[ ];
Expand Down Expand Up @@ -546,9 +575,14 @@ autoAssistQ // endDefinition;
(*StopChat*)
StopChat // beginDefinition;

StopChat[ cell_CellObject ] :=
With[ { parent = parentCell @ cell },
StopChat @ parent /; MatchQ[ parent, Except[ cell, _CellObject ] ]
];

StopChat[ cell0_CellObject ] := Enclose[
Module[ { cell, settings, container, content },
cell = ConfirmMatch[ ensureChatOutputCell @ parentCell @ cell0, _CellObject, "ParentCell" ];
cell = ConfirmMatch[ ensureChatOutputCell @ cell0, _CellObject, "ParentCell" ];
settings = ConfirmBy[ currentChatSettings @ cell, AssociationQ, "ChatNotebookSettings" ];
removeTask @ Lookup[ settings, "Task" ];
container = ConfirmBy[ Lookup[ settings, "Container" ], AssociationQ, "Container" ];
Expand Down Expand Up @@ -1222,7 +1256,7 @@ withChatState // Attributes = { HoldFirst };

withChatState[ eval_ ] :=
Block[ { $enableLLMServices },
$handlerArguments = <| |>;
$ChatHandlerData = <| |>;
withToolBox @ withBasePromptBuilder @ eval
];

Expand Down
102 changes: 90 additions & 12 deletions Source/Chatbook/ChatMessages.wl
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ Needs[ "Wolfram`Chatbook`" ];
Needs[ "Wolfram`Chatbook`Actions`" ];
Needs[ "Wolfram`Chatbook`Common`" ];
Needs[ "Wolfram`Chatbook`FrontEnd`" ];
Needs[ "Wolfram`Chatbook`Handlers`" ];
Needs[ "Wolfram`Chatbook`InlineReferences`" ];
Needs[ "Wolfram`Chatbook`Models`" ];
Needs[ "Wolfram`Chatbook`Personas`" ];
Needs[ "Wolfram`Chatbook`Prompting`" ];
Needs[ "Wolfram`Chatbook`Serialization`" ];
Needs[ "Wolfram`Chatbook`Settings`" ];
Needs[ "Wolfram`Chatbook`Tools`" ];

(* ::**************************************************************************************************************:: *)
Expand Down Expand Up @@ -56,16 +58,22 @@ $styleRoles = <|
CellToChatMessage // Options = { "Role" -> Automatic };

CellToChatMessage[ cell_Cell, opts: OptionsPattern[ ] ] :=
CellToChatMessage[ cell, <| "Cells" -> { cell }, "HistoryPosition" -> 0 |>, opts ];
catchMine @ CellToChatMessage[ cell, <| "Cells" -> { cell }, "HistoryPosition" -> 0 |>, opts ];

(* TODO: this should eventually utilize "HistoryPosition" for dynamic compression rates *)
CellToChatMessage[ cell_Cell, settings_Association? AssociationQ, opts: OptionsPattern[ ] ] :=
Block[ { $cellRole = OptionValue[ "Role" ] },
catchMine @ Block[ { $cellRole = OptionValue[ "Role" ] },
Replace[
Flatten @ {
If[ TrueQ @ Positive @ Lookup[ settings, "HistoryPosition", 0 ],
makeCellMessage @ cell,
makeCurrentCellMessage[ settings, Lookup[ settings, "Cells", { cell } ] ]
makeCurrentCellMessage[
settings,
Replace[
Lookup[ settings, "Cells", { cell } ],
{ c___, _Cell } :> { c, cell }
]
]
]
},
{ message_? AssociationQ } :> message
Expand Down Expand Up @@ -93,18 +101,87 @@ constructMessages[ settings_Association? AssociationQ, cells: { __Cell } ] :=
constructMessages[ settings, makeChatMessages[ settings, cells ] ];

constructMessages[ settings_Association? AssociationQ, messages0: { __Association } ] :=
Enclose @ Module[ { messages },
Enclose @ Module[ { prompted, messages, processed },
If[ settings[ "AutoFormat" ], needsBasePrompt[ "Formatting" ] ];
needsBasePrompt @ settings;
messages = messages0 /. s_String :> RuleCondition @ StringReplace[ s, "%%BASE_PROMPT%%" -> $basePrompt ];
prompted = addPrompts[ settings, messages0 ];
messages = prompted /. s_String :> RuleCondition @ StringReplace[ s, "%%BASE_PROMPT%%" -> $basePrompt ];
processed = applyProcessingFunction[ settings, "ChatMessages", HoldComplete[ messages, $ChatHandlerData ] ];

If[ ! MatchQ[ processed, $$validMessageResults ],
messagePrint[ "InvalidMessages", getProcessingFunction[ settings, "ChatMessages" ], processed ];
processed = messages
];
Sow[ <| "Messages" -> processed |>, $chatDataTag ];

$lastSettings = settings;
$lastMessages = messages;
Sow[ <| "Messages" -> messages |>, $chatDataTag ];
messages
$lastMessages = processed;

processed
];

constructMessages // endDefinition;

(* ::**************************************************************************************************************:: *)
(* ::Subsubsection::Closed:: *)
(*addPrompts*)
addPrompts // beginDefinition;

addPrompts[ settings_Association, messages_List ] :=
addPrompts[ assembleCustomPrompt @ settings, messages ];

addPrompts[ None, messages_List ] :=
messages;

addPrompts[ prompt_String, { sysMessage: KeyValuePattern[ "Role" -> "System" ], messages___ } ] := Enclose[
Module[ { systemPrompt, newPrompt, newMessage },
systemPrompt = ConfirmBy[ Lookup[ sysMessage, "Content" ] , StringQ , "SystemPrompt" ];
newPrompt = ConfirmBy[ StringJoin[ systemPrompt, "\n\n", prompt ] , StringQ , "NewPrompt" ];
newMessage = ConfirmBy[ Append[ sysMessage, "Content" -> newPrompt ], AssociationQ, "NewMessage" ];
{ newMessage, messages }
]
];

addPrompts[ prompt_String, { messages___ } ] := {
<| "Role" -> "System", "Content" -> prompt |>,
messages
};

addPrompts // endDefinition;

(* ::**************************************************************************************************************:: *)
(* ::Subsubsection::Closed:: *)
(*assembleCustomPrompt*)
assembleCustomPrompt // beginDefinition;
assembleCustomPrompt[ settings_Association ] := assembleCustomPrompt[ settings, Lookup[ settings, "Prompts" ] ];
assembleCustomPrompt[ settings_, $$unspecified ] := None;
assembleCustomPrompt[ settings_, prompt_String ] := prompt;
assembleCustomPrompt[ settings_, prompts: { ___String } ] := StringRiffle[ prompts, "\n\n" ];

assembleCustomPrompt[ settings_? AssociationQ, templated: { ___, _TemplateObject, ___ } ] := Enclose[
Module[ { params, prompts },
params = ConfirmBy[ Association[ settings, $ChatHandlerData ], AssociationQ, "Params" ];
prompts = Replace[ templated, t_TemplateObject :> applyPromptTemplate[ t, params ], { 1 } ];
assembleCustomPrompt[ settings, prompts ] /; MatchQ[ prompts, { ___String } ]
],
throwInternalFailure[ assembleCustomPrompt[ settings, templated ], ## ] &
];

assembleCustomPrompt // endDefinition;

(* ::**************************************************************************************************************:: *)
(* ::Subsubsection::Closed:: *)
(*applyPromptTemplate*)
applyPromptTemplate // beginDefinition;

applyPromptTemplate[ template_TemplateObject, params_Association ] :=
If[ FreeQ[ template, TemplateSlot[ _String, ___ ] ],
TemplateApply @ template,
TemplateApply[ template, params ]
];

applyPromptTemplate // endDefinition;

(* ::**************************************************************************************************************:: *)
(* ::Subsection::Closed:: *)
(*makeChatMessages*)
Expand Down Expand Up @@ -145,16 +222,17 @@ makeChatMessages // endDefinition;
(* ::Subsection::Closed:: *)
(*getCellMessageFunction*)
getCellMessageFunction // beginDefinition;
getCellMessageFunction[ as_? AssociationQ ] := getCellMessageFunction[ as, as[ "CellToMessageFunction" ] ];
getCellMessageFunction[ as_, _Missing|Automatic|Inherited ] := CellToChatMessage;
getCellMessageFunction[ as_, toMessage_ ] := checkedMessageFunction @ replaceCellContext @ toMessage;
getCellMessageFunction[ as_ ] := checkedMessageFunction @ getProcessingFunction[ as, "CellToChatMessage" ];
getCellMessageFunction // endDefinition;

(* ::**************************************************************************************************************:: *)
(* ::Subsubsection::Closed:: *)
(*checkedMessageFunction*)
checkedMessageFunction // beginDefinition;

checkedMessageFunction[ CellToChatMessage ] :=
CellToChatMessage;

checkedMessageFunction[ func_ ] :=
checkedMessageFunction[ func, { ## } ] &;

Expand All @@ -163,7 +241,7 @@ checkedMessageFunction[ func_, { cell_, settings_ } ] :=
func[ cell, settings ],
{
message_String? StringQ :> <| "Role" -> cellRole @ cell, "Content" -> message |>,
Except[ $$validMessageResults ] :> CellToChatMessage[ cell, settings ]
Except[ $$validMessageResults ] :> CellToChatMessage[ cell, settings ] (* TODO: issue message here? *)
}
];

Expand Down
10 changes: 8 additions & 2 deletions Source/Chatbook/Common.wl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ BeginPackage[ "Wolfram`Chatbook`Common`" ];
`$$excludeHistoryStyle;
`$$nestedCellStyle;

`$$textData;
`$$textDataList;
`$$unspecified;

Expand Down Expand Up @@ -100,7 +101,9 @@ $$chatOutputStyle = cellStylePattern @ $chatOutputStyles;
$$excludeHistoryStyle = cellStylePattern @ $excludeHistoryStyles;
$$nestedCellStyle = cellStylePattern @ $nestedCellStyles;

$$textDataList = { (_String|_Cell|_StyleBox|_ButtonBox)... };
$$textDataItem = (_String|_Cell|_StyleBox|_ButtonBox);
$$textDataList = { $$textDataItem... };
$$textData = $$textDataItem | $$textDataList;
$$unspecified = _Missing | Automatic | Inherited;

(* ::**************************************************************************************************************:: *)
Expand All @@ -116,9 +119,12 @@ KeyValueMap[ Function[ MessageName[ Chatbook, #1 ] = #2 ], <|
"Internal" -> "An unexpected error occurred. `1`",
"InvalidAPIKey" -> "Invalid value for API key: `1`",
"InvalidArguments" -> "Invalid arguments given for `1` in `2`.",
"InvalidFunctions" -> "Invalid setting for ProcessingFunctions: `1`; using defaults instead.",
"InvalidHandlerKeys" -> "Invalid setting for HandlerFunctionsKeys: `1`; using defaults instead.",
"InvalidHandlers" -> "Invalid setting for HandlerFunctions: `1`; using defaults instead.",
"InvalidHandlerArguments" -> "Invalid value for $ChatHandlerData: `1`; resetting to default value.",
"InvalidResourceSpecification" -> "The argument `1` is not a valid resource specification.",
"InvalidMessages" -> "The value `2` returned by `1` is not a valid list of messages.",
"InvalidResourceURL" -> "The specified URL does not represent a valid resource object.",
"InvalidStreamingOutputMethod" -> "Invalid streaming output method: `1`.",
"InvalidWriteMethod" -> "Invalid setting for NotebookWriteMethod: `1`; using default instead.",
Expand Down Expand Up @@ -790,7 +796,7 @@ sufficientVersionQ // endDefinition;
insufficientVersionQ // beginDefinition;
insufficientVersionQ[ version_? NumberQ ] := TrueQ[ $VersionNumber < version ];
insufficientVersionQ[ id_String ] := insufficientVersionQ[ id ] = insufficientVersionQ[ id, $versionRequirements[ id ] ];
insufficientVersionQ[ id_, version_? NumberQ ] := sufficientVersionQ @ version;
insufficientVersionQ[ id_, version_? NumberQ ] := insufficientVersionQ @ version;
insufficientVersionQ // endDefinition;

(* ::**************************************************************************************************************:: *)
Expand Down
Loading

0 comments on commit 7bb3fc3

Please sign in to comment.