Skip to content

Commit

Permalink
assistant2: Add support for using tools provided by context servers (#…
Browse files Browse the repository at this point in the history
…21418)

This PR adds support to Assistant 2 for using tools provided by context
servers.

As part of this I introduced a new `ThreadStore`.

Release Notes:

- N/A

---------

Co-authored-by: Cole <[email protected]>
  • Loading branch information
maxdeviant and cole-miller authored Dec 2, 2024
1 parent f32ffcf commit b88daae
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 2 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions crates/assistant2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@ anyhow.workspace = true
assistant_tool.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
context_server.workspace = true
editor.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
language_model.workspace = true
language_model_selector.workspace = true
log.workspace = true
project.workspace = true
proto.workspace = true
settings.workspace = true
serde.workspace = true
Expand Down
1 change: 1 addition & 0 deletions crates/assistant2/src/assistant.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod assistant_panel;
mod message_editor;
mod thread;
mod thread_store;

use command_palette_hooks::CommandPaletteFilter;
use feature_flags::{Assistant2FeatureFlag, FeatureFlagAppExt};
Expand Down
20 changes: 18 additions & 2 deletions crates/assistant2/src/assistant_panel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use workspace::Workspace;

use crate::message_editor::MessageEditor;
use crate::thread::{Message, Thread, ThreadEvent};
use crate::thread_store::ThreadStore;
use crate::{NewThread, ToggleFocus, ToggleModelSelector};

pub fn init(cx: &mut AppContext) {
Expand All @@ -29,6 +30,8 @@ pub fn init(cx: &mut AppContext) {

pub struct AssistantPanel {
workspace: WeakView<Workspace>,
#[allow(unused)]
thread_store: Model<ThreadStore>,
thread: Model<Thread>,
message_editor: View<MessageEditor>,
tools: Arc<ToolWorkingSet>,
Expand All @@ -42,13 +45,25 @@ impl AssistantPanel {
) -> Task<Result<View<Self>>> {
cx.spawn(|mut cx| async move {
let tools = Arc::new(ToolWorkingSet::default());
let thread_store = workspace
.update(&mut cx, |workspace, cx| {
let project = workspace.project().clone();
ThreadStore::new(project, tools.clone(), cx)
})?
.await?;

workspace.update(&mut cx, |workspace, cx| {
cx.new_view(|cx| Self::new(workspace, tools, cx))
cx.new_view(|cx| Self::new(workspace, thread_store, tools, cx))
})
})
}

fn new(workspace: &Workspace, tools: Arc<ToolWorkingSet>, cx: &mut ViewContext<Self>) -> Self {
fn new(
workspace: &Workspace,
thread_store: Model<ThreadStore>,
tools: Arc<ToolWorkingSet>,
cx: &mut ViewContext<Self>,
) -> Self {
let thread = cx.new_model(|cx| Thread::new(tools.clone(), cx));
let subscriptions = vec![
cx.observe(&thread, |_, _, cx| cx.notify()),
Expand All @@ -57,6 +72,7 @@ impl AssistantPanel {

Self {
workspace: workspace.weak_handle(),
thread_store,
thread: thread.clone(),
message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)),
tools,
Expand Down
114 changes: 114 additions & 0 deletions crates/assistant2/src/thread_store.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
use std::sync::Arc;

use anyhow::Result;
use assistant_tool::{ToolId, ToolWorkingSet};
use collections::HashMap;
use context_server::manager::ContextServerManager;
use context_server::{ContextServerFactoryRegistry, ContextServerTool};
use gpui::{prelude::*, AppContext, Model, ModelContext, Task};
use project::Project;
use util::ResultExt as _;

pub struct ThreadStore {
#[allow(unused)]
project: Model<Project>,
tools: Arc<ToolWorkingSet>,
context_server_manager: Model<ContextServerManager>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
}

impl ThreadStore {
pub fn new(
project: Model<Project>,
tools: Arc<ToolWorkingSet>,
cx: &mut AppContext,
) -> Task<Result<Model<Self>>> {
cx.spawn(|mut cx| async move {
let this = cx.new_model(|cx: &mut ModelContext<Self>| {
let context_server_factory_registry =
ContextServerFactoryRegistry::default_global(cx);
let context_server_manager = cx.new_model(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
});

let this = Self {
project,
tools,
context_server_manager,
context_server_tool_ids: HashMap::default(),
};
this.register_context_server_handlers(cx);

this
})?;

Ok(this)
})
}

fn register_context_server_handlers(&self, cx: &mut ModelContext<Self>) {
cx.subscribe(
&self.context_server_manager.clone(),
Self::handle_context_server_event,
)
.detach();
}

fn handle_context_server_event(
&mut self,
context_server_manager: Model<ContextServerManager>,
event: &context_server::manager::Event,
cx: &mut ModelContext<Self>,
) {
let tool_working_set = self.tools.clone();
match event {
context_server::manager::Event::ServerStarted { server_id } => {
if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
let context_server_manager = context_server_manager.clone();
cx.spawn({
let server = server.clone();
let server_id = server_id.clone();
|this, mut cx| async move {
let Some(protocol) = server.client() else {
return;
};

if protocol.capable(context_server::protocol::ServerCapability::Tools) {
if let Some(tools) = protocol.list_tools().await.log_err() {
let tool_ids = tools
.tools
.into_iter()
.map(|tool| {
log::info!(
"registering context server tool: {:?}",
tool.name
);
tool_working_set.insert(Arc::new(
ContextServerTool::new(
context_server_manager.clone(),
server.id(),
tool,
),
))
})
.collect::<Vec<_>>();

this.update(&mut cx, |this, _cx| {
this.context_server_tool_ids.insert(server_id, tool_ids);
})
.log_err();
}
}
}
})
.detach();
}
}
context_server::manager::Event::ServerStopped { server_id } => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.remove(&tool_ids);
}
}
}
}
}

0 comments on commit b88daae

Please sign in to comment.