diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs
index 4e1595d65e..81f68e1b0f 100644
--- a/sdk/openai/inference/examples/azure_chat_completions.rs
+++ b/sdk/openai/inference/examples/azure_chat_completions.rs
@@ -1,7 +1,8 @@
 use azure_core::Result;
 use azure_openai_inference::{
-    request::CreateChatCompletionsRequest, AzureOpenAIClient, AzureOpenAIClientOptions,
-    AzureServiceVersion,
+    clients::{AzureOpenAIClient, AzureOpenAIClientMethods, ChatCompletionsClientMethods},
+    request::CreateChatCompletionsRequest,
+    AzureOpenAIClientOptions, AzureServiceVersion,
 };
 
 #[tokio::main]
@@ -10,7 +11,7 @@ pub async fn main() -> Result<()> {
         std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT env variable");
     let secret = std::env::var("AZURE_OPENAI_KEY").expect("Set AZURE_OPENAI_KEY env variable");
 
-    let azure_openai_client = AzureOpenAIClient::with_key(
+    let chat_completions_client = AzureOpenAIClient::with_key(
         endpoint,
         secret,
         Some(
@@ -18,14 +19,15 @@ pub async fn main() -> Result<()> {
                 .with_api_version(AzureServiceVersion::V2023_12_01Preview)
                 .build(),
         ),
-    )?;
+    )?
+    .chat_completions_client();
 
     let chat_completions_request = CreateChatCompletionsRequest::new_with_user_message(
         "gpt-4-1106-preview",
         "Tell me a joke about pineapples",
     );
 
-    let response = azure_openai_client
+    let response = chat_completions_client
         .create_chat_completions(&chat_completions_request.model, &chat_completions_request)
         .await;
 
diff --git a/sdk/openai/inference/src/auth/mod.rs b/sdk/openai/inference/src/auth/mod.rs
index 63db568285..02dcf59f6f 100644
--- a/sdk/openai/inference/src/auth/mod.rs
+++ b/sdk/openai/inference/src/auth/mod.rs
@@ -1,5 +1,5 @@
 mod azure_key_credential;
-mod openai_key_credential;
+// mod openai_key_credential;
 
 pub(crate) use azure_key_credential::*;
-pub(crate) use openai_key_credential::*;
+// pub(crate) use openai_key_credential::*;
diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs
index 7abcff36c6..8e7156d1d1 100644
--- a/sdk/openai/inference/src/clients/azure_openai_client.rs
+++ b/sdk/openai/inference/src/clients/azure_openai_client.rs
@@ -3,19 +3,35 @@ use std::sync::Arc;
 use crate::auth::AzureKeyCredential;
 use crate::options::AzureOpenAIClientOptions;
-use crate::request::CreateChatCompletionsRequest;
-use crate::response::CreateChatCompletionsResponse;
-use azure_core::{self, Method, Policy, Result};
-use azure_core::{Context, Url};
+use azure_core::Url;
+use azure_core::{self, Policy, Result};
 
+use super::chat_completions_client::ChatCompletionsClient;
+use super::BaseOpenAIClientMethods;
+
+pub trait AzureOpenAIClientMethods: BaseOpenAIClientMethods {
+    fn with_key(
+        endpoint: impl AsRef<str>,
+        secret: impl Into<String>,
+        client_options: Option<AzureOpenAIClientOptions>,
+    ) -> Result<Self>
+    where
+        Self: Sized;
+
+    fn endpoint(&self) -> &Url;
+
+    fn chat_completions_client(&self) -> ChatCompletionsClient;
+}
+
+#[derive(Debug, Clone)]
 pub struct AzureOpenAIClient {
     endpoint: Url,
     pipeline: azure_core::Pipeline,
     options: AzureOpenAIClientOptions,
 }
 
-impl AzureOpenAIClient {
-    pub fn with_key(
+impl AzureOpenAIClientMethods for AzureOpenAIClient {
+    fn with_key(
         endpoint: impl AsRef<str>,
         secret: impl Into<String>,
         client_options: Option<AzureOpenAIClientOptions>,
     ) -> Result<Self>
@@ -37,37 +53,25 @@
         })
     }
 
-    pub fn endpoint(&self) -> &Url {
+    fn endpoint(&self) -> &Url {
         &self.endpoint
     }
 
-    pub async fn create_chat_completions(
-        &self,
-        deployment_name: &str,
-        chat_completions_request: &CreateChatCompletionsRequest,
-        // Should I be using RequestContent ? All the new methods have signatures that would force me to mutate
-        // the request object into &'static str, Vec<u8>, etc.
-        // chat_completions_request: RequestContent<CreateChatCompletionsRequest>,
-    ) -> Result<CreateChatCompletionsResponse> {
-        let url = Url::parse(&format!(
-            "{}/openai/deployments/{}/chat/completions",
-            &self.endpoint, deployment_name
-        ))?;
-
-        let context = Context::new();
-
-        let mut request = azure_core::Request::new(url, Method::Post);
-        // this was replaced by the AzureServiceVersion policy, not sure what is the right approach
-        // adding the mandatory header shouldn't be necessary if the pipeline was setup correctly (?)
-        // request.add_mandatory_header(&self.key_credential);
-
-        request.set_json(chat_completions_request)?;
-
-        let response = self
-            .pipeline
-            .send::<CreateChatCompletionsResponse>(&context, &mut request)
-            .await?;
-        response.into_body().json().await
+    fn chat_completions_client(&self) -> ChatCompletionsClient {
+        ChatCompletionsClient::new(Box::new(self.clone()))
+    }
+}
+
+impl BaseOpenAIClientMethods for AzureOpenAIClient {
+    fn base_url(&self, deployment_name: Option<&str>) -> azure_core::Result<Url> {
+        // TODO gracefully handle this
+        Ok(self
+            .endpoint()
+            .join(deployment_name.expect("deployment_name should be provided"))?)
+    }
+
+    fn pipeline(&self) -> &azure_core::Pipeline {
+        &self.pipeline
     }
 }
 
diff --git a/sdk/openai/inference/src/clients/chat_completions_client.rs b/sdk/openai/inference/src/clients/chat_completions_client.rs
index 955e32de51..54a4f85e4b 100644
--- a/sdk/openai/inference/src/clients/chat_completions_client.rs
+++ b/sdk/openai/inference/src/clients/chat_completions_client.rs
@@ -1 +1,51 @@
-pub struct ChatCompletionsClient;
+use super::BaseOpenAIClientMethods;
+use crate::{request::CreateChatCompletionsRequest, response::CreateChatCompletionsResponse};
+use azure_core::{Context, Method, Result};
+
+pub trait ChatCompletionsClientMethods {
+    #[allow(async_fn_in_trait)]
+    async fn create_chat_completions(
+        &self,
+        deployment_name: impl AsRef<str>,
+        chat_completions_request: &CreateChatCompletionsRequest,
+    ) -> Result<CreateChatCompletionsResponse>;
+}
+
+pub struct ChatCompletionsClient {
+    base_client: Box<dyn BaseOpenAIClientMethods>,
+}
+
+impl ChatCompletionsClient {
+    pub fn new(base_client: Box<dyn BaseOpenAIClientMethods>) -> Self {
+        Self { base_client }
+    }
+}
+
+impl ChatCompletionsClientMethods for ChatCompletionsClient {
+    async fn create_chat_completions(
+        &self,
+        deployment_name: impl AsRef<str>,
+        chat_completions_request: &CreateChatCompletionsRequest,
+    ) -> Result<CreateChatCompletionsResponse> {
+        let base_url = self.base_client.base_url(Some(deployment_name.as_ref()))?;
+        let request_url = base_url.join("chat/completions")?;
+
+        let context = Context::new();
+
+        let mut request = azure_core::Request::new(request_url, Method::Post);
+        // this was replaced by the AzureServiceVersion policy, not sure what is the right approach
+        // adding the mandatory header shouldn't be necessary if the pipeline was setup correctly (?)
+        // request.add_mandatory_header(&self.key_credential);
+
+        request.set_json(chat_completions_request)?;
+
+        let response = self
+            .base_client
+            .pipeline()
+            .send::<CreateChatCompletionsResponse>(&context, &mut request)
+            .await?;
+        response.into_body().json().await
+
+        // todo!()
+    }
+}
diff --git a/sdk/openai/inference/src/clients/mod.rs b/sdk/openai/inference/src/clients/mod.rs
index d40414f7f0..2c7e718436 100644
--- a/sdk/openai/inference/src/clients/mod.rs
+++ b/sdk/openai/inference/src/clients/mod.rs
@@ -1,3 +1,12 @@
-pub mod azure_openai_client;
-pub mod chat_completions_client;
-pub mod openai_client;
+mod azure_openai_client;
+mod chat_completions_client;
+mod openai_client;
+
+pub use azure_openai_client::{AzureOpenAIClient, AzureOpenAIClientMethods};
+pub use chat_completions_client::{ChatCompletionsClient, ChatCompletionsClientMethods};
+
+pub trait BaseOpenAIClientMethods {
+    fn base_url(&self, deployment_name: Option<&str>) -> azure_core::Result<azure_core::Url>;
+
+    fn pipeline(&self) -> &azure_core::Pipeline;
+}
diff --git a/sdk/openai/inference/src/lib.rs b/sdk/openai/inference/src/lib.rs
index 0b5dcb8c2e..8ade855521 100644
--- a/sdk/openai/inference/src/lib.rs
+++ b/sdk/openai/inference/src/lib.rs
@@ -1,9 +1,7 @@
 mod auth;
-mod clients;
+pub mod clients;
 mod models;
 mod options;
 
-pub use clients::azure_openai_client::*;
-pub use clients::openai_client::*;
 pub use models::*;
 pub use options::*;
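A note on the URL handling in `base_url` and `create_chat_completions` above: `azure_core::Url` is the `url` crate's `Url`, and its `join` resolves the argument the way a browser resolves a relative link, so a base without a trailing slash loses its last path segment. Under that assumption, the deployment name joined in `base_url` is dropped again by the subsequent `join("chat/completions")`, and the `openai/deployments/{}` prefix from the deleted `Url::parse` call is gone as well; presumably this is part of what the `// TODO gracefully handle this` comment covers. A minimal sketch of the behavior (the endpoint value is illustrative):

```rust
// Sketch only: demonstrates the `url::Url::join` semantics relied on by
// `base_url` + `create_chat_completions` in the diff above.
fn main() {
    let endpoint = url::Url::parse("https://myresource.openai.azure.com").unwrap();

    // What `base_url(Some("gpt-4"))` effectively produces: no trailing slash.
    let base = endpoint.join("gpt-4").unwrap();
    assert_eq!(base.as_str(), "https://myresource.openai.azure.com/gpt-4");

    // Joining the operation path replaces "gpt-4": relative resolution
    // discards the last segment of a base that does not end in '/'.
    let request_url = base.join("chat/completions").unwrap();
    assert_eq!(
        request_url.as_str(),
        "https://myresource.openai.azure.com/chat/completions"
    );

    // A trailing slash keeps the deployment segment.
    let base = endpoint.join("gpt-4/").unwrap();
    assert_eq!(
        base.join("chat/completions").unwrap().as_str(),
        "https://myresource.openai.azure.com/gpt-4/chat/completions"
    );
}
```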
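The `BaseOpenAIClientMethods` trait introduced in `clients/mod.rs` is what lets `ChatCompletionsClient` stay service-agnostic: it only ever asks its boxed base client for a base URL and a pipeline. A hypothetical sketch (names and field layout are assumptions, not part of this diff) of how the still-stubbed `openai_client.rs` could plug into the same `ChatCompletionsClient` once the commented-out `openai_key_credential` comes back:

```rust
use azure_core::Url;

use super::BaseOpenAIClientMethods;

// Hypothetical non-Azure client; construction details are omitted.
#[derive(Debug, Clone)]
pub struct OpenAIClient {
    base_url: Url,
    pipeline: azure_core::Pipeline,
}

impl BaseOpenAIClientMethods for OpenAIClient {
    // The public OpenAI service has no deployments, so the deployment name
    // can be ignored and the fixed API base returned as-is.
    fn base_url(&self, _deployment_name: Option<&str>) -> azure_core::Result<Url> {
        Ok(self.base_url.clone())
    }

    fn pipeline(&self) -> &azure_core::Pipeline {
        &self.pipeline
    }
}
```

With an impl like that, `ChatCompletionsClient::new(Box::new(openai_client.clone()))` works unchanged, which appears to be the motivation for moving `create_chat_completions` off `AzureOpenAIClient`.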