diff --git a/crates/feature_flags/src/feature_flags.rs b/crates/feature_flags/src/feature_flags.rs index 48e3cc95b2dda7..f4bebca4d87b24 100644 --- a/crates/feature_flags/src/feature_flags.rs +++ b/crates/feature_flags/src/feature_flags.rs @@ -59,6 +59,11 @@ impl FeatureFlag for ToolUseFeatureFlag { } } +pub struct ZetaFeatureFlag; +impl FeatureFlag for ZetaFeatureFlag { + const NAME: &'static str = "zeta"; +} + pub struct Remoting {} impl FeatureFlag for Remoting { const NAME: &'static str = "remoting"; diff --git a/crates/inline_completion_button/src/inline_completion_button.rs b/crates/inline_completion_button/src/inline_completion_button.rs index a18c250875b6ac..06664f403c1c65 100644 --- a/crates/inline_completion_button/src/inline_completion_button.rs +++ b/crates/inline_completion_button/src/inline_completion_button.rs @@ -1,7 +1,7 @@ use anyhow::Result; use copilot::{Copilot, Status}; use editor::{scroll::Autoscroll, Editor}; -use feature_flags::FeatureFlagAppExt; +use feature_flags::{FeatureFlagAppExt, ZetaFeatureFlag}; use fs::Fs; use gpui::{ div, Action, AnchorCorner, AppContext, AsyncWindowContext, Entity, IntoElement, ParentElement, @@ -199,7 +199,7 @@ impl Render for InlineCompletionButton { } InlineCompletionProvider::Zeta => { - if !cx.is_staff() { + if !cx.has_flag::() { return div(); } diff --git a/crates/zed/src/zed/inline_completion_registry.rs b/crates/zed/src/zed/inline_completion_registry.rs index 2b9e300273d51e..34bd4d67b9dfd9 100644 --- a/crates/zed/src/zed/inline_completion_registry.rs +++ b/crates/zed/src/zed/inline_completion_registry.rs @@ -4,9 +4,9 @@ use client::Client; use collections::HashMap; use copilot::{Copilot, CopilotCompletionProvider}; use editor::{Editor, EditorMode}; -use feature_flags::FeatureFlagAppExt; +use feature_flags::{FeatureFlagAppExt, ZetaFeatureFlag}; use gpui::{AnyWindowHandle, AppContext, Context, ViewContext, WeakView}; -use language::language_settings::all_language_settings; +use language::language_settings::{all_language_settings, InlineCompletionProvider}; use settings::SettingsStore; use supermaven::{Supermaven, SupermavenCompletionProvider}; @@ -49,22 +49,45 @@ pub fn init(client: Arc, cx: &mut AppContext) { }); } - cx.observe_global::(move |cx| { - let new_provider = all_language_settings(None, cx).inline_completions.provider; - if new_provider != provider { - provider = new_provider; - for (editor, window) in editors.borrow().iter() { - _ = window.update(cx, |_window, cx| { - _ = editor.update(cx, |editor, cx| { - assign_inline_completion_provider(editor, provider, &client, cx); - }) - }); + cx.observe_flag::({ + let editors = editors.clone(); + let client = client.clone(); + move |_flag, cx| { + let provider = all_language_settings(None, cx).inline_completions.provider; + assign_inline_completion_providers(&editors, provider, &client, cx) + } + }) + .detach(); + + cx.observe_global::({ + let editors = editors.clone(); + let client = client.clone(); + move |cx| { + let new_provider = all_language_settings(None, cx).inline_completions.provider; + if new_provider != provider { + provider = new_provider; + assign_inline_completion_providers(&editors, provider, &client, cx) } } }) .detach(); } +fn assign_inline_completion_providers( + editors: &Rc, AnyWindowHandle>>>, + provider: InlineCompletionProvider, + client: &Arc, + cx: &mut AppContext, +) { + for (editor, window) in editors.borrow().iter() { + _ = window.update(cx, |_window, cx| { + _ = editor.update(cx, |editor, cx| { + assign_inline_completion_provider(editor, provider, &client, cx); + }) + }); + } +} + fn register_backward_compatible_actions(editor: &mut Editor, cx: &ViewContext) { // We renamed some of these actions to not be copilot-specific, but that // would have not been backwards-compatible. So here we are re-registering @@ -129,7 +152,7 @@ fn assign_inline_completion_provider( } } language::language_settings::InlineCompletionProvider::Zeta => { - if cx.is_staff() { + if cx.has_flag::() { let zeta = zeta::Zeta::register(client.clone(), cx); if let Some(buffer) = editor.buffer().read(cx).as_singleton() { if buffer.read(cx).file().is_some() { diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index dea15b0b08dbd1..a1b171304e1793 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -13,7 +13,7 @@ use language::{ Point, ToOffset, ToPoint, }; use language_models::LlmApiToken; -use rpc::{PredictEditsParams, PredictEditsResponse}; +use rpc::{PredictEditsParams, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME}; use std::{ borrow::Cow, cmp, @@ -269,8 +269,6 @@ impl Zeta { cx.spawn(|this, mut cx| async move { let start = std::time::Instant::now(); - let token = llm_token.acquire(&client).await?; - let mut input_events = String::new(); for event in events { if !input_events.is_empty() { @@ -283,141 +281,192 @@ impl Zeta { log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt); - let http_client = client.http_client(); let body = PredictEditsParams { input_events: input_events.clone(), input_excerpt: input_excerpt.clone(), }; + + let response = Self::perform_predict_edits(&client, llm_token, body).await?; + + let output_excerpt = response.output_excerpt; + log::debug!("prediction took: {:?}", start.elapsed()); + log::debug!("completion response: {}", output_excerpt); + + let inline_completion = Self::process_completion_response( + output_excerpt, + &snapshot, + excerpt_range, + path, + input_events, + input_excerpt, + )?; + + this.update(&mut cx, |this, cx| { + this.recent_completions + .push_front(inline_completion.clone()); + if this.recent_completions.len() > 50 { + this.recent_completions.pop_back(); + } + cx.notify(); + })?; + + Ok(inline_completion) + }) + } + + async fn perform_predict_edits( + client: &Arc, + llm_token: LlmApiToken, + body: PredictEditsParams, + ) -> Result { + let http_client = client.http_client(); + let mut token = llm_token.acquire(client).await?; + let mut did_retry = false; + + loop { let request_builder = http_client::Request::builder(); let request = request_builder .method(Method::POST) .uri( - client - .http_client() + http_client .build_zed_llm_url("/predict_edits", &[])? .as_ref(), ) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", token)) .body(serde_json::to_string(&body)?.into())?; + let mut response = http_client.send(request).await?; - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - if !response.status().is_success() { + + if response.status().is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + return Ok(serde_json::from_str(&body)?); + } else if !did_retry + && response + .headers() + .get(EXPIRED_LLM_TOKEN_HEADER_NAME) + .is_some() + { + did_retry = true; + token = llm_token.refresh(client).await?; + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; return Err(anyhow!( "error predicting edits.\nStatus: {:?}\nBody: {}", response.status(), body )); } + } + } - let response = serde_json::from_str::(&body)?; - let output_excerpt = response.output_excerpt; - log::debug!("prediction took: {:?}", start.elapsed()); - log::debug!("completion response: {}", output_excerpt); + fn process_completion_response( + output_excerpt: String, + snapshot: &BufferSnapshot, + excerpt_range: Range, + path: Arc, + input_events: String, + input_excerpt: String, + ) -> Result { + let content = output_excerpt.replace(CURSOR_MARKER, ""); - let content = output_excerpt.replace(CURSOR_MARKER, ""); - let mut new_text = content.as_str(); + let codefence_start = content + .find(EDITABLE_REGION_START_MARKER) + .context("could not find start marker")?; + let content = &content[codefence_start..]; - let codefence_start = new_text - .find(EDITABLE_REGION_START_MARKER) - .context("could not find start marker")?; - new_text = &new_text[codefence_start..]; + let newline_ix = content.find('\n').context("could not find newline")?; + let content = &content[newline_ix + 1..]; - let newline_ix = new_text.find('\n').context("could not find newline")?; - new_text = &new_text[newline_ix + 1..]; + let codefence_end = content + .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}")) + .context("could not find end marker")?; + let new_text = &content[..codefence_end]; - let codefence_end = new_text - .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}")) - .context("could not find end marker")?; - new_text = &new_text[..codefence_end]; - log::debug!("sanitized completion response: {}", new_text); + let old_text = snapshot + .text_for_range(excerpt_range.clone()) + .collect::(); - let old_text = snapshot - .text_for_range(excerpt_range.clone()) - .collect::(); + let edits = Self::compute_edits(old_text, new_text, excerpt_range.start, snapshot); - let diff = similar::TextDiff::from_chars(old_text.as_str(), new_text); + Ok(InlineCompletion { + id: InlineCompletionId::new(), + path, + excerpt_range, + edits: edits.into(), + snapshot: snapshot.clone(), + input_events: input_events.into(), + input_excerpt: input_excerpt.into(), + output_excerpt: output_excerpt.into(), + }) + } - let mut edits: Vec<(Range, String)> = Vec::new(); - let mut old_start = excerpt_range.start; - for change in diff.iter_all_changes() { - let value = change.value(); - match change.tag() { - similar::ChangeTag::Equal => { - old_start += value.len(); - } - similar::ChangeTag::Delete => { - let old_end = old_start + value.len(); - if let Some((last_old_range, _)) = edits.last_mut() { - if last_old_range.end == old_start { - last_old_range.end = old_end; - } else { - edits.push((old_start..old_end, String::new())); - } + fn compute_edits( + old_text: String, + new_text: &str, + offset: usize, + snapshot: &BufferSnapshot, + ) -> Vec<(Range, String)> { + let diff = similar::TextDiff::from_chars(old_text.as_str(), new_text); + + let mut edits: Vec<(Range, String)> = Vec::new(); + let mut old_start = offset; + for change in diff.iter_all_changes() { + let value = change.value(); + match change.tag() { + similar::ChangeTag::Equal => { + old_start += value.len(); + } + similar::ChangeTag::Delete => { + let old_end = old_start + value.len(); + if let Some((last_old_range, _)) = edits.last_mut() { + if last_old_range.end == old_start { + last_old_range.end = old_end; } else { edits.push((old_start..old_end, String::new())); } - - old_start = old_end; + } else { + edits.push((old_start..old_end, String::new())); } - similar::ChangeTag::Insert => { - if let Some((last_old_range, last_new_text)) = edits.last_mut() { - if last_old_range.end == old_start { - last_new_text.push_str(value); - } else { - edits.push((old_start..old_start, value.into())); - } + old_start = old_end; + } + similar::ChangeTag::Insert => { + if let Some((last_old_range, last_new_text)) = edits.last_mut() { + if last_old_range.end == old_start { + last_new_text.push_str(value); } else { edits.push((old_start..old_start, value.into())); } + } else { + edits.push((old_start..old_start, value.into())); } } } + } - let edits = edits - .into_iter() - .map(|(mut old_range, new_text)| { - let prefix_len = common_prefix( - snapshot.chars_for_range(old_range.clone()), - new_text.chars(), - ); - old_range.start += prefix_len; - let suffix_len = common_prefix( - snapshot.reversed_chars_for_range(old_range.clone()), - new_text[prefix_len..].chars().rev(), - ); - old_range.end = old_range.end.saturating_sub(suffix_len); - - let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string(); - ( - snapshot.anchor_after(old_range.start) - ..snapshot.anchor_before(old_range.end), - new_text, - ) - }) - .collect(); - let inline_completion = InlineCompletion { - id: InlineCompletionId::new(), - path, - excerpt_range, - edits, - snapshot, - input_events: input_events.into(), - input_excerpt: input_excerpt.into(), - output_excerpt: output_excerpt.into(), - }; - this.update(&mut cx, |this, cx| { - this.recent_completions - .push_front(inline_completion.clone()); - if this.recent_completions.len() > 50 { - this.recent_completions.pop_back(); - } - cx.notify(); - })?; - - Ok(inline_completion) - }) + edits + .into_iter() + .map(|(mut old_range, new_text)| { + let prefix_len = common_prefix( + snapshot.chars_for_range(old_range.clone()), + new_text.chars(), + ); + old_range.start += prefix_len; + let suffix_len = common_prefix( + snapshot.reversed_chars_for_range(old_range.clone()), + new_text[prefix_len..].chars().rev(), + ); + old_range.end = old_range.end.saturating_sub(suffix_len); + + let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string(); + ( + snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end), + new_text, + ) + }) + .collect() } pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {