Skip to content

Commit

Permalink
Project compiles and runs, but request errors
Browse files Browse the repository at this point in the history
  • Loading branch information
jpalvarezl committed Sep 13, 2024
1 parent 1e0a8b0 commit 28c2c40
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 48 deletions.
12 changes: 7 additions & 5 deletions sdk/openai/inference/examples/azure_chat_completions.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use azure_core::Result;
use azure_openai_inference::{
request::CreateChatCompletionsRequest, AzureOpenAIClient, AzureOpenAIClientOptions,
AzureServiceVersion,
clients::{AzureOpenAIClient, AzureOpenAIClientMethods, ChatCompletionsClientMethods},
request::CreateChatCompletionsRequest,
AzureOpenAIClientOptions, AzureServiceVersion,
};

#[tokio::main]
Expand All @@ -10,22 +11,23 @@ pub async fn main() -> Result<()> {
std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT env variable");
let secret = std::env::var("AZURE_OPENAI_KEY").expect("Set AZURE_OPENAI_KEY env variable");

let azure_openai_client = AzureOpenAIClient::with_key(
let chat_completions_client = AzureOpenAIClient::with_key(
endpoint,
secret,
Some(
AzureOpenAIClientOptions::builder()
.with_api_version(AzureServiceVersion::V2023_12_01Preview)
.build(),
),
)?;
)?
.chat_completions_client();

let chat_completions_request = CreateChatCompletionsRequest::new_with_user_message(
"gpt-4-1106-preview",
"Tell me a joke about pineapples",
);

let response = azure_openai_client
let response = chat_completions_client
.create_chat_completions(&chat_completions_request.model, &chat_completions_request)
.await;

Expand Down
4 changes: 2 additions & 2 deletions sdk/openai/inference/src/auth/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mod azure_key_credential;
mod openai_key_credential;
// mod openai_key_credential;

pub(crate) use azure_key_credential::*;
pub(crate) use openai_key_credential::*;
// pub(crate) use openai_key_credential::*;
72 changes: 38 additions & 34 deletions sdk/openai/inference/src/clients/azure_openai_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,35 @@ use std::sync::Arc;
use crate::auth::AzureKeyCredential;

use crate::options::AzureOpenAIClientOptions;
use crate::request::CreateChatCompletionsRequest;
use crate::response::CreateChatCompletionsResponse;
use azure_core::{self, Method, Policy, Result};
use azure_core::{Context, Url};
use azure_core::Url;
use azure_core::{self, Policy, Result};

use super::chat_completions_client::ChatCompletionsClient;
use super::BaseOpenAIClientMethods;

/// Construction and entry-point methods for an Azure OpenAI service client.
///
/// Extends [`BaseOpenAIClientMethods`] with Azure-specific construction
/// (API-key authentication) and access to operation sub-clients.
pub trait AzureOpenAIClientMethods: BaseOpenAIClientMethods {
/// Creates a client that authenticates against an Azure OpenAI resource
/// with an API key.
///
/// * `endpoint` - base URL of the Azure OpenAI resource.
/// * `secret` - the API key used to authenticate requests.
/// * `client_options` - optional configuration; defaults are used when `None`.
fn with_key(
endpoint: impl AsRef<str>,
secret: impl Into<String>,
client_options: Option<AzureOpenAIClientOptions>,
) -> Result<Self>
where
// `Sized` is required because this constructor returns `Self` by value.
Self: Sized;

/// Returns the endpoint this client was configured with.
fn endpoint(&self) -> &Url;

/// Returns a [`ChatCompletionsClient`] backed by this client's
/// endpoint and pipeline.
fn chat_completions_client(&self) -> ChatCompletionsClient;
}

/// A client for the Azure OpenAI service, holding the resource endpoint
/// and the HTTP pipeline used to execute requests.
#[derive(Debug, Clone)]
pub struct AzureOpenAIClient {
// Base URL of the Azure OpenAI resource.
endpoint: Url,
// HTTP pipeline (policies + transport) that sends requests.
pipeline: azure_core::Pipeline,
// Client configuration; presumably drives pipeline policies such as the
// service API version — confirm against pipeline construction.
options: AzureOpenAIClientOptions,
}

impl AzureOpenAIClient {
pub fn with_key(
impl AzureOpenAIClientMethods for AzureOpenAIClient {
fn with_key(
endpoint: impl AsRef<str>,
secret: impl Into<String>,
client_options: Option<AzureOpenAIClientOptions>,
Expand All @@ -37,37 +53,25 @@ impl AzureOpenAIClient {
})
}

pub fn endpoint(&self) -> &Url {
fn endpoint(&self) -> &Url {
&self.endpoint
}

pub async fn create_chat_completions(
&self,
deployment_name: &str,
chat_completions_request: &CreateChatCompletionsRequest,
// Should I be using RequestContent ? All the new methods have signatures that would force me to mutate
// the request object into &static str, Vec<u8>, etc.
// chat_completions_request: RequestContent<CreateChatCompletionsRequest>,
) -> Result<CreateChatCompletionsResponse> {
let url = Url::parse(&format!(
"{}/openai/deployments/{}/chat/completions",
&self.endpoint, deployment_name
))?;

let context = Context::new();

let mut request = azure_core::Request::new(url, Method::Post);
// this was replaced by the AzureServiceVersion policy, not sure what is the right approach
// adding the mandatory header shouldn't be necessary if the pipeline was setup correctly (?)
// request.add_mandatory_header(&self.key_credential);

request.set_json(chat_completions_request)?;

let response = self
.pipeline
.send::<CreateChatCompletionsResponse>(&context, &mut request)
.await?;
response.into_body().json().await
fn chat_completions_client(&self) -> ChatCompletionsClient {
ChatCompletionsClient::new(Box::new(self.clone()))
}
}

impl BaseOpenAIClientMethods for AzureOpenAIClient {
    /// Builds the base URL for requests against an Azure OpenAI deployment.
    ///
    /// Azure OpenAI scopes deployment operations under
    /// `{endpoint}/openai/deployments/{deployment_name}/`, so that path is
    /// inserted here rather than joining the bare deployment name. The
    /// trailing `/` is significant: without it, a subsequent
    /// `Url::join("chat/completions")` would *replace* the deployment
    /// segment instead of appending to it, producing a URL the service
    /// rejects.
    fn base_url(&self, deployment_name: Option<&str>) -> azure_core::Result<Url> {
        // TODO: surface a proper azure_core error instead of panicking when
        // the deployment name is missing.
        let deployment_name = deployment_name.expect("deployment_name should be provided");
        Ok(self
            .endpoint()
            .join(&format!("openai/deployments/{}/", deployment_name))?)
    }

    /// Returns the HTTP pipeline used to execute requests.
    fn pipeline(&self) -> &azure_core::Pipeline {
        &self.pipeline
    }
}

Expand Down
52 changes: 51 additions & 1 deletion sdk/openai/inference/src/clients/chat_completions_client.rs
Original file line number Diff line number Diff line change
@@ -1 +1,51 @@
pub struct ChatCompletionsClient;
use super::BaseOpenAIClientMethods;
use crate::{request::CreateChatCompletionsRequest, response::CreateChatCompletionsResponse};
use azure_core::{Context, Method, Result};

/// Operations exposed by a chat-completions client.
pub trait ChatCompletionsClientMethods {
/// Sends a chat-completions request for the given model deployment and
/// returns the deserialized response.
// `async fn` in a public trait leaves the returned future's auto traits
// (e.g. `Send`) unnameable by callers; allowed here deliberately.
#[allow(async_fn_in_trait)]
async fn create_chat_completions(
&self,
deployment_name: impl AsRef<str>,
chat_completions_request: &CreateChatCompletionsRequest,
) -> Result<CreateChatCompletionsResponse>;
}

/// A client for chat-completions operations, generic over the underlying
/// service client (Azure or otherwise) through [`BaseOpenAIClientMethods`].
pub struct ChatCompletionsClient {
// Supplies the base request URL and the HTTP pipeline used to send requests.
base_client: Box<dyn BaseOpenAIClientMethods>,
}

impl ChatCompletionsClient {
/// Wraps a base client whose URL and pipeline will service all
/// chat-completions requests made through this client.
pub fn new(base_client: Box<dyn BaseOpenAIClientMethods>) -> Self {
Self { base_client }
}
}

impl ChatCompletionsClientMethods for ChatCompletionsClient {
    /// Sends a chat-completions request to `{base_url}/chat/completions`
    /// for the given deployment and deserializes the JSON response body.
    async fn create_chat_completions(
        &self,
        deployment_name: impl AsRef<str>,
        chat_completions_request: &CreateChatCompletionsRequest,
    ) -> Result<CreateChatCompletionsResponse> {
        // The base client scopes the URL to the deployment (on Azure this is
        // the `openai/deployments/{name}/` path segment).
        let base_url = self.base_client.base_url(Some(deployment_name.as_ref()))?;
        let request_url = base_url.join("chat/completions")?;

        let context = Context::new();

        // NOTE(review): no auth or api-version headers are set here —
        // presumably they are applied by pipeline policies configured at
        // client construction; confirm the pipeline is set up accordingly.
        let mut request = azure_core::Request::new(request_url, Method::Post);
        request.set_json(chat_completions_request)?;

        let response = self
            .base_client
            .pipeline()
            .send::<CreateChatCompletionsResponse>(&context, &mut request)
            .await?;
        response.into_body().json().await
    }
}
15 changes: 12 additions & 3 deletions sdk/openai/inference/src/clients/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
pub mod azure_openai_client;
pub mod chat_completions_client;
pub mod openai_client;
mod azure_openai_client;
mod chat_completions_client;
mod openai_client;

pub use azure_openai_client::{AzureOpenAIClient, AzureOpenAIClientMethods};
pub use chat_completions_client::{ChatCompletionsClient, ChatCompletionsClientMethods};

/// Shared plumbing implemented by each concrete OpenAI service client:
/// how to compute the base request URL and which HTTP pipeline to use.
pub trait BaseOpenAIClientMethods {
/// Returns the base URL requests should be issued against, optionally
/// scoped to a model deployment (used by the Azure service).
fn base_url(&self, deployment_name: Option<&str>) -> azure_core::Result<azure_core::Url>;

/// Returns the HTTP pipeline (policies + transport) used to send requests.
fn pipeline(&self) -> &azure_core::Pipeline;
}
4 changes: 1 addition & 3 deletions sdk/openai/inference/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
mod auth;
mod clients;
pub mod clients;
mod models;
mod options;

pub use clients::azure_openai_client::*;
pub use clients::openai_client::*;
pub use models::*;
pub use options::*;

0 comments on commit 28c2c40

Please sign in to comment.