Skip to content

Commit

Permalink
Optimize tls ffi implementation, remove ffi feature cfg.
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaofei0800 committed Dec 8, 2023
1 parent 64ba69d commit 4830988
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 52 deletions.
10 changes: 8 additions & 2 deletions include/tquic.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ typedef struct http3_conn_t http3_conn_t;
typedef struct http3_headers_t http3_headers_t;

typedef struct quic_tls_config_select_methods_t {
const SSL_CTX *(*get_default)(void *ctx);
const SSL_CTX *(*select)(void *ctx, const uint8_t *server_name, size_t server_name_len);
SSL_CTX *(*get_default)(void *ctx);
SSL_CTX *(*select)(void *ctx, const uint8_t *server_name, size_t server_name_len);
} quic_tls_config_select_methods_t;

typedef void *quic_tls_config_select_context_t;
Expand Down Expand Up @@ -416,6 +416,12 @@ void quic_config_set_tls_selector(struct quic_config_t *config,
const struct quic_tls_config_select_methods_t *methods,
quic_tls_config_select_context_t context);

/**
* Set TLS config.
* The caller is responsible for the memory of SSL_CTX when use this function.
*/
void quic_config_set_tls_config(struct quic_config_t *config, SSL_CTX *ssl_ctx);

/**
* Create a QUIC endpoint.
*
Expand Down
2 changes: 1 addition & 1 deletion src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2299,7 +2299,7 @@ mod tests {
true,
)?;
tls_config.set_ticket_key(&vec![0x01; 48])?;
srv_conf.set_tls_config_selector(Arc::new(tls_config));
srv_conf.set_tls_config(tls_config);

let mut case_conf = CaseConf::default();
case_conf.session = Some(TestPair::new_test_session_state());
Expand Down
23 changes: 17 additions & 6 deletions src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,17 @@ pub extern "C" fn quic_config_set_tls_selector(
config.set_tls_config_selector(Arc::new(selector));
}

/// Set TLS config.
/// The caller is responsible for the memory of SSL_CTX when use this function.
#[no_mangle]
pub extern "C" fn quic_config_set_tls_config(
config: &mut Config,
ssl_ctx: *mut crate::tls::SslCtx,
) {
let tls_config = crate::tls::TlsConfig::new_with_ssl_ctx(ssl_ctx);
config.set_tls_config(tls_config);
}

/// Create a QUIC endpoint.
///
/// The caller is responsible for the memory of the Endpoint and properly
Expand Down Expand Up @@ -949,12 +960,12 @@ pub struct TlsConfigSelectorContext(*mut c_void);

#[repr(C)]
pub struct TlsConfigSelectMethods {
pub get_default: fn(ctx: *mut c_void) -> *const crate::tls::SslCtx,
pub get_default: fn(ctx: *mut c_void) -> *mut crate::tls::SslCtx,
pub select: fn(
ctx: *mut c_void,
server_name: *const u8,
server_name_len: size_t,
) -> *const crate::tls::SslCtx,
) -> *mut crate::tls::SslCtx,
}

#[repr(C)]
Expand All @@ -967,17 +978,17 @@ unsafe impl Send for TlsConfigSelector {}
unsafe impl Sync for TlsConfigSelector {}

impl crate::tls::TlsConfigSelector for TlsConfigSelector {
fn get_default(&self) -> Option<&crate::tls::TlsConfig> {
fn get_default(&self) -> Option<Arc<crate::tls::TlsConfig>> {
let ssl_ctx = unsafe { ((*self.methods).get_default)(self.context.0) };
if ssl_ctx.is_null() {
return None;
}

let tls_config = unsafe { &(*(ssl_ctx as *const crate::tls::TlsConfig)) };
let tls_config = Arc::new(crate::tls::TlsConfig::new_with_ssl_ctx(ssl_ctx));
Some(tls_config)
}

fn select(&self, server_name: &str) -> Option<&crate::tls::TlsConfig> {
fn select(&self, server_name: &str) -> Option<Arc<crate::tls::TlsConfig>> {
let ssl_ctx = unsafe {
((*self.methods).select)(
self.context.0,
Expand All @@ -989,7 +1000,7 @@ impl crate::tls::TlsConfigSelector for TlsConfigSelector {
return None;
}

let tls_config = unsafe { &(*(ssl_ctx as *const crate::tls::TlsConfig)) };
let tls_config = Arc::new(crate::tls::TlsConfig::new_with_ssl_ctx(ssl_ctx));
Some(tls_config)
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,9 @@ impl Config {

/// Set TLS config.
pub fn set_tls_config(&mut self, tls_config: tls::TlsConfig) {
self.set_tls_config_selector(Arc::new(tls_config));
self.set_tls_config_selector(Arc::new(tls::DefaultTlsConfigSelector {
tls_config: Arc::new(tls_config),
}));
}

/// Set TLS config selector. Used for selecting TLS config according to SNI.
Expand Down
54 changes: 27 additions & 27 deletions src/tls/boringssl/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,56 +127,56 @@ static QUICHE_STREAM_METHOD: SslQuicMethod = SslQuicMethod {

/// Rust wrapper of SSL_CTX which holds various configuration and data relevant
/// to SSL/TLS session establishment.
#[repr(transparent)]
pub struct Context(*mut SslCtx);
pub struct Context {
ctx_raw: *mut SslCtx,
owned: bool,
}

impl Drop for Context {
fn drop(&mut self) {
if self.owned {
unsafe { SSL_CTX_free(self.as_mut_ptr()) }
}
}
}

#[cfg(not(feature = "ffi"))]
impl Context {
/// Create a new TLS context.
pub fn new() -> Result<Context> {
unsafe {
let ctx_raw = SSL_CTX_new(TLS_method());

let mut ctx = Context(ctx_raw);
let mut ctx = Context {
ctx_raw,
owned: true,
};

ctx.set_session_callback();
ctx.set_default_verify_paths()?;
Ok(ctx)
}
}

/// Return the mutable pointer of the inner SSL_CTX.
fn as_mut_ptr(&mut self) -> *mut SslCtx {
self.0
}

/// Return the const pointer of the inner SSL_CTX.
fn as_ptr(&self) -> *const SslCtx {
self.0
}
}

// The caller is responsible for the memory of SSL_CTX when using ffi.
#[cfg(not(feature = "ffi"))]
impl Drop for Context {
fn drop(&mut self) {
unsafe { SSL_CTX_free(self.as_mut_ptr()) }
/// Create a new TLS context with SSL_CTX.
/// The caller is responsible for the memory of SSL_CTX when use this function.
pub fn new_with_ssl_ctx(ssl_ctx: *mut SslCtx) -> Context {
Self {
ctx_raw: ssl_ctx,
owned: false,
}
}
}

#[cfg(feature = "ffi")]
impl Context {
/// Return the mutable pointer of the inner SSL_CTX.
fn as_mut_ptr(&mut self) -> *mut SslCtx {
self as *mut Context as *mut SslCtx
self.ctx_raw
}

/// Return the const pointer of the inner SSL_CTX.
fn as_ptr(&self) -> *const SslCtx {
self as *const Context as *const SslCtx
self.ctx_raw
}
}

impl Context {
/// Create a new TLS session.
pub fn new_session(&self) -> Result<Session> {
unsafe {
let ssl = SSL_new(self.as_ptr());
Expand Down
42 changes: 27 additions & 15 deletions src/tls/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,11 @@ where
}
}

#[repr(transparent)]
pub struct TlsConfig {
/// Boringssl SSL context.
tls_ctx: boringssl::tls::Context,
}

#[cfg(not(feature = "ffi"))]
impl TlsConfig {
/// Create a new TlsConfig.
pub fn new() -> Result<TlsConfig> {
Expand All @@ -84,6 +82,14 @@ impl TlsConfig {
Ok(TlsConfig { tls_ctx })
}

/// Create a new TlsConfig with SSL_CTX.
/// The caller is responsible for the memory of SSL_CTX when use this function.
pub fn new_with_ssl_ctx(ssl_ctx: *mut boringssl::tls::SslCtx) -> TlsConfig {
Self {
tls_ctx: boringssl::tls::Context::new_with_ssl_ctx(ssl_ctx),
}
}

/// Create a new client side TlsConfig.
pub fn new_client_config(
application_protos: Vec<Vec<u8>>,
Expand Down Expand Up @@ -188,25 +194,29 @@ impl TlsConfig {
}
}

impl TlsConfigSelector for TlsConfig {
pub(crate) struct DefaultTlsConfigSelector {
pub tls_config: Arc<TlsConfig>,
}

impl TlsConfigSelector for DefaultTlsConfigSelector {
// TODO: support local and peer address.
/// Get default TLS config.
fn get_default(&self) -> Option<&TlsConfig> {
Some(self)
fn get_default(&self) -> Option<Arc<TlsConfig>> {
Some(self.tls_config.clone())
}

/// Find TLS config according to server name.
fn select(&self, _server_name: &str) -> Option<&TlsConfig> {
Some(self)
fn select(&self, _server_name: &str) -> Option<Arc<TlsConfig>> {
Some(self.tls_config.clone())
}
}

pub trait TlsConfigSelector: Send + Sync {
/// Get default TLS config.
fn get_default(&self) -> Option<&TlsConfig>;
fn get_default(&self) -> Option<Arc<TlsConfig>>;

/// Find TLS config according to server name.
fn select(&self, server_name: &str) -> Option<&TlsConfig>;
fn select(&self, server_name: &str) -> Option<Arc<TlsConfig>>;
}

#[derive(Default)]
Expand Down Expand Up @@ -965,7 +975,7 @@ pub(crate) mod tests {
}

pub struct ServerConfigSelector {
hash_map: HashMap<String, TlsConfig>,
hash_map: HashMap<String, Arc<TlsConfig>>,
}

impl ServerConfigSelector {
Expand All @@ -985,7 +995,9 @@ pub(crate) mod tests {
cert,
keys[index],
)?;
cert_manager.hash_map.insert(index.to_string(), tls_config);
cert_manager
.hash_map
.insert(index.to_string(), tls_config.into());
}

Ok(cert_manager)
Expand All @@ -997,12 +1009,12 @@ pub(crate) mod tests {
}

impl TlsConfigSelector for ServerConfigSelector {
fn get_default(&self) -> Option<&TlsConfig> {
self.hash_map.get("0")
fn get_default(&self) -> Option<Arc<TlsConfig>> {
self.select("0")
}

fn select(&self, server_name: &str) -> Option<&TlsConfig> {
self.hash_map.get(server_name)
fn select(&self, server_name: &str) -> Option<Arc<TlsConfig>> {
self.hash_map.get(server_name).cloned()
}
}

Expand Down

0 comments on commit 4830988

Please sign in to comment.