Skip to content

Commit

Permalink
πŸ”€ Add system parameter to ERNIEBotChatCompletion (#76)
Browse files Browse the repository at this point in the history
Add a new system parameter to the ERNIEBotChatCompletion class in order to enhance the chat completion functionality.
  • Loading branch information
xbotter authored Nov 28, 2023
1 parent 17070d3 commit 2342c77
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 10 deletions.
20 changes: 20 additions & 0 deletions samples/SK-ERNIE-Bot.Sample/Controllers/ApiController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,25 @@ public async Task<IActionResult> SemanticPlugin([FromBody] UserInput input)
var result = await _kernel.RunAsync(input.Text, translateFunc);
return Ok(result.GetValue<string>());
}

[HttpPost("chat_with_system")]
public async Task<IActionResult> ChatWithSystemAsync([FromBody] UserInput input, CancellationToken cancellationToken)
{
if (string.IsNullOrWhiteSpace(input.Text))
{
return NoContent();
}

var chat = _kernel.GetService<IChatCompletion>();

var history = chat.CreateNewChat($"δ½ ζ˜―δΈ€δΈͺε‹ε–„ηš„AIεŠ©ζ‰‹γ€‚δ½ ηš„εε­—ε«εšAlice,今倩是{DateTime.Today}.");

history.AddUserMessage(input.Text);

var result = await chat.GetChatCompletionsAsync(history, null, cancellationToken);

var text = await result.First().GetChatMessageAsync();
return Ok(text.Content);
}
}
}
3 changes: 3 additions & 0 deletions src/ERNIE-Bot.SDK/Models/ChatRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ namespace ERNIE_Bot.SDK.Models
{
public class ChatRequest
{
[JsonPropertyName("system")]
public string? System { get; set; }

[JsonPropertyName("messages")]
public List<Message> Messages { get; set; } = new List<Message>();

Expand Down
37 changes: 27 additions & 10 deletions src/ERNIE-Bot.SemanticKernel/ERNIEBotChatCompletion.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ public ChatHistory CreateNewChat(string? instructions = null)

if (instructions != null)
{
history.AddAssistantMessage(instructions);
history.AddSystemMessage(instructions);
}

return history;
}

public async Task<IReadOnlyList<IChatResult>> GetChatCompletionsAsync(ChatHistory chat, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default)
{
var messages = ChatHistoryToMessages(chat);
var messages = ChatHistoryToMessages(chat, out var system);
requestSettings ??= new AIRequestSettings();

var settings = ERNIEBotAIRequestSettings.FromRequestSettings(requestSettings);
Expand All @@ -45,6 +45,7 @@ public async Task<IReadOnlyList<IChatResult>> GetChatCompletionsAsync(ChatHistor
settings.Temperature,
settings.TopP,
settings.PenaltyScore,
system,
cancellationToken
);
return new List<IChatResult>() { new ERNIEBotChatResult(result) };
Expand All @@ -61,6 +62,7 @@ public async Task<IReadOnlyList<ITextResult>> GetCompletionsAsync(string text, A
settings.Temperature,
settings.TopP,
settings.PenaltyScore,
null,
cancellationToken
);

Expand All @@ -69,7 +71,8 @@ public async Task<IReadOnlyList<ITextResult>> GetCompletionsAsync(string text, A

public async IAsyncEnumerable<IChatStreamingResult> GetStreamingChatCompletionsAsync(ChatHistory chat, AIRequestSettings? requestSettings = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = ChatHistoryToMessages(chat);
var messages = ChatHistoryToMessages(chat, out var system);

requestSettings ??= new AIRequestSettings();

var settings = ERNIEBotAIRequestSettings.FromRequestSettings(requestSettings);
Expand All @@ -78,6 +81,7 @@ public async IAsyncEnumerable<IChatStreamingResult> GetStreamingChatCompletionsA
settings.Temperature,
settings.TopP,
settings.PenaltyScore,
system,
cancellationToken
);

Expand All @@ -96,6 +100,7 @@ public async IAsyncEnumerable<ITextStreamingResult> GetStreamingCompletionsAsync
settings.Temperature,
settings.TopP,
settings.PenaltyScore,
null,
cancellationToken
);

Expand All @@ -115,13 +120,23 @@ private List<Message> StringToMessages(string text)
};
}

private List<Message> ChatHistoryToMessages(ChatHistory chatHistory)
private List<Message> ChatHistoryToMessages(ChatHistory chatHistory, out string? system)
{
return chatHistory.Select(m => new Message()
if (chatHistory.First().Role == AuthorRole.System)
{
system = chatHistory.First().Content;
}
else
{
Role = AuthorRoleToMessageRole(m.Role),
Content = m.Content
}).ToList();
system = null;
}
return chatHistory
.Where(_ => _.Role != AuthorRole.System)
.Select(m => new Message()
{
Role = AuthorRoleToMessageRole(m.Role),
Content = m.Content
}).ToList();
}

private string AuthorRoleToMessageRole(AuthorRole role)
Expand All @@ -131,7 +146,7 @@ private string AuthorRoleToMessageRole(AuthorRole role)
return MessageRole.User;
}

protected virtual async Task<ChatResponse> InternalCompletionsAsync(List<Message> messages, float? temperature, float? topP, float? penaltyScore, CancellationToken cancellationToken)
protected virtual async Task<ChatResponse> InternalCompletionsAsync(List<Message> messages, float? temperature, float? topP, float? penaltyScore, string? system, CancellationToken cancellationToken)
{
try
{
Expand All @@ -141,6 +156,7 @@ protected virtual async Task<ChatResponse> InternalCompletionsAsync(List<Message
Temperature = temperature,
TopP = topP,
PenaltyScore = penaltyScore,
System = system,
}, _modelEndpoint, cancellationToken);
}
catch (ERNIEBotException ex)
Expand All @@ -149,7 +165,7 @@ protected virtual async Task<ChatResponse> InternalCompletionsAsync(List<Message
}
}

protected virtual IAsyncEnumerable<ChatResponse> InternalCompletionsStreamAsync(List<Message> messages, float? temperature, float? topP, float? penaltyScore, CancellationToken cancellationToken)
protected virtual IAsyncEnumerable<ChatResponse> InternalCompletionsStreamAsync(List<Message> messages, float? temperature, float? topP, float? penaltyScore, string? system, CancellationToken cancellationToken)
{
try
{
Expand All @@ -159,6 +175,7 @@ protected virtual IAsyncEnumerable<ChatResponse> InternalCompletionsStreamAsync(
Temperature = temperature,
TopP = topP,
PenaltyScore = penaltyScore,
System = system,
}, _modelEndpoint, cancellationToken);
}
catch (ERNIEBotException ex)
Expand Down

0 comments on commit 2342c77

Please sign in to comment.