Skip to content

Commit

Permalink
Somehow hitting the noop client
Browse files Browse the repository at this point in the history
  • Loading branch information
jpalvarezl committed Aug 30, 2024
1 parent d2f2149 commit 39cc9a5
Show file tree
Hide file tree
Showing 11 changed files with 278 additions and 8 deletions.
2 changes: 2 additions & 0 deletions sdk/openai/inference/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ workspace = true
azure_core = { workspace = true }
reqwest = { workspace = true, optional = true }
tokio = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

[features]
default = [ "reqwest" ]
Expand Down
27 changes: 25 additions & 2 deletions sdk/openai/inference/examples/azure_chat_completions.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use azure_core::{auth::TokenCredential, Result};
use azure_openai_inference::AzureOpenAIClient;
use azure_core::Result;
use azure_openai_inference::{
AzureOpenAIClient, AzureServiceVersion, CreateChatCompletionsRequest,
};

#[tokio::main]
pub async fn main() -> Result<()> {
Expand All @@ -9,5 +11,26 @@ pub async fn main() -> Result<()> {

let azure_openai_client = AzureOpenAIClient::new(endpoint, secret)?;

let chat_completions_request = CreateChatCompletionsRequest::new_with_user_message(
"gpt-4-1106-preview",
"Tell me a joke about pineapples",
);

let response = azure_openai_client
.create_chat_completions(
&chat_completions_request.model,
AzureServiceVersion::V2023_12_01Preview,
&chat_completions_request,
)
.await;

match response {
Ok(chat_completions) => {
println!("{:#?}", &chat_completions);
}
Err(e) => {
println!("Error: {}", e);
}
};
Ok(())
}
41 changes: 41 additions & 0 deletions sdk/openai/inference/src/auth/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use azure_core::{
auth::Secret,
headers::{HeaderName, HeaderValue, AUTHORIZATION},
Header,
};

/// API-key credential for the Azure OpenAI service.
/// Sent on requests as the `api-key` header (see `impl Header for AzureKeyCredential`).
pub struct AzureKeyCredential(Secret);

/// API-key credential for the public (non-Azure) OpenAI service.
/// Sent on requests as a `Bearer` token in the `Authorization` header.
pub struct OpenAIKeyCredential(Secret);

impl OpenAIKeyCredential {
    /// Creates a credential for the public OpenAI service.
    ///
    /// Accepts anything convertible into a `String` (`&str`, `String`, …)
    /// so callers are not forced to allocate up front; the token is stored
    /// as a [`Secret`].
    pub fn new(access_token: impl Into<String>) -> Self {
        Self(Secret::new(access_token.into()))
    }
}

impl AzureKeyCredential {
    /// Creates a credential for an Azure OpenAI resource.
    ///
    /// Accepts anything convertible into a `String` (`&str`, `String`, …)
    /// so callers are not forced to allocate up front; the key is stored
    /// as a [`Secret`].
    pub fn new(api_key: impl Into<String>) -> Self {
        Self(Secret::new(api_key.into()))
    }
}

impl Header for AzureKeyCredential {
    /// Azure OpenAI authenticates with a custom `api-key` request header.
    fn name(&self) -> HeaderName {
        HeaderName::from_static("api-key")
    }

    /// The header value is the raw key. `to_owned` replaces the previous
    /// `format!("{}", ...)`, which round-tripped the string through the
    /// formatter for no benefit (clippy: `useless_format`).
    fn value(&self) -> HeaderValue {
        HeaderValue::from_cow(self.0.secret().to_owned())
    }
}

impl Header for OpenAIKeyCredential {
    /// OpenAI uses standard bearer-token auth on the `Authorization` header.
    fn name(&self) -> HeaderName {
        AUTHORIZATION
    }

    /// Formats the token in the `Bearer <token>` scheme. The redundant `&`
    /// in front of `self.0.secret()` is dropped — `format!` takes the value
    /// by reference already.
    fn value(&self) -> HeaderValue {
        HeaderValue::from_cow(format!("Bearer {}", self.0.secret()))
    }
}
60 changes: 54 additions & 6 deletions sdk/openai/inference/src/clients/azure.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,72 @@
use std::sync::Arc;

use azure_core::auth::{Secret, TokenCredential};
use azure_core::{self, HttpClient, Result};
use reqwest::Url;
use crate::auth::AzureKeyCredential;
use crate::models::CreateChatCompletionsRequest;
use crate::CreateChatCompletionsResponse;
use azure_core::Url;
use azure_core::{self, HttpClient, Method, Result};

// TODO: Implement using this instead
// typespec_client_core::json_model!(CreateChatCompletionsResponse);

/// Client for an Azure OpenAI resource endpoint.
// NOTE(review): this listing is a diff rendered without +/- markers; the
// `secret` field below is the pre-change (removed) line and `key_credential`
// is its replacement — the compiled struct should contain only one of them.
// TODO confirm against the repository.
pub struct AzureOpenAIClient {
    // Pluggable transport used to execute requests.
    http_client: Arc<dyn HttpClient>,
    // Base resource URL that operation paths are appended to.
    endpoint: Url,
    secret: Secret, // NOTE(review): removed line of the diff — superseded by `key_credential`.
    key_credential: AzureKeyCredential,
}

impl AzureOpenAIClient {
    /// Builds a client from the resource endpoint URL and its API key.
    ///
    /// # Errors
    /// Fails if `endpoint` is not a parseable URL.
    // NOTE(review): diff rendered without +/- markers — the `Secret::from`
    // line and the `secret,` initializer below are the removed (old) lines;
    // `key_credential` is their replacement. As shown, passing a `Secret`
    // into `AzureKeyCredential::new(String)` would not type-check, which is
    // consistent with `Secret::from` being the removed line.
    pub fn new(endpoint: impl AsRef<str>, secret: String) -> Result<Self> {
        let endpoint = Url::parse(endpoint.as_ref())?;
        let secret = Secret::from(secret);
        let key_credential = AzureKeyCredential::new(secret);

        Ok(AzureOpenAIClient {
            // Default transport from azure_core (reqwest-backed when the
            // crate's default `reqwest` feature is enabled — see Cargo.toml).
            http_client: azure_core::new_http_client(),
            endpoint,
            secret,
            key_credential,
        })
    }

    /// The configured base endpoint.
    pub fn endpoint(&self) -> &Url {
        &self.endpoint
    }

    /// Invokes the `chat/completions` operation on the given deployment and
    /// deserializes the JSON response body.
    ///
    /// `api_version` is sent as the `api-version` query parameter (see
    /// `AzureServiceVersion`).
    // NOTE(review): the URL is built with string formatting — a trailing '/'
    // on `endpoint` would produce "//", and `deployment_name` is not
    // percent-encoded; `Url::join` + `query_pairs_mut` would be safer. TODO.
    pub async fn create_chat_completions(
        &self,
        deployment_name: &str,
        api_version: impl Into<String>,
        chat_completions_request: &CreateChatCompletionsRequest,
    ) -> Result<CreateChatCompletionsResponse> {
        let url = Url::parse(&format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            &self.endpoint,
            deployment_name,
            api_version.into()
        ))?;
        let request = super::build_request(
            &self.key_credential,
            url,
            Method::Post,
            chat_completions_request,
        )?;
        let response = self.http_client.execute_request(&request).await?;
        Ok(response.into_body().json().await?)
    }
}

/// Known Azure OpenAI service API versions, sent as the `api-version`
/// query parameter.
// Derives added: a public enum should be `Debug`; the variants are plain
// discriminants, so `Clone`/`Copy`/`PartialEq`/`Eq` are free and useful.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AzureServiceVersion {
    V2023_09_01Preview,
    V2023_12_01Preview,
    V2024_07_01Preview,
}

impl From<AzureServiceVersion> for String {
    /// Renders the version in the wire format the service expects
    /// (e.g. `"2023-12-01-preview"`).
    fn from(version: AzureServiceVersion) -> String {
        // The match is the expression value; the trailing `return` was noise.
        String::from(match version {
            AzureServiceVersion::V2023_09_01Preview => "2023-09-01-preview",
            AzureServiceVersion::V2023_12_01Preview => "2023-12-01-preview",
            AzureServiceVersion::V2024_07_01Preview => "2024-07-01-preview",
        })
    }
}
40 changes: 40 additions & 0 deletions sdk/openai/inference/src/clients/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,42 @@
use azure_core::{
headers::{ACCEPT, CONTENT_TYPE},
Header, Method, Request, Result, Url,
};
use serde::Serialize;

pub mod azure;
pub mod non_azure;

/// Assembles an authenticated JSON request: the auth header supplied by
/// `key_credential`, JSON content-type/accept headers, and `data`
/// serialized as the request body.
///
/// # Errors
/// Propagates any serialization failure from `set_json`.
pub(crate) fn build_request<T>(
    key_credential: &impl Header,
    url: Url,
    method: Method,
    data: &T,
) -> Result<Request>
where
    T: ?Sized + Serialize,
{
    const JSON: &str = "application/json";

    let mut req = Request::new(url, method);
    req.add_mandatory_header(key_credential);
    req.insert_header(ACCEPT, JSON);
    req.insert_header(CONTENT_TYPE, JSON);
    req.set_json(data)?;
    Ok(req)
}

// pub(crate) fn build_multipart_request<F>(
// key_credential: &impl Header,
// url: Url,
// form_generator: F,
// ) -> Result<Request>
// where
// F: FnOnce() -> Result<MyForm>,
// {
// let mut request = Request::new(url, Method::Post);
// request.add_mandatory_header(key_credential);
// // handled internally by reqwest
// // request.insert_header(CONTENT_TYPE, "multipart/form-data");
// // request.insert_header(ACCEPT, "application/json");
// request.multipart(form_generator()?);
// Ok(request)
// }
3 changes: 3 additions & 0 deletions sdk/openai/inference/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
// Public authentication primitives (key credentials).
pub mod auth;
// Client implementations; private — only the re-exports below are public.
mod clients;
// Request/response wire models; private — re-exported below.
mod models;

// Flatten the crate surface: clients and models live at the crate root.
pub use clients::azure::*;
pub use clients::non_azure::*;
pub use models::*;
5 changes: 5 additions & 0 deletions sdk/openai/inference/src/models/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// Wire models, split by direction of travel.
mod request;
mod response;

pub use request::*;
pub use response::*;
66 changes: 66 additions & 0 deletions sdk/openai/inference/src/models/request/chat_completions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use serde::Serialize;

#[derive(Serialize, Debug, Clone, Default)]
pub struct CreateChatCompletionsRequest {
pub messages: Vec<ChatCompletionRequestMessage>,
pub model: String,
pub stream: Option<bool>,
// pub frequency_penalty: f64,
// pub logit_bias: Option<HashMap<String, f64>>,
// pub logprobs: Option<bool>,
// pub top_logprobs: Option<i32>,
// pub max_tokens: Option<i32>,
}

/// Payload common to every chat message regardless of role.
#[derive(Serialize, Debug, Clone, Default)]
pub struct ChatCompletionRequestMessageBase {
    // NOTE(review): `skip` omits `name` from (de)serialization entirely, even
    // when it is `Some`; if the intent is to send it when present,
    // `skip_serializing_if = "Option::is_none"` is likely what was meant —
    // TODO confirm.
    #[serde(skip)]
    pub name: Option<String>,
    pub content: String, // TODO this should be either a string or ChatCompletionRequestMessageContentPart (a polymorphic type)
}

/// A single chat message, discriminated on the wire by its `"role"` field.
// `rename_all = "lowercase"` emits `"system"` / `"user"`, identical to the
// previous per-variant `rename` attributes but stated once.
#[derive(Serialize, Debug, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatCompletionRequestMessage {
    /// Instructions that steer the assistant (`"system"` role).
    System(ChatCompletionRequestMessageBase),
    /// End-user input (`"user"` role).
    User(ChatCompletionRequestMessageBase),
}

impl ChatCompletionRequestMessage {
pub fn new_user(content: impl Into<String>) -> Self {
Self::User(ChatCompletionRequestMessageBase {
content: content.into(),
name: None,
})
}

pub fn new_system(content: impl Into<String>) -> Self {
Self::System(ChatCompletionRequestMessageBase {
content: content.into(),
name: None,
})
}
}
impl CreateChatCompletionsRequest {
    /// Builds a non-streaming request containing a single user message.
    ///
    /// Parameters are generalized from `&str` to `impl Into<String>` for
    /// consistency with [`Self::new_stream_with_user_message`]; existing
    /// `&str` callers are unaffected.
    pub fn new_with_user_message(model: impl Into<String>, prompt: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            messages: vec![ChatCompletionRequestMessage::new_user(prompt)],
            ..Default::default()
        }
    }

    /// Builds a request with a single user message and streaming enabled
    /// (`stream: Some(true)`).
    pub fn new_stream_with_user_message(
        model: impl Into<String>,
        prompt: impl Into<String>,
    ) -> Self {
        Self {
            model: model.into(),
            messages: vec![ChatCompletionRequestMessage::new_user(prompt)],
            stream: Some(true),
            ..Default::default()
        }
    }
}
3 changes: 3 additions & 0 deletions sdk/openai/inference/src/models/request/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// Request-side wire models.
mod chat_completions;

pub use chat_completions::*;
36 changes: 36 additions & 0 deletions sdk/openai/inference/src/models/response/chat_completions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use serde::Deserialize;

/// Deserialized body of a successful `chat/completions` call.
/// Only `choices` is modeled; other response fields are ignored by serde.
#[derive(Debug, Clone, Deserialize)]
pub struct CreateChatCompletionsResponse {
    pub choices: Vec<ChatCompletionChoice>,
}

/// One completion alternative returned by the service.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionChoice {
    pub message: ChatCompletionResponseMessage,
}

/// The assistant message inside a choice.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionResponseMessage {
    // Optional because the wire format permits absent/null content.
    pub content: Option<String>,
    pub role: String,
}

// region: --- Streaming

/// One chunk of a streamed `chat/completions` response.
#[derive(Debug, Clone, Deserialize)]
pub struct CreateChatCompletionsStreamResponse {
    pub choices: Vec<ChatCompletionStreamChoice>,
}

#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionStreamChoice {
    // Optional: some chunks carry no incremental update.
    pub delta: Option<ChatCompletionStreamResponseMessage>,
}

/// Partial message carried by a `delta`. Both fields are optional because a
/// chunk may carry only one of them (role-only or content-only updates).
#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionStreamResponseMessage {
    pub content: Option<String>,
    pub role: Option<String>,
}

// endregion: Streaming
3 changes: 3 additions & 0 deletions sdk/openai/inference/src/models/response/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// Response-side wire models.
mod chat_completions;

pub use chat_completions::*;

0 comments on commit 39cc9a5

Please sign in to comment.