feat: Add adaptors for various backends (ollama, tgi, api-inference) #40

Merged · 14 commits · Jan 2, 2024
Changes from 4 commits

163 changes: 163 additions & 0 deletions crates/llm-ls/src/adaptors.rs
@@ -0,0 +1,163 @@
use super::{
internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION,
};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tower_lsp::jsonrpc;

fn build_tgi_body(prompt: String, params: CompletionParams) -> Value {
serde_json::json!({
"inputs": prompt,
"parameters": params.request_params,
})
}

fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
let mut headers = HeaderMap::new();
let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
headers.insert(
USER_AGENT,
HeaderValue::from_str(&user_agent).map_err(internal_error)?,
);

if let Some(api_token) = api_token {
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?,
);
}

Ok(headers)
}

fn parse_tgi_text(text: reqwest::Result<String>) -> Result<Vec<Generation>, jsonrpc::Error> {
let generations =
match serde_json::from_str(&text.unwrap_or_default()).map_err(internal_error)? {
APIResponse::Generation(gen) => vec![gen],
APIResponse::Generations(_) => {
return Err(internal_error(
"TGI parser unexpectedly encountered api-inference",
))
}
APIResponse::Error(err) => return Err(internal_error(err)),
};
Ok(generations)
}

fn build_api_body(prompt: String, params: CompletionParams) -> Value {
build_tgi_body(prompt, params)
}

fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
build_tgi_headers(api_token, ide)
}

fn parse_api_text(text: reqwest::Result<String>) -> Result<Vec<Generation>, jsonrpc::Error> {
let generations =
match serde_json::from_str(&text.unwrap_or_default()).map_err(internal_error)? {
APIResponse::Generation(gen) => vec![gen],
APIResponse::Generations(gens) => gens,
APIResponse::Error(err) => return Err(internal_error(err)),
};
Ok(generations)
}

fn build_ollama_body(prompt: String, params: CompletionParams) -> Value {
let request_body = params.request_body.unwrap_or_default();
let body = serde_json::json!({
"prompt": prompt,
"model": request_body.get("model"),
});
body
}
fn build_ollama_headers() -> Result<HeaderMap, jsonrpc::Error> {
let headers = HeaderMap::new();
Ok(headers)
}

#[derive(Debug, Serialize, Deserialize)]
struct OllamaGeneration {
response: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum OllamaAPIResponse {
Generation(OllamaGeneration),
Error(APIError),
}

fn parse_ollama_text(
text: Result<String, reqwest::Error>,
) -> Result<Vec<Generation>, jsonrpc::Error> {
match text {
Ok(text) => {
let mut gen: Vec<Generation> = Vec::new();
for row in text.split('\n') {
if row.is_empty() {
continue;
}
let chunk = match serde_json::from_str(row) {
Ok(OllamaAPIResponse::Generation(ollama_gen)) => ollama_gen.response,
Ok(OllamaAPIResponse::Error(err)) => return Err(internal_error(err)),
Err(err) => return Err(internal_error(err)),
};
gen.push(Generation {
generated_text: chunk,
})
}
Ok(gen)
}
Err(err) => Err(internal_error(err)),
}
}

const TGI: &str = "tgi";
const HUGGING_FACE: &str = "huggingface";
const OLLAMA: &str = "ollama";
const DEFAULT_ADAPTOR: &str = HUGGING_FACE;

fn unknown_adaptor_error(adaptor: String) -> jsonrpc::Error {
internal_error(format!("Unknown adaptor {}", adaptor))
}

pub fn adapt_body(prompt: String, params: CompletionParams) -> Result<Value, jsonrpc::Error> {
let adaptor = params
.adaptor
.clone()
.unwrap_or(DEFAULT_ADAPTOR.to_string());
match adaptor.as_str() {
TGI => Ok(build_tgi_body(prompt, params)),
HUGGING_FACE => Ok(build_api_body(prompt, params)),
OLLAMA => Ok(build_ollama_body(prompt, params)),
_ => Err(unknown_adaptor_error(adaptor)),
}
}

pub fn adapt_headers(
adaptor: Option<String>,
api_token: Option<&String>,
ide: Ide,
) -> Result<HeaderMap, jsonrpc::Error> {
let adaptor = adaptor.clone().unwrap_or(DEFAULT_ADAPTOR.to_string());
match adaptor.as_str() {
TGI => build_tgi_headers(api_token, ide),
HUGGING_FACE => build_api_headers(api_token, ide),
OLLAMA => build_ollama_headers(),
_ => Err(unknown_adaptor_error(adaptor)),
}
}

pub fn adapt_text(
adaptor: Option<String>,
text: Result<String, reqwest::Error>,
) -> jsonrpc::Result<Vec<Generation>> {
let adaptor = adaptor.clone().unwrap_or(DEFAULT_ADAPTOR.to_string());
match adaptor.as_str() {
TGI => parse_tgi_text(text),
HUGGING_FACE => parse_api_text(text),
OLLAMA => parse_ollama_text(text),
_ => Err(unknown_adaptor_error(adaptor)),
}
}
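
For reference, Ollama's generate endpoint streams its output as newline-delimited JSON, one object per line, which is why parse_ollama_text above splits on '\n' and collects each `response` fragment into a Generation. Below is a minimal, self-contained sketch of that parsing idea, not the PR's code: it assumes serde (with the derive feature) and serde_json are available, and the response body in main is a hypothetical example rather than captured Ollama output.

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct OllamaChunk {
    // Each streamed line carries a fragment of the generated text.
    response: String,
}

fn collect_ollama_chunks(body: &str) -> Result<Vec<String>, serde_json::Error> {
    body.lines()
        .filter(|line| !line.is_empty())
        .map(|line| serde_json::from_str::<OllamaChunk>(line).map(|chunk| chunk.response))
        .collect()
}

fn main() {
    // Hypothetical two-chunk stream, for illustration only.
    let body = "{\"response\":\"fn add(a: i32, \"}\n{\"response\":\"b: i32) -> i32 { a + b }\"}";
    println!("{:?}", collect_ollama_chunks(body));
}
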
63 changes: 32 additions & 31 deletions crates/llm-ls/src/main.rs
@@ -1,3 +1,4 @@
use adaptors::{adapt_body, adapt_headers, adapt_text};
use document::Document;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
use ropey::Rope;
@@ -18,12 +19,13 @@ use tracing_appender::rolling;
use tracing_subscriber::EnvFilter;
use uuid::Uuid;

mod adaptors;
mod document;
mod language_id;

const MAX_WARNING_REPEAT: Duration = Duration::from_secs(3_600);
const NAME: &str = "llm-ls";
const VERSION: &str = env!("CARGO_PKG_VERSION");
pub const NAME: &str = "llm-ls";
pub const VERSION: &str = env!("CARGO_PKG_VERSION");

fn get_position_idx(rope: &Rope, row: usize, col: usize) -> Result<usize> {
Ok(rope.try_line_to_char(row).map_err(internal_error)?
@@ -120,7 +122,7 @@ fn should_complete(document: &Document, position: Position) -> Result<Completion
}
}

#[derive(Debug, Deserialize, Serialize)]
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(untagged)]
enum TokenizerConfig {
Local { path: PathBuf },
@@ -178,12 +180,12 @@ struct APIRequest {
}

#[derive(Debug, Serialize, Deserialize)]
struct Generation {
pub struct Generation {
generated_text: String,
}

#[derive(Debug, Deserialize)]
struct APIError {
pub struct APIError {
error: String,
}

@@ -195,7 +197,7 @@ impl Display for APIError {

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum APIResponse {
pub enum APIResponse {
Generation(Generation),
Generations(Vec<Generation>),
Error(APIError),
@@ -219,7 +221,7 @@ struct Completion {

#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
enum Ide {
pub enum Ide {
Neovim,
VSCode,
JetBrains,
@@ -261,7 +263,7 @@ struct RejectedCompletion {

#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
struct CompletionParams {
pub struct CompletionParams {
#[serde(flatten)]
text_document_position: TextDocumentPositionParams,
request_params: RequestParams,
@@ -271,10 +273,12 @@ struct CompletionParams {
fim: FimParams,
api_token: Option<String>,
model: String,
adaptor: Option<String>,
tokens_to_clear: Vec<String>,
tokenizer_config: Option<TokenizerConfig>,
context_window: usize,
tls_skip_verify_insecure: bool,
request_body: Option<serde_json::Map<String, serde_json::Value>>,
}

#[derive(Debug, Deserialize, Serialize)]
@@ -283,7 +287,7 @@ struct CompletionResult {
completions: Vec<Completion>,
}

fn internal_error<E: Display>(err: E) -> Error {
pub fn internal_error<E: Display>(err: E) -> Error {
let err_msg = err.to_string();
error!(err_msg);
Error {
@@ -398,37 +402,34 @@ fn build_prompt(

async fn request_completion(
http_client: &reqwest::Client,
ide: Ide,
model: &str,
request_params: RequestParams,
api_token: Option<&String>,
prompt: String,
params: CompletionParams,
) -> Result<Vec<Generation>> {
let t = Instant::now();
let model = params.model.clone();
let adaptor = params.adaptor.clone();
let api_token = params.api_token.clone();
let ide = params.ide;

let json = adapt_body(prompt, params);
let headers = adapt_headers(adaptor.clone(), api_token.as_ref(), ide)?;
let res = http_client
.post(build_url(model))
.json(&APIRequest {
inputs: prompt,
parameters: request_params.into(),
})
.headers(build_headers(api_token, ide)?)
.post(build_url(&model))
.json(&json)
.headers(headers)
.send()
.await
.map_err(internal_error)?;

let generations = match res.json().await.map_err(internal_error)? {
APIResponse::Generation(gen) => vec![gen],
APIResponse::Generations(gens) => gens,
APIResponse::Error(err) => return Err(internal_error(err)),
};
let generations = adapt_text(adaptor, res.text().await);
let time = t.elapsed().as_millis();
info!(
model,
compute_generations_ms = time,
generations = serde_json::to_string(&generations).map_err(internal_error)?,
"{model} computed generations in {time} ms"
);
Ok(generations)
generations
}

fn parse_generations(
@@ -624,10 +625,11 @@ impl Backend {
return Ok(CompletionResult { request_id, completions: vec![]});
}

let tokenizer_config = params.tokenizer_config.clone();
let tokenizer = get_tokenizer(
&params.model,
&mut *self.tokenizer_map.write().await,
params.tokenizer_config,
tokenizer_config,
&self.http_client,
&self.cache_dir,
params.api_token.as_ref(),
@@ -648,17 +650,15 @@
} else {
&self.http_client
};
let tokens_to_clear = params.tokens_to_clear.clone();
let result = request_completion(
http_client,
params.ide,
&params.model,
params.request_params,
params.api_token.as_ref(),
prompt,
params,
)
.await?;

let completions = parse_generations(result, &params.tokens_to_clear, completion_type);
let completions = parse_generations(result, &tokens_to_clear, completion_type);
Ok(CompletionResult { request_id, completions })
}.instrument(span).await
}
@@ -849,3 +849,4 @@ async fn main() {

Server::new(stdin, stdout, socket).serve(service).await;
}
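
As a hedged illustration of the client side of this change, not part of the diff: CompletionParams is renamed to camelCase by serde, so an editor client would set the two fields this PR adds roughly as sketched below. The Ollama model name is hypothetical, and the other required params (text document position, requestParams, fim, model, and so on) are omitted for brevity.

fn main() {
    // Sketch only: the two fields added to CompletionParams in this PR,
    // as a client might populate them (keys are camelCase per the serde rename).
    // The model value is illustrative, not a recommendation.
    let new_fields = serde_json::json!({
        "adaptor": "ollama",
        "requestBody": { "model": "codellama:7b" }
    });
    println!("{new_fields}");
}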