diff --git a/crates/assistant2/src/assistant2.rs b/crates/assistant2/src/assistant2.rs index 4ae783c062e49..f54dd3b7d102d 100644 --- a/crates/assistant2/src/assistant2.rs +++ b/crates/assistant2/src/assistant2.rs @@ -6,19 +6,14 @@ mod saved_conversation_picker; mod tools; pub mod ui; -use crate::saved_conversation::{SavedConversation, SavedMessage, SavedMessageRole}; -use crate::saved_conversation_picker::SavedConversationPicker; -use crate::{ - attachments::ActiveEditorAttachmentTool, - tools::{CreateBufferTool, ProjectIndexTool}, - ui::UserOrAssistant, -}; +use crate::ui::UserOrAssistant; use ::ui::{div, prelude::*, Color, Tooltip, ViewContext}; use anyhow::{Context, Result}; use assistant_tooling::{ tool_running_placeholder, AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry, UserAttachment, }; +use attachments::ActiveEditorAttachmentTool; use client::{proto, Client, UserStore}; use collections::HashMap; use completion_provider::*; @@ -33,11 +28,13 @@ use gpui::{ use language::{language_settings::SoftWrap, LanguageRegistry}; use open_ai::{FunctionContent, ToolCall, ToolCallContent}; use rich_text::RichText; +use saved_conversation::{SavedAssistantMessagePart, SavedChatMessage, SavedConversation}; +use saved_conversation_picker::SavedConversationPicker; use semantic_index::{CloudEmbeddingProvider, ProjectIndex, ProjectIndexDebugView, SemanticIndex}; use serde::{Deserialize, Serialize}; use settings::Settings; use std::sync::Arc; -use tools::AnnotationTool; +use tools::{AnnotationTool, CreateBufferTool, ProjectIndexTool}; use ui::{ActiveFileButton, Composer, ProjectIndexButton}; use util::paths::CONVERSATIONS_DIR; use util::{maybe, paths::EMBEDDINGS_DIR, ResultExt}; @@ -506,13 +503,11 @@ impl AssistantChat { while let Some(delta) = stream.next().await { let delta = delta?; this.update(cx, |this, cx| { - if let Some(ChatMessage::Assistant(GroupedAssistantMessage { - messages, - .. - })) = this.messages.last_mut() + if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) = + this.messages.last_mut() { if messages.is_empty() { - messages.push(AssistantMessage { + messages.push(AssistantMessagePart { body: RichText::default(), tool_calls: Vec::new(), }) @@ -563,7 +558,7 @@ impl AssistantChat { let mut tool_tasks = Vec::new(); this.update(cx, |this, cx| { - if let Some(ChatMessage::Assistant(GroupedAssistantMessage { + if let Some(ChatMessage::Assistant(AssistantMessage { error: message_error, messages, .. @@ -592,7 +587,7 @@ impl AssistantChat { let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect(); this.update(cx, |this, cx| { - if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) = + if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) = this.messages.last_mut() { if let Some(current_message) = messages.last_mut() { @@ -608,19 +603,19 @@ impl AssistantChat { fn push_new_assistant_message(&mut self, cx: &mut ViewContext) { // If the last message is a grouped assistant message, add to the grouped message - if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) = + if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) = self.messages.last_mut() { - messages.push(AssistantMessage { + messages.push(AssistantMessagePart { body: RichText::default(), tool_calls: Vec::new(), }); return; } - let message = ChatMessage::Assistant(GroupedAssistantMessage { + let message = ChatMessage::Assistant(AssistantMessage { id: self.next_message_id.post_inc(), - messages: vec![AssistantMessage { + messages: vec![AssistantMessagePart { body: RichText::default(), tool_calls: Vec::new(), }], @@ -669,40 +664,30 @@ impl AssistantChat { *entry = !*entry; } - fn new_conversation(&mut self, cx: &mut ViewContext) { - let messages = self - .messages - .drain(..) - .map(|message| { - let text = match &message { - ChatMessage::User(message) => message.body.read(cx).text(cx), - ChatMessage::Assistant(message) => message - .messages - .iter() - .map(|message| message.body.text.to_string()) - .collect::>() - .join("\n\n"), - }; - - SavedMessage { - id: message.id(), - role: match message { - ChatMessage::User(_) => SavedMessageRole::User, - ChatMessage::Assistant(_) => SavedMessageRole::Assistant, - }, - text, - } - }) - .collect::>(); - - // Reset the chat for the new conversation. + fn reset(&mut self) { + self.messages.clear(); self.list_state.reset(0); self.editing_message.take(); self.collapsed_messages.clear(); + } + + fn new_conversation(&mut self, cx: &mut ViewContext) { + let messages = std::mem::take(&mut self.messages) + .into_iter() + .map(|message| self.serialize_message(message, cx)) + .collect::>(); + + self.reset(); let title = messages .first() - .map(|message| message.text.clone()) + .map(|message| match message { + SavedChatMessage::User { body, .. } => body.clone(), + SavedChatMessage::Assistant { messages, .. } => messages + .first() + .map(|message| message.body.to_string()) + .unwrap_or_default(), + }) .unwrap_or_else(|| "A conversation with the assistant.".to_string()); let saved_conversation = SavedConversation { @@ -836,7 +821,7 @@ impl AssistantChat { } }) .into_any(), - ChatMessage::Assistant(GroupedAssistantMessage { + ChatMessage::Assistant(AssistantMessage { id, messages, error, @@ -917,7 +902,7 @@ impl AssistantChat { content: body.read(cx).text(cx), }); } - ChatMessage::Assistant(GroupedAssistantMessage { messages, .. }) => { + ChatMessage::Assistant(AssistantMessage { messages, .. }) => { for message in messages { let body = message.body.clone(); @@ -971,6 +956,43 @@ impl AssistantChat { Ok(completion_messages) }) } + + fn serialize_message( + &self, + message: ChatMessage, + cx: &mut ViewContext, + ) -> SavedChatMessage { + match message { + ChatMessage::User(message) => SavedChatMessage::User { + id: message.id, + body: message.body.read(cx).text(cx), + attachments: message + .attachments + .iter() + .map(|attachment| { + self.attachment_registry + .serialize_user_attachment(attachment) + }) + .collect(), + }, + ChatMessage::Assistant(message) => SavedChatMessage::Assistant { + id: message.id, + error: message.error, + messages: message + .messages + .iter() + .map(|message| SavedAssistantMessagePart { + body: message.body.text.clone(), + tool_calls: message + .tool_calls + .iter() + .map(|tool_call| self.tool_registry.serialize_tool_call(tool_call)) + .collect(), + }) + .collect(), + }, + } + } } impl Render for AssistantChat { @@ -1053,17 +1075,10 @@ impl MessageId { enum ChatMessage { User(UserMessage), - Assistant(GroupedAssistantMessage), + Assistant(AssistantMessage), } impl ChatMessage { - pub fn id(&self) -> MessageId { - match self { - ChatMessage::User(message) => message.id, - ChatMessage::Assistant(message) => message.id, - } - } - fn focus_handle(&self, cx: &AppContext) -> Option { match self { ChatMessage::User(UserMessage { body, .. }) => Some(body.focus_handle(cx)), @@ -1073,18 +1088,18 @@ impl ChatMessage { } struct UserMessage { - id: MessageId, - body: View, - attachments: Vec, + pub id: MessageId, + pub body: View, + pub attachments: Vec, } -struct AssistantMessage { - body: RichText, - tool_calls: Vec, +struct AssistantMessagePart { + pub body: RichText, + pub tool_calls: Vec, } -struct GroupedAssistantMessage { - id: MessageId, - messages: Vec, - error: Option, +struct AssistantMessage { + pub id: MessageId, + pub messages: Vec, + pub error: Option, } diff --git a/crates/assistant2/src/attachments/active_file.rs b/crates/assistant2/src/attachments/active_file.rs index 54bcee940734d..811eb4219c10d 100644 --- a/crates/assistant2/src/attachments/active_file.rs +++ b/crates/assistant2/src/attachments/active_file.rs @@ -1,64 +1,68 @@ +use std::{path::PathBuf, sync::Arc}; + use anyhow::{anyhow, Result}; use assistant_tooling::{LanguageModelAttachment, ProjectContext, ToolOutput}; use editor::Editor; use gpui::{Render, Task, View, WeakModel, WeakView}; use language::Buffer; use project::ProjectPath; +use serde::{Deserialize, Serialize}; use ui::{prelude::*, ButtonLike, Tooltip, WindowContext}; use util::maybe; use workspace::Workspace; +#[derive(Serialize, Deserialize)] pub struct ActiveEditorAttachment { - buffer: WeakModel, - path: Option, + #[serde(skip)] + buffer: Option>, + path: Option, } pub struct FileAttachmentView { - output: Result, + project_path: Option, + buffer: Option>, + error: Option, } impl Render for FileAttachmentView { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { - match &self.output { - Ok(attachment) => { - let filename: SharedString = attachment - .path - .as_ref() - .and_then(|p| p.path.file_name()?.to_str()) - .unwrap_or("Untitled") - .to_string() - .into(); - - // todo!(): make the button link to the actual file to open - ButtonLike::new("file-attachment") - .child( - h_flex() - .gap_1() - .bg(cx.theme().colors().editor_background) - .rounded_md() - .child(ui::Icon::new(IconName::File)) - .child(filename.clone()), - ) - .tooltip({ - move |cx| Tooltip::with_meta("File Attached", None, filename.clone(), cx) - }) - .into_any_element() - } - Err(err) => div().child(err.to_string()).into_any_element(), + if let Some(error) = &self.error { + return div().child(error.to_string()).into_any_element(); } + + let filename: SharedString = self + .project_path + .as_ref() + .and_then(|p| p.path.file_name()?.to_str()) + .unwrap_or("Untitled") + .to_string() + .into(); + + ButtonLike::new("file-attachment") + .child( + h_flex() + .gap_1() + .bg(cx.theme().colors().editor_background) + .rounded_md() + .child(ui::Icon::new(IconName::File)) + .child(filename.clone()), + ) + .tooltip(move |cx| Tooltip::with_meta("File Attached", None, filename.clone(), cx)) + .into_any_element() } } impl ToolOutput for FileAttachmentView { fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String { - if let Ok(result) = &self.output { - if let Some(path) = &result.path { - project.add_file(path.clone()); - return format!("current file: {}", path.path.display()); - } else if let Some(buffer) = result.buffer.upgrade() { - return format!("current untitled buffer text:\n{}", buffer.read(cx).text()); - } + if let Some(path) = &self.project_path { + project.add_file(path.clone()); + return format!("current file: {}", path.path.display()); + } + + if let Some(buffer) = self.buffer.as_ref().and_then(|buffer| buffer.upgrade()) { + return format!("current untitled buffer text:\n{}", buffer.read(cx).text()); } + String::new() } } @@ -77,6 +81,10 @@ impl LanguageModelAttachment for ActiveEditorAttachmentTool { type Output = ActiveEditorAttachment; type View = FileAttachmentView; + fn name(&self) -> Arc { + "active-editor-attachment".into() + } + fn run(&self, cx: &mut WindowContext) -> Task> { Task::ready(maybe!({ let active_buffer = self @@ -91,13 +99,10 @@ impl LanguageModelAttachment for ActiveEditorAttachmentTool { let buffer = active_buffer.read(cx); if let Some(buffer) = buffer.as_singleton() { - let path = - project::File::from_dyn(buffer.read(cx).file()).map(|file| ProjectPath { - worktree_id: file.worktree_id(cx), - path: file.path.clone(), - }); + let path = project::File::from_dyn(buffer.read(cx).file()) + .and_then(|file| file.worktree.read(cx).absolutize(&file.path).ok()); return Ok(ActiveEditorAttachment { - buffer: buffer.downgrade(), + buffer: Some(buffer.downgrade()), path, }); } else { @@ -106,7 +111,34 @@ impl LanguageModelAttachment for ActiveEditorAttachmentTool { })) } - fn view(output: Result, cx: &mut WindowContext) -> View { - cx.new_view(|_cx| FileAttachmentView { output }) + fn view( + &self, + output: Result, + cx: &mut WindowContext, + ) -> View { + let error; + let project_path; + let buffer; + match output { + Ok(output) => { + error = None; + let workspace = self.workspace.upgrade().unwrap(); + let project = workspace.read(cx).project(); + project_path = output + .path + .and_then(|path| project.read(cx).project_path_for_absolute_path(&path, cx)); + buffer = output.buffer; + } + Err(err) => { + error = Some(err); + buffer = None; + project_path = None; + } + } + cx.new_view(|_cx| FileAttachmentView { + project_path, + buffer, + error, + }) } } diff --git a/crates/assistant2/src/saved_conversation.rs b/crates/assistant2/src/saved_conversation.rs index 2eb6af0557af7..1434b777f5218 100644 --- a/crates/assistant2/src/saved_conversation.rs +++ b/crates/assistant2/src/saved_conversation.rs @@ -1,3 +1,5 @@ +use assistant_tooling::{SavedToolFunctionCall, SavedUserAttachment}; +use gpui::SharedString; use serde::{Deserialize, Serialize}; use crate::MessageId; @@ -8,21 +10,27 @@ pub struct SavedConversation { pub version: String, /// The title of the conversation, generated by the Assistant. pub title: String, - pub messages: Vec, + pub messages: Vec, } #[derive(Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum SavedMessageRole { - User, - Assistant, +pub enum SavedChatMessage { + User { + id: MessageId, + body: String, + attachments: Vec, + }, + Assistant { + id: MessageId, + messages: Vec, + error: Option, + }, } #[derive(Serialize, Deserialize)] -pub struct SavedMessage { - pub id: MessageId, - pub role: SavedMessageRole, - pub text: String, +pub struct SavedAssistantMessagePart { + pub body: SharedString, + pub tool_calls: Vec, } /// Returns a list of placeholder conversations for mocking the UI. diff --git a/crates/assistant2/src/tools/annotate_code.rs b/crates/assistant2/src/tools/annotate_code.rs index 7e9216650d6fc..f2427bd440b68 100644 --- a/crates/assistant2/src/tools/annotate_code.rs +++ b/crates/assistant2/src/tools/annotate_code.rs @@ -6,7 +6,7 @@ use editor::{ }; use gpui::{prelude::*, AnyElement, Model, Task, View, WeakView}; use language::ToPoint; -use project::{Project, ProjectPath}; +use project::{search::SearchQuery, Project, ProjectPath}; use schemars::JsonSchema; use serde::Deserialize; use std::path::Path; @@ -29,17 +29,18 @@ impl AnnotationTool { pub struct AnnotationInput { /// Name for this set of annotations title: String, - annotations: Vec, + /// Excerpts from the file to show to the user. + excerpts: Vec, } #[derive(Debug, Deserialize, JsonSchema, Clone)] -struct Annotation { +struct Excerpt { /// Path to the file path: String, - /// Name of a symbol in the code - symbol_name: String, - /// Text to display near the symbol definition - text: String, + /// A short, distinctive string that appears in the file, used to define a location in the file. + text_passage: String, + /// Text to display above the code excerpt + annotation: String, } impl LanguageModelTool for AnnotationTool { @@ -58,7 +59,7 @@ impl LanguageModelTool for AnnotationTool { fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task> { let workspace = self.workspace.clone(); let project = self.project.clone(); - let excerpts = input.annotations.clone(); + let excerpts = input.excerpts.clone(); let title = input.title.clone(); let worktree_id = project.update(cx, |project, cx| { @@ -74,15 +75,16 @@ impl LanguageModelTool for AnnotationTool { }; let buffer_tasks = project.update(cx, |project, cx| { - let excerpts = excerpts.clone(); excerpts .iter() .map(|excerpt| { - let project_path = ProjectPath { - worktree_id, - path: Path::new(&excerpt.path).into(), - }; - project.open_buffer(project_path.clone(), cx) + project.open_buffer( + ProjectPath { + worktree_id, + path: Path::new(&excerpt.path).into(), + }, + cx, + ) }) .collect::>() }); @@ -99,39 +101,43 @@ impl LanguageModelTool for AnnotationTool { for (excerpt, buffer) in excerpts.iter().zip(buffers.iter()) { let snapshot = buffer.update(&mut cx, |buffer, _cx| buffer.snapshot())?; - if let Some(outline) = snapshot.outline(None) { - let matches = outline - .search(&excerpt.symbol_name, cx.background_executor().clone()) - .await; - if let Some(mat) = matches.first() { - let item = &outline.items[mat.candidate_id]; - let start = item.range.start.to_point(&snapshot); - editor.update(&mut cx, |editor, cx| { - let ranges = editor.buffer().update(cx, |multibuffer, cx| { - multibuffer.push_excerpts_with_context_lines( - buffer.clone(), - vec![start..start], - 5, - cx, - ) - }); - let explanation = SharedString::from(excerpt.text.clone()); - editor.insert_blocks( - [BlockProperties { - position: ranges[0].start, - height: 2, - style: BlockStyle::Fixed, - render: Box::new(move |cx| { - Self::render_note_block(&explanation, cx) - }), - disposition: BlockDisposition::Above, - }], - None, - cx, - ); - })?; - } - } + let query = + SearchQuery::text(&excerpt.text_passage, false, false, false, vec![], vec![])?; + + let matches = query.search(&snapshot, None).await; + let Some(first_match) = matches.first() else { + log::warn!( + "text {:?} does not appear in '{}'", + excerpt.text_passage, + excerpt.path + ); + continue; + }; + let mut start = first_match.start.to_point(&snapshot); + start.column = 0; + + editor.update(&mut cx, |editor, cx| { + let ranges = editor.buffer().update(cx, |multibuffer, cx| { + multibuffer.push_excerpts_with_context_lines( + buffer.clone(), + vec![start..start], + 5, + cx, + ) + }); + let annotation = SharedString::from(excerpt.annotation.clone()); + editor.insert_blocks( + [BlockProperties { + position: ranges[0].start, + height: annotation.split('\n').count() as u8 + 1, + style: BlockStyle::Fixed, + render: Box::new(move |cx| Self::render_note_block(&annotation, cx)), + disposition: BlockDisposition::Above, + }], + None, + cx, + ); + })?; } workspace @@ -144,7 +150,8 @@ impl LanguageModelTool for AnnotationTool { }) } - fn output_view( + fn view( + &self, _: Self::Input, output: Result, cx: &mut WindowContext, diff --git a/crates/assistant2/src/tools/create_buffer.rs b/crates/assistant2/src/tools/create_buffer.rs index 13e4cc7081471..5563615bdd1b0 100644 --- a/crates/assistant2/src/tools/create_buffer.rs +++ b/crates/assistant2/src/tools/create_buffer.rs @@ -86,7 +86,8 @@ impl LanguageModelTool for CreateBufferTool { }) } - fn output_view( + fn view( + &self, input: Self::Input, output: Result, cx: &mut WindowContext, diff --git a/crates/assistant2/src/tools/project_index.rs b/crates/assistant2/src/tools/project_index.rs index 32a56fe9edad7..0c43ef09b96eb 100644 --- a/crates/assistant2/src/tools/project_index.rs +++ b/crates/assistant2/src/tools/project_index.rs @@ -1,13 +1,13 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use assistant_tooling::{LanguageModelTool, ToolOutput}; use collections::BTreeMap; use gpui::{prelude::*, Model, Task}; use project::ProjectPath; use schemars::JsonSchema; use semantic_index::{ProjectIndex, Status}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::{fmt::Write as _, ops::Range}; +use std::{fmt::Write as _, ops::Range, path::Path, sync::Arc}; use ui::{div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext}; const DEFAULT_SEARCH_LIMIT: usize = 20; @@ -29,28 +29,24 @@ pub struct CodebaseQuery { pub struct ProjectIndexView { input: CodebaseQuery, - output: Result, + status: Status, + excerpts: Result>>>, element_id: ElementId, expanded_header: bool, } +#[derive(Serialize, Deserialize)] pub struct ProjectIndexOutput { status: Status, - excerpts: BTreeMap>>, + worktrees: BTreeMap, WorktreeIndexOutput>, } -impl ProjectIndexView { - fn new(input: CodebaseQuery, output: Result) -> Self { - let element_id = ElementId::Name(nanoid::nanoid!().into()); - - Self { - input, - output, - element_id, - expanded_header: false, - } - } +#[derive(Serialize, Deserialize)] +struct WorktreeIndexOutput { + excerpts: BTreeMap, Vec>>, +} +impl ProjectIndexView { fn toggle_header(&mut self, cx: &mut ViewContext) { self.expanded_header = !self.expanded_header; cx.notify(); @@ -60,18 +56,14 @@ impl ProjectIndexView { impl Render for ProjectIndexView { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { let query = self.input.query.clone(); - - let result = &self.output; - - let output = match result { + let excerpts = match &self.excerpts { Err(err) => { return div().child(Label::new(format!("Error: {}", err)).color(Color::Error)); } - Ok(output) => output, + Ok(excerpts) => excerpts, }; - let file_count = output.excerpts.len(); - + let file_count = excerpts.len(); let header = h_flex() .gap_2() .child(Icon::new(IconName::File)) @@ -97,16 +89,12 @@ impl Render for ProjectIndexView { .child(Icon::new(IconName::MagnifyingGlass)) .child(Label::new(format!("`{}`", query)).color(Color::Muted)), ) - .child( - v_flex() - .gap_2() - .children(output.excerpts.keys().map(|path| { - h_flex().gap_2().child(Icon::new(IconName::File)).child( - Label::new(path.path.to_string_lossy().to_string()) - .color(Color::Muted), - ) - })), - ), + .child(v_flex().gap_2().children(excerpts.keys().map(|path| { + h_flex().gap_2().child(Icon::new(IconName::File)).child( + Label::new(path.path.to_string_lossy().to_string()) + .color(Color::Muted), + ) + }))), ), ) } @@ -118,16 +106,16 @@ impl ToolOutput for ProjectIndexView { context: &mut assistant_tooling::ProjectContext, _: &mut WindowContext, ) -> String { - match &self.output { - Ok(output) => { + match &self.excerpts { + Ok(excerpts) => { let mut body = "found results in the following paths:\n".to_string(); - for (project_path, ranges) in &output.excerpts { + for (project_path, ranges) in excerpts { context.add_excerpts(project_path.clone(), ranges); writeln!(&mut body, "* {}", &project_path.path.display()).unwrap(); } - if output.status != Status::Idle { + if self.status != Status::Idle { body.push_str("Still indexing. Results may be incomplete.\n"); } @@ -172,16 +160,20 @@ impl LanguageModelTool for ProjectIndexTool { cx.update(|cx| { let mut output = ProjectIndexOutput { status, - excerpts: Default::default(), + worktrees: Default::default(), }; for search_result in search_results { - let path = ProjectPath { - worktree_id: search_result.worktree.read(cx).id(), - path: search_result.path.clone(), - }; - - let excerpts_for_path = output.excerpts.entry(path).or_default(); + let worktree_path = search_result.worktree.read(cx).abs_path(); + let excerpts = &mut output + .worktrees + .entry(worktree_path) + .or_insert(WorktreeIndexOutput { + excerpts: Default::default(), + }) + .excerpts; + + let excerpts_for_path = excerpts.entry(search_result.path).or_default(); let ix = match excerpts_for_path .binary_search_by_key(&search_result.range.start, |r| r.start) { @@ -195,12 +187,57 @@ impl LanguageModelTool for ProjectIndexTool { }) } - fn output_view( + fn view( + &self, input: Self::Input, output: Result, cx: &mut WindowContext, ) -> gpui::View { - cx.new_view(|_cx| ProjectIndexView::new(input, output)) + cx.new_view(|cx| { + let status; + let excerpts; + match output { + Ok(output) => { + status = output.status; + let project_index = self.project_index.read(cx); + if let Some(project) = project_index.project().upgrade() { + let project = project.read(cx); + excerpts = Ok(output + .worktrees + .into_iter() + .filter_map(|(abs_path, output)| { + for worktree in project.worktrees() { + let worktree = worktree.read(cx); + if worktree.abs_path() == abs_path { + return Some((worktree.id(), output.excerpts)); + } + } + None + }) + .flat_map(|(worktree_id, excerpts)| { + excerpts.into_iter().map(move |(path, ranges)| { + (ProjectPath { worktree_id, path }, ranges) + }) + }) + .collect::>()); + } else { + excerpts = Err(anyhow!("project was dropped")); + } + } + Err(err) => { + status = Status::Idle; + excerpts = Err(err); + } + }; + + ProjectIndexView { + input, + status, + excerpts, + element_id: ElementId::Name(nanoid::nanoid!().into()), + expanded_header: false, + } + }) } fn render_running(arguments: &Option, _: &mut WindowContext) -> impl IntoElement { diff --git a/crates/assistant_tooling/src/assistant_tooling.rs b/crates/assistant_tooling/src/assistant_tooling.rs index 39dabf08305c0..e5aff01edf3c0 100644 --- a/crates/assistant_tooling/src/assistant_tooling.rs +++ b/crates/assistant_tooling/src/assistant_tooling.rs @@ -2,9 +2,12 @@ mod attachment_registry; mod project_context; mod tool_registry; -pub use attachment_registry::{AttachmentRegistry, LanguageModelAttachment, UserAttachment}; +pub use attachment_registry::{ + AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment, UserAttachment, +}; pub use project_context::ProjectContext; pub use tool_registry::{ - tool_running_placeholder, LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition, + tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall, + SavedToolFunctionCallResult, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition, ToolOutput, ToolRegistry, }; diff --git a/crates/assistant_tooling/src/attachment_registry.rs b/crates/assistant_tooling/src/attachment_registry.rs index 8c0ae347a05c5..8c82099f4d8f4 100644 --- a/crates/assistant_tooling/src/attachment_registry.rs +++ b/crates/assistant_tooling/src/attachment_registry.rs @@ -3,6 +3,8 @@ use anyhow::{anyhow, Result}; use collections::HashMap; use futures::future::join_all; use gpui::{AnyView, Render, Task, View, WindowContext}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde_json::value::RawValue; use std::{ any::TypeId, sync::{ @@ -17,24 +19,34 @@ pub struct AttachmentRegistry { } pub trait LanguageModelAttachment { - type Output: 'static; + type Output: DeserializeOwned + Serialize + 'static; type View: Render + ToolOutput; + fn name(&self) -> Arc; fn run(&self, cx: &mut WindowContext) -> Task>; - - fn view(output: Result, cx: &mut WindowContext) -> View; + fn view(&self, output: Result, cx: &mut WindowContext) -> View; } /// A collected attachment from running an attachment tool pub struct UserAttachment { pub view: AnyView, + name: Arc, + serialized_output: Result, String>, generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String, } +#[derive(Serialize, Deserialize)] +pub struct SavedUserAttachment { + name: Arc, + serialized_output: Result, String>, +} + /// Internal representation of an attachment tool to allow us to treat them dynamically struct RegisteredAttachment { + name: Arc, enabled: AtomicBool, call: Box Task>>, + deserialize: Box Result>, } impl AttachmentRegistry { @@ -45,24 +57,65 @@ impl AttachmentRegistry { } pub fn register(&mut self, attachment: A) { - let call = Box::new(move |cx: &mut WindowContext| { - let result = attachment.run(cx); + let attachment = Arc::new(attachment); + + let call = Box::new({ + let attachment = attachment.clone(); + move |cx: &mut WindowContext| { + let result = attachment.run(cx); + let attachment = attachment.clone(); + cx.spawn(move |mut cx| async move { + let result: Result = result.await; + let serialized_output = + result + .as_ref() + .map_err(ToString::to_string) + .and_then(|output| { + Ok(RawValue::from_string( + serde_json::to_string(output).map_err(|e| e.to_string())?, + ) + .unwrap()) + }); + + let view = cx.update(|cx| attachment.view(result, cx))?; + + Ok(UserAttachment { + name: attachment.name(), + view: view.into(), + generate_fn: generate::, + serialized_output, + }) + }) + } + }); - cx.spawn(move |mut cx| async move { - let result: Result = result.await; - let view = cx.update(|cx| A::view(result, cx))?; + let deserialize = Box::new({ + let attachment = attachment.clone(); + move |saved_attachment: &SavedUserAttachment, cx: &mut WindowContext| { + let serialized_output = saved_attachment.serialized_output.clone(); + let output = match &serialized_output { + Ok(serialized_output) => { + Ok(serde_json::from_str::(serialized_output.get())?) + } + Err(error) => Err(anyhow!("{error}")), + }; + let view = attachment.view(output, cx).into(); Ok(UserAttachment { - view: view.into(), + name: saved_attachment.name.clone(), + view, + serialized_output, generate_fn: generate::, }) - }) + } }); self.registered_attachments.insert( TypeId::of::(), RegisteredAttachment { + name: attachment.name(), call, + deserialize, enabled: AtomicBool::new(true), }, ); @@ -134,6 +187,35 @@ impl AttachmentRegistry { .collect()) }) } + + pub fn serialize_user_attachment( + &self, + user_attachment: &UserAttachment, + ) -> SavedUserAttachment { + SavedUserAttachment { + name: user_attachment.name.clone(), + serialized_output: user_attachment.serialized_output.clone(), + } + } + + pub fn deserialize_user_attachment( + &self, + saved_user_attachment: SavedUserAttachment, + cx: &mut WindowContext, + ) -> Result { + if let Some(registered_attachment) = self + .registered_attachments + .values() + .find(|attachment| attachment.name == saved_user_attachment.name) + { + (registered_attachment.deserialize)(&saved_user_attachment, cx) + } else { + Err(anyhow!( + "no attachment tool for name {}", + saved_user_attachment.name + )) + } + } } impl UserAttachment { diff --git a/crates/assistant_tooling/src/tool_registry.rs b/crates/assistant_tooling/src/tool_registry.rs index d32f756e5f1d1..d1a14c4c9df5f 100644 --- a/crates/assistant_tooling/src/tool_registry.rs +++ b/crates/assistant_tooling/src/tool_registry.rs @@ -1,41 +1,60 @@ +use crate::ProjectContext; use anyhow::{anyhow, Result}; use gpui::{ div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext, }; use schemars::{schema::RootSchema, schema_for, JsonSchema}; -use serde::Deserialize; -use serde_json::Value; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde_json::{value::RawValue, Value}; use std::{ any::TypeId, collections::HashMap, fmt::Display, - sync::atomic::{AtomicBool, Ordering::SeqCst}, + sync::{ + atomic::{AtomicBool, Ordering::SeqCst}, + Arc, + }, }; -use crate::ProjectContext; - pub struct ToolRegistry { registered_tools: HashMap, } -#[derive(Default, Deserialize)] +#[derive(Default)] pub struct ToolFunctionCall { pub id: String, pub name: String, pub arguments: String, - #[serde(skip)] pub result: Option, } +#[derive(Default, Serialize, Deserialize)] +pub struct SavedToolFunctionCall { + pub id: String, + pub name: String, + pub arguments: String, + pub result: Option, +} + pub enum ToolFunctionCallResult { NoSuchTool, ParsingFailed, Finished { view: AnyView, + serialized_output: Result, String>, generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String, }, } +#[derive(Serialize, Deserialize)] +pub enum SavedToolFunctionCallResult { + NoSuchTool, + ParsingFailed, + Finished { + serialized_output: Result, String>, + }, +} + #[derive(Clone)] pub struct ToolFunctionDefinition { pub name: String, @@ -46,10 +65,10 @@ pub struct ToolFunctionDefinition { pub trait LanguageModelTool { /// The input type that will be passed in to `execute` when the tool is called /// by the language model. - type Input: for<'de> Deserialize<'de> + JsonSchema; + type Input: DeserializeOwned + JsonSchema; /// The output returned by executing the tool. - type Output: 'static; + type Output: DeserializeOwned + Serialize + 'static; type View: Render + ToolOutput; @@ -80,7 +99,8 @@ pub trait LanguageModelTool { fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task>; /// A view of the output of running the tool, for displaying to the user. - fn output_view( + fn view( + &self, input: Self::Input, output: Result, cx: &mut WindowContext, @@ -102,7 +122,8 @@ pub trait ToolOutput: Sized { struct RegisteredTool { enabled: AtomicBool, type_id: TypeId, - call: Box Task>>, + execute: Box Task>>, + deserialize: Box ToolFunctionCall>, render_running: fn(&ToolFunctionCall, &mut WindowContext) -> gpui::AnyElement, definition: ToolFunctionDefinition, } @@ -162,23 +183,125 @@ impl ToolRegistry { } } + pub fn serialize_tool_call(&self, call: &ToolFunctionCall) -> SavedToolFunctionCall { + SavedToolFunctionCall { + id: call.id.clone(), + name: call.name.clone(), + arguments: call.arguments.clone(), + result: call.result.as_ref().map(|result| match result { + ToolFunctionCallResult::NoSuchTool => SavedToolFunctionCallResult::NoSuchTool, + ToolFunctionCallResult::ParsingFailed => SavedToolFunctionCallResult::ParsingFailed, + ToolFunctionCallResult::Finished { + serialized_output, .. + } => SavedToolFunctionCallResult::Finished { + serialized_output: match serialized_output { + Ok(value) => Ok(value.clone()), + Err(e) => Err(e.to_string()), + }, + }, + }), + } + } + + pub fn deserialize_tool_call( + &self, + call: &SavedToolFunctionCall, + cx: &mut WindowContext, + ) -> ToolFunctionCall { + if let Some(tool) = &self.registered_tools.get(&call.name) { + (tool.deserialize)(call, cx) + } else { + ToolFunctionCall { + id: call.id.clone(), + name: call.name.clone(), + arguments: call.arguments.clone(), + result: Some(ToolFunctionCallResult::NoSuchTool), + } + } + } + pub fn register( &mut self, tool: T, _cx: &mut WindowContext, ) -> Result<()> { let name = tool.name(); + let tool = Arc::new(tool); let registered_tool = RegisteredTool { type_id: TypeId::of::(), definition: tool.definition(), enabled: AtomicBool::new(true), - call: Box::new( - move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| { + deserialize: Box::new({ + let tool = tool.clone(); + move |tool_call: &SavedToolFunctionCall, cx: &mut WindowContext| { + let id = tool_call.id.clone(); let name = tool_call.name.clone(); let arguments = tool_call.arguments.clone(); + + let Ok(input) = serde_json::from_str::(&tool_call.arguments) else { + return ToolFunctionCall { + id, + name: name.clone(), + arguments, + result: Some(ToolFunctionCallResult::ParsingFailed), + }; + }; + + let result = match &tool_call.result { + Some(result) => match result { + SavedToolFunctionCallResult::NoSuchTool => { + Some(ToolFunctionCallResult::NoSuchTool) + } + SavedToolFunctionCallResult::ParsingFailed => { + Some(ToolFunctionCallResult::ParsingFailed) + } + SavedToolFunctionCallResult::Finished { serialized_output } => { + let output = match serialized_output { + Ok(value) => { + match serde_json::from_str::(value.get()) { + Ok(value) => Ok(value), + Err(_) => { + return ToolFunctionCall { + id, + name: name.clone(), + arguments, + result: Some( + ToolFunctionCallResult::ParsingFailed, + ), + }; + } + } + } + Err(e) => Err(anyhow!("{e}")), + }; + + let view = tool.view(input, output, cx).into(); + Some(ToolFunctionCallResult::Finished { + serialized_output: serialized_output.clone(), + generate_fn: generate::, + view, + }) + } + }, + None => None, + }; + + ToolFunctionCall { + id: tool_call.id.clone(), + name: name.clone(), + arguments: tool_call.arguments.clone(), + result, + } + } + }), + execute: Box::new({ + let tool = tool.clone(); + move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| { let id = tool_call.id.clone(); + let name = tool_call.name.clone(); + let arguments = tool_call.arguments.clone(); - let Ok(input) = serde_json::from_str::(arguments.as_str()) else { + let Ok(input) = serde_json::from_str::(&arguments) else { return Task::ready(Ok(ToolFunctionCall { id, name: name.clone(), @@ -188,23 +311,33 @@ impl ToolRegistry { }; let result = tool.execute(&input, cx); - + let tool = tool.clone(); cx.spawn(move |mut cx| async move { - let result: Result = result.await; - let view = cx.update(|cx| T::output_view(input, result, cx))?; + let result = result.await; + let serialized_output = result + .as_ref() + .map_err(ToString::to_string) + .and_then(|output| { + Ok(RawValue::from_string( + serde_json::to_string(output).map_err(|e| e.to_string())?, + ) + .unwrap()) + }); + let view = cx.update(|cx| tool.view(input, result, cx))?; Ok(ToolFunctionCall { id, name: name.clone(), arguments, result: Some(ToolFunctionCallResult::Finished { + serialized_output, view: view.into(), generate_fn: generate::, }), }) }) - }, - ), + } + }), render_running: render_running::, }; @@ -259,7 +392,7 @@ impl ToolRegistry { } }; - (tool.call)(tool_call, cx) + (tool.execute)(tool_call, cx) } } @@ -275,9 +408,9 @@ impl ToolFunctionCallResult { ToolFunctionCallResult::ParsingFailed => { format!("Unable to parse arguments for {name}") } - ToolFunctionCallResult::Finished { generate_fn, view } => { - (generate_fn)(view.clone(), project, cx) - } + ToolFunctionCallResult::Finished { + generate_fn, view, .. + } => (generate_fn)(view.clone(), project, cx), } } @@ -373,7 +506,8 @@ mod test { Task::ready(Ok(weather)) } - fn output_view( + fn view( + &self, _input: Self::Input, result: Result, cx: &mut WindowContext, diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index e162f8c3250cf..42fce2d648b93 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -7864,6 +7864,18 @@ impl Project { }) } + pub fn project_path_for_absolute_path( + &self, + abs_path: &Path, + cx: &AppContext, + ) -> Option { + self.find_local_worktree(abs_path, cx) + .map(|(worktree, relative_path)| ProjectPath { + worktree_id: worktree.read(cx).id(), + path: relative_path.into(), + }) + } + pub fn get_workspace_root( &self, project_path: &ProjectPath, diff --git a/crates/project/src/search.rs b/crates/project/src/search.rs index 5df9748f97aad..11b708c8fec84 100644 --- a/crates/project/src/search.rs +++ b/crates/project/src/search.rs @@ -250,6 +250,7 @@ impl SearchQuery { } } } + pub async fn search( &self, buffer: &BufferSnapshot, diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 98ca2f25c7f1a..94a2e222216c0 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -450,7 +450,7 @@ pub struct WorktreeSearchResult { pub score: f32, } -#[derive(Copy, Clone, Debug, Eq, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] pub enum Status { Idle, Loading,