forked from Azure/azure-sdk-for-rust
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d2f2149
commit 39cc9a5
Showing
11 changed files
with
278 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
use azure_core::{ | ||
auth::Secret, | ||
headers::{HeaderName, HeaderValue, AUTHORIZATION}, | ||
Header, | ||
}; | ||
|
||
pub struct AzureKeyCredential(Secret); | ||
|
||
pub struct OpenAIKeyCredential(Secret); | ||
|
||
impl OpenAIKeyCredential { | ||
pub fn new(access_token: String) -> Self { | ||
Self(Secret::new(access_token)) | ||
} | ||
} | ||
|
||
impl AzureKeyCredential { | ||
pub fn new(api_key: String) -> Self { | ||
Self(Secret::new(api_key)) | ||
} | ||
} | ||
|
||
impl Header for AzureKeyCredential { | ||
fn name(&self) -> HeaderName { | ||
HeaderName::from_static("api-key") | ||
} | ||
|
||
fn value(&self) -> HeaderValue { | ||
HeaderValue::from_cow(format!("{}", self.0.secret())) | ||
} | ||
} | ||
|
||
impl Header for OpenAIKeyCredential { | ||
fn name(&self) -> HeaderName { | ||
AUTHORIZATION | ||
} | ||
|
||
fn value(&self) -> HeaderValue { | ||
HeaderValue::from_cow(format!("Bearer {}", &self.0.secret())) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,72 @@ | ||
use std::sync::Arc; | ||
|
||
use azure_core::auth::{Secret, TokenCredential}; | ||
use azure_core::{self, HttpClient, Result}; | ||
use reqwest::Url; | ||
use crate::auth::AzureKeyCredential; | ||
use crate::models::CreateChatCompletionsRequest; | ||
use crate::CreateChatCompletionsResponse; | ||
use azure_core::Url; | ||
use azure_core::{self, HttpClient, Method, Result}; | ||
|
||
// TODO: Implement using this instead | ||
// typespec_client_core::json_model!(CreateChatCompletionsResponse); | ||
|
||
pub struct AzureOpenAIClient { | ||
http_client: Arc<dyn HttpClient>, | ||
endpoint: Url, | ||
secret: Secret, | ||
key_credential: AzureKeyCredential, | ||
} | ||
|
||
impl AzureOpenAIClient { | ||
pub fn new(endpoint: impl AsRef<str>, secret: String) -> Result<Self> { | ||
let endpoint = Url::parse(endpoint.as_ref())?; | ||
let secret = Secret::from(secret); | ||
let key_credential = AzureKeyCredential::new(secret); | ||
|
||
Ok(AzureOpenAIClient { | ||
http_client: azure_core::new_http_client(), | ||
endpoint, | ||
secret, | ||
key_credential, | ||
}) | ||
} | ||
|
||
pub fn endpoint(&self) -> &Url { | ||
&self.endpoint | ||
} | ||
|
||
pub async fn create_chat_completions( | ||
&self, | ||
deployment_name: &str, | ||
api_version: impl Into<String>, | ||
chat_completions_request: &CreateChatCompletionsRequest, | ||
) -> Result<CreateChatCompletionsResponse> { | ||
let url = Url::parse(&format!( | ||
"{}/openai/deployments/{}/chat/completions?api-version={}", | ||
&self.endpoint, | ||
deployment_name, | ||
api_version.into() | ||
))?; | ||
let request = super::build_request( | ||
&self.key_credential, | ||
url, | ||
Method::Post, | ||
chat_completions_request, | ||
)?; | ||
let response = self.http_client.execute_request(&request).await?; | ||
Ok(response.into_body().json().await?) | ||
} | ||
} | ||
|
||
pub enum AzureServiceVersion { | ||
V2023_09_01Preview, | ||
V2023_12_01Preview, | ||
V2024_07_01Preview, | ||
} | ||
|
||
impl From<AzureServiceVersion> for String { | ||
fn from(version: AzureServiceVersion) -> String { | ||
let as_str = match version { | ||
AzureServiceVersion::V2023_09_01Preview => "2023-09-01-preview", | ||
AzureServiceVersion::V2023_12_01Preview => "2023-12-01-preview", | ||
AzureServiceVersion::V2024_07_01Preview => "2024-07-01-preview", | ||
}; | ||
return String::from(as_str); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,42 @@ | ||
use azure_core::{ | ||
headers::{ACCEPT, CONTENT_TYPE}, | ||
Header, Method, Request, Result, Url, | ||
}; | ||
use serde::Serialize; | ||
|
||
pub mod azure; | ||
pub mod non_azure; | ||
|
||
pub(crate) fn build_request<T>( | ||
key_credential: &impl Header, | ||
url: Url, | ||
method: Method, | ||
data: &T, | ||
) -> Result<Request> | ||
where | ||
T: ?Sized + Serialize, | ||
{ | ||
let mut request = Request::new(url, method); | ||
request.add_mandatory_header(key_credential); | ||
request.insert_header(CONTENT_TYPE, "application/json"); | ||
request.insert_header(ACCEPT, "application/json"); | ||
request.set_json(data)?; | ||
Ok(request) | ||
} | ||
|
||
// pub(crate) fn build_multipart_request<F>( | ||
// key_credential: &impl Header, | ||
// url: Url, | ||
// form_generator: F, | ||
// ) -> Result<Request> | ||
// where | ||
// F: FnOnce() -> Result<MyForm>, | ||
// { | ||
// let mut request = Request::new(url, Method::Post); | ||
// request.add_mandatory_header(key_credential); | ||
// // handled insternally by reqwest | ||
// // request.insert_header(CONTENT_TYPE, "multipart/form-data"); | ||
// // request.insert_header(ACCEPT, "application/json"); | ||
// request.multipart(form_generator()?); | ||
// Ok(request) | ||
// } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,7 @@ | ||
pub mod auth; | ||
mod clients; | ||
mod models; | ||
|
||
pub use clients::azure::*; | ||
pub use clients::non_azure::*; | ||
pub use models::*; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
mod request; | ||
mod response; | ||
|
||
pub use request::*; | ||
pub use response::*; |
66 changes: 66 additions & 0 deletions
66
sdk/openai/inference/src/models/request/chat_completions.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
use serde::Serialize; | ||
|
||
#[derive(Serialize, Debug, Clone, Default)] | ||
pub struct CreateChatCompletionsRequest { | ||
pub messages: Vec<ChatCompletionRequestMessage>, | ||
pub model: String, | ||
pub stream: Option<bool>, | ||
// pub frequency_penalty: f64, | ||
// pub logit_bias: Option<HashMap<String, f64>>, | ||
// pub logprobs: Option<bool>, | ||
// pub top_logprobs: Option<i32>, | ||
// pub max_tokens: Option<i32>, | ||
} | ||
|
||
#[derive(Serialize, Debug, Clone, Default)] | ||
pub struct ChatCompletionRequestMessageBase { | ||
#[serde(skip)] | ||
pub name: Option<String>, | ||
pub content: String, // TODO this should be either a string or ChatCompletionRequestMessageContentPart (a polymorphic type) | ||
} | ||
|
||
#[derive(Serialize, Debug, Clone)] | ||
#[serde(tag = "role")] | ||
pub enum ChatCompletionRequestMessage { | ||
#[serde(rename = "system")] | ||
System(ChatCompletionRequestMessageBase), | ||
#[serde(rename = "user")] | ||
User(ChatCompletionRequestMessageBase), | ||
} | ||
|
||
impl ChatCompletionRequestMessage { | ||
pub fn new_user(content: impl Into<String>) -> Self { | ||
Self::User(ChatCompletionRequestMessageBase { | ||
content: content.into(), | ||
name: None, | ||
}) | ||
} | ||
|
||
pub fn new_system(content: impl Into<String>) -> Self { | ||
Self::System(ChatCompletionRequestMessageBase { | ||
content: content.into(), | ||
name: None, | ||
}) | ||
} | ||
} | ||
impl CreateChatCompletionsRequest { | ||
pub fn new_with_user_message(model: &str, prompt: &str) -> Self { | ||
Self { | ||
model: model.to_string(), | ||
messages: vec![ChatCompletionRequestMessage::new_user(prompt)], | ||
..Default::default() | ||
} | ||
} | ||
|
||
pub fn new_stream_with_user_message( | ||
model: impl Into<String>, | ||
prompt: impl Into<String>, | ||
) -> Self { | ||
Self { | ||
model: model.into(), | ||
messages: vec![ChatCompletionRequestMessage::new_user(prompt)], | ||
stream: Some(true), | ||
..Default::default() | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
mod chat_completions; | ||
|
||
pub use chat_completions::*; |
36 changes: 36 additions & 0 deletions
36
sdk/openai/inference/src/models/response/chat_completions.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
use serde::Deserialize; | ||
|
||
#[derive(Debug, Clone, Deserialize)] | ||
pub struct CreateChatCompletionsResponse { | ||
pub choices: Vec<ChatCompletionChoice>, | ||
} | ||
|
||
#[derive(Debug, Clone, Deserialize)] | ||
pub struct ChatCompletionChoice { | ||
pub message: ChatCompletionResponseMessage, | ||
} | ||
|
||
#[derive(Debug, Clone, Deserialize)] | ||
pub struct ChatCompletionResponseMessage { | ||
pub content: Option<String>, | ||
pub role: String, | ||
} | ||
|
||
// region: --- Streaming | ||
#[derive(Debug, Clone, Deserialize)] | ||
pub struct CreateChatCompletionsStreamResponse { | ||
pub choices: Vec<ChatCompletionStreamChoice>, | ||
} | ||
|
||
#[derive(Debug, Clone, Deserialize)] | ||
pub struct ChatCompletionStreamChoice { | ||
pub delta: Option<ChatCompletionStreamResponseMessage>, | ||
} | ||
|
||
#[derive(Debug, Clone, Deserialize)] | ||
pub struct ChatCompletionStreamResponseMessage { | ||
pub content: Option<String>, | ||
pub role: Option<String>, | ||
} | ||
|
||
// endregion: Streaming |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
mod chat_completions; | ||
|
||
pub use chat_completions::*; |