Skip to content

Commit

Permalink
support AWS_MSK_IAM authentication (#789)
Browse files Browse the repository at this point in the history
  • Loading branch information
emef authored Nov 22, 2024
1 parent bd3acb3 commit 4d57354
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 96 deletions.
42 changes: 42 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion crates/arroyo-connectors/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ regex = "1"
##########################

# Kafka
aws-sdk-kafka = { version = "1.44" }
aws-msk-iam-sasl-signer = "1.0.0"
rdkafka = { version = "0.36", features = ["cmake-build", "tracing", "sasl", "ssl-vendored"] }
rdkafka-sys = "4.5.0"
rdkafka-sys = "4.7.0"
sasl2-sys = { version = "0.1.6", features = ["vendored"] }

# SSE
Expand Down
159 changes: 108 additions & 51 deletions crates/arroyo-connectors/src/kafka/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,28 @@ use arroyo_rpc::schema_resolver::{
};
use arroyo_rpc::{schema_resolver, var_str::VarStr, OperatorConfig};
use arroyo_types::string_to_map;
use aws_config::Region;
use aws_msk_iam_sasl_signer::generate_auth_token;
use futures::TryFutureExt;
use rdkafka::{
consumer::{BaseConsumer, Consumer},
ClientConfig, Message, Offset, TopicPartitionList,
client::OAuthToken,
consumer::{Consumer, ConsumerContext},
producer::ProducerContext,
ClientConfig, ClientContext, Message, Offset, TopicPartitionList,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::borrow::Cow;
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant, SystemTime};
use tokio::runtime::Handle;
use tokio::sync::mpsc::Sender;
use tokio::sync::oneshot;
use tokio::sync::oneshot::Receiver;
use tokio::time::timeout;
use tonic::Status;
use tracing::{error, info, warn};
use typify::import_types;
Expand Down Expand Up @@ -77,6 +84,9 @@ impl KafkaConnector {
username: VarStr::new(pull_opt("auth.username", options)?),
password: VarStr::new(pull_opt("auth.password", options)?),
},
Some("aws_msk_iam") => KafkaConfigAuthentication::AwsMskIam {
region: pull_opt("auth.region", options)?,
},
Some(other) => bail!("unknown auth type '{}'", other),
};

Expand Down Expand Up @@ -362,7 +372,7 @@ impl Connector for KafkaConnector {
read_mode,
group_id_prefix,
} => {
let mut client_configs = client_configs(&profile, &table);
let mut client_configs = client_configs(&profile, Some(table.clone()))?;
if let Some(ReadMode::ReadCommitted) = read_mode {
client_configs
.insert("isolation.level".to_string(), "read_committed".to_string());
Expand Down Expand Up @@ -399,6 +409,7 @@ impl Connector for KafkaConnector {
schema_resolver,
bad_data: config.bad_data,
client_configs,
context: Context::new(Some(profile.clone())),
messages_per_second: NonZeroU32::new(
config
.rate_limit
Expand All @@ -422,7 +433,8 @@ impl Connector for KafkaConnector {
key_field: key_field.clone(),
key_col: None,
write_futures: vec![],
client_config: client_configs(&profile, &table),
client_config: client_configs(&profile, Some(table.clone()))?,
context: Context::new(Some(profile.clone())),
topic: table.topic,
serializer: ArrowSerializer::new(
config.format.expect("Format must be defined for KafkaSink"),
Expand Down Expand Up @@ -467,38 +479,24 @@ impl KafkaTester {
},
);

// TODO: merge this with client_configs()
match &self.connection.authentication {
KafkaConfigAuthentication::None {} => {}
KafkaConfigAuthentication::Sasl {
mechanism,
password,
protocol,
username,
} => {
client_config.set("sasl.mechanism", mechanism);
client_config.set("security.protocol", protocol);
client_config.set(
"sasl.username",
username.sub_env_vars().map_err(|e| e.to_string())?,
);
client_config.set(
"sasl.password",
password.sub_env_vars().map_err(|e| e.to_string())?,
);
}
};

if let Some(table) = table {
for (k, v) in table.client_configs {
client_config.set(k, v);
}
for (k, v) in client_configs(&self.connection, table)
.map_err(|e| e.to_string())?
.into_iter()
{
client_config.set(k, v);
}

let context = Context::new(Some(self.connection.clone()));
let client: BaseConsumer = client_config
.create()
.create_with_context(context)
.map_err(|e| format!("invalid kafka config: {:?}", e))?;

// NOTE: this is required to trigger an oauth token refresh (when using
// OAUTHBEARER auth).
if client.poll(Duration::from_secs(0)).is_some() {
return Err("unexpected poll event from new consumer".to_string());
}

tokio::task::spawn_blocking(move || {
client
.fetch_metadata(None, Duration::from_secs(10))
Expand Down Expand Up @@ -903,7 +901,10 @@ impl SourceOffset {
}
}

pub fn client_configs(connection: &KafkaConfig, table: &KafkaTable) -> HashMap<String, String> {
pub fn client_configs(
connection: &KafkaConfig,
table: Option<KafkaTable>,
) -> anyhow::Result<HashMap<String, String>> {
let mut client_configs: HashMap<String, String> = HashMap::new();

match &connection.authentication {
Expand All @@ -916,27 +917,83 @@ pub fn client_configs(connection: &KafkaConfig, table: &KafkaTable) -> HashMap<S
} => {
client_configs.insert("sasl.mechanism".to_string(), mechanism.to_string());
client_configs.insert("security.protocol".to_string(), protocol.to_string());
client_configs.insert(
"sasl.username".to_string(),
username
.sub_env_vars()
.expect("Missing env-vars for Kafka username"),
);
client_configs.insert(
"sasl.password".to_string(),
password
.sub_env_vars()
.expect("Missing env-vars for Kafka password"),
);
client_configs.insert("sasl.username".to_string(), username.sub_env_vars()?);
client_configs.insert("sasl.password".to_string(), password.sub_env_vars()?);
}
KafkaConfigAuthentication::AwsMskIam { region: _ } => {
client_configs.insert("sasl.mechanism".to_string(), "OAUTHBEARER".to_string());
client_configs.insert("security.protocol".to_string(), "SASL_SSL".to_string());
}
};

client_configs.extend(
table
.client_configs
.iter()
.map(|(k, v)| (k.to_string(), v.to_string())),
);
if let Some(table) = table {
client_configs.extend(
table
.client_configs
.iter()
.map(|(k, v)| (k.to_string(), v.to_string())),
);
}

Ok(client_configs)
}

type BaseConsumer = rdkafka::consumer::BaseConsumer<Context>;
type FutureProducer = rdkafka::producer::FutureProducer<Context>;
type StreamConsumer = rdkafka::consumer::StreamConsumer<Context>;

#[derive(Clone)]
pub struct Context {
config: Option<KafkaConfig>,
}

impl Context {
pub fn new(config: Option<KafkaConfig>) -> Self {
Self { config }
}
}

impl ConsumerContext for Context {}

impl ProducerContext for Context {
type DeliveryOpaque = ();
fn delivery(
&self,
_delivery_result: &rdkafka::message::DeliveryResult<'_>,
_delivery_opaque: Self::DeliveryOpaque,
) {
}
}

impl ClientContext for Context {
const ENABLE_REFRESH_OAUTH_TOKEN: bool = true;

client_configs
fn generate_oauth_token(
&self,
_oauthbearer_config: Option<&str>,
) -> Result<OAuthToken, Box<dyn std::error::Error>> {
if let Some(KafkaConfigAuthentication::AwsMskIam { region }) =
self.config.as_ref().map(|c| &c.authentication)
{
let region = Region::new(region.clone());
let rt = Handle::current();

let (token, expiration_time_ms) = {
let handle = thread::spawn(move || {
rt.block_on(async {
timeout(Duration::from_secs(10), generate_auth_token(region.clone())).await
})
});
handle.join().unwrap()??
};

Ok(OAuthToken {
token,
principal_name: "".to_string(),
lifetime_ms: expiration_time_ms,
})
} else {
Err(anyhow!("only AWS_MSK_IAM is supported for sasl oauth").into())
}
}
}
Loading

0 comments on commit 4d57354

Please sign in to comment.