diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs index 78b06f65e3..4e1595d65e 100644 --- a/sdk/openai/inference/examples/azure_chat_completions.rs +++ b/sdk/openai/inference/examples/azure_chat_completions.rs @@ -10,7 +10,7 @@ pub async fn main() -> Result<()> { std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT env variable"); let secret = std::env::var("AZURE_OPENAI_KEY").expect("Set AZURE_OPENAI_KEY env variable"); - let azure_openai_client = AzureOpenAIClient::new( + let azure_openai_client = AzureOpenAIClient::with_key( endpoint, secret, Some( diff --git a/sdk/openai/inference/src/clients/azure.rs b/sdk/openai/inference/src/clients/azure.rs index 01eb19f853..13857c5391 100644 --- a/sdk/openai/inference/src/clients/azure.rs +++ b/sdk/openai/inference/src/clients/azure.rs @@ -8,24 +8,20 @@ use crate::response::CreateChatCompletionsResponse; use azure_core::{self, Method, Policy, Result}; use azure_core::{Context, Url}; -pub struct AzureOpenAIClient<'a> { +pub struct AzureOpenAIClient { endpoint: Url, - context: Context<'a>, pipeline: azure_core::Pipeline, options: AzureOpenAIClientOptions, } -impl AzureOpenAIClient<'_> { - // TODO: not sure if this should be named `with_key_credential` instead - pub fn new( +impl AzureOpenAIClient { + pub fn with_key( endpoint: impl AsRef, - secret: String, + secret: impl Into, client_options: Option, ) -> Result { let endpoint = Url::parse(endpoint.as_ref())?; - let key_credential = AzureKeyCredential::new(secret); - - let context = Context::new(); + let key_credential = AzureKeyCredential::new(secret.into()); let options = client_options.unwrap_or_default(); let per_call_policies: Vec> = key_credential.clone().into(); @@ -34,7 +30,6 @@ impl AzureOpenAIClient<'_> { Ok(AzureOpenAIClient { endpoint, - context, pipeline, options, }) @@ -76,6 +71,8 @@ impl AzureOpenAIClient<'_> { &self.options.api_service_version.to_string(), ))?; + let context = Context::new(); + let mut request = azure_core::Request::new(url, Method::Post); // adding the mandatory header shouldn't be necessary if the pipeline was setup correctly (?) // request.add_mandatory_header(&self.key_credential); @@ -84,7 +81,7 @@ impl AzureOpenAIClient<'_> { let response = self .pipeline - .send::(&self.context, &mut request) + .send::(&context, &mut request) .await?; response.into_body().json().await } diff --git a/sdk/openai/inference/src/clients/chat_completions_client.rs b/sdk/openai/inference/src/clients/chat_completions_client.rs new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/sdk/openai/inference/src/clients/chat_completions_client.rs @@ -0,0 +1 @@ + diff --git a/sdk/openai/inference/src/clients/mod.rs b/sdk/openai/inference/src/clients/mod.rs index 2e8d14cefc..c6dba4c0d3 100644 --- a/sdk/openai/inference/src/clients/mod.rs +++ b/sdk/openai/inference/src/clients/mod.rs @@ -5,6 +5,7 @@ // use serde::Serialize; pub mod azure; +pub mod chat_completions_client; pub mod non_azure; // pub(crate) fn build_request( diff --git a/sdk/openai/inference/src/options/mod.rs b/sdk/openai/inference/src/options/mod.rs index b3e2d9cbe0..c0d2c929d2 100644 --- a/sdk/openai/inference/src/options/mod.rs +++ b/sdk/openai/inference/src/options/mod.rs @@ -4,7 +4,7 @@ use azure_core::ClientOptions; pub use service_version::AzureServiceVersion; -// TODO: I was not able to find ClientOptions as a derive macros +// TODO: I was not able to find ClientOptions as a derive macros #[derive(Clone, Debug, Default)] pub struct AzureOpenAIClientOptions { pub(crate) client_options: ClientOptions,