Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔀 Add system parameter to ERNIEBotChatCompletion #76

Merged
merged 1 commit into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions samples/SK-ERNIE-Bot.Sample/Controllers/ApiController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,25 @@ public async Task<IActionResult> SemanticPlugin([FromBody] UserInput input)
var result = await _kernel.RunAsync(input.Text, translateFunc);
return Ok(result.GetValue<string>());
}

[HttpPost("chat_with_system")]
public async Task<IActionResult> ChatWithSystemAsync([FromBody] UserInput input, CancellationToken cancellationToken)
{
if (string.IsNullOrWhiteSpace(input.Text))
{
return NoContent();
}

var chat = _kernel.GetService<IChatCompletion>();

var history = chat.CreateNewChat($"你是一个友善的AI助手。你的名字叫做Alice,今天是{DateTime.Today}.");

history.AddUserMessage(input.Text);

var result = await chat.GetChatCompletionsAsync(history, null, cancellationToken);

var text = await result.First().GetChatMessageAsync();
return Ok(text.Content);
}
}
}
3 changes: 3 additions & 0 deletions src/ERNIE-Bot.SDK/Models/ChatRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ namespace ERNIE_Bot.SDK.Models
{
public class ChatRequest
{
[JsonPropertyName("system")]
public string? System { get; set; }

[JsonPropertyName("messages")]
public List<Message> Messages { get; set; } = new List<Message>();

Expand Down
37 changes: 27 additions & 10 deletions src/ERNIE-Bot.SemanticKernel/ERNIEBotChatCompletion.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ public ChatHistory CreateNewChat(string? instructions = null)

if (instructions != null)
{
history.AddAssistantMessage(instructions);
history.AddSystemMessage(instructions);
}

return history;
}

public async Task<IReadOnlyList<IChatResult>> GetChatCompletionsAsync(ChatHistory chat, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default)
{
var messages = ChatHistoryToMessages(chat);
var messages = ChatHistoryToMessages(chat, out var system);
requestSettings ??= new AIRequestSettings();

var settings = ERNIEBotAIRequestSettings.FromRequestSettings(requestSettings);
Expand All @@ -45,6 +45,7 @@ public async Task<IReadOnlyList<IChatResult>> GetChatCompletionsAsync(ChatHistor
settings.Temperature,
settings.TopP,
settings.PenaltyScore,
system,
cancellationToken
);
return new List<IChatResult>() { new ERNIEBotChatResult(result) };
Expand All @@ -61,6 +62,7 @@ public async Task<IReadOnlyList<ITextResult>> GetCompletionsAsync(string text, A
settings.Temperature,
settings.TopP,
settings.PenaltyScore,
null,
cancellationToken
);

Expand All @@ -69,7 +71,8 @@ public async Task<IReadOnlyList<ITextResult>> GetCompletionsAsync(string text, A

public async IAsyncEnumerable<IChatStreamingResult> GetStreamingChatCompletionsAsync(ChatHistory chat, AIRequestSettings? requestSettings = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = ChatHistoryToMessages(chat);
var messages = ChatHistoryToMessages(chat, out var system);

requestSettings ??= new AIRequestSettings();

var settings = ERNIEBotAIRequestSettings.FromRequestSettings(requestSettings);
Expand All @@ -78,6 +81,7 @@ public async IAsyncEnumerable<IChatStreamingResult> GetStreamingChatCompletionsA
settings.Temperature,
settings.TopP,
settings.PenaltyScore,
system,
cancellationToken
);

Expand All @@ -96,6 +100,7 @@ public async IAsyncEnumerable<ITextStreamingResult> GetStreamingCompletionsAsync
settings.Temperature,
settings.TopP,
settings.PenaltyScore,
null,
cancellationToken
);

Expand All @@ -115,13 +120,23 @@ private List<Message> StringToMessages(string text)
};
}

private List<Message> ChatHistoryToMessages(ChatHistory chatHistory)
private List<Message> ChatHistoryToMessages(ChatHistory chatHistory, out string? system)
{
return chatHistory.Select(m => new Message()
if (chatHistory.First().Role == AuthorRole.System)
{
system = chatHistory.First().Content;
}
else
{
Role = AuthorRoleToMessageRole(m.Role),
Content = m.Content
}).ToList();
system = null;
}
return chatHistory
.Where(_ => _.Role != AuthorRole.System)
.Select(m => new Message()
{
Role = AuthorRoleToMessageRole(m.Role),
Content = m.Content
}).ToList();
}

private string AuthorRoleToMessageRole(AuthorRole role)
Expand All @@ -131,7 +146,7 @@ private string AuthorRoleToMessageRole(AuthorRole role)
return MessageRole.User;
}

protected virtual async Task<ChatResponse> InternalCompletionsAsync(List<Message> messages, float? temperature, float? topP, float? penaltyScore, CancellationToken cancellationToken)
protected virtual async Task<ChatResponse> InternalCompletionsAsync(List<Message> messages, float? temperature, float? topP, float? penaltyScore, string? system, CancellationToken cancellationToken)
{
try
{
Expand All @@ -141,6 +156,7 @@ protected virtual async Task<ChatResponse> InternalCompletionsAsync(List<Message
Temperature = temperature,
TopP = topP,
PenaltyScore = penaltyScore,
System = system,
}, _modelEndpoint, cancellationToken);
}
catch (ERNIEBotException ex)
Expand All @@ -149,7 +165,7 @@ protected virtual async Task<ChatResponse> InternalCompletionsAsync(List<Message
}
}

protected virtual IAsyncEnumerable<ChatResponse> InternalCompletionsStreamAsync(List<Message> messages, float? temperature, float? topP, float? penaltyScore, CancellationToken cancellationToken)
protected virtual IAsyncEnumerable<ChatResponse> InternalCompletionsStreamAsync(List<Message> messages, float? temperature, float? topP, float? penaltyScore, string? system, CancellationToken cancellationToken)
{
try
{
Expand All @@ -159,6 +175,7 @@ protected virtual IAsyncEnumerable<ChatResponse> InternalCompletionsStreamAsync(
Temperature = temperature,
TopP = topP,
PenaltyScore = penaltyScore,
System = system,
}, _modelEndpoint, cancellationToken);
}
catch (ERNIEBotException ex)
Expand Down