From 6064cbbe5988c7dca310886a8e7dc1ad35e019d2 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Fri, 6 Sep 2024 11:57:59 +0200 Subject: [PATCH] tried implementing custom type as policy and pass in the pipeline --- sdk/openai/inference/Cargo.toml | 1 + .../examples/azure_chat_completions.rs | 2 +- sdk/openai/inference/src/auth/mod.rs | 29 +++++++++++++++++-- sdk/openai/inference/src/clients/azure.rs | 22 +++++++++----- 4 files changed, 43 insertions(+), 11 deletions(-) diff --git a/sdk/openai/inference/Cargo.toml b/sdk/openai/inference/Cargo.toml index cc4ca823e7..f28284a620 100644 --- a/sdk/openai/inference/Cargo.toml +++ b/sdk/openai/inference/Cargo.toml @@ -20,6 +20,7 @@ azure_core = { workspace = true } tokio = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } +async-trait = { workspace = true } [dev-dependencies] reqwest = { workspace = true } diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs index b6faeb1763..32f1e609a6 100644 --- a/sdk/openai/inference/examples/azure_chat_completions.rs +++ b/sdk/openai/inference/examples/azure_chat_completions.rs @@ -9,7 +9,7 @@ pub async fn main() -> Result<()> { std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT env variable"); let secret = std::env::var("AZURE_OPENAI_KEY").expect("Set AZURE_OPENAI_KEY env variable"); - let azure_openai_client = AzureOpenAIClient::new(endpoint, secret)?; + let azure_openai_client = AzureOpenAIClient::new(endpoint, secret, None)?; let chat_completions_request = CreateChatCompletionsRequest::new_with_user_message( "gpt-4-1106-preview", diff --git a/sdk/openai/inference/src/auth/mod.rs b/sdk/openai/inference/src/auth/mod.rs index 88ad9f3fe8..fcdd926922 100644 --- a/sdk/openai/inference/src/auth/mod.rs +++ b/sdk/openai/inference/src/auth/mod.rs @@ -1,7 +1,8 @@ +use std::sync::Arc; +use async_trait::async_trait; + use azure_core::{ - auth::Secret, - headers::{HeaderName, HeaderValue, AUTHORIZATION}, - Header, + auth::Secret, headers::{HeaderName, HeaderValue, AUTHORIZATION}, Context, Header, Policy, PolicyResult, Request }; #[derive(Debug, Clone)] @@ -31,6 +32,28 @@ impl Header for AzureKeyCredential { } } +// code lifted from BearerTokenCredentialPolicy +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +impl Policy for AzureKeyCredential { + + async fn send( + &self, + ctx: &Context, + request: &mut Request, + next: &[Arc], + ) -> PolicyResult { + request.insert_header(Header::name(self), Header::value(self)); + next[0].send(ctx, request, &next[1..]).await + } +} + +impl Into>> for AzureKeyCredential { + fn into(self) -> Vec> { + vec![Arc::new(self)] + } +} + impl Header for OpenAIKeyCredential { fn name(&self) -> HeaderName { AUTHORIZATION diff --git a/sdk/openai/inference/src/clients/azure.rs b/sdk/openai/inference/src/clients/azure.rs index ecfa78aa03..65921d6c7c 100644 --- a/sdk/openai/inference/src/clients/azure.rs +++ b/sdk/openai/inference/src/clients/azure.rs @@ -3,31 +3,38 @@ use std::sync::Arc; use crate::auth::AzureKeyCredential; use crate::models::CreateChatCompletionsRequest; use crate::CreateChatCompletionsResponse; -use azure_core::{self, HttpClient, Method, Result}; +use azure_core::{self, ClientOptions, HttpClient, Method, Policy, Result}; use azure_core::{Context, Url}; // TODO: Implement using this instead // typespec_client_core::json_model!(CreateChatCompletionsResponse); -pub struct AzureOpenAIClient<'a> { +#[derive(Clone, Debug, Default)] +pub struct AzureOpenAIClientOptions { + client_options: ClientOptions, +} + +pub struct AzureOpenAIClient <'a> { http_client: Arc, endpoint: Url, key_credential: AzureKeyCredential, context: Context<'a>, pipeline: azure_core::Pipeline, + azure_openai_client_options: AzureOpenAIClientOptions } -impl AzureOpenAIClient<'_> { +impl AzureOpenAIClient <'_> { // TODO: not sure if this should be named `with_key_credential` instead - pub fn new(endpoint: impl AsRef, secret: String) -> Result { + pub fn new(endpoint: impl AsRef, secret: String, client_options: Option) -> Result { let endpoint = Url::parse(endpoint.as_ref())?; let key_credential = AzureKeyCredential::new(secret); - // let auth_header_policy = CustomHeadersPolicy(key_credential.into()); - let mut context = Context::new(); - context.insert(key_credential.clone()); + let context = Context::new(); let pipeline = Self::new_pipeline(); + let mut azure_openai_client_options = client_options.unwrap_or_default(); + let per_call_policies: Vec> = key_credential.clone().into(); + azure_openai_client_options.client_options.set_per_call_policies(per_call_policies); Ok(AzureOpenAIClient { http_client: azure_core::new_http_client(), @@ -35,6 +42,7 @@ impl AzureOpenAIClient<'_> { key_credential, context, pipeline, + azure_openai_client_options }) }