Skip to content

Commit

Permalink
tried implementing custom type as policy and pass in the pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
jpalvarezl committed Sep 6, 2024
1 parent f9b03a3 commit 6064cbb
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 11 deletions.
1 change: 1 addition & 0 deletions sdk/openai/inference/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ azure_core = { workspace = true }
tokio = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
async-trait = { workspace = true }

[dev-dependencies]
reqwest = { workspace = true }
Expand Down
2 changes: 1 addition & 1 deletion sdk/openai/inference/examples/azure_chat_completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub async fn main() -> Result<()> {
std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT env variable");
let secret = std::env::var("AZURE_OPENAI_KEY").expect("Set AZURE_OPENAI_KEY env variable");

let azure_openai_client = AzureOpenAIClient::new(endpoint, secret)?;
let azure_openai_client = AzureOpenAIClient::new(endpoint, secret, None)?;

let chat_completions_request = CreateChatCompletionsRequest::new_with_user_message(
"gpt-4-1106-preview",
Expand Down
29 changes: 26 additions & 3 deletions sdk/openai/inference/src/auth/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::sync::Arc;
use async_trait::async_trait;

use azure_core::{
auth::Secret,
headers::{HeaderName, HeaderValue, AUTHORIZATION},
Header,
auth::Secret, headers::{HeaderName, HeaderValue, AUTHORIZATION}, Context, Header, Policy, PolicyResult, Request
};

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -31,6 +32,28 @@ impl Header for AzureKeyCredential {
}
}

// code lifted from BearerTokenCredentialPolicy
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl Policy for AzureKeyCredential {

async fn send(
&self,
ctx: &Context,
request: &mut Request,
next: &[Arc<dyn Policy>],
) -> PolicyResult {
request.insert_header(Header::name(self), Header::value(self));
next[0].send(ctx, request, &next[1..]).await
}
}

impl Into<Vec<Arc<dyn Policy>>> for AzureKeyCredential {
fn into(self) -> Vec<Arc<dyn Policy>> {
vec![Arc::new(self)]
}
}

impl Header for OpenAIKeyCredential {
fn name(&self) -> HeaderName {
AUTHORIZATION
Expand Down
22 changes: 15 additions & 7 deletions sdk/openai/inference/src/clients/azure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,46 @@ use std::sync::Arc;
use crate::auth::AzureKeyCredential;
use crate::models::CreateChatCompletionsRequest;
use crate::CreateChatCompletionsResponse;
use azure_core::{self, HttpClient, Method, Result};
use azure_core::{self, ClientOptions, HttpClient, Method, Policy, Result};
use azure_core::{Context, Url};

// TODO: Implement using this instead
// typespec_client_core::json_model!(CreateChatCompletionsResponse);

pub struct AzureOpenAIClient<'a> {
#[derive(Clone, Debug, Default)]
pub struct AzureOpenAIClientOptions {
client_options: ClientOptions,
}

pub struct AzureOpenAIClient <'a> {
http_client: Arc<dyn HttpClient>,
endpoint: Url,
key_credential: AzureKeyCredential,
context: Context<'a>,
pipeline: azure_core::Pipeline,
azure_openai_client_options: AzureOpenAIClientOptions
}

impl AzureOpenAIClient<'_> {
impl AzureOpenAIClient <'_> {
// TODO: not sure if this should be named `with_key_credential` instead
pub fn new(endpoint: impl AsRef<str>, secret: String) -> Result<Self> {
pub fn new(endpoint: impl AsRef<str>, secret: String, client_options: Option<AzureOpenAIClientOptions>) -> Result<Self> {
let endpoint = Url::parse(endpoint.as_ref())?;
let key_credential = AzureKeyCredential::new(secret);
// let auth_header_policy = CustomHeadersPolicy(key_credential.into());

let mut context = Context::new();
context.insert(key_credential.clone());
let context = Context::new();

let pipeline = Self::new_pipeline();
let mut azure_openai_client_options = client_options.unwrap_or_default();
let per_call_policies: Vec<Arc<dyn Policy>> = key_credential.clone().into();
azure_openai_client_options.client_options.set_per_call_policies(per_call_policies);

Ok(AzureOpenAIClient {
http_client: azure_core::new_http_client(),
endpoint,
key_credential,
context,
pipeline,
azure_openai_client_options
})
}

Expand Down

0 comments on commit 6064cbb

Please sign in to comment.