Merge pull request #214 from gkumbhat/add_header_passthrough
✨ Add header passthrough for NLP and detector clients
gkumbhat authored Oct 3, 2024
2 parents f26d490 + bfdc403 commit 5ae4c2c
Showing 10 changed files with 318 additions and 62 deletions.
5 changes: 5 additions & 0 deletions config/config.yaml
@@ -53,3 +53,8 @@ tls:
detector_bundle_no_ca:
cert_path: /path/to/client-bundle.pem
insecure: true
+
+# The following section can be used to configure the allowed headers that the orchestrator will pass to
+# the NLP provider and detectors. Note that this section takes header keys, not values.
+# passthrough_headers:
+# - header-key
8 changes: 7 additions & 1 deletion src/clients/detector.rs
@@ -17,7 +17,7 @@

use std::collections::HashMap;

-use hyper::StatusCode;
+use hyper::{HeaderMap, StatusCode};
use serde::{Deserialize, Serialize};

use super::{create_http_clients, Error, HttpClient};
@@ -75,11 +75,13 @@ impl DetectorClient
&self,
model_id: &str,
request: ContentAnalysisRequest,
+headers: HeaderMap,
) -> Result<Vec<Vec<ContentAnalysisResponse>>, Error> {
let client = self.client(model_id)?;
let url = client.base_url().as_str();
let response = client
.post(url)
+.headers(headers)
.header(DETECTOR_ID_HEADER_NAME, model_id)
.json(&request)
.send()
@@ -104,11 +106,13 @@ impl DetectorClient
&self,
model_id: &str,
request: GenerationDetectionRequest,
+headers: HeaderMap,
) -> Result<Vec<DetectionResult>, Error> {
let client = self.client(model_id)?;
let url = client.base_url().as_str();
let response = client
.post(url)
+.headers(headers)
.header(DETECTOR_ID_HEADER_NAME, model_id)
.json(&request)
.send()
@@ -133,11 +137,13 @@ impl DetectorClient
&self,
model_id: &str,
request: ContextDocsDetectionRequest,
+headers: HeaderMap,
) -> Result<Vec<DetectionResult>, Error> {
let client = self.client(model_id)?;
let url = client.base_url().as_str();
let response = client
.post(url)
+.headers(headers)
.header(DETECTOR_ID_HEADER_NAME, model_id)
.json(&request)
.send()
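Each detector call now forwards the caller's allow-listed headers in addition to the detector-id header. Below is a minimal sketch of the new call shape, assuming the hidden method name is `text_contents` and eliding the crate-internal imports (both are assumptions, since the hunks above begin below the method names):

```rust
use hyper::HeaderMap;
// DetectorClient, ContentAnalysisRequest, ContentAnalysisResponse and Error are
// the types from src/clients/detector.rs; import paths are elided here.

// Hypothetical call site: `text_contents` is an assumed name for the
// ContentAnalysisRequest method whose signature appears in the hunk above.
async fn detect(
    client: &DetectorClient,
    request: ContentAnalysisRequest,
    passthrough: HeaderMap, // already filtered against `passthrough_headers`
) -> Result<Vec<Vec<ContentAnalysisResponse>>, Error> {
    client.text_contents("hap", request, passthrough).await
}
```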
18 changes: 12 additions & 6 deletions src/clients/generation.rs
@@ -18,6 +18,7 @@
use std::collections::HashMap;

use futures::{StreamExt, TryStreamExt};
+use hyper::HeaderMap;
use tracing::debug;

use super::{BoxStream, Error, NlpClient, TgisClient};
@@ -85,6 +86,7 @@ impl GenerationClient
&self,
model_id: String,
text: String,
+headers: HeaderMap,
) -> Result<(u32, Vec<String>), Error> {
match &self.0 {
Some(GenerationClientInner::Tgis(client)) => {
@@ -96,15 +98,17 @@ impl GenerationClient
truncate_input_tokens: 0,
};
debug!(%model_id, provider = "tgis", ?request, "sending tokenize request");
-let mut response = client.tokenize(request).await?;
+let mut response = client.tokenize(request, headers).await?;
debug!(%model_id, provider = "tgis", ?response, "received tokenize response");
let response = response.responses.swap_remove(0);
Ok((response.token_count, response.tokens))
}
Some(GenerationClientInner::Nlp(client)) => {
let request = TokenizationTaskRequest { text };
debug!(%model_id, provider = "nlp", ?request, "sending tokenize request");
-let response = client.tokenization_task_predict(&model_id, request).await?;
+let response = client
+.tokenization_task_predict(&model_id, request, headers)
+.await?;
debug!(%model_id, provider = "nlp", ?response, "received tokenize response");
let tokens = response
.results
@@ -122,6 +126,7 @@ impl GenerationClient
model_id: String,
text: String,
params: Option<GuardrailsTextGenerationParameters>,
+headers: HeaderMap,
) -> Result<ClassifiedGeneratedTextResult, Error> {
match &self.0 {
Some(GenerationClientInner::Tgis(client)) => {
@@ -133,7 +138,7 @@ impl GenerationClient
params,
};
debug!(%model_id, provider = "tgis", ?request, "sending generate request");
-let response = client.generate(request).await?;
+let response = client.generate(request, headers).await?;
debug!(%model_id, provider = "tgis", ?response, "received generate response");
Ok(response.into())
}
@@ -171,7 +176,7 @@ impl GenerationClient
};
debug!(%model_id, provider = "nlp", ?request, "sending generate request");
let response = client
-.text_generation_task_predict(&model_id, request)
+.text_generation_task_predict(&model_id, request, headers)
.await?;
debug!(%model_id, provider = "nlp", ?response, "received generate response");
Ok(response.into())
@@ -185,6 +190,7 @@ impl GenerationClient
model_id: String,
text: String,
params: Option<GuardrailsTextGenerationParameters>,
+headers: HeaderMap,
) -> Result<BoxStream<Result<ClassifiedGeneratedTextStreamResult, Error>>, Error> {
match &self.0 {
Some(GenerationClientInner::Tgis(client)) => {
Expand All @@ -197,7 +203,7 @@ impl GenerationClient {
};
debug!(%model_id, provider = "tgis", ?request, "sending generate_stream request");
let response_stream = client
-.generate_stream(request)
+.generate_stream(request, headers)
.await?
.map_ok(Into::into)
.boxed();
@@ -237,7 +243,7 @@ impl GenerationClient
};
debug!(%model_id, provider = "nlp", ?request, "sending generate_stream request");
let response_stream = client
-.server_streaming_text_generation_task_predict(&model_id, request)
+.server_streaming_text_generation_task_predict(&model_id, request, headers)
.await?
.map_ok(Into::into)
.boxed();
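The `GenerationClient` facade keeps one signature for both providers, so callers pass the header map once and the TGIS or NLP branch threads it through. A sketch of a call site follows; the method name `generate` is inferred from the "sending generate request" debug lines above, and the model id and prompt are illustrative:

```rust
use hyper::HeaderMap;
// GenerationClient, ClassifiedGeneratedTextResult and Error come from this crate;
// imports are elided.

async fn generate_with_passthrough(
    client: &GenerationClient,
    headers: HeaderMap, // already filtered against `passthrough_headers`
) -> Result<ClassifiedGeneratedTextResult, Error> {
    // `None` skips the optional GuardrailsTextGenerationParameters;
    // model id and prompt are placeholders.
    client
        .generate("flan-t5-xl".to_string(), "Hello!".to_string(), None, headers)
        .await
}
```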
20 changes: 13 additions & 7 deletions src/clients/nlp.rs
@@ -17,9 +17,10 @@

use std::collections::HashMap;

+use axum::http::{Extensions, HeaderMap};
use futures::{StreamExt, TryStreamExt};
use ginepro::LoadBalancedChannel;
-use tonic::Request;
+use tonic::{metadata::MetadataMap, Request};

use super::{create_grpc_clients, BoxStream, Error};
use crate::{
@@ -94,8 +95,9 @@ impl NlpClient
&self,
model_id: &str,
request: TokenizationTaskRequest,
+headers: HeaderMap,
) -> Result<TokenizationResults, Error> {
-let request = request_with_model_id(request, model_id);
+let request = request_with_model_id(request, model_id, headers);
Ok(self
.client(model_id)?
.tokenization_task_predict(request)
@@ -107,8 +109,9 @@ impl NlpClient
&self,
model_id: &str,
request: TokenClassificationTaskRequest,
+headers: HeaderMap,
) -> Result<TokenClassificationResults, Error> {
-let request = request_with_model_id(request, model_id);
+let request = request_with_model_id(request, model_id, headers);
Ok(self
.client(model_id)?
.token_classification_task_predict(request)
@@ -120,8 +123,9 @@ impl NlpClient
&self,
model_id: &str,
request: TextGenerationTaskRequest,
+headers: HeaderMap,
) -> Result<GeneratedTextResult, Error> {
-let request = request_with_model_id(request, model_id);
+let request = request_with_model_id(request, model_id, headers);
Ok(self
.client(model_id)?
.text_generation_task_predict(request)
@@ -133,8 +137,9 @@ impl NlpClient
&self,
model_id: &str,
request: ServerStreamingTextGenerationTaskRequest,
+headers: HeaderMap,
) -> Result<BoxStream<Result<GeneratedTextStreamResult, Error>>, Error> {
-let request = request_with_model_id(request, model_id);
+let request = request_with_model_id(request, model_id, headers);
let response_stream = self
.client(model_id)?
.server_streaming_text_generation_task_predict(request)
@@ -146,8 +151,9 @@ impl NlpClient
}
}

-fn request_with_model_id<T>(request: T, model_id: &str) -> Request<T> {
-let mut request = Request::new(request);
+fn request_with_model_id<T>(request: T, model_id: &str, headers: HeaderMap) -> Request<T> {
+let metadata = MetadataMap::from_headers(headers);
+let mut request = Request::from_parts(metadata, Extensions::new(), request);
request
.metadata_mut()
.insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap());
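The gRPC side cannot take an `http::HeaderMap` directly, which is what the reworked `request_with_model_id` above solves: tonic's `MetadataMap` is a wrapper over `http::HeaderMap`, so the allowed headers become request metadata wholesale and the model-id key is inserted on top. A standalone sketch of that conversion (the header name and value are illustrative):

```rust
use axum::http::{HeaderMap, HeaderValue};
use tonic::metadata::MetadataMap;

fn headers_to_metadata(mut headers: HeaderMap) -> MetadataMap {
    headers.insert("x-correlation-id", HeaderValue::from_static("abc-123"));
    // Every HTTP header carries over into gRPC metadata unchanged.
    let metadata = MetadataMap::from_headers(headers);
    assert!(metadata.get("x-correlation-id").is_some());
    metadata
}
```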
5 changes: 4 additions & 1 deletion src/clients/tgis.rs
@@ -14,9 +14,9 @@
limitations under the License.
*/

-use std::collections::HashMap;

+use axum::http::HeaderMap;
use futures::{StreamExt, TryStreamExt};
use ginepro::LoadBalancedChannel;
use tonic::Code;
@@ -99,6 +99,7 @@ impl TgisClient
pub async fn generate(
&self,
request: BatchedGenerationRequest,
+_headers: HeaderMap,
) -> Result<BatchedGenerationResponse, Error> {
let model_id = request.model_id.as_str();
Ok(self.client(model_id)?.generate(request).await?.into_inner())
@@ -107,6 +108,7 @@ impl TgisClient
pub async fn generate_stream(
&self,
request: SingleGenerationRequest,
+_headers: HeaderMap,
) -> Result<BoxStream<Result<GenerationResponse, Error>>, Error> {
let model_id = request.model_id.as_str();
let response_stream = self
@@ -122,6 +124,7 @@ impl TgisClient
pub async fn tokenize(
&self,
request: BatchedTokenizeRequest,
+_headers: HeaderMap,
) -> Result<BatchedTokenizeResponse, Error> {
let model_id = request.model_id.as_str();
Ok(self.client(model_id)?.tokenize(request).await?.into_inner())
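Note the leading underscores on `_headers` in all three TGIS methods: the parameter exists so the TGIS client keeps the same call shape as the NLP client, but the headers are not forwarded to TGIS in this commit, matching the title's scope of NLP and detector clients.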
92 changes: 90 additions & 2 deletions src/config.rs
@@ -16,15 +16,18 @@
*/

use std::{
-collections::HashMap,
+collections::{HashMap, HashSet},
path::{Path, PathBuf},
};

use serde::Deserialize;
-use tracing::{debug, error, warn};
+use tracing::{debug, error, info, warn};

use crate::clients::chunker::DEFAULT_MODEL_ID;

+// Placeholder to add default allowed headers
+const DEFAULT_ALLOWED_HEADERS: &[&str] = &[];
+
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("failed to read config from `{path}`: {error}")]
@@ -143,6 +146,9 @@ pub struct OrchestratorConfig
/// Map of TLS connections, allowing reuse across services
/// that may require the same TLS information
pub tls: Option<HashMap<String, TlsConfig>>,
+// List of header keys allowed to be passed to downstream servers
+#[serde(default)]
+pub passthrough_headers: HashSet<String>,
}

impl OrchestratorConfig {
Expand All @@ -166,6 +172,27 @@ impl OrchestratorConfig {
warn!("no chunker configs provided");
}

+if config.passthrough_headers.is_empty() {
+info!("No allowed headers specified");
+}
+
+// Add default headers to the passthrough_headers list
+debug!(
+"Adding default headers: [{}]. ",
+DEFAULT_ALLOWED_HEADERS.join(", ")
+);
+
+// Lowercase all headers for case-insensitive comparison
+config.passthrough_headers = config
+.passthrough_headers
+.into_iter()
+.map(|h| h.to_lowercase())
+.collect::<HashSet<String>>();
+
+config
+.passthrough_headers
+.extend(DEFAULT_ALLOWED_HEADERS.iter().map(|h| h.to_lowercase()));
+
config.apply_named_tls_configs()?;
config.validate()?;

Expand Down Expand Up @@ -521,4 +548,65 @@ tls:
.expect_err("Config should not have been validated");
assert!(matches!(error, Error::DetectorChunkerNotFound { .. }))
}
+
+#[test]
+fn test_passthrough_headers_empty_config() -> Result<(), Error> {
+let s = r#"
+generation:
+  provider: tgis
+  service:
+    hostname: localhost
+    port: 8000
+chunkers:
+  sentence-en:
+    type: sentence
+    service:
+      hostname: localhost
+      port: 9000
+detectors:
+  hap:
+    service:
+      hostname: localhost
+      port: 9000
+      tls: detector
+    chunker_id: sentence-fr
+    default_threshold: 0.5
+"#;
+let config: OrchestratorConfig = serde_yml::from_str(s).unwrap();
+assert!(config.passthrough_headers.is_empty());
+Ok(())
+}
+
+#[test]
+fn test_allowed_headers_passthrough_in_config() -> Result<(), Error> {
+let s = r#"
+generation:
+  provider: tgis
+  service:
+    hostname: localhost
+    port: 8000
+chunkers:
+  sentence-en:
+    type: sentence
+    service:
+      hostname: localhost
+      port: 9000
+detectors:
+  hap:
+    service:
+      hostname: localhost
+      port: 9000
+      tls: detector
+    chunker_id: sentence-fr
+    default_threshold: 0.5
+passthrough_headers:
+  - test-header
+"#;
+let config: OrchestratorConfig = serde_yml::from_str(s).unwrap();
+assert_eq!(
+config.passthrough_headers,
+HashSet::from(["test-header".to_string()])
+);
+Ok(())
+}
}
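The allow-list itself is only half the story: upstream, the orchestrator has to intersect each incoming request's headers with `passthrough_headers` before handing them to the clients above. That code is not part of this diff; what follows is a sketch of what such a filter could look like, relying on the fact that the config keys were lowercased at load time and that `HeaderName`s are always lowercase:

```rust
use std::collections::HashSet;

use axum::http::HeaderMap;

// Hypothetical helper (not in this commit): keep only the allow-listed headers.
fn filter_passthrough_headers(headers: &HeaderMap, allowed: &HashSet<String>) -> HeaderMap {
    headers
        .iter()
        // `HeaderName::as_str` is lowercase and the config keys were lowercased
        // in `OrchestratorConfig::load`, so a plain set lookup is case-insensitive.
        .filter(|(name, _)| allowed.contains(name.as_str()))
        .map(|(name, value)| (name.clone(), value.clone()))
        .collect()
}
```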