Skip to content

Commit

Permalink
Implement serialization of assistant conversations, including tool ca…
Browse files Browse the repository at this point in the history
…lls and attachments (#11577)

Release Notes:

- N/A

---------

Co-authored-by: Kyle <[email protected]>
Co-authored-by: Marshall <[email protected]>
  • Loading branch information
3 people authored May 8, 2024
1 parent 24ffa0f commit a7aa257
Show file tree
Hide file tree
Showing 12 changed files with 584 additions and 252 deletions.
149 changes: 82 additions & 67 deletions crates/assistant2/src/assistant2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,14 @@ mod saved_conversation_picker;
mod tools;
pub mod ui;

use crate::saved_conversation::{SavedConversation, SavedMessage, SavedMessageRole};
use crate::saved_conversation_picker::SavedConversationPicker;
use crate::{
attachments::ActiveEditorAttachmentTool,
tools::{CreateBufferTool, ProjectIndexTool},
ui::UserOrAssistant,
};
use crate::ui::UserOrAssistant;
use ::ui::{div, prelude::*, Color, Tooltip, ViewContext};
use anyhow::{Context, Result};
use assistant_tooling::{
tool_running_placeholder, AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry,
UserAttachment,
};
use attachments::ActiveEditorAttachmentTool;
use client::{proto, Client, UserStore};
use collections::HashMap;
use completion_provider::*;
Expand All @@ -33,11 +28,13 @@ use gpui::{
use language::{language_settings::SoftWrap, LanguageRegistry};
use open_ai::{FunctionContent, ToolCall, ToolCallContent};
use rich_text::RichText;
use saved_conversation::{SavedAssistantMessagePart, SavedChatMessage, SavedConversation};
use saved_conversation_picker::SavedConversationPicker;
use semantic_index::{CloudEmbeddingProvider, ProjectIndex, ProjectIndexDebugView, SemanticIndex};
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::sync::Arc;
use tools::AnnotationTool;
use tools::{AnnotationTool, CreateBufferTool, ProjectIndexTool};
use ui::{ActiveFileButton, Composer, ProjectIndexButton};
use util::paths::CONVERSATIONS_DIR;
use util::{maybe, paths::EMBEDDINGS_DIR, ResultExt};
Expand Down Expand Up @@ -506,13 +503,11 @@ impl AssistantChat {
while let Some(delta) = stream.next().await {
let delta = delta?;
this.update(cx, |this, cx| {
if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
messages,
..
})) = this.messages.last_mut()
if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) =
this.messages.last_mut()
{
if messages.is_empty() {
messages.push(AssistantMessage {
messages.push(AssistantMessagePart {
body: RichText::default(),
tool_calls: Vec::new(),
})
Expand Down Expand Up @@ -563,7 +558,7 @@ impl AssistantChat {

let mut tool_tasks = Vec::new();
this.update(cx, |this, cx| {
if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
if let Some(ChatMessage::Assistant(AssistantMessage {
error: message_error,
messages,
..
Expand Down Expand Up @@ -592,7 +587,7 @@ impl AssistantChat {
let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect();

this.update(cx, |this, cx| {
if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) =
this.messages.last_mut()
{
if let Some(current_message) = messages.last_mut() {
Expand All @@ -608,19 +603,19 @@ impl AssistantChat {

fn push_new_assistant_message(&mut self, cx: &mut ViewContext<Self>) {
// If the last message is a grouped assistant message, add to the grouped message
if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) =
self.messages.last_mut()
{
messages.push(AssistantMessage {
messages.push(AssistantMessagePart {
body: RichText::default(),
tool_calls: Vec::new(),
});
return;
}

let message = ChatMessage::Assistant(GroupedAssistantMessage {
let message = ChatMessage::Assistant(AssistantMessage {
id: self.next_message_id.post_inc(),
messages: vec![AssistantMessage {
messages: vec![AssistantMessagePart {
body: RichText::default(),
tool_calls: Vec::new(),
}],
Expand Down Expand Up @@ -669,40 +664,30 @@ impl AssistantChat {
*entry = !*entry;
}

fn new_conversation(&mut self, cx: &mut ViewContext<Self>) {
let messages = self
.messages
.drain(..)
.map(|message| {
let text = match &message {
ChatMessage::User(message) => message.body.read(cx).text(cx),
ChatMessage::Assistant(message) => message
.messages
.iter()
.map(|message| message.body.text.to_string())
.collect::<Vec<_>>()
.join("\n\n"),
};

SavedMessage {
id: message.id(),
role: match message {
ChatMessage::User(_) => SavedMessageRole::User,
ChatMessage::Assistant(_) => SavedMessageRole::Assistant,
},
text,
}
})
.collect::<Vec<_>>();

// Reset the chat for the new conversation.
fn reset(&mut self) {
self.messages.clear();
self.list_state.reset(0);
self.editing_message.take();
self.collapsed_messages.clear();
}

fn new_conversation(&mut self, cx: &mut ViewContext<Self>) {
let messages = std::mem::take(&mut self.messages)
.into_iter()
.map(|message| self.serialize_message(message, cx))
.collect::<Vec<_>>();

self.reset();

let title = messages
.first()
.map(|message| message.text.clone())
.map(|message| match message {
SavedChatMessage::User { body, .. } => body.clone(),
SavedChatMessage::Assistant { messages, .. } => messages
.first()
.map(|message| message.body.to_string())
.unwrap_or_default(),
})
.unwrap_or_else(|| "A conversation with the assistant.".to_string());

let saved_conversation = SavedConversation {
Expand Down Expand Up @@ -836,7 +821,7 @@ impl AssistantChat {
}
})
.into_any(),
ChatMessage::Assistant(GroupedAssistantMessage {
ChatMessage::Assistant(AssistantMessage {
id,
messages,
error,
Expand Down Expand Up @@ -917,7 +902,7 @@ impl AssistantChat {
content: body.read(cx).text(cx),
});
}
ChatMessage::Assistant(GroupedAssistantMessage { messages, .. }) => {
ChatMessage::Assistant(AssistantMessage { messages, .. }) => {
for message in messages {
let body = message.body.clone();

Expand Down Expand Up @@ -971,6 +956,43 @@ impl AssistantChat {
Ok(completion_messages)
})
}

fn serialize_message(
&self,
message: ChatMessage,
cx: &mut ViewContext<AssistantChat>,
) -> SavedChatMessage {
match message {
ChatMessage::User(message) => SavedChatMessage::User {
id: message.id,
body: message.body.read(cx).text(cx),
attachments: message
.attachments
.iter()
.map(|attachment| {
self.attachment_registry
.serialize_user_attachment(attachment)
})
.collect(),
},
ChatMessage::Assistant(message) => SavedChatMessage::Assistant {
id: message.id,
error: message.error,
messages: message
.messages
.iter()
.map(|message| SavedAssistantMessagePart {
body: message.body.text.clone(),
tool_calls: message
.tool_calls
.iter()
.map(|tool_call| self.tool_registry.serialize_tool_call(tool_call))
.collect(),
})
.collect(),
},
}
}
}

impl Render for AssistantChat {
Expand Down Expand Up @@ -1053,17 +1075,10 @@ impl MessageId {

enum ChatMessage {
User(UserMessage),
Assistant(GroupedAssistantMessage),
Assistant(AssistantMessage),
}

impl ChatMessage {
pub fn id(&self) -> MessageId {
match self {
ChatMessage::User(message) => message.id,
ChatMessage::Assistant(message) => message.id,
}
}

fn focus_handle(&self, cx: &AppContext) -> Option<FocusHandle> {
match self {
ChatMessage::User(UserMessage { body, .. }) => Some(body.focus_handle(cx)),
Expand All @@ -1073,18 +1088,18 @@ impl ChatMessage {
}

struct UserMessage {
id: MessageId,
body: View<Editor>,
attachments: Vec<UserAttachment>,
pub id: MessageId,
pub body: View<Editor>,
pub attachments: Vec<UserAttachment>,
}

struct AssistantMessage {
body: RichText,
tool_calls: Vec<ToolFunctionCall>,
struct AssistantMessagePart {
pub body: RichText,
pub tool_calls: Vec<ToolFunctionCall>,
}

struct GroupedAssistantMessage {
id: MessageId,
messages: Vec<AssistantMessage>,
error: Option<SharedString>,
struct AssistantMessage {
pub id: MessageId,
pub messages: Vec<AssistantMessagePart>,
pub error: Option<SharedString>,
}
Loading

0 comments on commit a7aa257

Please sign in to comment.