From 781311babab66eccc506ce9bbbc2a28d3463d726 Mon Sep 17 00:00:00 2001 From: xbotter Date: Tue, 28 Nov 2023 09:34:20 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=80=20Add=20system=20parameter=20to=20?= =?UTF-8?q?ERNIEBotChatCompletion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a new system parameter to the ERNIEBotChatCompletion class in order to enhance the chat completion functionality. --- .../Controllers/ApiController.cs | 20 ++++++++++ src/ERNIE-Bot.SDK/Models/ChatRequest.cs | 3 ++ .../ERNIEBotChatCompletion.cs | 37 ++++++++++++++----- 3 files changed, 50 insertions(+), 10 deletions(-) diff --git a/samples/SK-ERNIE-Bot.Sample/Controllers/ApiController.cs b/samples/SK-ERNIE-Bot.Sample/Controllers/ApiController.cs index 1166e3a..6a34126 100644 --- a/samples/SK-ERNIE-Bot.Sample/Controllers/ApiController.cs +++ b/samples/SK-ERNIE-Bot.Sample/Controllers/ApiController.cs @@ -114,5 +114,25 @@ public async Task SemanticPlugin([FromBody] UserInput input) var result = await _kernel.RunAsync(input.Text, translateFunc); return Ok(result.GetValue()); } + + [HttpPost("chat_with_system")] + public async Task ChatWithSystemAsync([FromBody] UserInput input, CancellationToken cancellationToken) + { + if (string.IsNullOrWhiteSpace(input.Text)) + { + return NoContent(); + } + + var chat = _kernel.GetService(); + + var history = chat.CreateNewChat($"你是一个友善的AI助手。你的名字叫做Alice,今天是{DateTime.Today}."); + + history.AddUserMessage(input.Text); + + var result = await chat.GetChatCompletionsAsync(history, null, cancellationToken); + + var text = await result.First().GetChatMessageAsync(); + return Ok(text.Content); + } } } diff --git a/src/ERNIE-Bot.SDK/Models/ChatRequest.cs b/src/ERNIE-Bot.SDK/Models/ChatRequest.cs index ffa4229..829dc50 100644 --- a/src/ERNIE-Bot.SDK/Models/ChatRequest.cs +++ b/src/ERNIE-Bot.SDK/Models/ChatRequest.cs @@ -4,6 +4,9 @@ namespace ERNIE_Bot.SDK.Models { public class ChatRequest { + [JsonPropertyName("system")] + public string? System { get; set; } + [JsonPropertyName("messages")] public List Messages { get; set; } = new List(); diff --git a/src/ERNIE-Bot.SemanticKernel/ERNIEBotChatCompletion.cs b/src/ERNIE-Bot.SemanticKernel/ERNIEBotChatCompletion.cs index 57db27d..662e409 100644 --- a/src/ERNIE-Bot.SemanticKernel/ERNIEBotChatCompletion.cs +++ b/src/ERNIE-Bot.SemanticKernel/ERNIEBotChatCompletion.cs @@ -28,7 +28,7 @@ public ChatHistory CreateNewChat(string? instructions = null) if (instructions != null) { - history.AddAssistantMessage(instructions); + history.AddSystemMessage(instructions); } return history; @@ -36,7 +36,7 @@ public ChatHistory CreateNewChat(string? instructions = null) public async Task> GetChatCompletionsAsync(ChatHistory chat, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default) { - var messages = ChatHistoryToMessages(chat); + var messages = ChatHistoryToMessages(chat, out var system); requestSettings ??= new AIRequestSettings(); var settings = ERNIEBotAIRequestSettings.FromRequestSettings(requestSettings); @@ -45,6 +45,7 @@ public async Task> GetChatCompletionsAsync(ChatHistor settings.Temperature, settings.TopP, settings.PenaltyScore, + system, cancellationToken ); return new List() { new ERNIEBotChatResult(result) }; @@ -61,6 +62,7 @@ public async Task> GetCompletionsAsync(string text, A settings.Temperature, settings.TopP, settings.PenaltyScore, + null, cancellationToken ); @@ -69,7 +71,8 @@ public async Task> GetCompletionsAsync(string text, A public async IAsyncEnumerable GetStreamingChatCompletionsAsync(ChatHistory chat, AIRequestSettings? requestSettings = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - var messages = ChatHistoryToMessages(chat); + var messages = ChatHistoryToMessages(chat, out var system); + requestSettings ??= new AIRequestSettings(); var settings = ERNIEBotAIRequestSettings.FromRequestSettings(requestSettings); @@ -78,6 +81,7 @@ public async IAsyncEnumerable GetStreamingChatCompletionsA settings.Temperature, settings.TopP, settings.PenaltyScore, + system, cancellationToken ); @@ -96,6 +100,7 @@ public async IAsyncEnumerable GetStreamingCompletionsAsync settings.Temperature, settings.TopP, settings.PenaltyScore, + null, cancellationToken ); @@ -115,13 +120,23 @@ private List StringToMessages(string text) }; } - private List ChatHistoryToMessages(ChatHistory chatHistory) + private List ChatHistoryToMessages(ChatHistory chatHistory, out string? system) { - return chatHistory.Select(m => new Message() + if (chatHistory.First().Role == AuthorRole.System) + { + system = chatHistory.First().Content; + } + else { - Role = AuthorRoleToMessageRole(m.Role), - Content = m.Content - }).ToList(); + system = null; + } + return chatHistory + .Where(_ => _.Role != AuthorRole.System) + .Select(m => new Message() + { + Role = AuthorRoleToMessageRole(m.Role), + Content = m.Content + }).ToList(); } private string AuthorRoleToMessageRole(AuthorRole role) @@ -131,7 +146,7 @@ private string AuthorRoleToMessageRole(AuthorRole role) return MessageRole.User; } - protected virtual async Task InternalCompletionsAsync(List messages, float? temperature, float? topP, float? penaltyScore, CancellationToken cancellationToken) + protected virtual async Task InternalCompletionsAsync(List messages, float? temperature, float? topP, float? penaltyScore, string? system, CancellationToken cancellationToken) { try { @@ -141,6 +156,7 @@ protected virtual async Task InternalCompletionsAsync(List InternalCompletionsAsync(List InternalCompletionsStreamAsync(List messages, float? temperature, float? topP, float? penaltyScore, CancellationToken cancellationToken) + protected virtual IAsyncEnumerable InternalCompletionsStreamAsync(List messages, float? temperature, float? topP, float? penaltyScore, string? system, CancellationToken cancellationToken) { try { @@ -159,6 +175,7 @@ protected virtual IAsyncEnumerable InternalCompletionsStreamAsync( Temperature = temperature, TopP = topP, PenaltyScore = penaltyScore, + System = system, }, _modelEndpoint, cancellationToken); } catch (ERNIEBotException ex)