Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add adaptors for various backends (ollama, tgi, api-inference) #40

Merged
merged 14 commits into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 142 additions & 0 deletions crates/llm-ls/src/adaptors.rs
noahbald marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
use super::{
internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION,
};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tower_lsp::jsonrpc;

/// Adapts completion requests for the Hugging Face inference APIs
/// (Inference API and TGI), the default backend.
struct AdaptHuggingFaceRequest;
impl AdaptHuggingFaceRequest {
    /// Builds the JSON body expected by the HF text-generation endpoints:
    /// the prompt under `inputs` and the generation settings under `parameters`.
    fn adapt_body(&self, prompt: String, params: CompletionParams) -> Value {
        serde_json::json!({
            "inputs": prompt,
            "parameters": params.request_params,
        })
    }

    /// Builds the request headers: a descriptive `User-Agent` identifying
    /// llm-ls and the editor, plus a `Bearer` token when one is configured.
    ///
    /// Returns an internal error if either header value contains characters
    /// invalid in an HTTP header.
    fn adapt_headers(
        &self,
        api_token: Option<&String>,
        ide: Ide,
    ) -> Result<HeaderMap, jsonrpc::Error> {
        let mut headers = HeaderMap::new();
        let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
        headers.insert(
            USER_AGENT,
            HeaderValue::from_str(&user_agent).map_err(internal_error)?,
        );

        if let Some(api_token) = api_token {
            headers.insert(
                AUTHORIZATION,
                HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?,
            );
        }

        Ok(headers)
    }
}

struct AdaptHuggingFaceResponse;
noahbald marked this conversation as resolved.
Show resolved Hide resolved
impl AdaptHuggingFaceResponse {
fn adapt_blob(&self, text: reqwest::Result<String>) -> Result<Vec<Generation>, jsonrpc::Error> {
let generations =
match serde_json::from_str(&text.unwrap_or_default()).map_err(internal_error)? {
APIResponse::Generation(gen) => vec![gen],
APIResponse::Generations(gens) => gens,
APIResponse::Error(err) => return Err(internal_error(err)),
};
Ok(generations)
}
}

struct AdaptOllamaRequest;
noahbald marked this conversation as resolved.
Show resolved Hide resolved
impl AdaptOllamaRequest {
fn adapt_body(&self, prompt: String, params: CompletionParams) -> Value {
let request_body = params.request_body.unwrap_or_default();
let body = serde_json::json!({
"prompt": prompt,
"model": request_body.get("model"),
});
body
}
fn adapt_headers(&self) -> Result<HeaderMap, jsonrpc::Error> {
let headers = HeaderMap::new();
Ok(headers)
}
}

/// A single generation chunk as returned by an Ollama server.
#[derive(Debug, Serialize, Deserialize)]
pub struct OllamaGeneration {
    // The generated text carried by this chunk.
    response: String,
}

/// Response envelope for Ollama. `untagged` means serde tries each
/// variant in order against the raw JSON shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum OllamaAPIResponse {
    /// A successful generation chunk.
    Generation(OllamaGeneration),
    /// An error payload reported by the server.
    Error(APIError),
}

struct AdaptOllamaResponse;
impl AdaptOllamaResponse {
fn adapt_blob(
&self,
text: Result<String, reqwest::Error>,
) -> Result<Vec<Generation>, jsonrpc::Error> {
match text {
Ok(text) => {
let mut gen: Vec<Generation> = Vec::new();
for row in text.split("\n") {
if row.is_empty() {
continue;
}
let chunk = match serde_json::from_str(row) {
Ok(OllamaAPIResponse::Generation(ollama_gen)) => ollama_gen.response,
Ok(OllamaAPIResponse::Error(err)) => return Err(internal_error(err)),
Err(err) => return Err(internal_error(err)),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Ok(OllamaAPIResponse::Error(err)) => return Err(internal_error(err)),
Err(err) => return Err(internal_error(err)),
Ok(OllamaAPIResponse::Error(err)) | Err(err) => return Err(internal_error(err)),

does this work?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think not, since Ok(OllamaAPIResponse::Error(err)) and Err(err) have conflicting types

};
gen.push(Generation {
generated_text: chunk,
})
}
Ok(gen)
}
Err(err) => Err(internal_error(err)),
}
}
}

const HUGGING_FACE_ADAPTOR: &str = "huggingface";

pub struct Adaptors;
noahbald marked this conversation as resolved.
Show resolved Hide resolved
impl Adaptors {
pub fn adapt_body(&self, prompt: String, params: CompletionParams) -> Value {
let adaptor = params.adaptor.clone();
match adaptor.unwrap_or(HUGGING_FACE_ADAPTOR.to_string()).as_str() {
"ollama" => AdaptOllamaRequest.adapt_body(prompt, params),
_ => AdaptHuggingFaceRequest.adapt_body(prompt, params),
}
}
pub fn adapt_headers(
&self,
adaptor: Option<String>,
api_token: Option<&String>,
ide: Ide,
) -> Result<HeaderMap, jsonrpc::Error> {
match adaptor.unwrap_or(HUGGING_FACE_ADAPTOR.to_string()).as_str() {
"ollama" => AdaptOllamaRequest.adapt_headers(),
_ => AdaptHuggingFaceRequest.adapt_headers(api_token, ide),
}
}
pub fn adapt_blob(
noahbald marked this conversation as resolved.
Show resolved Hide resolved
&self,
adaptor: Option<String>,
text: Result<String, reqwest::Error>,
) -> Result<Vec<Generation>, jsonrpc::Error> {
match adaptor.unwrap_or(HUGGING_FACE_ADAPTOR.to_string()).as_str() {
"ollama" => AdaptOllamaResponse.adapt_blob(text),
_ => AdaptHuggingFaceResponse.adapt_blob(text),
noahbald marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
63 changes: 32 additions & 31 deletions crates/llm-ls/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use adaptors::Adaptors;
use document::Document;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
use ropey::Rope;
Expand All @@ -18,12 +19,13 @@ use tracing_appender::rolling;
use tracing_subscriber::EnvFilter;
use uuid::Uuid;

mod adaptors;
mod document;
mod language_id;

const MAX_WARNING_REPEAT: Duration = Duration::from_secs(3_600);
const NAME: &str = "llm-ls";
const VERSION: &str = env!("CARGO_PKG_VERSION");
pub const NAME: &str = "llm-ls";
pub const VERSION: &str = env!("CARGO_PKG_VERSION");

fn get_position_idx(rope: &Rope, row: usize, col: usize) -> Result<usize> {
Ok(rope.try_line_to_char(row).map_err(internal_error)?
Expand Down Expand Up @@ -120,7 +122,7 @@ fn should_complete(document: &Document, position: Position) -> Result<Completion
}
}

#[derive(Debug, Deserialize, Serialize)]
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(untagged)]
enum TokenizerConfig {
Local { path: PathBuf },
Expand Down Expand Up @@ -178,12 +180,12 @@ struct APIRequest {
}

#[derive(Debug, Serialize, Deserialize)]
struct Generation {
pub struct Generation {
generated_text: String,
}

#[derive(Debug, Deserialize)]
struct APIError {
pub struct APIError {
error: String,
}

Expand All @@ -195,7 +197,7 @@ impl Display for APIError {

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum APIResponse {
pub enum APIResponse {
Generation(Generation),
Generations(Vec<Generation>),
Error(APIError),
Expand All @@ -219,7 +221,7 @@ struct Completion {

#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
enum Ide {
pub enum Ide {
Neovim,
VSCode,
JetBrains,
Expand Down Expand Up @@ -261,7 +263,7 @@ struct RejectedCompletion {

#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
struct CompletionParams {
pub struct CompletionParams {
#[serde(flatten)]
text_document_position: TextDocumentPositionParams,
request_params: RequestParams,
Expand All @@ -271,10 +273,12 @@ struct CompletionParams {
fim: FimParams,
api_token: Option<String>,
model: String,
adaptor: Option<String>,
tokens_to_clear: Vec<String>,
tokenizer_config: Option<TokenizerConfig>,
context_window: usize,
tls_skip_verify_insecure: bool,
request_body: Option<serde_json::Map<String, serde_json::Value>>,
}

#[derive(Debug, Deserialize, Serialize)]
Expand All @@ -283,7 +287,7 @@ struct CompletionResult {
completions: Vec<Completion>,
}

fn internal_error<E: Display>(err: E) -> Error {
pub fn internal_error<E: Display>(err: E) -> Error {
let err_msg = err.to_string();
error!(err_msg);
Error {
Expand Down Expand Up @@ -398,37 +402,34 @@ fn build_prompt(

async fn request_completion(
http_client: &reqwest::Client,
ide: Ide,
model: &str,
request_params: RequestParams,
api_token: Option<&String>,
prompt: String,
params: CompletionParams,
) -> Result<Vec<Generation>> {
let t = Instant::now();
let model = params.model.clone();
let adaptor = params.adaptor.clone();
let api_token = params.api_token.clone();
let ide = params.ide.clone();

let json = Adaptors.adapt_body(prompt, params);
let headers = Adaptors.adapt_headers(adaptor.clone(), api_token.as_ref(), ide)?;
let res = http_client
.post(build_url(model))
.json(&APIRequest {
inputs: prompt,
parameters: request_params.into(),
})
.headers(build_headers(api_token, ide)?)
.post(build_url(&model))
.json(&json)
.headers(headers)
.send()
.await
.map_err(internal_error)?;

let generations = match res.json().await.map_err(internal_error)? {
APIResponse::Generation(gen) => vec![gen],
APIResponse::Generations(gens) => gens,
APIResponse::Error(err) => return Err(internal_error(err)),
};
let generations = Adaptors.adapt_blob(adaptor, res.text().await);
noahbald marked this conversation as resolved.
Show resolved Hide resolved
let time = t.elapsed().as_millis();
info!(
model,
compute_generations_ms = time,
generations = serde_json::to_string(&generations).map_err(internal_error)?,
"{model} computed generations in {time} ms"
);
Ok(generations)
generations
}

fn parse_generations(
Expand Down Expand Up @@ -624,10 +625,11 @@ impl Backend {
return Ok(CompletionResult { request_id, completions: vec![]});
}

let tokenizer_config = params.tokenizer_config.clone();
noahbald marked this conversation as resolved.
Show resolved Hide resolved
let tokenizer = get_tokenizer(
&params.model,
&mut *self.tokenizer_map.write().await,
params.tokenizer_config,
tokenizer_config,
&self.http_client,
&self.cache_dir,
params.api_token.as_ref(),
Expand All @@ -648,17 +650,15 @@ impl Backend {
} else {
&self.http_client
};
let tokens_to_clear = params.tokens_to_clear.clone();
noahbald marked this conversation as resolved.
Show resolved Hide resolved
let result = request_completion(
http_client,
params.ide,
&params.model,
params.request_params,
params.api_token.as_ref(),
prompt,
params,
)
.await?;

let completions = parse_generations(result, &params.tokens_to_clear, completion_type);
let completions = parse_generations(result, &tokens_to_clear, completion_type);
Ok(CompletionResult { request_id, completions })
}.instrument(span).await
}
Expand Down Expand Up @@ -849,3 +849,4 @@ async fn main() {

Server::new(stdin, stdout, socket).serve(service).await;
}