diff --git a/include/tquic.h b/include/tquic.h index b30e237c4..b5facc81d 100644 --- a/include/tquic.h +++ b/include/tquic.h @@ -108,8 +108,8 @@ typedef struct http3_conn_t http3_conn_t; typedef struct http3_headers_t http3_headers_t; typedef struct quic_tls_config_select_methods_t { - const SSL_CTX *(*get_default)(void *ctx); - const SSL_CTX *(*select)(void *ctx, const uint8_t *server_name, size_t server_name_len); + SSL_CTX *(*get_default)(void *ctx); + SSL_CTX *(*select)(void *ctx, const uint8_t *server_name, size_t server_name_len); } quic_tls_config_select_methods_t; typedef void *quic_tls_config_select_context_t; @@ -416,6 +416,12 @@ void quic_config_set_tls_selector(struct quic_config_t *config, const struct quic_tls_config_select_methods_t *methods, quic_tls_config_select_context_t context); +/** + * Set TLS config. + * The caller is responsible for the memory of SSL_CTX when use this function. + */ +void quic_config_set_tls_config(struct quic_config_t *config, SSL_CTX *ssl_ctx); + /** * Create a QUIC endpoint. * diff --git a/src/endpoint.rs b/src/endpoint.rs index 1448aa2c2..87e56b97e 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -2299,7 +2299,7 @@ mod tests { true, )?; tls_config.set_ticket_key(&vec![0x01; 48])?; - srv_conf.set_tls_config_selector(Arc::new(tls_config)); + srv_conf.set_tls_config(tls_config); let mut case_conf = CaseConf::default(); case_conf.session = Some(TestPair::new_test_session_state()); diff --git a/src/ffi.rs b/src/ffi.rs index ed9234ea0..0c076813f 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -316,6 +316,17 @@ pub extern "C" fn quic_config_set_tls_selector( config.set_tls_config_selector(Arc::new(selector)); } +/// Set TLS config. +/// The caller is responsible for the memory of SSL_CTX when use this function. +#[no_mangle] +pub extern "C" fn quic_config_set_tls_config( + config: &mut Config, + ssl_ctx: *mut crate::tls::SslCtx, +) { + let tls_config = crate::tls::TlsConfig::new_with_ssl_ctx(ssl_ctx); + config.set_tls_config(tls_config); +} + /// Create a QUIC endpoint. /// /// The caller is responsible for the memory of the Endpoint and properly @@ -949,12 +960,12 @@ pub struct TlsConfigSelectorContext(*mut c_void); #[repr(C)] pub struct TlsConfigSelectMethods { - pub get_default: fn(ctx: *mut c_void) -> *const crate::tls::SslCtx, + pub get_default: fn(ctx: *mut c_void) -> *mut crate::tls::SslCtx, pub select: fn( ctx: *mut c_void, server_name: *const u8, server_name_len: size_t, - ) -> *const crate::tls::SslCtx, + ) -> *mut crate::tls::SslCtx, } #[repr(C)] @@ -967,17 +978,17 @@ unsafe impl Send for TlsConfigSelector {} unsafe impl Sync for TlsConfigSelector {} impl crate::tls::TlsConfigSelector for TlsConfigSelector { - fn get_default(&self) -> Option<&crate::tls::TlsConfig> { + fn get_default(&self) -> Option> { let ssl_ctx = unsafe { ((*self.methods).get_default)(self.context.0) }; if ssl_ctx.is_null() { return None; } - let tls_config = unsafe { &(*(ssl_ctx as *const crate::tls::TlsConfig)) }; + let tls_config = Arc::new(crate::tls::TlsConfig::new_with_ssl_ctx(ssl_ctx)); Some(tls_config) } - fn select(&self, server_name: &str) -> Option<&crate::tls::TlsConfig> { + fn select(&self, server_name: &str) -> Option> { let ssl_ctx = unsafe { ((*self.methods).select)( self.context.0, @@ -989,7 +1000,7 @@ impl crate::tls::TlsConfigSelector for TlsConfigSelector { return None; } - let tls_config = unsafe { &(*(ssl_ctx as *const crate::tls::TlsConfig)) }; + let tls_config = Arc::new(crate::tls::TlsConfig::new_with_ssl_ctx(ssl_ctx)); Some(tls_config) } } diff --git a/src/lib.rs b/src/lib.rs index 86fe72a7c..b1e2ab1a2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -584,7 +584,9 @@ impl Config { /// Set TLS config. pub fn set_tls_config(&mut self, tls_config: tls::TlsConfig) { - self.set_tls_config_selector(Arc::new(tls_config)); + self.set_tls_config_selector(Arc::new(tls::DefaultTlsConfigSelector { + tls_config: Arc::new(tls_config), + })); } /// Set TLS config selector. Used for selecting TLS config according to SNI. diff --git a/src/tls/boringssl/tls.rs b/src/tls/boringssl/tls.rs index 7a5a9032b..a3f2dc8d2 100644 --- a/src/tls/boringssl/tls.rs +++ b/src/tls/boringssl/tls.rs @@ -127,16 +127,29 @@ static QUICHE_STREAM_METHOD: SslQuicMethod = SslQuicMethod { /// Rust wrapper of SSL_CTX which holds various configuration and data relevant /// to SSL/TLS session establishment. -#[repr(transparent)] -pub struct Context(*mut SslCtx); +pub struct Context { + ctx_raw: *mut SslCtx, + owned: bool, +} + +impl Drop for Context { + fn drop(&mut self) { + if self.owned { + unsafe { SSL_CTX_free(self.as_mut_ptr()) } + } + } +} -#[cfg(not(feature = "ffi"))] impl Context { + /// Create a new TLS context. pub fn new() -> Result { unsafe { let ctx_raw = SSL_CTX_new(TLS_method()); - let mut ctx = Context(ctx_raw); + let mut ctx = Context { + ctx_raw, + owned: true, + }; ctx.set_session_callback(); ctx.set_default_verify_paths()?; @@ -144,39 +157,26 @@ impl Context { } } - /// Return the mutable pointer of the inner SSL_CTX. - fn as_mut_ptr(&mut self) -> *mut SslCtx { - self.0 - } - - /// Return the const pointer of the inner SSL_CTX. - fn as_ptr(&self) -> *const SslCtx { - self.0 - } -} - -// The caller is responsible for the memory of SSL_CTX when using ffi. -#[cfg(not(feature = "ffi"))] -impl Drop for Context { - fn drop(&mut self) { - unsafe { SSL_CTX_free(self.as_mut_ptr()) } + /// Create a new TLS context with SSL_CTX. + /// The caller is responsible for the memory of SSL_CTX when use this function. + pub fn new_with_ssl_ctx(ssl_ctx: *mut SslCtx) -> Context { + Self { + ctx_raw: ssl_ctx, + owned: false, + } } -} -#[cfg(feature = "ffi")] -impl Context { /// Return the mutable pointer of the inner SSL_CTX. fn as_mut_ptr(&mut self) -> *mut SslCtx { - self as *mut Context as *mut SslCtx + self.ctx_raw } /// Return the const pointer of the inner SSL_CTX. fn as_ptr(&self) -> *const SslCtx { - self as *const Context as *const SslCtx + self.ctx_raw } -} -impl Context { + /// Create a new TLS session. pub fn new_session(&self) -> Result { unsafe { let ssl = SSL_new(self.as_ptr()); diff --git a/src/tls/tls.rs b/src/tls/tls.rs index 8f4f1b2a2..f6c4f4b85 100644 --- a/src/tls/tls.rs +++ b/src/tls/tls.rs @@ -68,13 +68,11 @@ where } } -#[repr(transparent)] pub struct TlsConfig { /// Boringssl SSL context. tls_ctx: boringssl::tls::Context, } -#[cfg(not(feature = "ffi"))] impl TlsConfig { /// Create a new TlsConfig. pub fn new() -> Result { @@ -84,6 +82,14 @@ impl TlsConfig { Ok(TlsConfig { tls_ctx }) } + /// Create a new TlsConfig with SSL_CTX. + /// The caller is responsible for the memory of SSL_CTX when use this function. + pub fn new_with_ssl_ctx(ssl_ctx: *mut boringssl::tls::SslCtx) -> TlsConfig { + Self { + tls_ctx: boringssl::tls::Context::new_with_ssl_ctx(ssl_ctx), + } + } + /// Create a new client side TlsConfig. pub fn new_client_config( application_protos: Vec>, @@ -188,25 +194,29 @@ impl TlsConfig { } } -impl TlsConfigSelector for TlsConfig { +pub(crate) struct DefaultTlsConfigSelector { + pub tls_config: Arc, +} + +impl TlsConfigSelector for DefaultTlsConfigSelector { // TODO: support local and peer address. /// Get default TLS config. - fn get_default(&self) -> Option<&TlsConfig> { - Some(self) + fn get_default(&self) -> Option> { + Some(self.tls_config.clone()) } /// Find TLS config according to server name. - fn select(&self, _server_name: &str) -> Option<&TlsConfig> { - Some(self) + fn select(&self, _server_name: &str) -> Option> { + Some(self.tls_config.clone()) } } pub trait TlsConfigSelector: Send + Sync { /// Get default TLS config. - fn get_default(&self) -> Option<&TlsConfig>; + fn get_default(&self) -> Option>; /// Find TLS config according to server name. - fn select(&self, server_name: &str) -> Option<&TlsConfig>; + fn select(&self, server_name: &str) -> Option>; } #[derive(Default)] @@ -965,7 +975,7 @@ pub(crate) mod tests { } pub struct ServerConfigSelector { - hash_map: HashMap, + hash_map: HashMap>, } impl ServerConfigSelector { @@ -985,7 +995,9 @@ pub(crate) mod tests { cert, keys[index], )?; - cert_manager.hash_map.insert(index.to_string(), tls_config); + cert_manager + .hash_map + .insert(index.to_string(), tls_config.into()); } Ok(cert_manager) @@ -997,12 +1009,12 @@ pub(crate) mod tests { } impl TlsConfigSelector for ServerConfigSelector { - fn get_default(&self) -> Option<&TlsConfig> { - self.hash_map.get("0") + fn get_default(&self) -> Option> { + self.select("0") } - fn select(&self, server_name: &str) -> Option<&TlsConfig> { - self.hash_map.get(server_name) + fn select(&self, server_name: &str) -> Option> { + self.hash_map.get(server_name).cloned() } }