Skip to content

Commit

Permalink
Add detector_id in response for orchestrator apis (#278)
Browse files Browse the repository at this point in the history
* added detector_id to TokenClassificationResults

Signed-off-by: Shonda-Adena-Witherspoon <[email protected]>

* added detector_id to detection_result model

Signed-off-by: Shonda-Adena-Witherspoon <[email protected]>

* updated detector_id field to be optional

Signed-off-by: Shonda-Adena-Witherspoon <[email protected]>

* added detector_id to ContentAnalysisResponse data model

Signed-off-by: Shonda-Adena-Witherspoon <[email protected]>

* modified relevant unit tests to add detector_id field in expected results

Signed-off-by: Shonda-Adena-Witherspoon <[email protected]>

* updated streaming_content to utilize new detector_id field

Signed-off-by: Shonda-Adena-Witherspoon <[email protected]>

* fixed issues with clippy

Signed-off-by: Shonda-Adena-Witherspoon <[email protected]>

* updated formatting using cargo +nightly fmt

Signed-off-by: Shonda-Adena-Witherspoon <[email protected]>

* updated addition of detector_id in detectorresult instantiations and reverted nlp proto

Signed-off-by: Shonda-Adena-Witherspoon <[email protected]>

* Update src/orchestrator/unary.rs

Update setting the detector_id in the TokenClassificationResult.

Co-authored-by: Dan Clark <[email protected]>
Signed-off-by: swith004 <[email protected]>

* updated TokenClassificationResult data model and detector_id setting

Signed-off-by: Shonda-Adena-Witherspoon <[email protected]>

* updated detector_id comments for consistency

Signed-off-by: Shonda-Adena-Witherspoon <[email protected]>

* Update comment for detector_id

Co-authored-by: Evaline Ju <[email protected]>
Signed-off-by: swith004 <[email protected]>

* add space in comment for detector_id

Co-authored-by: Evaline Ju <[email protected]>
Signed-off-by: swith004 <[email protected]>

* Update comment for detector_id

Co-authored-by: Evaline Ju <[email protected]>
Signed-off-by: swith004 <[email protected]>

* add detector_id comment to new line

Co-authored-by: Evaline Ju <[email protected]>
Signed-off-by: swith004 <[email protected]>

* update detector_id comment

Co-authored-by: Evaline Ju <[email protected]>
Signed-off-by: swith004 <[email protected]>

---------

Signed-off-by: Shonda-Adena-Witherspoon <[email protected]>
Signed-off-by: swith004 <[email protected]>
Co-authored-by: Shonda-Adena-Witherspoon <[email protected]>
Co-authored-by: Dan Clark <[email protected]>
Co-authored-by: Evaline Ju <[email protected]>
  • Loading branch information
4 people authored Jan 21, 2025
1 parent cf5dd83 commit 1dd51fb
Show file tree
Hide file tree
Showing 12 changed files with 75 additions and 27 deletions.
3 changes: 1 addition & 2 deletions src/clients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ use ginepro::LoadBalancedChannel;
use hyper_timeout::TimeoutConnector;
use hyper_util::rt::TokioExecutor;
use tonic::{metadata::MetadataMap, Request};
use tower::timeout::TimeoutLayer;
use tower::ServiceBuilder;
use tower::{timeout::TimeoutLayer, ServiceBuilder};
use tracing::{debug, instrument, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;
use url::Url;
Expand Down
3 changes: 3 additions & 0 deletions src/clients/detector/text_contents.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ pub struct ContentAnalysisResponse {
pub detection: String,
/// Detection type or aggregate detection label
pub detection_type: String,
/// Optional, ID of Detector
pub detector_id: Option<String>,
/// Score of detection
pub score: f64,
/// Optional, any applicable evidence for detection
Expand All @@ -147,6 +149,7 @@ impl From<ContentAnalysisResponse> for crate::models::TokenClassificationResult
word: value.text,
entity: value.detection,
entity_group: value.detection_type,
detector_id: value.detector_id,
score: value.score,
token_count: None,
}
Expand Down
3 changes: 1 addition & 2 deletions src/clients/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ use hyper_rustls::HttpsConnector;
use hyper_timeout::TimeoutConnector;
use hyper_util::client::legacy::connect::HttpConnector;
use serde::{de::DeserializeOwned, Serialize};
use tower::timeout::Timeout;
use tower::Service;
use tower::{timeout::Timeout, Service};
use tower_http::{
classify::{
NeverClassifyEos, ServerErrorsAsFailures, ServerErrorsFailureClass, SharedClassifier,
Expand Down
8 changes: 5 additions & 3 deletions src/clients/otel_grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,23 @@
*/

use crate::utils::trace::{current_trace_id, with_traceparent_header};
use http::{Request, Response, StatusCode};
use pin_project_lite::pin_project;
use std::{
error::Error,
future::Future,
pin::Pin,
task::{Context, Poll},
};

use http::{Request, Response, StatusCode};
use pin_project_lite::pin_project;
use tokio::time::Instant;
use tonic::client::GrpcService;
use tower::Layer;
use tracing::{error, info, info_span, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;

use crate::utils::trace::{current_trace_id, with_traceparent_header};

// Adapted from https://github.com/davidB/tracing-opentelemetry-instrumentation-sdk/tree/main/tonic-tracing-opentelemetry
/// Layer for grpc (tonic client):
///
Expand Down
6 changes: 6 additions & 0 deletions src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,9 @@ pub struct TokenClassificationResult {
/// Aggregate label, if applicable
pub entity_group: String,

/// Optional id of detector (model) responsible for result(s)
pub detector_id: Option<String>,

/// Confidence-like score of this classification prediction in [0, 1]
pub score: f64,

Expand Down Expand Up @@ -894,6 +897,9 @@ pub struct DetectionResult {
// The detection class
pub detection: String,

// Optional id of the detector
pub detector_id: Option<String>,

// The confidence level in the detection class
pub score: f64,

Expand Down
4 changes: 3 additions & 1 deletion src/orchestrator/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,9 @@ async fn detection_task(
.into_iter()
.flat_map(|r| {
r.into_iter().filter_map(|resp| {
let result: TokenClassificationResult = resp.into();
let mut result: TokenClassificationResult = resp.into();
// add detector_id
result.detector_id = Some(detector_id.clone());
(result.score >= threshold).then_some(result)
})
})
Expand Down
6 changes: 4 additions & 2 deletions src/orchestrator/streaming/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -458,13 +458,15 @@ mod tests {
text: &str,
detection: &str,
detection_type: &str,
detector_id: &str,
) -> TokenClassificationResult {
TokenClassificationResult {
start: span.0 as u32,
end: span.1 as u32,
word: text.to_string(),
entity: detection.to_string(),
entity_group: detection_type.to_string(),
detector_id: Some(detector_id.to_string()),
score: 0.99,
token_count: None,
}
Expand Down Expand Up @@ -498,11 +500,11 @@ mod tests {
let partial_span = (chunk_token.start + 2, chunk_token.end - 2);

let (detector_tx1, detector_rx1) = mpsc::channel(1);
let detection = get_detection_obj(whole_span, text, "has_HAP", "HAP");
let detection = get_detection_obj(whole_span, text, "has_HAP", "HAP", "en-hap");
let _ = detector_tx1.send((chunk.clone(), vec![detection])).await;

let (detector_tx2, detector_rx2) = mpsc::channel(1);
let detection = get_detection_obj(partial_span, text, "email_ID", "PII");
let detection = get_detection_obj(partial_span, text, "email_ID", "PII", "en-pii");
let _ = detector_tx2.send((chunk.clone(), vec![detection])).await;

// Push HAP after PII to make sure detection ordering is not coincidental
Expand Down
28 changes: 16 additions & 12 deletions src/orchestrator/streaming_content_detection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,29 @@
///////////////////////////////////////////////////////////////////////////////////
mod aggregator;

use aggregator::Aggregator;
use std::{collections::HashMap, pin::Pin, sync::Arc, time::Duration};

use aggregator::Aggregator;
use futures::{future::try_join_all, stream::Peekable, Stream, StreamExt, TryStreamExt};
use hyper::HeaderMap;
use tokio::sync::{broadcast, mpsc};
use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
use tracing::{debug, error, info, instrument, warn};

use super::{streaming::Detections, Context, Error, Orchestrator, StreamingContentDetectionTask};
use crate::clients::{
chunker::{tokenize_whole_doc_stream, ChunkerClient, DEFAULT_CHUNKER_ID},
detector::ContentAnalysisRequest,
TextContentsDetectorClient,
};
use crate::models::{
DetectorParams, StreamingContentDetectionRequest, StreamingContentDetectionResponse,
TokenClassificationResult,
use crate::{
clients::{
chunker::{tokenize_whole_doc_stream, ChunkerClient, DEFAULT_CHUNKER_ID},
detector::ContentAnalysisRequest,
TextContentsDetectorClient,
},
models::{
DetectorParams, StreamingContentDetectionRequest, StreamingContentDetectionResponse,
TokenClassificationResult,
},
orchestrator::{get_chunker_ids, streaming::Chunk},
pb::caikit::runtime::chunkers,
};
use crate::orchestrator::{get_chunker_ids, streaming::Chunk};
use crate::pb::caikit::runtime::chunkers;

type ContentInputStream =
Pin<Box<dyn Stream<Item = Result<StreamingContentDetectionRequest, Error>> + Send>>;
Expand Down Expand Up @@ -419,7 +421,9 @@ async fn detection_task(
.into_iter()
.flat_map(|r| {
r.into_iter().filter_map(|resp| {
let result: TokenClassificationResult = resp.into();
let mut result: TokenClassificationResult = resp.into();
// add detector_id
result.detector_id = Some(detector_id.clone());
(result.score >= threshold).then_some(result)
})
})
Expand Down
7 changes: 5 additions & 2 deletions src/orchestrator/streaming_content_detection/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ impl AggregationActor {
text: r.word,
detection: r.entity,
detection_type: r.entity_group,
detector_id: r.detector_id,
score: r.score,
evidence: None,
})
Expand Down Expand Up @@ -205,13 +206,15 @@ mod tests {
text: &str,
detection: &str,
detection_type: &str,
detector_id: &str,
) -> TokenClassificationResult {
TokenClassificationResult {
start: span.0 as u32,
end: span.1 as u32,
word: text.to_string(),
entity: detection.to_string(),
entity_group: detection_type.to_string(),
detector_id: Some(detector_id.to_string()),
score: 0.99,
token_count: None,
}
Expand Down Expand Up @@ -245,11 +248,11 @@ mod tests {
let partial_span = (chunk_token.start + 2, chunk_token.end - 2);

let (detector_tx1, detector_rx1) = mpsc::channel(1);
let detection = get_detection_obj(whole_span, text, "has_HAP", "HAP");
let detection = get_detection_obj(whole_span, text, "has_HAP", "HAP", "en-hap");
let _ = detector_tx1.send((chunk.clone(), vec![detection])).await;

let (detector_tx2, detector_rx2) = mpsc::channel(1);
let detection = get_detection_obj(partial_span, text, "email_ID", "PII");
let detection = get_detection_obj(partial_span, text, "email_ID", "PII", "en-pii");
let _ = detector_tx2.send((chunk.clone(), vec![detection])).await;

// Push HAP after PII to make sure detection ordering is not coincidental
Expand Down
25 changes: 25 additions & 0 deletions src/orchestrator/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,8 @@ pub async fn detect(
.into_iter()
.filter_map(|resp| {
let mut result: TokenClassificationResult = resp.into();
// add detector_id
result.detector_id = Some(detector_id.clone());
result.start += chunk.offset as u32;
result.end += chunk.offset as u32;
(result.score >= threshold).then_some(result)
Expand Down Expand Up @@ -756,6 +758,8 @@ pub async fn detect_content(
.filter_map(|mut resp| {
resp.start += chunk.offset;
resp.end += chunk.offset;
// add detector_id
resp.detector_id = Some(detector_id.clone());
(resp.score >= threshold).then_some(resp)
})
.collect::<Vec<_>>()
Expand Down Expand Up @@ -803,6 +807,11 @@ pub async fn detect_for_generation(
results
.into_iter()
.filter(|detection| detection.score > threshold)
.map(|mut detection| {
// add detector_id
detection.detector_id = Some(detector_id.clone());
detection
})
.collect()
})
.map_err(|error| Error::DetectorRequestFailed {
Expand Down Expand Up @@ -844,6 +853,11 @@ pub async fn detect_for_chat(
results
.into_iter()
.filter(|detection| detection.score > threshold)
.map(|mut detection| {
//add detector_id
detection.detector_id = Some(detector_id.clone());
detection
})
.collect()
})
.map_err(|error| Error::DetectorRequestFailed {
Expand Down Expand Up @@ -899,6 +913,11 @@ pub async fn detect_for_context(
results
.into_iter()
.filter(|detection| detection.score > threshold)
.map(|mut detection| {
//add detector_id
detection.detector_id = Some(detector_id.clone());
detection
})
.collect()
})
.map_err(|error| Error::DetectorRequestFailed {
Expand Down Expand Up @@ -1131,6 +1150,7 @@ mod tests {
word: second_sentence.clone(),
entity: "has_HAP".to_string(),
entity_group: "hap".to_string(),
detector_id: Some(detector_id.to_string()),
score: 0.9,
token_count: None,
}];
Expand All @@ -1151,6 +1171,7 @@ mod tests {
text: first_sentence.clone(),
detection: "has_HAP".to_string(),
detection_type: "hap".to_string(),
detector_id: Some(detector_id.to_string()),
score: 0.1,
evidence: Some(vec![]),
}],
Expand All @@ -1160,6 +1181,7 @@ mod tests {
text: second_sentence.clone(),
detection: "has_HAP".to_string(),
detection_type: "hap".to_string(),
detector_id: Some(detector_id.to_string()),
score: 0.9,
evidence: Some(vec![]),
}],
Expand Down Expand Up @@ -1300,6 +1322,7 @@ mod tests {
let expected_response: Vec<DetectionResult> = vec![DetectionResult {
detection_type: "relevance".to_string(),
detection: "is_relevant".to_string(),
detector_id: Some(detector_id.to_string()),
score: 0.9,
evidence: Some(
[EvidenceObj {
Expand All @@ -1325,6 +1348,7 @@ mod tests {
.then_return(Ok(vec![DetectionResult {
detection_type: "relevance".to_string(),
detection: "is_relevant".to_string(),
detector_id: Some(detector_id.to_string()),
score: 0.9,
evidence: Some(
[EvidenceObj {
Expand Down Expand Up @@ -1393,6 +1417,7 @@ mod tests {
.then_return(Ok(vec![DetectionResult {
detection_type: "relevance".to_string(),
detection: "is_relevant".to_string(),
detector_id: Some(detector_id.to_string()),
score: 0.1,
evidence: None,
}]));
Expand Down
3 changes: 2 additions & 1 deletion src/utils/tls.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::{fs::File, io, path::PathBuf, sync::Arc};

use http_serde::http::StatusCode;
use hyper_rustls::ConfigBuilderExt;
use rustls::{
Expand All @@ -6,7 +8,6 @@ use rustls::{
ClientConfig, DigitallySignedStruct, SignatureScheme,
};
use serde::Deserialize;
use std::{fs::File, io, path::PathBuf, sync::Arc};

use crate::{clients, config::TlsConfig};

Expand Down
6 changes: 4 additions & 2 deletions src/utils/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ use tracing::{error, info, info_span, Span};
use tracing_opentelemetry::{MetricsLayer, OpenTelemetrySpanExt};
use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Layer};

use crate::args::{LogFormat, OtlpProtocol, TracingConfig};
use crate::clients::http::TracedResponse;
use crate::{
args::{LogFormat, OtlpProtocol, TracingConfig},
clients::http::TracedResponse,
};

#[derive(Debug, thiserror::Error)]
pub enum TracingError {
Expand Down

0 comments on commit 1dd51fb

Please sign in to comment.