From 39cc9a55eefae3fa8c641243e92c6922be6d98c9 Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Fri, 30 Aug 2024 16:48:00 +0200
Subject: [PATCH] Somehow hitting the noop client

---
 sdk/openai/inference/Cargo.toml               |  2 +
 .../examples/azure_chat_completions.rs        | 27 +++++++-
 sdk/openai/inference/src/auth/mod.rs          | 41 ++++++++++++
 sdk/openai/inference/src/clients/azure.rs     | 60 +++++++++++++++--
 sdk/openai/inference/src/clients/mod.rs       | 40 +++++++++++
 sdk/openai/inference/src/lib.rs               |  3 +
 sdk/openai/inference/src/models/mod.rs        |  5 ++
 .../src/models/request/chat_completions.rs    | 66 +++++++++++++++++++
 .../inference/src/models/request/mod.rs       |  3 +
 .../src/models/response/chat_completions.rs   | 36 ++++++++++
 .../inference/src/models/response/mod.rs      |  3 +
 11 files changed, 278 insertions(+), 8 deletions(-)
 create mode 100644 sdk/openai/inference/src/auth/mod.rs
 create mode 100644 sdk/openai/inference/src/models/mod.rs
 create mode 100644 sdk/openai/inference/src/models/request/chat_completions.rs
 create mode 100644 sdk/openai/inference/src/models/request/mod.rs
 create mode 100644 sdk/openai/inference/src/models/response/chat_completions.rs
 create mode 100644 sdk/openai/inference/src/models/response/mod.rs

diff --git a/sdk/openai/inference/Cargo.toml b/sdk/openai/inference/Cargo.toml
index 679117e9bd..dcd68edec9 100644
--- a/sdk/openai/inference/Cargo.toml
+++ b/sdk/openai/inference/Cargo.toml
@@ -18,6 +18,8 @@ workspace = true
 azure_core = { workspace = true }
 reqwest = { workspace = true, optional = true }
 tokio = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
 
 [features]
 default = [ "reqwest" ]
diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs
index c593a093b0..1b2324c848 100644
--- a/sdk/openai/inference/examples/azure_chat_completions.rs
+++ b/sdk/openai/inference/examples/azure_chat_completions.rs
@@ -1,5 +1,7 @@
-use azure_core::{auth::TokenCredential, Result};
-use azure_openai_inference::AzureOpenAIClient;
+use azure_core::Result;
+use azure_openai_inference::{
+    AzureOpenAIClient, AzureServiceVersion, CreateChatCompletionsRequest,
+};
 
 #[tokio::main]
 pub async fn main() -> Result<()> {
@@ -9,5 +11,26 @@ pub async fn main() -> Result<()> {
 
     let azure_openai_client = AzureOpenAIClient::new(endpoint, secret)?;
 
+    let chat_completions_request = CreateChatCompletionsRequest::new_with_user_message(
+        "gpt-4-1106-preview",
+        "Tell me a joke about pineapples",
+    );
+
+    let response = azure_openai_client
+        .create_chat_completions(
+            &chat_completions_request.model,
+            AzureServiceVersion::V2023_12_01Preview,
+            &chat_completions_request,
+        )
+        .await;
+
+    match response {
+        Ok(chat_completions) => {
+            println!("{:#?}", &chat_completions);
+        }
+        Err(e) => {
+            println!("Error: {}", e);
+        }
+    };
     Ok(())
 }
diff --git a/sdk/openai/inference/src/auth/mod.rs b/sdk/openai/inference/src/auth/mod.rs
new file mode 100644
index 0000000000..e03d485fca
--- /dev/null
+++ b/sdk/openai/inference/src/auth/mod.rs
@@ -0,0 +1,41 @@
+use azure_core::{
+    auth::Secret,
+    headers::{HeaderName, HeaderValue, AUTHORIZATION},
+    Header,
+};
+
+pub struct AzureKeyCredential(Secret);
+
+pub struct OpenAIKeyCredential(Secret);
+
+impl OpenAIKeyCredential {
+    pub fn new(access_token: String) -> Self {
+        Self(Secret::new(access_token))
+    }
+}
+
+impl AzureKeyCredential {
+    pub fn new(api_key: String) -> Self {
+        Self(Secret::new(api_key))
+    }
+}
+
+impl Header for AzureKeyCredential {
+    fn name(&self) -> HeaderName {
+        HeaderName::from_static("api-key")
+    }
+
+    fn value(&self) -> HeaderValue {
+        HeaderValue::from_cow(format!("{}", self.0.secret()))
+    }
+}
+
+impl Header for OpenAIKeyCredential {
+    fn name(&self) -> HeaderName {
+        AUTHORIZATION
+    }
+
+    fn value(&self) -> HeaderValue {
+        HeaderValue::from_cow(format!("Bearer {}", &self.0.secret()))
+    }
+}
diff --git a/sdk/openai/inference/src/clients/azure.rs b/sdk/openai/inference/src/clients/azure.rs
index 48fdc1e6c0..ce15841a93 100644
--- a/sdk/openai/inference/src/clients/azure.rs
+++ b/sdk/openai/inference/src/clients/azure.rs
@@ -1,24 +1,72 @@
 use std::sync::Arc;
 
-use azure_core::auth::{Secret, TokenCredential};
-use azure_core::{self, HttpClient, Result};
-use reqwest::Url;
+use crate::auth::AzureKeyCredential;
+use crate::models::CreateChatCompletionsRequest;
+use crate::CreateChatCompletionsResponse;
+use azure_core::Url;
+use azure_core::{self, HttpClient, Method, Result};
+
+// TODO: Implement using this instead
+// typespec_client_core::json_model!(CreateChatCompletionsResponse);
 
 pub struct AzureOpenAIClient {
     http_client: Arc<dyn HttpClient>,
     endpoint: Url,
-    secret: Secret,
+    key_credential: AzureKeyCredential,
 }
 
 impl AzureOpenAIClient {
     pub fn new(endpoint: impl AsRef<str>, secret: String) -> Result<Self> {
         let endpoint = Url::parse(endpoint.as_ref())?;
-        let secret = Secret::from(secret);
+        let key_credential = AzureKeyCredential::new(secret);
 
         Ok(AzureOpenAIClient {
             http_client: azure_core::new_http_client(),
             endpoint,
-            secret,
+            key_credential,
         })
     }
+
+    pub fn endpoint(&self) -> &Url {
+        &self.endpoint
+    }
+
+    pub async fn create_chat_completions(
+        &self,
+        deployment_name: &str,
+        api_version: impl Into<String>,
+        chat_completions_request: &CreateChatCompletionsRequest,
+    ) -> Result<CreateChatCompletionsResponse> {
+        let url = Url::parse(&format!(
+            "{}/openai/deployments/{}/chat/completions?api-version={}",
+            &self.endpoint,
+            deployment_name,
+            api_version.into()
+        ))?;
+        let request = super::build_request(
+            &self.key_credential,
+            url,
+            Method::Post,
+            chat_completions_request,
+        )?;
+        let response = self.http_client.execute_request(&request).await?;
+        Ok(response.into_body().json().await?)
+    }
+}
+
+pub enum AzureServiceVersion {
+    V2023_09_01Preview,
+    V2023_12_01Preview,
+    V2024_07_01Preview,
+}
+
+impl From<AzureServiceVersion> for String {
+    fn from(version: AzureServiceVersion) -> String {
+        let as_str = match version {
+            AzureServiceVersion::V2023_09_01Preview => "2023-09-01-preview",
+            AzureServiceVersion::V2023_12_01Preview => "2023-12-01-preview",
+            AzureServiceVersion::V2024_07_01Preview => "2024-07-01-preview",
+        };
+        String::from(as_str)
+    }
+}
diff --git a/sdk/openai/inference/src/clients/mod.rs b/sdk/openai/inference/src/clients/mod.rs
index 1b3cf68770..8740f354ae 100644
--- a/sdk/openai/inference/src/clients/mod.rs
+++ b/sdk/openai/inference/src/clients/mod.rs
@@ -1,2 +1,42 @@
+use azure_core::{
+    headers::{ACCEPT, CONTENT_TYPE},
+    Header, Method, Request, Result, Url,
+};
+use serde::Serialize;
+
 pub mod azure;
 pub mod non_azure;
+
+pub(crate) fn build_request<T>(
+    key_credential: &impl Header,
+    url: Url,
+    method: Method,
+    data: &T,
+) -> Result<Request>
+where
+    T: ?Sized + Serialize,
+{
+    let mut request = Request::new(url, method);
+    request.add_mandatory_header(key_credential);
+    request.insert_header(CONTENT_TYPE, "application/json");
+    request.insert_header(ACCEPT, "application/json");
+    request.set_json(data)?;
+    Ok(request)
+}
+
+// pub(crate) fn build_multipart_request<F>(
+//     key_credential: &impl Header,
+//     url: Url,
+//     form_generator: F,
+// ) -> Result<Request>
+// where
+//     F: FnOnce() -> Result<Form>,
+// {
+//     let mut request = Request::new(url, Method::Post);
+//     request.add_mandatory_header(key_credential);
+//     // handled internally by reqwest
+//     // request.insert_header(CONTENT_TYPE, "multipart/form-data");
+//     // request.insert_header(ACCEPT, "application/json");
+//     request.multipart(form_generator()?);
+//     Ok(request)
+// }
diff --git a/sdk/openai/inference/src/lib.rs b/sdk/openai/inference/src/lib.rs
index 228c367b4c..a84c4f69b4 100644
--- a/sdk/openai/inference/src/lib.rs
+++ b/sdk/openai/inference/src/lib.rs
@@ -1,4 +1,7 @@
+pub mod auth;
 mod clients;
+mod models;
 
 pub use clients::azure::*;
 pub use clients::non_azure::*;
+pub use models::*;
diff --git a/sdk/openai/inference/src/models/mod.rs b/sdk/openai/inference/src/models/mod.rs
new file mode 100644
index 0000000000..b8be6322b5
--- /dev/null
+++ b/sdk/openai/inference/src/models/mod.rs
@@ -0,0 +1,5 @@
+mod request;
+mod response;
+
+pub use request::*;
+pub use response::*;
diff --git a/sdk/openai/inference/src/models/request/chat_completions.rs b/sdk/openai/inference/src/models/request/chat_completions.rs
new file mode 100644
index 0000000000..3af246183a
--- /dev/null
+++ b/sdk/openai/inference/src/models/request/chat_completions.rs
@@ -0,0 +1,66 @@
+use serde::Serialize;
+
+#[derive(Serialize, Debug, Clone, Default)]
+pub struct CreateChatCompletionsRequest {
+    pub messages: Vec<ChatCompletionRequestMessage>,
+    pub model: String,
+    pub stream: Option<bool>,
+    // pub frequency_penalty: f64,
+    // pub logit_bias: Option<HashMap<String, i32>>,
+    // pub logprobs: Option<bool>,
+    // pub top_logprobs: Option<u32>,
+    // pub max_tokens: Option<u32>,
+}
+
+#[derive(Serialize, Debug, Clone, Default)]
+pub struct ChatCompletionRequestMessageBase {
+    #[serde(skip)]
+    pub name: Option<String>,
+    pub content: String, // TODO: this should be either a String or a ChatCompletionRequestMessageContentPart (a polymorphic type)
+}
+
+#[derive(Serialize, Debug, Clone)]
+#[serde(tag = "role")]
+pub enum ChatCompletionRequestMessage {
+    #[serde(rename = "system")]
+    System(ChatCompletionRequestMessageBase),
+    #[serde(rename = "user")]
+    User(ChatCompletionRequestMessageBase),
+}
+
+impl ChatCompletionRequestMessage {
+    pub fn new_user(content: impl Into<String>) -> Self {
+        Self::User(ChatCompletionRequestMessageBase {
+            content: content.into(),
+            name: None,
+        })
+    }
+
+    pub fn new_system(content: impl Into<String>) -> Self {
+        Self::System(ChatCompletionRequestMessageBase {
+            content: content.into(),
+            name: None,
+        })
+    }
+}
+impl CreateChatCompletionsRequest {
+    pub fn new_with_user_message(model: &str, prompt: &str) -> Self {
+        Self {
+            model: model.to_string(),
+            messages: vec![ChatCompletionRequestMessage::new_user(prompt)],
+            ..Default::default()
+        }
+    }
+
+    pub fn new_stream_with_user_message(
+        model: impl Into<String>,
+        prompt: impl Into<String>,
+    ) -> Self {
+        Self {
+            model: model.into(),
+            messages: vec![ChatCompletionRequestMessage::new_user(prompt)],
+            stream: Some(true),
+            ..Default::default()
+        }
+    }
+}
diff --git a/sdk/openai/inference/src/models/request/mod.rs b/sdk/openai/inference/src/models/request/mod.rs
new file mode 100644
index 0000000000..8ccec0e32c
--- /dev/null
+++ b/sdk/openai/inference/src/models/request/mod.rs
@@ -0,0 +1,3 @@
+mod chat_completions;
+
+pub use chat_completions::*;
diff --git a/sdk/openai/inference/src/models/response/chat_completions.rs b/sdk/openai/inference/src/models/response/chat_completions.rs
new file mode 100644
index 0000000000..687fa7dca4
--- /dev/null
+++ b/sdk/openai/inference/src/models/response/chat_completions.rs
@@ -0,0 +1,36 @@
+use serde::Deserialize;
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct CreateChatCompletionsResponse {
+    pub choices: Vec<ChatCompletionChoice>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct ChatCompletionChoice {
+    pub message: ChatCompletionResponseMessage,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct ChatCompletionResponseMessage {
+    pub content: Option<String>,
+    pub role: String,
+}
+
+// region: --- Streaming
+#[derive(Debug, Clone, Deserialize)]
+pub struct CreateChatCompletionsStreamResponse {
+    pub choices: Vec<ChatCompletionStreamChoice>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct ChatCompletionStreamChoice {
+    pub delta: Option<ChatCompletionStreamResponseMessage>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct ChatCompletionStreamResponseMessage {
+    pub content: Option<String>,
+    pub role: Option<String>,
+}
+
+// endregion: Streaming
diff --git a/sdk/openai/inference/src/models/response/mod.rs b/sdk/openai/inference/src/models/response/mod.rs
new file mode 100644
index 0000000000..8ccec0e32c
--- /dev/null
+++ b/sdk/openai/inference/src/models/response/mod.rs
@@ -0,0 +1,3 @@
+mod chat_completions;
+
+pub use chat_completions::*;
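
Note (not part of the patch): a minimal sketch of driving the new create_chat_completions API end to end, built only from the types this patch introduces. The AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY variable names are assumptions; the patched example reads its endpoint and secret outside the diff context, so the real names are not visible here. Passing request.model as the deployment name mirrors the patched example, which only works when the Azure deployment happens to be named after the model.

use azure_core::Result;
use azure_openai_inference::{
    AzureOpenAIClient, AzureServiceVersion, CreateChatCompletionsRequest,
};

#[tokio::main]
async fn main() -> Result<()> {
    // Assumed environment variable names (see note above).
    let endpoint =
        std::env::var("AZURE_OPENAI_ENDPOINT").expect("AZURE_OPENAI_ENDPOINT not set");
    let secret = std::env::var("AZURE_OPENAI_KEY").expect("AZURE_OPENAI_KEY not set");

    let client = AzureOpenAIClient::new(endpoint, secret)?;

    let request = CreateChatCompletionsRequest::new_with_user_message(
        "gpt-4-1106-preview",
        "Tell me a joke about pineapples",
    );

    // First argument is the deployment name; the patch's example reuses
    // request.model for it. new_stream_with_user_message would serialize
    // stream: Some(true), but this patch wires up no streaming transport,
    // so only the non-streaming path is exercised here.
    let response = client
        .create_chat_completions(
            &request.model,
            AzureServiceVersion::V2023_12_01Preview,
            &request,
        )
        .await?;

    if let Some(choice) = response.choices.first() {
        println!("{:?}", choice.message.content);
    }
    Ok(())
}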