diff --git a/Cargo.lock b/Cargo.lock index bb12571684..47873fde4f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -745,12 +745,6 @@ dependencies = [ "rustc-demangle", ] -[[package]] -name = "base64" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" - [[package]] name = "base64" version = "0.21.7" @@ -2070,8 +2064,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -2149,6 +2145,7 @@ checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" dependencies = [ "ahash", "allocator-api2", + "serde", ] [[package]] @@ -2166,7 +2163,7 @@ version = "7.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "765c9198f173dd59ce26ff9f95ef0aafd0a0fe01fb9d72841bc5066a4c06511d" dependencies = [ - "base64 0.21.7", + "base64", "byteorder", "flate2", "nom", @@ -2179,7 +2176,7 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06683b93020a07e3dbcf5f8c0f6d40080d725bea7936fc01ad345c01b97dc270" dependencies = [ - "base64 0.21.7", + "base64", "bytes", "headers-core", "http", @@ -2613,13 +2610,14 @@ dependencies = [ [[package]] name = "jsonwebtoken" -version = "8.3.0" +version = "9.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6971da4d9c3aa03c3d8f3ff0f4155b534aad021292003895a469716b2a230378" +checksum = "5c7ea04a7c5c055c175f189b6dc6ba036fd62306b58c66c9f6389036c503a3f4" dependencies = [ - "base64 0.21.7", + "base64", + "js-sys", "pem", - "ring 0.16.20", + "ring 0.17.8", "serde", "serde_json", "simple_asn1", @@ -2683,7 +2681,7 @@ dependencies = [ "anyhow", "async-stream", "async-trait", - "base64 0.21.7", + "base64", "bincode", "bitflags 2.4.2", "bytes", @@ -2725,7 +2723,7 @@ checksum = "9c7b1c078b4d3d45ba0db91accc23dcb8d2761d67f819efd94293065597b7ac8" dependencies = [ "anyhow", "async-trait", - "base64 0.21.7", + "base64", "num-traits", "reqwest", "serde_json", @@ -2780,7 +2778,7 @@ dependencies = [ "aws-sdk-s3", "axum", "axum-extra", - "base64 0.21.7", + "base64", "bincode", "bottomless", "bytes", @@ -2825,6 +2823,7 @@ dependencies = [ "rand", "regex", "reqwest", + "ring 0.17.8", "rustls 0.21.10", "rustls-pemfile", "s3s", @@ -2858,7 +2857,7 @@ name = "libsql-shell" version = "0.1.1" dependencies = [ "anyhow", - "base64 0.21.7", + "base64", "clap 4.4.18", "home", "libsql-rusqlite", @@ -2891,7 +2890,7 @@ dependencies = [ name = "libsql-sys" version = "0.3.0" dependencies = [ - "base64 0.21.7", + "base64", "bytes", "libsql-ffi", "libsql-rusqlite", @@ -3099,7 +3098,7 @@ version = "0.12.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d4fa7ce7c4862db464a37b0b31d89bca874562f034bd7993895572783d02950" dependencies = [ - "base64 0.21.7", + "base64", "hyper", "indexmap 1.9.3", "ipnet", @@ -3467,11 +3466,12 @@ checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" [[package]] name = "pem" -version = "1.1.1" +version = "3.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8835c273a76a90455d7344889b0964598e3316e2a79ede8e36f16bdcf2228b8" +checksum = "1b8fcc794035347fb64beda2d3b462595dd2753e3f268d89c5aae77e8cf2c310" dependencies = [ - "base64 0.13.1", + "base64", + "serde", ] [[package]] @@ -4038,7 +4038,7 @@ version = "0.11.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c6920094eb85afde5e4a138be3f2de8bbdf28000f0029e72c45025a56b042251" dependencies = [ - "base64 0.21.7", + "base64", "bytes", "encoding_rs", "futures-core", @@ -4099,16 +4099,17 @@ dependencies = [ [[package]] name = "ring" -version = "0.17.7" +version = "0.17.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" dependencies = [ "cc", + "cfg-if", "getrandom", "libc", "spin 0.9.8", "untrusted 0.9.0", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -4180,7 +4181,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" dependencies = [ "log", - "ring 0.17.7", + "ring 0.17.8", "rustls-webpki", "sct", ] @@ -4203,7 +4204,7 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" dependencies = [ - "base64 0.21.7", + "base64", ] [[package]] @@ -4212,7 +4213,7 @@ version = "0.101.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" dependencies = [ - "ring 0.17.7", + "ring 0.17.8", "untrusted 0.9.0", ] @@ -4384,7 +4385,7 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" dependencies = [ - "ring 0.17.7", + "ring 0.17.8", "untrusted 0.9.0", ] @@ -5057,7 +5058,7 @@ dependencies = [ "async-stream", "async-trait", "axum", - "base64 0.21.7", + "base64", "bytes", "h2", "http", @@ -5099,7 +5100,7 @@ version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fddb2a37b247e6adcb9f239f4e5cefdcc5ed526141a416b943929f13aea2cce" dependencies = [ - "base64 0.21.7", + "base64", "bytes", "http", "http-body", @@ -5697,7 +5698,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6107809b2d9f5b2fd3ddbaddb3bb92ff8048b62f4030debf1408119ffd38c6cb" dependencies = [ "anyhow", - "base64 0.21.7", + "base64", "bincode", "directories-next", "file-per-thread-logger", @@ -5958,7 +5959,7 @@ version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53" dependencies = [ - "ring 0.17.7", + "ring 0.17.8", "untrusted 0.9.0", ] diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index 570215f8f1..14e3b02fa3 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -131315,9 +131315,9 @@ int libsql_try_initialize_wasm_func_table(sqlite3 *db) { sqlite3_finalize(stmt); return rc; } - const char *pName = sqlite3_column_text(stmt, 0); + const unsigned char *pName = sqlite3_column_text(stmt, 0); const void *pBody = body_type == SQLITE_TEXT ? sqlite3_column_text(stmt, 1) : sqlite3_column_blob(stmt, 1); - try_instantiate_wasm_function(db, pName, name_size, pBody, body_size, -1, NULL); + try_instantiate_wasm_function(db, (const char *)pName, name_size, pBody, body_size, -1, NULL); } } sqlite3_finalize(stmt); diff --git a/libsql-server/Cargo.toml b/libsql-server/Cargo.toml index 6710cb594b..67495c3b58 100644 --- a/libsql-server/Cargo.toml +++ b/libsql-server/Cargo.toml @@ -34,7 +34,7 @@ hyper = { version = "0.14.23", features = ["http2"] } hyper-rustls = { git = "https://github.com/rustls/hyper-rustls.git", rev = "163b3f5" } hyper-tungstenite = "0.11" itertools = "0.10.5" -jsonwebtoken = "8.2.0" +jsonwebtoken = "9" libsql = { path = "../libsql/", optional = true } libsql_replication = { path = "../libsql-replication" } metrics = "0.21.1" @@ -80,7 +80,7 @@ uuid = { version = "1.3", features = ["v4", "serde"] } aes = { version = "0.8.3", optional = true } cbc = { version = "0.1.2", optional = true } zerocopy = { version = "0.7.28", features = ["derive", "alloc"] } -hashbrown = "0.14.3" +hashbrown = { version = "0.14.3", features = ["serde"] } [dev-dependencies] arbitrary = { version = "1.3.0", features = ["derive_arbitrary"] } @@ -99,6 +99,7 @@ url = "2.3" metrics-util = "0.15" s3s = "0.8.1" s3s-fs = "0.8.1" +ring = { version = "0.17.8", features = ["std"] } [build-dependencies] vergen = { version = "8", features = ["build", "git", "gitcl"] } diff --git a/libsql-server/src/auth/authenticated.rs b/libsql-server/src/auth/authenticated.rs index e606032e36..62b640cac7 100644 --- a/libsql-server/src/auth/authenticated.rs +++ b/libsql-server/src/auth/authenticated.rs @@ -1,34 +1,23 @@ -use crate::auth::{constants::GRPC_PROXY_AUTH_HEADER, Authorized, Permission}; +use std::sync::Arc; + +use crate::auth::{constants::GRPC_PROXY_AUTH_HEADER, Authorized}; use crate::namespace::NamespaceName; -use libsql_replication::rpc::replication::NAMESPACE_METADATA_KEY; use tonic::Status; +use super::authorized::Scope; +use super::Permission; + /// A witness that the user has been authenticated. #[non_exhaustive] -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Debug, serde::Serialize, serde::Deserialize, Clone)] pub enum Authenticated { Anonymous, - Authorized(Authorized), + Authorized(Arc), + FullAccess, } impl Authenticated { - pub fn from_proxy_grpc_request( - req: &tonic::Request, - disable_namespace: bool, - ) -> Result { - let namespace = if disable_namespace { - None - } else { - req.metadata() - .get_bin(NAMESPACE_METADATA_KEY) - .map(|c| c.to_bytes()) - .transpose() - .map_err(|_| Status::invalid_argument("failed to parse namespace header"))? - .map(NamespaceName::from_bytes) - .transpose() - .map_err(|_| Status::invalid_argument("invalid namespace name"))? - }; - + pub fn from_proxy_grpc_request(req: &tonic::Request) -> Result { let auth = match req .metadata() .get(GRPC_PROXY_AUTH_HEADER) @@ -36,21 +25,7 @@ impl Authenticated { .transpose() .map_err(|_| Status::invalid_argument("missing authorization header"))? { - Some("full_access") => Authenticated::Authorized(Authorized { - namespace, - permission: Permission::FullAccess, - }), - Some("read_only") => Authenticated::Authorized(Authorized { - namespace, - permission: Permission::ReadOnly, - }), - Some("anonymous") => Authenticated::Anonymous, - Some(level) => { - return Err(Status::permission_denied(format!( - "invalid authorization level: {}", - level - ))) - } + Some(s) => serde_json::from_str::(s).unwrap(), None => return Err(Status::invalid_argument("x-proxy-authorization not set")), }; @@ -60,18 +35,7 @@ impl Authenticated { pub fn upgrade_grpc_request(&self, req: &mut tonic::Request) { let key = tonic::metadata::AsciiMetadataKey::from_static(GRPC_PROXY_AUTH_HEADER); - let auth = match self { - Authenticated::Anonymous => "anonymous", - Authenticated::Authorized(Authorized { - permission: Permission::FullAccess, - .. - }) => "full_access", - Authenticated::Authorized(Authorized { - permission: Permission::ReadOnly, - .. - }) => "read_only", - }; - + let auth = serde_json::to_string(self).unwrap(); let value = tonic::metadata::AsciiMetadataValue::try_from(auth).unwrap(); req.metadata_mut().insert(key, value); @@ -80,14 +44,31 @@ impl Authenticated { pub fn is_namespace_authorized(&self, namespace: &NamespaceName) -> bool { match self { Authenticated::Anonymous => false, - Authenticated::Authorized(Authorized { - namespace: Some(ns), - .. - }) => ns == namespace, - // we threat the absence of a specific namespace has a permission to any namespace - Authenticated::Authorized(Authorized { - namespace: None, .. - }) => true, + Authenticated::Authorized(auth) => { + auth.has_right(Scope::Namespace(namespace.clone()), Permission::Read) + } + Authenticated::FullAccess => true, + } + } + + pub(crate) fn has_right( + &self, + namespace: &NamespaceName, + perm: Permission, + ) -> crate::Result<()> { + match self { + Authenticated::Anonymous => Err(crate::Error::NotAuthorized( + "anonymous access not allowed".to_string(), + )), + Authenticated::Authorized(a) => { + if !a.has_right(Scope::Namespace(namespace.clone()), perm) { + Err(crate::Error::NotAuthorized(format!( + "Current session doest not have {perm:?} permission to namespace {namespace}"))) + } else { + Ok(()) + } + } + Authenticated::FullAccess => Ok(()), } } } diff --git a/libsql-server/src/auth/authorized.rs b/libsql-server/src/auth/authorized.rs index 2cae84bc28..398bcb6bed 100644 --- a/libsql-server/src/auth/authorized.rs +++ b/libsql-server/src/auth/authorized.rs @@ -1,8 +1,162 @@ -use crate::auth::Permission; +use hashbrown::HashSet; + use crate::namespace::NamespaceName; -#[derive(Clone, Debug, PartialEq, Eq)] +use super::Permission; + +#[derive(Debug, serde::Deserialize, serde::Serialize, Default)] pub struct Authorized { - pub namespace: Option, - pub permission: Permission, + #[serde(rename = "ro", default)] + pub read_only: Option, + #[serde(rename = "rw", default)] + pub read_write: Option, + #[serde(rename = "roa", default)] + pub read_only_attach: Option, + #[serde(rename = "rwa", default)] + pub read_write_attach: Option, +} + +#[derive(Debug, PartialEq, Eq)] +pub enum Scope { + Namespace(NamespaceName), +} + +impl Authorized { + pub fn has_right(&self, scope: Scope, perm: Permission) -> bool { + match (perm, scope) { + (Permission::Read, Scope::Namespace(ref name)) => self.can_read_ns(name), + (Permission::Write, Scope::Namespace(ref name)) => self.can_write_ns(name), + (Permission::AttachRead, Scope::Namespace(ref name)) => self.can_attach_ns(name), + } + } + + pub fn merge_legacy(&mut self, namespace: NamespaceName, perm: Permission) { + let scope = match perm { + Permission::Read => self.read_only.get_or_insert_with(Default::default), + Permission::Write => self.read_write.get_or_insert_with(Default::default), + Permission::AttachRead => self.read_only_attach.get_or_insert_with(Default::default), + }; + + scope + .namespaces + .get_or_insert_with(Default::default) + .insert(namespace); + } + + fn can_write_ns(&self, name: &NamespaceName) -> bool { + if let Some(Scopes { + namespaces: Some(ref ns), + .. + }) = self.read_write + { + if ns.contains(name) { + return true; + } + } + + if let Some(Scopes { + namespaces: Some(ref ns), + .. + }) = self.read_write_attach + { + if ns.contains(name) { + return true; + } + } + + false + } + + fn can_read_ns(&self, name: &NamespaceName) -> bool { + if self.can_write_ns(name) { + return true; + } + + if let Some(Scopes { + namespaces: Some(ref ns), + .. + }) = self.read_only + { + if ns.contains(name) { + return true; + } + } + + if let Some(Scopes { + namespaces: Some(ref ns), + .. + }) = self.read_only_attach + { + if ns.contains(name) { + return true; + } + } + + false + } + + #[cfg(test)] + pub fn perms_iter(&self) -> impl Iterator + '_ { + macro_rules! perm_iter { + ($field:ident, $perm:expr) => { + self.$field + .as_ref() + .map(|s| s.iter()) + .into_iter() + .flatten() + .zip(std::iter::repeat($perm)) + }; + } + + let ro_iter = perm_iter!(read_only, Permission::Read); + let rw_iter = perm_iter!(read_write, Permission::Write); + let ro_attach_iter = perm_iter!(read_only_attach, Permission::AttachRead); + let rw_attach_iter = perm_iter!(read_write_attach, Permission::AttachRead); + + ro_iter + .chain(rw_iter) + .chain(ro_attach_iter) + .chain(rw_attach_iter) + } + + fn can_attach_ns(&self, name: &NamespaceName) -> bool { + if let Some(Scopes { + namespaces: Some(ref ns), + .. + }) = self.read_only_attach + { + if ns.contains(name) { + return true; + } + } + + if let Some(Scopes { + namespaces: Some(ref ns), + .. + }) = self.read_write_attach + { + if ns.contains(name) { + return true; + } + } + + false + } +} + +#[derive(Debug, serde::Deserialize, serde::Serialize, Default)] +pub struct Scopes { + #[serde(rename = "ns", default)] + pub namespaces: Option>, +} + +impl Scopes { + #[cfg(test)] + fn iter(&self) -> impl Iterator + '_ { + self.namespaces + .as_ref() + .map(|nss| nss.iter().cloned().map(|ns| Scope::Namespace(ns))) + .into_iter() + .flatten() + } } diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index eb40b9c009..e09d930eae 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -30,7 +30,7 @@ pub fn parse_jwt_key(data: &str) -> Result { bail!("Key is in unsupported PEM format") } else { jsonwebtoken::DecodingKey::from_ed_components(data) - .context("Could not decode Ed25519 public key from base64") + .map_err(|e| anyhow::anyhow!("Could not decode Ed25519 public key from base64: {e}")) } } diff --git a/libsql-server/src/auth/permission.rs b/libsql-server/src/auth/permission.rs index ab9c66c786..ec33865ac7 100644 --- a/libsql-server/src/auth/permission.rs +++ b/libsql-server/src/auth/permission.rs @@ -1,6 +1,10 @@ +#[derive(Debug, Clone, Copy, serde::Deserialize, serde::Serialize, PartialEq, Eq)] #[non_exhaustive] -#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Permission { - FullAccess, - ReadOnly, + #[serde(rename = "ro")] + Read, + #[serde(rename = "rw")] + Write, + #[serde(rename = "roa")] + AttachRead, } diff --git a/libsql-server/src/auth/user_auth_strategies/disabled.rs b/libsql-server/src/auth/user_auth_strategies/disabled.rs index bc0c5d3dd7..8ffa5e7028 100644 --- a/libsql-server/src/auth/user_auth_strategies/disabled.rs +++ b/libsql-server/src/auth/user_auth_strategies/disabled.rs @@ -1,16 +1,12 @@ use super::{UserAuthContext, UserAuthStrategy}; -use crate::auth::{AuthError, Authenticated, Authorized, Permission}; +use crate::auth::{AuthError, Authenticated}; pub struct Disabled {} impl UserAuthStrategy for Disabled { fn authenticate(&self, _context: UserAuthContext) -> Result { tracing::trace!("executing disabled auth"); - - Ok(Authenticated::Authorized(Authorized { - namespace: None, - permission: Permission::FullAccess, - })) + Ok(Authenticated::FullAccess) } } @@ -22,25 +18,18 @@ impl Disabled { #[cfg(test)] mod tests { - use crate::namespace::NamespaceName; - use super::*; #[test] fn authenticates() { let strategy = Disabled::new(); let context = UserAuthContext { - namespace: NamespaceName::default(), - namespace_credential: None, user_credential: None, }; - assert_eq!( + assert!(matches!( strategy.authenticate(context).unwrap(), - Authenticated::Authorized(Authorized { - namespace: None, - permission: Permission::FullAccess, - }) - ) + Authenticated::FullAccess + )) } } diff --git a/libsql-server/src/auth/user_auth_strategies/http_basic.rs b/libsql-server/src/auth/user_auth_strategies/http_basic.rs index 3bd05c750b..01ab33422a 100644 --- a/libsql-server/src/auth/user_auth_strategies/http_basic.rs +++ b/libsql-server/src/auth/user_auth_strategies/http_basic.rs @@ -1,4 +1,4 @@ -use crate::auth::{parse_http_auth_header, AuthError, Authenticated, Authorized, Permission}; +use crate::auth::{parse_http_auth_header, AuthError, Authenticated}; use super::{UserAuthContext, UserAuthStrategy}; @@ -18,10 +18,7 @@ impl UserAuthStrategy for HttpBasic { let expected_value = self.credential.trim_end_matches('='); if actual_value == expected_value { - return Ok(Authenticated::Authorized(Authorized { - namespace: None, - permission: Permission::FullAccess, - })); + return Ok(Authenticated::FullAccess); } Err(AuthError::BasicRejected) @@ -38,8 +35,6 @@ impl HttpBasic { mod tests { use axum::http::HeaderValue; - use crate::namespace::NamespaceName; - use super::*; const CREDENTIAL: &str = "d29qdGVrOnRoZWJlYXI="; @@ -51,18 +46,13 @@ mod tests { #[test] fn authenticates_with_valid_credential() { let context = UserAuthContext { - namespace: NamespaceName::default(), - namespace_credential: None, user_credential: HeaderValue::from_str(&format!("Basic {CREDENTIAL}")).ok(), }; - assert_eq!( + assert!(matches!( strategy().authenticate(context).unwrap(), - Authenticated::Authorized(Authorized { - namespace: None, - permission: Permission::FullAccess, - }) - ) + Authenticated::FullAccess + )) } #[test] @@ -70,25 +60,18 @@ mod tests { let credential = CREDENTIAL.trim_end_matches('='); let context = UserAuthContext { - namespace: NamespaceName::default(), - namespace_credential: None, user_credential: HeaderValue::from_str(&format!("Basic {credential}")).ok(), }; - assert_eq!( + assert!(matches!( strategy().authenticate(context).unwrap(), - Authenticated::Authorized(Authorized { - namespace: None, - permission: Permission::FullAccess, - }) - ) + Authenticated::FullAccess + )) } #[test] fn errors_when_credentials_do_not_match() { let context = UserAuthContext { - namespace: NamespaceName::default(), - namespace_credential: None, user_credential: HeaderValue::from_str("Basic abc").ok(), }; diff --git a/libsql-server/src/auth/user_auth_strategies/jwt.rs b/libsql-server/src/auth/user_auth_strategies/jwt.rs index 5e37f26418..e85bde43c4 100644 --- a/libsql-server/src/auth/user_auth_strategies/jwt.rs +++ b/libsql-server/src/auth/user_auth_strategies/jwt.rs @@ -1,3 +1,5 @@ +use chrono::{DateTime, Utc}; + use crate::{ auth::{parse_http_auth_header, AuthError, Authenticated, Authorized, Permission}, namespace::NamespaceName, @@ -12,15 +14,8 @@ pub struct Jwt { impl UserAuthStrategy for Jwt { fn authenticate(&self, context: UserAuthContext) -> Result { tracing::trace!("executing jwt auth"); - let param = parse_http_auth_header("bearer", &context.user_credential)?; - - let jwt_key = match context.namespace_credential.as_ref() { - Some(jwt_key) => jwt_key, - None => &self.key, - }; - - validate_jwt(jwt_key, param, context.namespace) + validate_jwt(&self.key, param) } } @@ -30,41 +25,69 @@ impl Jwt { } } +#[derive(serde::Deserialize, serde::Serialize, Debug)] +struct Token { + #[serde(default)] + id: Option, + #[serde(default)] + a: Option, + #[serde(default)] + p: Option, + #[serde(with = "jwt_time", default)] + exp: Option>, +} + +mod jwt_time { + use chrono::{DateTime, Utc}; + use serde::{de::Error, Deserialize, Deserializer, Serializer}; + + pub fn serialize(date: &Option>, serializer: S) -> Result + where + S: Serializer, + { + match date { + Some(date) => serializer.serialize_i64(date.timestamp()), + None => serializer.serialize_none(), + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result>, D::Error> + where + D: Deserializer<'de>, + { + Option::::deserialize(deserializer)? + .map(|x| { + DateTime::from_timestamp(x, 0).ok_or_else(|| D::Error::custom("invalid exp claim")) + }) + .transpose() + } +} + fn validate_jwt( jwt_key: &jsonwebtoken::DecodingKey, jwt: &str, - namespace: NamespaceName, ) -> Result { use jsonwebtoken::errors::ErrorKind; let mut validation = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::EdDSA); validation.required_spec_claims.remove("exp"); - match jsonwebtoken::decode::(jwt, jwt_key, &validation).map(|t| t.claims) { - Ok(serde_json::Value::Object(claims)) => { - tracing::trace!("Claims: {claims:#?}"); - let namespace = if namespace == NamespaceName::default() { - None + match jsonwebtoken::decode::(jwt, jwt_key, &validation).map(|t| t.claims) { + Ok(Token { id, a, p, .. }) => { + // This is legacy: when nothing is specified, then it's full access + if p.is_none() && id.is_none() && a.is_none() { + return Ok(Authenticated::FullAccess); } else { - claims - .get("id") - .and_then(|ns| NamespaceName::from_string(ns.as_str()?.into()).ok()) - }; - - let permission = match claims.get("a").and_then(|s| s.as_str()) { - Some("ro") => Permission::ReadOnly, - Some("rw") => Permission::FullAccess, - Some(_) => return Ok(Authenticated::Anonymous), - // Backward compatibility - no access claim means full access - None => Permission::FullAccess, - }; - - Ok(Authenticated::Authorized(Authorized { - namespace, - permission, - })) + let mut auth = p.unwrap_or_default(); + + // We only allow tokens if they contains a ns and a perm + if let Some((ns, a)) = id.zip(a) { + auth.merge_legacy(ns, a); + } + + Ok(Authenticated::Authorized(auth.into())) + } } - Ok(_) => Err(AuthError::JwtInvalid), Err(error) => Err(match error.kind() { ErrorKind::InvalidToken | ErrorKind::InvalidSignature @@ -81,71 +104,162 @@ fn validate_jwt( #[cfg(test)] mod tests { + use std::time::Duration; + use axum::http::HeaderValue; + use jsonwebtoken::{DecodingKey, EncodingKey}; + use ring::signature::{Ed25519KeyPair, KeyPair}; + use serde::Serialize; - use crate::auth::parse_jwt_key; + use crate::auth::authorized::Scope; use super::*; - const KEY: &str = "zaMv-aFGmB7PXkjM4IrMdF6B5zCYEiEGXW3RgMjNAtc"; + fn strategy(dec: jsonwebtoken::DecodingKey) -> Jwt { + Jwt::new(dec) + } + + fn key_pair() -> (jsonwebtoken::EncodingKey, jsonwebtoken::DecodingKey) { + let doc = Ed25519KeyPair::generate_pkcs8(&ring::rand::SystemRandom::new()).unwrap(); + let encoding_key = EncodingKey::from_ed_der(doc.as_ref()); + let pair = Ed25519KeyPair::from_pkcs8(doc.as_ref()).unwrap(); + let decoding_key = DecodingKey::from_ed_der(pair.public_key().as_ref()); + (encoding_key, decoding_key) + } - fn strategy() -> Jwt { - Jwt::new(parse_jwt_key(KEY).unwrap()) + fn encode(claims: &T, key: &EncodingKey) -> String { + let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::EdDSA); + jsonwebtoken::encode(&header, &claims, key).unwrap() } #[test] fn authenticates_valid_jwt_token_with_full_access() { - let token = "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.\ - eyJleHAiOjc5ODg0ODM4Mjd9.\ - MatB2aLnPFusagqH2RMoVExP37o2GFLmaJbmd52OdLtAehRNeqeJZPrefP1t2GBFidApUTLlaBRL6poKq_s3CQ"; + // this is a full access token + let (enc, dec) = key_pair(); + let token = Token { + id: None, + a: None, + p: None, + exp: None, + }; + let token = encode(&token, &enc); let context = UserAuthContext { - namespace: NamespaceName::default(), - namespace_credential: None, user_credential: HeaderValue::from_str(&format!("Bearer {token}")).ok(), }; - assert_eq!( - strategy().authenticate(context).unwrap(), - Authenticated::Authorized(Authorized { - namespace: None, - permission: Permission::FullAccess, - }) - ) + assert!(matches!( + strategy(dec).authenticate(context).unwrap(), + Authenticated::FullAccess + )) } #[test] fn authenticates_valid_jwt_token_with_read_only_access() { - let token = "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.\ - eyJleHAiOjc5ODg0ODM4MjcsImEiOiJybyJ9.\ - _2ZZiO2HC8b3CbCHSCufXXBmwpl-dLCv5O9Owvpy7LZ9aiQhXODpgV-iCdTsLQJ5FVanWhfn3FtJSnmWHn25DQ"; + let (enc, dec) = key_pair(); + let token = Token { + id: Some(NamespaceName::default()), + a: Some(Permission::Read), + p: None, + exp: None, + }; + let token = encode(&token, &enc); let context = UserAuthContext { - namespace: NamespaceName::default(), - namespace_credential: None, user_credential: HeaderValue::from_str(&format!("Bearer {token}")).ok(), }; + let Authenticated::Authorized(a) = strategy(dec).authenticate(context).unwrap() else { + panic!() + }; + + let mut perms = a.perms_iter(); assert_eq!( - strategy().authenticate(context).unwrap(), - Authenticated::Authorized(Authorized { - namespace: None, - permission: Permission::ReadOnly, - }) - ) + perms.next().unwrap(), + (Scope::Namespace(NamespaceName::default()), Permission::Read) + ); + assert!(perms.next().is_none()); } #[test] fn errors_when_jwt_token_invalid() { + let (_enc, dec) = key_pair(); let context = UserAuthContext { - namespace: NamespaceName::default(), - namespace_credential: None, user_credential: HeaderValue::from_str("Bearer abc").ok(), }; assert_eq!( - strategy().authenticate(context).unwrap_err(), + strategy(dec).authenticate(context).unwrap_err(), AuthError::JwtInvalid ) } + + #[test] + fn expired_token() { + let (enc, dec) = key_pair(); + let token = Token { + id: None, + a: None, + p: None, + exp: Some(Utc::now() - Duration::from_secs(5 * 60)), + }; + + let token = encode(&token, &enc); + + let context = UserAuthContext { + user_credential: HeaderValue::from_str(&format!("Bearer {token}")).ok(), + }; + + assert_eq!( + strategy(dec).authenticate(context).unwrap_err(), + AuthError::JwtExpired + ); + } + + #[test] + fn multi_scopes() { + let (enc, dec) = key_pair(); + let token = serde_json::json!({ + "id": "foobar", + "a": "ro", + "p": { + "rw": { "ns": ["foo"] }, + "roa": { "ns": ["bar"] } + } + }); + + let token = encode(&token, &enc); + + let context = UserAuthContext { + user_credential: HeaderValue::from_str(&format!("Bearer {token}")).ok(), + }; + + let Authenticated::Authorized(a) = strategy(dec).authenticate(context).unwrap() else { + panic!() + }; + + let mut perms = a.perms_iter(); + assert_eq!( + perms.next().unwrap(), + ( + Scope::Namespace(NamespaceName::from_string("foobar".into()).unwrap()), + Permission::Read + ) + ); + assert_eq!( + perms.next().unwrap(), + ( + Scope::Namespace(NamespaceName::from_string("foo".into()).unwrap()), + Permission::Write + ) + ); + assert_eq!( + perms.next().unwrap(), + ( + Scope::Namespace(NamespaceName::from_string("bar".into()).unwrap()), + Permission::AttachRead + ) + ); + assert!(perms.next().is_none()); + } } diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index afa19eb55f..f3b8a2e5b2 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -7,13 +7,9 @@ pub use disabled::*; pub use http_basic::*; pub use jwt::*; -use crate::namespace::NamespaceName; - use super::{AuthError, Authenticated}; pub struct UserAuthContext { - pub namespace: NamespaceName, - pub namespace_credential: Option, pub user_credential: Option, } diff --git a/libsql-server/src/connection/libsql.rs b/libsql-server/src/connection/libsql.rs index df33e28c04..21ee7d99da 100644 --- a/libsql-server/src/connection/libsql.rs +++ b/libsql-server/src/connection/libsql.rs @@ -13,7 +13,7 @@ use rusqlite::{DatabaseName, ErrorCode, OpenFlags, StatementStatus, TransactionS use tokio::sync::{watch, Notify}; use tokio::time::{Duration, Instant}; -use crate::auth::{Authenticated, Authorized, Permission}; +use crate::auth::Permission; use crate::connection::TXN_TIMEOUT; use crate::error::Error; use crate::metrics::{ @@ -29,7 +29,7 @@ use crate::stats::Stats; use crate::Result; use super::program::{Cond, DescribeCol, DescribeParam, DescribeResponse}; -use super::{MakeConnection, Program, Step}; +use super::{MakeConnection, Program, RequestContext, Step}; pub struct MakeLibSqlConn { db_path: PathBuf, @@ -804,17 +804,20 @@ impl Connection { let config = self.config_store.get(); let blocked = match query.stmt.kind { - StmtKind::Read | StmtKind::TxnBegin | StmtKind::Other => config.block_reads, + StmtKind::Read | StmtKind::TxnBegin => config.block_reads, StmtKind::Write => config.block_reads || config.block_writes, StmtKind::DDL => config.block_reads || config.block_writes || config.block_ddl(), - StmtKind::TxnEnd | StmtKind::Release | StmtKind::Savepoint => false, - StmtKind::Attach | StmtKind::Detach => !config.allow_attach, + StmtKind::TxnEnd + | StmtKind::Release + | StmtKind::Savepoint + | StmtKind::Detach + | StmtKind::Attach(_) => false, }; if blocked { return Err(Error::Blocked(config.block_reason.clone())); } - let mut stmt = if matches!(query.stmt.kind, StmtKind::Attach) { + let mut stmt = if matches!(query.stmt.kind, StmtKind::Attach(_)) { match &query.stmt.attach_info { Some((attached, attached_alias)) => { let query = self.prepare_attach_query(attached, attached_alias)?; @@ -1019,42 +1022,37 @@ fn eval_cond(cond: &Cond, results: &[bool], is_autocommit: bool) -> Result }) } -fn check_program_auth(auth: Authenticated, pgm: &Program) -> Result<()> { +fn check_program_auth(ctx: &RequestContext, pgm: &Program) -> Result<()> { for step in pgm.steps() { - let query = &step.query; - match (query.stmt.kind, &auth) { - (_, Authenticated::Anonymous) => { - return Err(Error::NotAuthorized( - "anonymous access not allowed".to_string(), - )); + match step.query.stmt.kind { + StmtKind::TxnBegin + | StmtKind::TxnEnd + | StmtKind::Read + | StmtKind::Savepoint + | StmtKind::Release => { + ctx.auth.has_right(&ctx.namespace, Permission::Read)?; + } + StmtKind::DDL | StmtKind::Write => { + ctx.auth.has_right(&ctx.namespace, Permission::Write)?; } - (StmtKind::Read, Authenticated::Authorized(_)) => (), - (StmtKind::TxnBegin, _) | (StmtKind::TxnEnd, _) => (), - ( - _, - Authenticated::Authorized(Authorized { - permission: Permission::FullAccess, - .. - }), - ) => (), - _ => { - return Err(Error::NotAuthorized(format!( - "Current session is not authorized to run: {}", - query.stmt.stmt - ))); + StmtKind::Attach(ref ns) => { + ctx.auth.has_right(ns, Permission::AttachRead)?; + if !ctx.meta_store.handle(ns.clone()).get().allow_attach { + return Err(Error::NotAuthorized(format!( + "Namespace `{ns}` doesn't allow attach" + ))); + } } + StmtKind::Detach => (), } } + Ok(()) } -fn check_describe_auth(auth: Authenticated) -> Result<()> { - match auth { - Authenticated::Anonymous => { - Err(Error::NotAuthorized("anonymous access not allowed".into())) - } - Authenticated::Authorized(_) => Ok(()), - } +fn check_describe_auth(ctx: RequestContext) -> Result<()> { + ctx.auth().has_right(ctx.namespace(), Permission::Read)?; + Ok(()) } /// We use a different runtime to run the connection, because long running tasks block turmoil @@ -1073,13 +1071,13 @@ where async fn execute_program( &self, pgm: Program, - auth: Authenticated, + ctx: RequestContext, builder: B, _replication_index: Option, ) -> Result { PROGRAM_EXEC_COUNT.increment(1); - check_program_auth(auth, &pgm)?; + check_program_auth(&ctx, &pgm)?; let conn = self.inner.clone(); CONN_RT .spawn_blocking(move || Connection::run(conn, pgm, builder)) @@ -1090,11 +1088,11 @@ where async fn describe( &self, sql: String, - auth: Authenticated, + ctx: RequestContext, _replication_index: Option, ) -> Result> { DESCRIBE_COUNT.increment(1); - check_describe_auth(auth)?; + check_describe_auth(ctx)?; let conn = self.inner.clone(); let res = tokio::task::spawn_blocking(move || conn.lock().describe(&sql)) .await @@ -1142,7 +1140,10 @@ mod test { use tempfile::tempdir; use tokio::task::JoinSet; + use crate::auth::Authenticated; use crate::connection::Connection as _; + use crate::namespace::meta_store::MetaStore; + use crate::namespace::NamespaceName; use crate::query_result_builder::test::{test_driver, TestBuilder}; use crate::query_result_builder::QueryResultBuilder; use crate::DEFAULT_AUTO_CHECKPOINT; @@ -1363,28 +1364,31 @@ mod test { ) .await .unwrap(); - let auth = Authenticated::Authorized(Authorized { - namespace: None, - permission: Permission::FullAccess, - }); let conn = make_conn.make_connection().await.unwrap(); + let ctx = RequestContext::new( + Authenticated::FullAccess, + NamespaceName::default(), + MetaStore::new(Default::default(), tmp.path()) + .await + .unwrap(), + ); conn.execute_program( Program::seq(&["CREATE TABLE test (x)"]), - auth.clone(), + ctx.clone(), TestBuilder::default(), None, ) .await .unwrap(); let run_conn = |maker: Arc>| { - let auth = auth.clone(); + let ctx = ctx.clone(); async move { for _ in 0..1000 { let conn = maker.make_connection().await.unwrap(); let pgm = Program::seq(&["BEGIN IMMEDIATE", "INSERT INTO test VALUES (42)"]); let res = conn - .execute_program(pgm, auth.clone(), TestBuilder::default(), None) + .execute_program(pgm, ctx.clone(), TestBuilder::default(), None) .await .unwrap() .into_ret(); @@ -1395,7 +1399,7 @@ mod test { if rand::thread_rng().gen_range(0..100) > 1 { let pgm = Program::seq(&["INSERT INTO test VALUES (43)", "COMMIT"]); let res = conn - .execute_program(pgm, auth.clone(), TestBuilder::default(), None) + .execute_program(pgm, ctx.clone(), TestBuilder::default(), None) .await .unwrap() .into_ret(); diff --git a/libsql-server/src/connection/mod.rs b/libsql-server/src/connection/mod.rs index fc0835c426..19811a2ceb 100644 --- a/libsql-server/src/connection/mod.rs +++ b/libsql-server/src/connection/mod.rs @@ -1,15 +1,19 @@ +use libsql_replication::rpc::replication::NAMESPACE_METADATA_KEY; use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; use tokio::time::{Duration, Instant}; use futures::Future; use tokio::{sync::Semaphore, time::timeout}; +use tonic::metadata::BinaryMetadataValue; use crate::auth::Authenticated; use crate::error::Error; use crate::metrics::{ CONCCURENT_CONNECTIONS_COUNT, CONNECTION_ALIVE_DURATION, CONNECTION_CREATE_TIME, }; +use crate::namespace::meta_store::MetaStore; +use crate::namespace::NamespaceName; use crate::query::{Params, Query}; use crate::query_analysis::Statement; use crate::query_result_builder::{IgnoreResult, QueryResultBuilder}; @@ -29,13 +33,47 @@ const TXN_TIMEOUT: Duration = Duration::from_secs(5); #[cfg(test)] const TXN_TIMEOUT: Duration = Duration::from_millis(100); +#[derive(Clone)] +pub struct RequestContext { + /// Authentication for this request + auth: Authenticated, + /// current namespace + namespace: NamespaceName, + meta_store: MetaStore, +} + +impl RequestContext { + pub fn new(auth: Authenticated, namespace: NamespaceName, meta_store: MetaStore) -> Self { + Self { + auth, + namespace, + meta_store, + } + } + + pub fn upgrade_grpc_request(&self, req: &mut tonic::Request) { + let namespace = BinaryMetadataValue::from_bytes(self.namespace.as_slice()); + req.metadata_mut() + .insert_bin(NAMESPACE_METADATA_KEY, namespace); + self.auth.upgrade_grpc_request(req); + } + + pub fn namespace(&self) -> &NamespaceName { + &self.namespace + } + + pub fn auth(&self) -> &Authenticated { + &self.auth + } +} + #[async_trait::async_trait] pub trait Connection: Send + Sync + 'static { /// Executes a query program async fn execute_program( &self, pgm: Program, - auth: Authenticated, + ctx: RequestContext, response_builder: B, replication_index: Option, ) -> Result; @@ -46,7 +84,7 @@ pub trait Connection: Send + Sync + 'static { async fn execute_batch_or_rollback( &self, batch: Vec, - auth: Authenticated, + ctx: RequestContext, result_builder: B, replication_index: Option, ) -> Result { @@ -74,7 +112,7 @@ pub trait Connection: Send + Sync + 'static { // ignore the rollback result let builder = result_builder.take(batch_len); let builder = self - .execute_program(pgm, auth, builder, replication_index) + .execute_program(pgm, ctx, builder, replication_index) .await?; Ok(builder.into_inner()) @@ -85,24 +123,24 @@ pub trait Connection: Send + Sync + 'static { async fn execute_batch( &self, batch: Vec, - auth: Authenticated, + ctx: RequestContext, result_builder: B, replication_index: Option, ) -> Result { let steps = make_batch_program(batch); let pgm = Program::new(steps); - self.execute_program(pgm, auth, result_builder, replication_index) + self.execute_program(pgm, ctx, result_builder, replication_index) .await } - async fn rollback(&self, auth: Authenticated) -> Result<()> { + async fn rollback(&self, ctx: RequestContext) -> Result<()> { self.execute_batch( vec![Query { stmt: Statement::parse("ROLLBACK").next().unwrap().unwrap(), params: Params::empty(), want_rows: false, }], - auth, + ctx, IgnoreResult, None, ) @@ -115,7 +153,7 @@ pub trait Connection: Send + Sync + 'static { async fn describe( &self, sql: String, - auth: Authenticated, + ctx: RequestContext, replication_index: Option, ) -> Result>; @@ -339,13 +377,13 @@ impl Connection for TrackedConnection { async fn execute_program( &self, pgm: Program, - auth: Authenticated, + ctx: RequestContext, builder: B, replication_index: Option, ) -> crate::Result { self.atime.store(now_millis(), Ordering::Relaxed); self.inner - .execute_program(pgm, auth, builder, replication_index) + .execute_program(pgm, ctx, builder, replication_index) .await } @@ -353,11 +391,11 @@ impl Connection for TrackedConnection { async fn describe( &self, sql: String, - auth: Authenticated, + ctx: RequestContext, replication_index: Option, ) -> crate::Result> { self.atime.store(now_millis(), Ordering::Relaxed); - self.inner.describe(sql, auth, replication_index).await + self.inner.describe(sql, ctx, replication_index).await } #[inline] @@ -394,7 +432,7 @@ pub mod test { async fn execute_program( &self, _pgm: Program, - _auth: Authenticated, + _ctx: RequestContext, _builder: B, _replication_index: Option, ) -> crate::Result { @@ -404,7 +442,7 @@ pub mod test { async fn describe( &self, _sql: String, - _auth: Authenticated, + _ctx: RequestContext, _replication_index: Option, ) -> crate::Result> { unreachable!() diff --git a/libsql-server/src/connection/write_proxy.rs b/libsql-server/src/connection/write_proxy.rs index cec74d6764..5139484b6d 100644 --- a/libsql-server/src/connection/write_proxy.rs +++ b/libsql-server/src/connection/write_proxy.rs @@ -7,22 +7,18 @@ use libsql_replication::rpc::proxy::proxy_client::ProxyClient; use libsql_replication::rpc::proxy::{ exec_req, exec_resp, ExecReq, ExecResp, StreamDescribeReq, StreamProgramReq, }; -use libsql_replication::rpc::replication::NAMESPACE_METADATA_KEY; use libsql_sys::wal::{Sqlite3Wal, Sqlite3WalManager}; use libsql_sys::EncryptionConfig; use parking_lot::Mutex as PMutex; use tokio::sync::{mpsc, watch, Mutex}; use tokio_stream::StreamExt; -use tonic::metadata::BinaryMetadataValue; use tonic::transport::Channel; use tonic::{Request, Streaming}; -use crate::auth::Authenticated; use crate::connection::program::{DescribeCol, DescribeParam}; use crate::error::Error; use crate::metrics::{REPLICA_LOCAL_EXEC_MISPREDICT, REPLICA_LOCAL_PROGRAM_EXEC}; use crate::namespace::meta_store::MetaStoreHandle; -use crate::namespace::NamespaceName; use crate::query_analysis::TxnStatus; use crate::query_result_builder::{QueryBuilderConfig, QueryResultBuilder}; use crate::replication::FrameNo; @@ -31,7 +27,7 @@ use crate::{Result, DEFAULT_AUTO_CHECKPOINT}; use super::libsql::{LibSqlConnection, MakeLibSqlConn}; use super::program::DescribeResponse; -use super::Connection; +use super::{Connection, RequestContext}; use super::{MakeConnection, Program}; pub type RpcStream = Streaming; @@ -42,7 +38,6 @@ pub struct MakeWriteProxyConn { applied_frame_no_receiver: watch::Receiver>, max_response_size: u64, max_total_response_size: u64, - namespace: NamespaceName, primary_replication_index: Option, make_read_only_conn: MakeLibSqlConn, encryption_config: Option, @@ -60,7 +55,6 @@ impl MakeWriteProxyConn { applied_frame_no_receiver: watch::Receiver>, max_response_size: u64, max_total_response_size: u64, - namespace: NamespaceName, primary_replication_index: Option, encryption_config: Option, ) -> crate::Result { @@ -85,7 +79,6 @@ impl MakeWriteProxyConn { applied_frame_no_receiver, max_response_size, max_total_response_size, - namespace, make_read_only_conn, primary_replication_index, encryption_config, @@ -107,7 +100,6 @@ impl MakeConnection for MakeWriteProxyConn { auto_checkpoint: DEFAULT_AUTO_CHECKPOINT, encryption_config: self.encryption_config.clone(), }, - self.namespace.clone(), self.primary_replication_index, self.make_read_only_conn.create().await?, )?) @@ -127,7 +119,6 @@ pub struct WriteProxyConnection { applied_frame_no_receiver: watch::Receiver>, builder_config: QueryBuilderConfig, stats: Arc, - namespace: NamespaceName, remote_conn: Mutex>>, /// the primary replication index when the namespace was loaded @@ -141,7 +132,6 @@ impl WriteProxyConnection { stats: Arc, applied_frame_no_receiver: watch::Receiver>, builder_config: QueryBuilderConfig, - namespace: NamespaceName, primary_replication_index: Option, read_conn: LibSqlConnection, ) -> Result { @@ -153,7 +143,6 @@ impl WriteProxyConnection { applied_frame_no_receiver, builder_config, stats, - namespace, remote_conn: Default::default(), primary_replication_index, }) @@ -161,7 +150,7 @@ impl WriteProxyConnection { async fn with_remote_conn( &self, - auth: Authenticated, + ctx: RequestContext, builder_config: QueryBuilderConfig, cb: F, ) -> crate::Result @@ -172,13 +161,8 @@ impl WriteProxyConnection { if remote_conn.is_some() { cb(remote_conn.as_mut().unwrap()).await } else { - let conn = RemoteConnection::connect( - self.write_proxy.clone(), - self.namespace.clone(), - auth, - builder_config, - ) - .await?; + let conn = + RemoteConnection::connect(self.write_proxy.clone(), ctx, builder_config).await?; let conn = remote_conn.insert(conn); cb(conn).await } @@ -188,13 +172,13 @@ impl WriteProxyConnection { &self, pgm: Program, status: &mut TxnStatus, - auth: Authenticated, + ctx: RequestContext, builder: B, ) -> Result { self.stats.inc_write_requests_delegated(); *status = TxnStatus::Invalid; let res = self - .with_remote_conn(auth, self.builder_config.clone(), |conn| { + .with_remote_conn(ctx, self.builder_config.clone(), |conn| { Box::pin(conn.execute(pgm, builder)) }) .await; @@ -278,18 +262,14 @@ struct RemoteConnection> { impl RemoteConnection { async fn connect( mut client: ProxyClient, - namespace: NamespaceName, - auth: Authenticated, + ctx: RequestContext, builder_config: QueryBuilderConfig, ) -> crate::Result { let (request_sender, receiver) = mpsc::channel(1); let stream = tokio_stream::wrappers::ReceiverStream::new(receiver); let mut req = Request::new(stream); - let namespace = BinaryMetadataValue::from_bytes(namespace.as_slice()); - req.metadata_mut() - .insert_bin(NAMESPACE_METADATA_KEY, namespace); - auth.upgrade_grpc_request(&mut req); + ctx.upgrade_grpc_request(&mut req); let response_stream = client.stream_exec(req).await?.into_inner(); Ok(Self { @@ -461,14 +441,14 @@ impl Connection for WriteProxyConnection { async fn execute_program( &self, pgm: Program, - auth: Authenticated, + ctx: RequestContext, builder: B, replication_index: Option, ) -> Result { let mut state = self.state.lock().await; if self.should_proxy() { - self.execute_remote(pgm, &mut state, auth, builder).await + self.execute_remote(pgm, &mut state, ctx, builder).await } else if *state == TxnStatus::Init && pgm.is_read_only() { // set the state to invalid before doing anything, and set it to a valid state after. *state = TxnStatus::Invalid; @@ -478,31 +458,31 @@ impl Connection for WriteProxyConnection { // transaction, so we rollback the replica, and execute again on the primary. let builder = self .read_conn - .execute_program(pgm.clone(), auth.clone(), builder, replication_index) + .execute_program(pgm.clone(), ctx.clone(), builder, replication_index) .await?; let new_state = self.read_conn.txn_status()?; if new_state != TxnStatus::Init { REPLICA_LOCAL_EXEC_MISPREDICT.increment(1); - self.read_conn.rollback(auth.clone()).await?; - self.execute_remote(pgm, &mut state, auth, builder).await + self.read_conn.rollback(ctx.clone()).await?; + self.execute_remote(pgm, &mut state, ctx, builder).await } else { REPLICA_LOCAL_PROGRAM_EXEC.increment(1); *state = new_state; Ok(builder) } } else { - self.execute_remote(pgm, &mut state, auth, builder).await + self.execute_remote(pgm, &mut state, ctx, builder).await } } async fn describe( &self, sql: String, - auth: Authenticated, + ctx: RequestContext, replication_index: Option, ) -> Result> { self.wait_replication_sync(replication_index).await?; - self.read_conn.describe(sql, auth, replication_index).await + self.read_conn.describe(sql, ctx, replication_index).await } async fn is_autocommit(&self) -> Result { diff --git a/libsql-server/src/hrana/batch.rs b/libsql-server/src/hrana/batch.rs index a2ddd9e291..a54547134e 100644 --- a/libsql-server/src/hrana/batch.rs +++ b/libsql-server/src/hrana/batch.rs @@ -2,9 +2,8 @@ use anyhow::{anyhow, bail, Result}; use std::collections::HashMap; use std::sync::Arc; -use crate::auth::Authenticated; use crate::connection::program::{Cond, Program, Step}; -use crate::connection::Connection; +use crate::connection::{Connection, RequestContext}; use crate::error::Error as SqldError; use crate::hrana::stmt::StmtError; use crate::query::{Params, Query}; @@ -105,13 +104,13 @@ pub fn proto_batch_to_program( pub async fn execute_batch( db: &impl Connection, - auth: Authenticated, + ctx: RequestContext, pgm: Program, replication_index: Option, ) -> Result { let batch_builder = HranaBatchProtoBuilder::default(); let builder = db - .execute_program(pgm, auth, batch_builder, replication_index) + .execute_program(pgm, ctx, batch_builder, replication_index) .await .map_err(catch_batch_error)?; @@ -146,13 +145,13 @@ pub fn proto_sequence_to_program(sql: &str) -> Result { pub async fn execute_sequence( db: &impl Connection, - auth: Authenticated, + ctx: RequestContext, pgm: Program, replication_index: Option, ) -> Result<()> { let builder = StepResultsBuilder::default(); let builder = db - .execute_program(pgm, auth, builder, replication_index) + .execute_program(pgm, ctx, builder, replication_index) .await .map_err(catch_batch_error)?; builder diff --git a/libsql-server/src/hrana/cursor.rs b/libsql-server/src/hrana/cursor.rs index 5792eff1ba..97a8c8c191 100644 --- a/libsql-server/src/hrana/cursor.rs +++ b/libsql-server/src/hrana/cursor.rs @@ -5,9 +5,8 @@ use std::sync::Arc; use std::task; use tokio::sync::{mpsc, oneshot}; -use crate::auth::Authenticated; use crate::connection::program::Program; -use crate::connection::Connection; +use crate::connection::{Connection, RequestContext}; use crate::query_result_builder::{ Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, }; @@ -29,7 +28,7 @@ pub struct SizedEntry { struct OpenReq { db: Arc, - auth: Authenticated, + ctx: RequestContext, pgm: Program, replication_index: Option, } @@ -52,14 +51,14 @@ impl CursorHandle { pub fn open( &mut self, db: Arc, - auth: Authenticated, + ctx: RequestContext, pgm: Program, replication_index: Option, ) { let open_tx = self.open_tx.take().unwrap(); let _: Result<_, _> = open_tx.send(OpenReq { db, - auth, + ctx, pgm, replication_index, }); @@ -90,7 +89,7 @@ async fn run_cursor( .db .execute_program( open_req.pgm, - open_req.auth, + open_req.ctx, result_builder, open_req.replication_index, ) diff --git a/libsql-server/src/hrana/http/mod.rs b/libsql-server/src/hrana/http/mod.rs index 0778c3b194..f7ff2e9801 100644 --- a/libsql-server/src/hrana/http/mod.rs +++ b/libsql-server/src/hrana/http/mod.rs @@ -9,8 +9,7 @@ use std::sync::Arc; use std::task; use super::{batch, cursor, Encoding, ProtocolError, Version}; -use crate::auth::Authenticated; -use crate::connection::{Connection, MakeConnection}; +use crate::connection::{Connection, MakeConnection, RequestContext}; use crate::hrana::http::stream::StreamError; mod request; @@ -44,7 +43,7 @@ impl Server { pub async fn handle_request( &self, connection_maker: Arc>, - auth: Authenticated, + ctx: RequestContext, req: hyper::Request, endpoint: Endpoint, version: Version, @@ -53,7 +52,7 @@ impl Server { handle_request( self, connection_maker, - auth, + ctx, req, endpoint, version, @@ -86,7 +85,7 @@ pub(crate) async fn handle_index() -> hyper::Response { async fn handle_request( server: &Server, connection_maker: Arc>, - auth: Authenticated, + ctx: RequestContext, req: hyper::Request, endpoint: Endpoint, version: Version, @@ -94,10 +93,10 @@ async fn handle_request( ) -> Result> { match endpoint { Endpoint::Pipeline => { - handle_pipeline(server, connection_maker, auth, req, version, encoding).await + handle_pipeline(server, connection_maker, ctx, req, version, encoding).await } Endpoint::Cursor => { - handle_cursor(server, connection_maker, auth, req, version, encoding).await + handle_cursor(server, connection_maker, ctx, req, version, encoding).await } } } @@ -105,7 +104,7 @@ async fn handle_request( async fn handle_pipeline( server: &Server, connection_maker: Arc>, - auth: Authenticated, + ctx: RequestContext, req: hyper::Request, version: Version, encoding: Encoding, @@ -117,7 +116,7 @@ async fn handle_pipeline( let mut results = Vec::with_capacity(req_body.requests.len()); for request in req_body.requests.into_iter() { tracing::debug!("pipeline:{{ {:?}, {:?} }}", version, request); - let result = request::handle(&mut stream_guard, auth.clone(), request, version).await?; + let result = request::handle(&mut stream_guard, ctx.clone(), request, version).await?; results.push(result); } @@ -132,7 +131,7 @@ async fn handle_pipeline( async fn handle_cursor( server: &Server, connection_maker: Arc>, - auth: Authenticated, + ctx: RequestContext, req: hyper::Request, version: Version, encoding: Encoding, @@ -145,7 +144,7 @@ async fn handle_cursor( let db = stream_guard.get_db_owned()?; let sqls = stream_guard.sqls(); let pgm = batch::proto_batch_to_program(&req_body.batch, sqls, version)?; - cursor_hnd.open(db, auth, pgm, req_body.batch.replication_index); + cursor_hnd.open(db, ctx, pgm, req_body.batch.replication_index); let resp_body = proto::CursorRespBody { baton: stream_guard.release(), diff --git a/libsql-server/src/hrana/http/request.rs b/libsql-server/src/hrana/http/request.rs index 818c3009a2..c91d807484 100644 --- a/libsql-server/src/hrana/http/request.rs +++ b/libsql-server/src/hrana/http/request.rs @@ -3,8 +3,7 @@ use bytesize::ByteSize; use super::super::{batch, stmt, ProtocolError, Version}; use super::stream; -use crate::auth::Authenticated; -use crate::connection::Connection; +use crate::connection::{Connection, RequestContext}; use libsql_sys::hrana::proto; const MAX_SQL_COUNT: usize = 50; @@ -25,11 +24,11 @@ enum StreamResponseError { pub async fn handle( stream_guard: &mut stream::Guard<'_, D>, - auth: Authenticated, + ctx: RequestContext, request: proto::StreamRequest, version: Version, ) -> Result { - let result = match try_handle(stream_guard, auth, request, version).await { + let result = match try_handle(stream_guard, ctx, request, version).await { Ok(response) => proto::StreamResult::Ok { response }, Err(err) => { let resp_err = err.downcast::()?; @@ -45,7 +44,7 @@ pub async fn handle( async fn try_handle( stream_guard: &mut stream::Guard<'_, D>, - auth: Authenticated, + ctx: RequestContext, request: proto::StreamRequest, version: Version, ) -> Result { @@ -71,7 +70,7 @@ async fn try_handle( let sqls = stream_guard.sqls(); let query = stmt::proto_stmt_to_query(&req.stmt, sqls, version).map_err(catch_stmt_error)?; - let result = stmt::execute_stmt(db, auth, query, req.stmt.replication_index) + let result = stmt::execute_stmt(db, ctx, query, req.stmt.replication_index) .await .map_err(catch_stmt_error)?; proto::StreamResponse::Execute(proto::ExecuteStreamResp { result }) @@ -82,7 +81,7 @@ async fn try_handle( let pgm = batch::proto_batch_to_program(&req.batch, sqls, version) .map_err(catch_stmt_error) .map_err(catch_batch_error)?; - let result = batch::execute_batch(db, auth, pgm, req.batch.replication_index) + let result = batch::execute_batch(db, ctx, pgm, req.batch.replication_index) .await .map_err(catch_batch_error)?; proto::StreamResponse::Batch(proto::BatchStreamResp { result }) @@ -92,7 +91,7 @@ async fn try_handle( let sqls = stream_guard.sqls(); let sql = stmt::proto_sql_to_sql(req.sql.as_deref(), req.sql_id, sqls, version)?; let pgm = batch::proto_sequence_to_program(sql).map_err(catch_stmt_error)?; - batch::execute_sequence(db, auth, pgm, req.replication_index) + batch::execute_sequence(db, ctx, pgm, req.replication_index) .await .map_err(catch_stmt_error) .map_err(catch_batch_error)?; @@ -102,7 +101,7 @@ async fn try_handle( let db = stream_guard.get_db()?; let sqls = stream_guard.sqls(); let sql = stmt::proto_sql_to_sql(req.sql.as_deref(), req.sql_id, sqls, version)?; - let result = stmt::describe_stmt(db, auth, sql.into(), req.replication_index) + let result = stmt::describe_stmt(db, ctx, sql.into(), req.replication_index) .await .map_err(catch_stmt_error)?; proto::StreamResponse::Describe(proto::DescribeStreamResp { result }) diff --git a/libsql-server/src/hrana/stmt.rs b/libsql-server/src/hrana/stmt.rs index b782442f50..4c76c5d7b4 100644 --- a/libsql-server/src/hrana/stmt.rs +++ b/libsql-server/src/hrana/stmt.rs @@ -3,9 +3,8 @@ use std::collections::HashMap; use super::result_builder::SingleStatementBuilder; use super::{proto, ProtocolError, Version}; -use crate::auth::Authenticated; use crate::connection::program::DescribeResponse; -use crate::connection::Connection; +use crate::connection::{Connection, RequestContext}; use crate::error::Error as SqldError; use crate::hrana; use crate::query::{Params, Query, Value}; @@ -53,13 +52,13 @@ pub enum StmtError { pub async fn execute_stmt( db: &impl Connection, - auth: Authenticated, + ctx: RequestContext, query: Query, replication_index: Option, ) -> Result { let builder = SingleStatementBuilder::default(); let stmt_res = db - .execute_batch(vec![query], auth, builder, replication_index) + .execute_batch(vec![query], ctx, builder, replication_index) .await .map_err(catch_stmt_error)?; stmt_res.into_ret().map_err(catch_stmt_error) @@ -67,11 +66,11 @@ pub async fn execute_stmt( pub async fn describe_stmt( db: &impl Connection, - auth: Authenticated, + ctx: RequestContext, sql: String, replication_index: Option, ) -> Result { - match db.describe(sql, auth, replication_index).await? { + match db.describe(sql, ctx, replication_index).await? { Ok(describe_response) => Ok(proto_describe_result_from_describe_response( describe_response, )), diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index e0336eeb77..f88541e246 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -9,14 +9,14 @@ use tokio::sync::{mpsc, oneshot}; use super::super::{batch, cursor, stmt, ProtocolError, Version}; use super::{proto, Server}; use crate::auth::user_auth_strategies::UserAuthContext; -use crate::auth::{AuthError, Authenticated}; -use crate::connection::Connection; +use crate::auth::{Auth, AuthError, Authenticated, Jwt}; +use crate::connection::{Connection, RequestContext}; use crate::database::Database; use crate::namespace::{MakeNamespace, NamespaceName}; /// Session-level state of an authenticated Hrana connection. pub struct Session { - authenticated: Authenticated, + auth: Authenticated, version: Version, streams: HashMap>, sqls: HashMap, @@ -82,17 +82,15 @@ pub(super) async fn handle_initial_hello( .clone() .and_then(|t| HeaderValue::from_str(&format!("Bearer {t}")).ok()); - let authenticated = server - .user_auth_strategy - .authenticate(UserAuthContext { - namespace, - user_credential, - namespace_credential: namespace_jwt_key, - }) + let auth = namespace_jwt_key + .map(Jwt::new) + .map(Auth::new) + .unwrap_or(server.user_auth_strategy.clone()) + .authenticate(UserAuthContext { user_credential }) .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; Ok(Session { - authenticated, + auth, version, streams: HashMap::new(), sqls: HashMap::new(), @@ -122,13 +120,11 @@ pub(super) async fn handle_repeated_hello( .clone() .and_then(|t| HeaderValue::from_str(&format!("Bearer {t}")).ok()); - session.authenticated = server - .user_auth_strategy - .authenticate(UserAuthContext { - namespace, - user_credential, - namespace_credential: namespace_jwt_key, - }) + session.auth = namespace_jwt_key + .map(Jwt::new) + .map(Auth::new) + .unwrap_or_else(|| server.user_auth_strategy.clone()) + .authenticate(UserAuthContext { user_credential }) .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; Ok(()) @@ -221,10 +217,10 @@ pub(super) async fn handle_request( ); let namespaces = server.namespaces.clone(); - let authenticated = session.authenticated.clone(); + let auth = session.auth.clone(); stream_respond!(&mut stream_hnd, async move |stream| { let db = namespaces - .with_authenticated(namespace, authenticated, |ns| ns.db.connection_maker()) + .with_authenticated(namespace, auth, |ns| ns.db.connection_maker()) .await? .create() .await?; @@ -253,11 +249,12 @@ pub(super) async fn handle_request( let query = stmt::proto_stmt_to_query(&req.stmt, &session.sqls, session.version) .map_err(catch_stmt_error)?; - let auth = session.authenticated.clone(); + let auth = session.auth.clone(); + let ctx = RequestContext::new(auth, namespace, server.namespaces.meta_store()); stream_respond!(stream_hnd, async move |stream| { let db = get_stream_db!(stream, stream_id); - let result = stmt::execute_stmt(&**db, auth, query, req.replication_index) + let result = stmt::execute_stmt(&**db, ctx, query, req.replication_index) .await .map_err(catch_stmt_error)?; Ok(proto::Response::Execute(proto::ExecuteResp { result })) @@ -269,11 +266,15 @@ pub(super) async fn handle_request( let pgm = batch::proto_batch_to_program(&req.batch, &session.sqls, session.version) .map_err(catch_stmt_error)?; - let auth = session.authenticated.clone(); + let ctx = RequestContext::new( + session.auth.clone(), + namespace, + server.namespaces.meta_store(), + ); stream_respond!(stream_hnd, async move |stream| { let db = get_stream_db!(stream, stream_id); - let result = batch::execute_batch(&**db, auth, pgm, req.batch.replication_index) + let result = batch::execute_batch(&**db, ctx, pgm, req.batch.replication_index) .await .map_err(catch_batch_error)?; Ok(proto::Response::Batch(proto::BatchResp { result })) @@ -291,11 +292,15 @@ pub(super) async fn handle_request( session.version, )?; let pgm = batch::proto_sequence_to_program(sql).map_err(catch_stmt_error)?; - let auth = session.authenticated.clone(); + let ctx = RequestContext::new( + session.auth.clone(), + namespace, + server.namespaces.meta_store(), + ); stream_respond!(stream_hnd, async move |stream| { let db = get_stream_db!(stream, stream_id); - batch::execute_sequence(&**db, auth, pgm, req.replication_index) + batch::execute_sequence(&**db, ctx, pgm, req.replication_index) .await .map_err(catch_stmt_error) .map_err(catch_batch_error)?; @@ -314,11 +319,15 @@ pub(super) async fn handle_request( session.version, )? .into(); - let auth = session.authenticated.clone(); + let ctx = RequestContext::new( + session.auth.clone(), + namespace, + server.namespaces.meta_store(), + ); stream_respond!(stream_hnd, async move |stream| { let db = get_stream_db!(stream, stream_id); - let result = stmt::describe_stmt(&**db, auth, sql, req.replication_index) + let result = stmt::describe_stmt(&**db, ctx, sql, req.replication_index) .await .map_err(catch_stmt_error)?; Ok(proto::Response::Describe(proto::DescribeResp { result })) @@ -359,12 +368,16 @@ pub(super) async fn handle_request( let pgm = batch::proto_batch_to_program(&req.batch, &session.sqls, session.version) .map_err(catch_stmt_error)?; - let auth = session.authenticated.clone(); - + let ctx = RequestContext::new( + session.auth.clone(), + namespace, + server.namespaces.meta_store(), + ); let mut cursor_hnd = cursor::CursorHandle::spawn(join_set); + stream_respond!(stream_hnd, async move |stream| { let db = get_stream_db!(stream, stream_id); - cursor_hnd.open(db.clone(), auth, pgm, req.batch.replication_index); + cursor_hnd.open(db.clone(), ctx, pgm, req.batch.replication_index); stream.cursor_hnd = Some(cursor_hnd); Ok(proto::Response::OpenCursor(proto::OpenCursorResp {})) }); diff --git a/libsql-server/src/http/admin/mod.rs b/libsql-server/src/http/admin/mod.rs index c505613693..837e7c6ebf 100644 --- a/libsql-server/src/http/admin/mod.rs +++ b/libsql-server/src/http/admin/mod.rs @@ -287,6 +287,8 @@ struct CreateNamespaceReq { /// If some, this is a [NamespaceName] reference to a shared schema DB. #[serde(default)] shared_schema_name: Option, + #[serde(default)] + allow_attach: bool, } async fn handle_create_namespace( @@ -335,6 +337,7 @@ async fn handle_create_namespace( config.is_shared_schema = req.shared_schema; config.shared_schema_name = shared_schema_name.as_ref().map(|x| x.to_string()); + config.allow_attach = req.allow_attach; if let Some(max_db_size) = req.max_db_size { config.max_db_pages = max_db_size.as_u64() / LIBSQL_PAGE_SIZE; } diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs new file mode 100644 index 0000000000..14f7138ee3 --- /dev/null +++ b/libsql-server/src/http/user/extract.rs @@ -0,0 +1,46 @@ +use axum::extract::FromRequestParts; + +use crate::{ + auth::{Jwt, UserAuthContext, UserAuthStrategy}, + connection::RequestContext, + namespace::MakeNamespace, +}; + +use super::{db_factory, AppState}; + +#[async_trait::async_trait] +impl FromRequestParts> for RequestContext +where + F: MakeNamespace, +{ + type Rejection = crate::error::Error; + + async fn from_request_parts( + parts: &mut axum::http::request::Parts, + state: &AppState, + ) -> std::result::Result { + let namespace = db_factory::namespace_from_headers( + &parts.headers, + state.disable_default_namespace, + state.disable_namespaces, + )?; + + let namespace_jwt_key = state + .namespaces + .with(namespace.clone(), |ns| ns.jwt_key()) + .await??; + + let auth_header = parts.headers.get(hyper::header::AUTHORIZATION); + + let auth = match namespace_jwt_key { + Some(key) => Jwt::new(key).authenticate(UserAuthContext { + user_credential: auth_header.cloned(), + })?, + None => state.user_auth_strategy.authenticate(UserAuthContext { + user_credential: auth_header.cloned(), + })?, + }; + + Ok(Self::new(auth, namespace, state.namespaces.meta_store())) + } +} diff --git a/libsql-server/src/http/user/hrana_over_http_1.rs b/libsql-server/src/http/user/hrana_over_http_1.rs index b06a6ec9a3..2217c6c511 100644 --- a/libsql-server/src/http/user/hrana_over_http_1.rs +++ b/libsql-server/src/http/user/hrana_over_http_1.rs @@ -4,8 +4,7 @@ use std::collections::HashMap; use std::future::Future; use std::sync::Arc; -use crate::auth::Authenticated; -use crate::connection::{Connection, MakeConnection}; +use crate::connection::{Connection, MakeConnection, RequestContext}; use crate::hrana; use super::db_factory::MakeConnectionExtractor; @@ -26,7 +25,7 @@ pub async fn handle_index() -> hyper::Response { pub(crate) async fn handle_execute( MakeConnectionExtractor(factory): MakeConnectionExtractor, - auth: Authenticated, + ctx: RequestContext, req: hyper::Request, ) -> crate::Result> { #[derive(Debug, Deserialize)] @@ -46,7 +45,7 @@ pub(crate) async fn handle_execute( hrana::Version::Hrana1, ) .map_err(catch_stmt_error)?; - hrana::stmt::execute_stmt(&db, auth, query, req_body.stmt.replication_index) + hrana::stmt::execute_stmt(&db, ctx, query, req_body.stmt.replication_index) .await .map(|result| RespBody { result }) .map_err(catch_stmt_error) @@ -59,7 +58,7 @@ pub(crate) async fn handle_execute( pub(crate) async fn handle_batch( MakeConnectionExtractor(factory): MakeConnectionExtractor, - auth: Authenticated, + ctx: RequestContext, req: hyper::Request, ) -> crate::Result> { #[derive(Debug, Deserialize)] @@ -79,7 +78,7 @@ pub(crate) async fn handle_batch( hrana::Version::Hrana1, ) .map_err(catch_stmt_error)?; - hrana::batch::execute_batch(&db, auth, pgm, req_body.batch.replication_index) + hrana::batch::execute_batch(&db, ctx, pgm, req_body.batch.replication_index) .await .map(|result| RespBody { result }) .context("Could not execute batch") diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index 23b8e6205f..29b52ad6f5 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -1,5 +1,6 @@ pub mod db_factory; mod dump; +mod extract; mod hrana_over_http_1; mod result_builder; mod trace; @@ -29,8 +30,8 @@ use tonic::transport::Server; use tower_http::{compression::CompressionLayer, cors}; use crate::auth::user_auth_strategies::UserAuthContext; -use crate::auth::{Auth, Authenticated}; -use crate::connection::Connection; +use crate::auth::{Auth, Authenticated, Jwt}; +use crate::connection::{Connection, RequestContext}; use crate::database::Database; use crate::error::Error; use crate::hrana; @@ -126,7 +127,7 @@ fn parse_queries(queries: Vec) -> crate::Result> { } async fn handle_query( - auth: Authenticated, + ctx: RequestContext, MakeConnectionExtractor(connection_maker): MakeConnectionExtractor, Json(query): Json, ) -> Result { @@ -137,7 +138,7 @@ async fn handle_query( let builder = JsonHttpPayloadBuilder::new(); let builder = db - .execute_batch_or_rollback(batch, auth, builder, query.replication_index) + .execute_batch_or_rollback(batch, ctx, builder, query.replication_index) .await?; let res = ( @@ -202,7 +203,7 @@ async fn handle_hrana_pipeline( MakeConnectionExtractorPath(connection_maker): MakeConnectionExtractorPath< ::Connection, >, - auth: Authenticated, + ctx: RequestContext, axum::extract::Path((_, version)): axum::extract::Path<(String, String)>, req: Request, ) -> Result, Error> { @@ -215,7 +216,7 @@ async fn handle_hrana_pipeline( .hrana_http_srv .handle_request( connection_maker, - auth, + ctx, req, hrana::http::Endpoint::Pipeline, hrana_version, @@ -344,14 +345,14 @@ where MakeConnectionExtractor(connection_maker): MakeConnectionExtractor< ::Connection, >, - auth: Authenticated, + ctx: RequestContext, req: Request, ) -> Result, Error> { Ok(state .hrana_http_srv .handle_request( connection_maker, - auth, + ctx, req, $endpoint, $version, @@ -493,11 +494,13 @@ where let auth_header = parts.headers.get(hyper::header::AUTHORIZATION); - let auth = state.user_auth_strategy.authenticate(UserAuthContext { - namespace: ns, - namespace_credential: namespace_jwt_key, - user_credential: auth_header.cloned(), - })?; + let auth = namespace_jwt_key + .map(Jwt::new) + .map(Auth::new) + .unwrap_or_else(|| state.user_auth_strategy.clone()) + .authenticate(UserAuthContext { + user_credential: auth_header.cloned(), + })?; Ok(auth) } diff --git a/libsql-server/src/namespace/meta_store.rs b/libsql-server/src/namespace/meta_store.rs index aef7fecff4..1ea1d0f1a6 100644 --- a/libsql-server/src/namespace/meta_store.rs +++ b/libsql-server/src/namespace/meta_store.rs @@ -29,6 +29,7 @@ type ChangeMsg = (NamespaceName, Arc); type WalManager = WalWrapper, Sqlite3WalManager>; type Connection = libsql_sys::Connection, Sqlite3Wal>>; +#[derive(Clone)] pub struct MetaStore { changes_tx: mpsc::Sender, inner: Arc>, diff --git a/libsql-server/src/namespace/mod.rs b/libsql-server/src/namespace/mod.rs index 5e3df1ff44..732f7d9a86 100644 --- a/libsql-server/src/namespace/mod.rs +++ b/libsql-server/src/namespace/mod.rs @@ -155,6 +155,15 @@ impl<'de> Deserialize<'de> for NamespaceName { } } +impl serde::Serialize for NamespaceName { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(self.as_str()) + } +} + pub enum ResetOp { Reset(NamespaceName), Destroy(NamespaceName), @@ -777,6 +786,10 @@ impl NamespaceStore { ) -> crate::Result { self.with(namespace, |ns| ns.db_config_store.clone()).await } + + pub(crate) fn meta_store(&self) -> MetaStore { + self.inner.metadata.clone() + } } /// A namespace isolates the resources pertaining to a database of type T @@ -978,7 +991,6 @@ impl Namespace { applied_frame_no_receiver, config.max_response_size, config.max_total_response_size, - name.clone(), primary_current_replicatio_index, config.encryption_config.clone(), ) diff --git a/libsql-server/src/query_analysis.rs b/libsql-server/src/query_analysis.rs index 85c059ad59..e544d6a05d 100644 --- a/libsql-server/src/query_analysis.rs +++ b/libsql-server/src/query_analysis.rs @@ -5,6 +5,8 @@ use fallible_iterator::FallibleIterator; use sqlite3_parser::ast::{Cmd, Expr, Id, PragmaBody, QualifiedName, Stmt}; use sqlite3_parser::lexer::sql::{Parser, ParserError}; +use crate::namespace::NamespaceName; + /// A group of statements to be executed together. #[derive(Debug, Clone)] pub struct Statement { @@ -24,7 +26,7 @@ impl Default for Statement { } /// Classify statement in categories of interest. -#[derive(Debug, PartialEq, Clone, Copy)] +#[derive(Debug, PartialEq, Clone)] pub enum StmtKind { /// The beginning of a transaction TxnBegin, @@ -34,10 +36,9 @@ pub enum StmtKind { Write, Savepoint, Release, - Attach, + Attach(NamespaceName), Detach, DDL, - Other, } fn is_temp(name: &QualifiedName) -> bool { @@ -61,8 +62,8 @@ impl StmtKind { fn kind(cmd: &Cmd) -> Option { match cmd { Cmd::Explain(Stmt::Pragma(name, body)) => Self::pragma_kind(name, body.as_ref()), - Cmd::Explain(_) => Some(Self::Other), - Cmd::ExplainQueryPlan(_) => Some(Self::Other), + Cmd::Explain(_) => Some(Self::Read), + Cmd::ExplainQueryPlan(_) => Some(Self::Read), Cmd::Stmt(Stmt::Begin { .. }) => Some(Self::TxnBegin), Cmd::Stmt( Stmt::Commit { .. } @@ -124,7 +125,9 @@ impl StmtKind { savepoint_name: Some(_), .. }) => Some(Self::Release), - Cmd::Stmt(Stmt::Attach { .. }) => Some(Self::Attach), + Cmd::Stmt(Stmt::Attach { db_name, .. }) => Some(Self::Attach( + NamespaceName::from_string(db_name.to_string()).ok()?, + )), Cmd::Stmt(Stmt::Detach(_)) => Some(Self::Detach), _ => None, } @@ -231,13 +234,13 @@ pub enum TxnStatus { } impl TxnStatus { - pub fn step(&mut self, kind: StmtKind) { + pub fn step(&mut self, kind: &StmtKind) { *self = match (*self, kind) { (TxnStatus::Txn, StmtKind::TxnBegin) | (TxnStatus::Init, StmtKind::TxnEnd) => { TxnStatus::Invalid } (TxnStatus::Txn, StmtKind::TxnEnd) => TxnStatus::Init, - (state, StmtKind::Other | StmtKind::Write | StmtKind::Read | StmtKind::DDL) => state, + (state, StmtKind::Write | StmtKind::Read | StmtKind::DDL) => state, (TxnStatus::Invalid, _) => TxnStatus::Invalid, (TxnStatus::Init, StmtKind::TxnBegin) => TxnStatus::Txn, _ => TxnStatus::Invalid, @@ -352,7 +355,7 @@ pub fn predict_final_state<'a>( stmts: impl Iterator, ) -> TxnStatus { for stmt in stmts { - state.step(stmt.kind); + state.step(&stmt.kind); } state } diff --git a/libsql-server/src/rpc/proxy.rs b/libsql-server/src/rpc/proxy.rs index dcdd4527b7..dea8dfb7d6 100644 --- a/libsql-server/src/rpc/proxy.rs +++ b/libsql-server/src/rpc/proxy.rs @@ -17,10 +17,10 @@ use uuid::Uuid; use crate::auth::parsers::parse_grpc_auth_header; use crate::auth::user_auth_strategies::UserAuthContext; -use crate::auth::{Auth, Authenticated}; -use crate::connection::Connection; +use crate::auth::{Auth, Authenticated, Jwt}; +use crate::connection::{Connection, RequestContext}; use crate::database::{Database, PrimaryConnection}; -use crate::namespace::{NamespaceName, NamespaceStore, PrimaryNamespaceMaker}; +use crate::namespace::{NamespaceStore, PrimaryNamespaceMaker}; use crate::query_result_builder::{ Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, }; @@ -305,40 +305,46 @@ impl ProxyService { self.clients.clone() } - async fn auth( + async fn extract_context( &self, req: &mut tonic::Request, - namespace: NamespaceName, - ) -> Result { + ) -> Result { + let namespace = super::extract_namespace(self.disable_namespaces, req)?; + let namespace_jwt_key = self .namespaces .with(namespace.clone(), |ns| ns.jwt_key()) .await; - let namespace_jwt_key = match namespace_jwt_key { - Ok(Ok(jwt_key)) => Ok(jwt_key), + let auth = match namespace_jwt_key { + Ok(Ok(Some(key))) => Some(Auth::new(Jwt::new(key))), + Ok(Ok(None)) => self.user_auth_strategy.clone(), Err(e) => match e.as_ref() { - crate::error::Error::NamespaceDoesntExist(_) => Ok(None), + crate::error::Error::NamespaceDoesntExist(_) => None, _ => Err(tonic::Status::internal(format!( "Error fetching jwt key for a namespace: {}", e - ))), + )))?, }, Ok(Err(e)) => Err(tonic::Status::internal(format!( "Error fetching jwt key for a namespace: {}", e - ))), - }?; + )))?, + }; - Ok(if let Some(auth) = &self.user_auth_strategy { + let auth = if let Some(auth) = auth { auth.authenticate(UserAuthContext { - namespace, - namespace_credential: namespace_jwt_key, user_credential: parse_grpc_auth_header(req.metadata()), })? } else { - Authenticated::from_proxy_grpc_request(req, self.disable_namespaces)? - }) + Authenticated::from_proxy_grpc_request(req)? + }; + + Ok(RequestContext::new( + auth, + namespace, + self.namespaces.meta_store(), + )) } } @@ -531,12 +537,11 @@ impl Proxy for ProxyService { &self, mut req: tonic::Request>, ) -> Result, tonic::Status> { - let namespace = super::extract_namespace(self.disable_namespaces, &req)?; - let auth = self.auth(&mut req, namespace.clone()).await?; + let ctx = self.extract_context(&mut req).await?; let (connection_maker, _new_frame_notifier) = self .namespaces - .with(namespace, |ns| { + .with(ctx.namespace().clone(), |ns| { let connection_maker = ns.db.connection_maker(); let notifier = ns .db @@ -558,7 +563,7 @@ impl Proxy for ProxyService { let conn = connection_maker.create().await.unwrap(); - let stream = make_proxy_stream(conn, auth, req.into_inner()); + let stream = make_proxy_stream(conn, ctx, req.into_inner()); Ok(tonic::Response::new(Box::pin(stream))) } @@ -567,8 +572,7 @@ impl Proxy for ProxyService { &self, mut req: tonic::Request, ) -> Result, tonic::Status> { - let namespace = super::extract_namespace(self.disable_namespaces, &req)?; - let auth = self.auth(&mut req, namespace.clone()).await?; + let ctx = self.extract_context(&mut req).await?; let req = req.into_inner(); let pgm = crate::connection::program::Program::try_from(req.pgm.unwrap()) .map_err(|e| tonic::Status::new(tonic::Code::InvalidArgument, e.to_string()))?; @@ -576,7 +580,7 @@ impl Proxy for ProxyService { let connection_maker = self .namespaces - .with(namespace, |ns| ns.db.connection_maker()) + .with(ctx.namespace().clone(), |ns| ns.db.connection_maker()) .await .map_err(|e| { if let crate::error::Error::NamespaceDoesntExist(_) = e { @@ -607,7 +611,7 @@ impl Proxy for ProxyService { let builder = ExecuteResultsBuilder::default(); let builder = db - .execute_program(pgm, auth, builder, None) + .execute_program(pgm, ctx, builder, None) .await // TODO: this is no necessarily a permission denied error! .map_err(|e| tonic::Status::new(tonic::Code::PermissionDenied, e.to_string()))?; @@ -632,16 +636,15 @@ impl Proxy for ProxyService { async fn describe( &self, - mut msg: tonic::Request, + mut req: tonic::Request, ) -> Result, tonic::Status> { - let namespace = super::extract_namespace(self.disable_namespaces, &msg)?; - let auth = self.auth(&mut msg, namespace.clone()).await?; + let ctx = self.extract_context(&mut req).await?; // FIXME: copypasta from execute(), creatively extract to a helper function let lock = self.clients.upgradable_read().await; let (connection_maker, _new_frame_notifier) = self .namespaces - .with(namespace, |ns| { + .with(ctx.namespace().clone(), |ns| { let connection_maker = ns.db.connection_maker(); let notifier = ns .db @@ -661,7 +664,7 @@ impl Proxy for ProxyService { } })?; - let DescribeRequest { client_id, stmt } = msg.into_inner(); + let DescribeRequest { client_id, stmt } = req.into_inner(); let client_id = Uuid::from_str(&client_id).unwrap(); let db = match lock.get(&client_id) { @@ -681,7 +684,7 @@ impl Proxy for ProxyService { }; let description = db - .describe(stmt, auth, None) + .describe(stmt, ctx, None) .await // TODO: this is no necessarily a permission denied error! // FIXME: the double map_err looks off diff --git a/libsql-server/src/rpc/replica_proxy.rs b/libsql-server/src/rpc/replica_proxy.rs index b05ede4f2e..fc7b99f95b 100644 --- a/libsql-server/src/rpc/replica_proxy.rs +++ b/libsql-server/src/rpc/replica_proxy.rs @@ -7,7 +7,10 @@ use tokio_stream::StreamExt; use tonic::{transport::Channel, Request, Status}; use crate::{ - auth::{parsers::parse_grpc_auth_header, user_auth_strategies::UserAuthContext, Auth}, + auth::{ + parsers::parse_grpc_auth_header, user_auth_strategies::UserAuthContext, Auth, Jwt, + UserAuthStrategy, + }, namespace::{NamespaceStore, ReplicaNamespaceMaker}, }; @@ -46,23 +49,26 @@ impl ReplicaProxyService { let user_credential = parse_grpc_auth_header(req.metadata()); match namespace_jwt_key { - Ok(Ok(jwt_key)) => { - let authenticated = self.user_auth_strategy.authenticate(UserAuthContext { - namespace, - namespace_credential: jwt_key, - user_credential, - })?; + Ok(Ok(Some(key))) => { + let authenticated = + Jwt::new(key).authenticate(UserAuthContext { user_credential })?; + authenticated.upgrade_grpc_request(req); + + Ok(()) + } + Ok(Ok(None)) => { + let authenticated = self + .user_auth_strategy + .authenticate(UserAuthContext { user_credential })?; authenticated.upgrade_grpc_request(req); Ok(()) } Err(e) => match e.as_ref() { crate::error::Error::NamespaceDoesntExist(_) => { - let authenticated = self.user_auth_strategy.authenticate(UserAuthContext { - namespace, - namespace_credential: None, - user_credential, - })?; + let authenticated = self + .user_auth_strategy + .authenticate(UserAuthContext { user_credential })?; authenticated.upgrade_grpc_request(req); Ok(()) diff --git a/libsql-server/src/rpc/replication_log.rs b/libsql-server/src/rpc/replication_log.rs index 32b09ea0b9..3bf447dd89 100644 --- a/libsql-server/src/rpc/replication_log.rs +++ b/libsql-server/src/rpc/replication_log.rs @@ -20,6 +20,7 @@ use tonic::Status; use uuid::Uuid; use crate::auth::user_auth_strategies::UserAuthContext; +use crate::auth::Jwt; use crate::auth::{parsers::parse_grpc_auth_header, Auth}; use crate::connection::config::DatabaseConfig; use crate::namespace::{NamespaceName, NamespaceStore, PrimaryNamespaceMaker}; @@ -78,38 +79,27 @@ impl ReplicationLogService { let user_credential = parse_grpc_auth_header(req.metadata()); - match namespace_jwt_key { - Ok(Ok(jwt_key)) => { - if let Some(auth) = &self.user_auth_strategy { - auth.authenticate(UserAuthContext { - namespace, - namespace_credential: jwt_key, - user_credential, - })?; - } - Ok(()) - } + let auth = match namespace_jwt_key { + Ok(Ok(Some(key))) => Some(Auth::new(Jwt::new(key))), + Ok(Ok(None)) => self.user_auth_strategy.clone(), Err(e) => match e.as_ref() { - crate::error::Error::NamespaceDoesntExist(_) => { - if let Some(auth) = &self.user_auth_strategy { - auth.authenticate(UserAuthContext { - namespace, - namespace_credential: None, - user_credential, - })?; - } - Ok(()) - } + crate::error::Error::NamespaceDoesntExist(_) => self.user_auth_strategy.clone(), _ => Err(Status::internal(format!( "Error fetching jwt key for a namespace: {}", e - ))), + )))?, }, Ok(Err(e)) => Err(Status::internal(format!( "Error fetching jwt key for a namespace: {}", e - ))), + )))?, + }; + + if let Some(auth) = auth { + auth.authenticate(UserAuthContext { user_credential })?; } + + Ok(()) } fn verify_session_token( diff --git a/libsql-server/src/rpc/streaming_exec.rs b/libsql-server/src/rpc/streaming_exec.rs index 802bf06a10..4b852c5da0 100644 --- a/libsql-server/src/rpc/streaming_exec.rs +++ b/libsql-server/src/rpc/streaming_exec.rs @@ -20,8 +20,7 @@ use tokio::sync::mpsc; use tokio_stream::StreamExt; use tonic::{Code, Status}; -use crate::auth::Authenticated; -use crate::connection::Connection; +use crate::connection::{Connection, RequestContext}; use crate::error::Error; use crate::query_result_builder::{ Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, @@ -32,19 +31,19 @@ const MAX_RESPONSE_SIZE: usize = bytesize::ByteSize::kb(100).as_u64() as usize; pub fn make_proxy_stream( conn: C, - auth: Authenticated, + ctx: RequestContext, request_stream: S, ) -> impl Stream> where S: Stream>, C: Connection, { - make_proxy_stream_inner(conn, auth, request_stream, MAX_RESPONSE_SIZE) + make_proxy_stream_inner(conn, ctx, request_stream, MAX_RESPONSE_SIZE) } fn make_proxy_stream_inner( conn: C, - auth: Authenticated, + ctx: RequestContext, request_stream: S, max_program_resp_size: usize, ) -> impl Stream> @@ -92,7 +91,7 @@ where break }; let conn = conn.clone(); - let auth = auth.clone(); + let ctx = ctx.clone(); let sender = snd.clone(); let fut = async move { @@ -104,19 +103,19 @@ where max_program_resp_size, }; - let ret = conn.execute_program(pgm, auth, builder, None).await.map(|_| ()); + let ret = conn.execute_program(pgm, ctx, builder, None).await.map(|_| ()); (ret, request_id) }; current_request_fut = Box::pin(fut); } Some(Request::Describe(StreamDescribeReq { stmt })) => { - let auth = auth.clone(); + let ctx = ctx.clone(); let sender = snd.clone(); let conn = conn.clone(); let fut = async move { let do_describe = || async move { - let ret = conn.describe(stmt, auth, None).await??; + let ret = conn.describe(stmt, ctx, None).await??; Ok(DescribeResp { cols: ret.cols.into_iter().map(|c| DescribeCol { name: c.name, decltype: c.decltype }).collect(), params: ret.params.into_iter().map(|p| DescribeParam { name: p.name }).collect(), @@ -364,9 +363,11 @@ pub mod test { use tempfile::tempdir; use tokio_stream::wrappers::ReceiverStream; - use crate::auth::{Authorized, Permission}; + use crate::auth::Authenticated; use crate::connection::libsql::LibSqlConnection; use crate::connection::program::Program; + use crate::namespace::meta_store::MetaStore; + use crate::namespace::NamespaceName; use crate::query_result_builder::test::{ fsm_builder_driver, random_transition, TestBuilder, ValidateTraceBuilder, }; @@ -388,7 +389,14 @@ pub mod test { let tmp = tempdir().unwrap(); let conn = LibSqlConnection::new_test(tmp.path()); let (snd, rcv) = mpsc::channel(1); - let stream = make_proxy_stream(conn, Authenticated::Anonymous, ReceiverStream::new(rcv)); + let ctx = RequestContext::new( + Authenticated::Anonymous, + NamespaceName::default(), + MetaStore::new(Default::default(), tmp.path()) + .await + .unwrap(), + ); + let stream = make_proxy_stream(conn, ctx, ReceiverStream::new(rcv)); pin!(stream); let req = ExecReq { @@ -406,11 +414,14 @@ pub mod test { let tmp = tempdir().unwrap(); let conn = LibSqlConnection::new_test(tmp.path()); let (snd, rcv) = mpsc::channel(1); - let auth = Authenticated::Authorized(Authorized { - namespace: None, - permission: Permission::FullAccess, - }); - let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); + let ctx = RequestContext::new( + Authenticated::FullAccess, + NamespaceName::default(), + MetaStore::new(Default::default(), tmp.path()) + .await + .unwrap(), + ); + let stream = make_proxy_stream(conn, ctx, ReceiverStream::new(rcv)); pin!(stream); @@ -424,11 +435,14 @@ pub mod test { let tmp = tempdir().unwrap(); let conn = LibSqlConnection::new_test(tmp.path()); let (snd, rcv) = mpsc::channel(1); - let auth = Authenticated::Authorized(Authorized { - namespace: None, - permission: Permission::FullAccess, - }); - let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); + let ctx = RequestContext::new( + Authenticated::FullAccess, + NamespaceName::default(), + MetaStore::new(Default::default(), tmp.path()) + .await + .unwrap(), + ); + let stream = make_proxy_stream(conn, ctx, ReceiverStream::new(rcv)); pin!(stream); @@ -444,12 +458,15 @@ pub mod test { let tmp = tempdir().unwrap(); let conn = LibSqlConnection::new_test(tmp.path()); let (snd, rcv) = mpsc::channel(1); - let auth = Authenticated::Authorized(Authorized { - namespace: None, - permission: Permission::FullAccess, - }); + let ctx = RequestContext::new( + Authenticated::FullAccess, + NamespaceName::default(), + MetaStore::new(Default::default(), tmp.path()) + .await + .unwrap(), + ); // limit the size of the response to force a split - let stream = make_proxy_stream_inner(conn, auth, ReceiverStream::new(rcv), 500); + let stream = make_proxy_stream_inner(conn, ctx, ReceiverStream::new(rcv), 500); pin!(stream); @@ -497,11 +514,14 @@ pub mod test { let tmp = tempdir().unwrap(); let conn = LibSqlConnection::new_test(tmp.path()); let (snd, rcv) = mpsc::channel(2); - let auth = Authenticated::Authorized(Authorized { - namespace: None, - permission: Permission::FullAccess, - }); - let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); + let ctx = RequestContext::new( + Authenticated::FullAccess, + NamespaceName::default(), + MetaStore::new(Default::default(), tmp.path()) + .await + .unwrap(), + ); + let stream = make_proxy_stream(conn, ctx, ReceiverStream::new(rcv)); pin!(stream); @@ -520,11 +540,14 @@ pub mod test { let tmp = tempdir().unwrap(); let conn = LibSqlConnection::new_test(tmp.path()); let (snd, rcv) = mpsc::channel(1); - let auth = Authenticated::Authorized(Authorized { - namespace: None, - permission: Permission::FullAccess, - }); - let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); + let ctx = RequestContext::new( + Authenticated::FullAccess, + NamespaceName::default(), + MetaStore::new(Default::default(), tmp.path()) + .await + .unwrap(), + ); + let stream = make_proxy_stream(conn, ctx, ReceiverStream::new(rcv)); pin!(stream); @@ -543,11 +566,14 @@ pub mod test { let tmp = tempdir().unwrap(); let conn = LibSqlConnection::new_test(tmp.path()); let (snd, rcv) = mpsc::channel(1); - let auth = Authenticated::Authorized(Authorized { - namespace: None, - permission: Permission::FullAccess, - }); - let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); + let ctx = RequestContext::new( + Authenticated::FullAccess, + NamespaceName::default(), + MetaStore::new(Default::default(), tmp.path()) + .await + .unwrap(), + ); + let stream = make_proxy_stream(conn, ctx, ReceiverStream::new(rcv)); pin!(stream); @@ -568,11 +594,14 @@ pub mod test { let tmp = tempdir().unwrap(); let conn = LibSqlConnection::new_test(tmp.path()); let (snd, rcv) = mpsc::channel(1); - let auth = Authenticated::Authorized(Authorized { - namespace: None, - permission: Permission::FullAccess, - }); - let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); + let ctx = RequestContext::new( + Authenticated::FullAccess, + NamespaceName::default(), + MetaStore::new(Default::default(), tmp.path()) + .await + .unwrap(), + ); + let stream = make_proxy_stream(conn, ctx, ReceiverStream::new(rcv)); pin!(stream); diff --git a/libsql-server/tests/standalone/attach.rs b/libsql-server/tests/standalone/attach.rs new file mode 100644 index 0000000000..7bd04bfa3d --- /dev/null +++ b/libsql-server/tests/standalone/attach.rs @@ -0,0 +1,201 @@ +use base64::Engine; +use insta::assert_debug_snapshot; +use jsonwebtoken::EncodingKey; +use libsql::Database; +use ring::signature::{Ed25519KeyPair, KeyPair}; + +use crate::common::{http::Client, net::TurmoilConnector}; + +use super::make_standalone_server; + +fn key_pair() -> (EncodingKey, Ed25519KeyPair) { + let doc = Ed25519KeyPair::generate_pkcs8(&ring::rand::SystemRandom::new()).unwrap(); + let encoding_key = EncodingKey::from_ed_der(doc.as_ref()); + let pair = Ed25519KeyPair::from_pkcs8(doc.as_ref()).unwrap(); + (encoding_key, pair) +} + +fn encode(claims: &T, key: &EncodingKey) -> String { + let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::EdDSA); + jsonwebtoken::encode(&header, &claims, key).unwrap() +} + +#[test] +fn attach_no_auth() { + let mut sim = turmoil::Builder::new().build(); + + sim.host("primary", make_standalone_server); + + sim.client("test", async { + let client = Client::new(); + + client + .post( + "http://primary:9090/v1/namespaces/foo/create", + serde_json::json!({}), + ) + .await + .unwrap(); + client + .post( + "http://primary:9090/v1/namespaces/bar/create", + serde_json::json!({ "allow_attach": true}), + ) + .await + .unwrap(); + + let foo_db = + Database::open_remote_with_connector("http://foo.primary:8080", "", TurmoilConnector)?; + let foo_conn = foo_db.connect().unwrap(); + foo_conn + .execute("CREATE TABLE foo_table (x)", ()) + .await + .unwrap(); + foo_conn + .execute("insert into foo_table values (42)", ()) + .await + .unwrap(); + + let bar_db = + Database::open_remote_with_connector("http://bar.primary:8080", "", TurmoilConnector)?; + let bar_conn = bar_db.connect().unwrap(); + bar_conn + .execute("CREATE TABLE bar_table (x)", ()) + .await + .unwrap(); + bar_conn + .execute("insert into bar_table values (43)", ()) + .await + .unwrap(); + + // fails: foo doesn't allow attach + assert_debug_snapshot!(bar_conn.execute("ATTACH foo as foo", ()).await.unwrap_err()); + + let txn = foo_conn.transaction().await.unwrap(); + txn.execute("ATTACH DATABASE bar as bar", ()).await.unwrap(); + let mut rows = txn.query("SELECT * FROM bar.bar_table", ()).await.unwrap(); + // succeeds! + assert_debug_snapshot!(rows.next().await); + + Ok(()) + }); + + sim.run().unwrap(); +} + +#[test] +fn attach_auth() { + let mut sim = turmoil::Builder::new().build(); + + sim.host("primary", make_standalone_server); + + sim.client("test", async { + let client = Client::new(); + + let (enc, pair) = key_pair(); + + let jwt_key = base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(pair.public_key().as_ref()); + assert!(client + .post( + "http://primary:9090/v1/namespaces/foo/create", + serde_json::json!({ "jwt_key": jwt_key }) + ) + .await + .unwrap() + .status() + .is_success()); + assert!(client + .post( + "http://primary:9090/v1/namespaces/bar/create", + serde_json::json!({ "allow_attach": true, "jwt_key": jwt_key }) + ) + .await + .unwrap() + .status() + .is_success()); + + let claims = serde_json::json!({ + "p": { + "rw": { + "ns": ["bar", "foo"] + } + } + }); + let token = encode(&claims, &enc); + + let foo_db = Database::open_remote_with_connector( + "http://foo.primary:8080", + &token, + TurmoilConnector, + )?; + let foo_conn = foo_db.connect().unwrap(); + foo_conn + .execute("CREATE TABLE foo_table (x)", ()) + .await + .unwrap(); + foo_conn + .execute("insert into foo_table values (42)", ()) + .await + .unwrap(); + + let bar_db = Database::open_remote_with_connector( + "http://bar.primary:8080", + &token, + TurmoilConnector, + )?; + let bar_conn = bar_db.connect().unwrap(); + bar_conn + .execute("CREATE TABLE bar_table (x)", ()) + .await + .unwrap(); + bar_conn + .execute("insert into bar_table values (43)", ()) + .await + .unwrap(); + + // fails: no perm + assert_debug_snapshot!(bar_conn.execute("ATTACH foo as foo", ()).await.unwrap_err()); + + let txn = foo_conn.transaction().await.unwrap(); + // fails: no perm + assert_debug_snapshot!(txn + .execute("ATTACH DATABASE bar as bar", ()) + .await + .unwrap_err()); + + let claims = serde_json::json!({ + "p": { + "roa": { + "ns": ["bar", "foo"] + } + } + }); + let token = encode(&claims, &enc); + + let foo_db = Database::open_remote_with_connector( + "http://foo.primary:8080", + &token, + TurmoilConnector, + )?; + let foo_conn = foo_db.connect().unwrap(); + let bar_db = Database::open_remote_with_connector( + "http://bar.primary:8080", + &token, + TurmoilConnector, + )?; + let bar_conn = bar_db.connect().unwrap(); + + // fails: namesapce doesn't allow attach + assert_debug_snapshot!(bar_conn.execute("ATTACH foo as foo", ()).await.unwrap_err()); + + let txn = foo_conn.transaction().await.unwrap(); + txn.execute("ATTACH DATABASE bar as bar", ()).await.unwrap(); + let mut rows = txn.query("SELECT * FROM bar.bar_table", ()).await.unwrap(); + // succeeds! + assert_debug_snapshot!(rows.next().await); + + Ok(()) + }); + + sim.run().unwrap(); +} diff --git a/libsql-server/tests/standalone/mod.rs b/libsql-server/tests/standalone/mod.rs index b46156e34e..65f4784e77 100644 --- a/libsql-server/tests/standalone/mod.rs +++ b/libsql-server/tests/standalone/mod.rs @@ -17,6 +17,8 @@ use libsql_server::config::{AdminApiConfig, UserApiConfig}; use common::net::{init_tracing, TestServer, TurmoilConnector}; +mod attach; + async fn make_standalone_server() -> Result<(), Box> { init_tracing(); let tmp = tempdir()?; @@ -26,6 +28,12 @@ async fn make_standalone_server() -> Result<(), Box> { hrana_ws_acceptor: None, ..Default::default() }, + admin_api_config: Some(AdminApiConfig { + acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(), + connector: TurmoilConnector, + disable_metrics: true, + }), + disable_namespaces: false, ..Default::default() }; diff --git a/libsql-server/tests/standalone/snapshots/tests__standalone__attach__attach_auth-2.snap b/libsql-server/tests/standalone/snapshots/tests__standalone__attach__attach_auth-2.snap new file mode 100644 index 0000000000..e10a2b7def --- /dev/null +++ b/libsql-server/tests/standalone/snapshots/tests__standalone__attach__attach_auth-2.snap @@ -0,0 +1,9 @@ +--- +source: libsql-server/tests/standalone/attach.rs +expression: "txn.execute(\"ATTACH DATABASE bar as bar\", ()).await.unwrap_err()" +--- +Hrana( + Api( + "{\"error\":\"Internal Error: `Not authorized to execute query: Current session doest not have AttachRead permission to namespace bar`\"}", + ), +) diff --git a/libsql-server/tests/standalone/snapshots/tests__standalone__attach__attach_auth-3.snap b/libsql-server/tests/standalone/snapshots/tests__standalone__attach__attach_auth-3.snap new file mode 100644 index 0000000000..9eb2631d4f --- /dev/null +++ b/libsql-server/tests/standalone/snapshots/tests__standalone__attach__attach_auth-3.snap @@ -0,0 +1,9 @@ +--- +source: libsql-server/tests/standalone/attach.rs +expression: "bar_conn.execute(\"ATTACH foo as foo\", ()).await.unwrap_err()" +--- +Hrana( + Api( + "{\"error\":\"Internal Error: `Not authorized to execute query: Namespace `foo` doesn't allow attach`\"}", + ), +) diff --git a/libsql-server/tests/standalone/snapshots/tests__standalone__attach__attach_auth-4.snap b/libsql-server/tests/standalone/snapshots/tests__standalone__attach__attach_auth-4.snap new file mode 100644 index 0000000000..57222c38de --- /dev/null +++ b/libsql-server/tests/standalone/snapshots/tests__standalone__attach__attach_auth-4.snap @@ -0,0 +1,23 @@ +--- +source: libsql-server/tests/standalone/attach.rs +expression: rows.next().await +--- +Ok( + Some( + Row { + cols: [ + Col { + name: Some( + "x", + ), + decltype: None, + }, + ], + inner: [ + Integer { + value: 43, + }, + ], + }, + ), +) diff --git a/libsql-server/tests/standalone/snapshots/tests__standalone__attach__attach_auth.snap b/libsql-server/tests/standalone/snapshots/tests__standalone__attach__attach_auth.snap new file mode 100644 index 0000000000..35445b308e --- /dev/null +++ b/libsql-server/tests/standalone/snapshots/tests__standalone__attach__attach_auth.snap @@ -0,0 +1,9 @@ +--- +source: libsql-server/tests/standalone/attach.rs +expression: "bar_conn.execute(\"ATTACH foo as foo\", ()).await.unwrap_err()" +--- +Hrana( + Api( + "{\"error\":\"Internal Error: `Not authorized to execute query: Current session doest not have AttachRead permission to namespace foo`\"}", + ), +) diff --git a/libsql-server/tests/standalone/snapshots/tests__standalone__attach__attach_no_auth-2.snap b/libsql-server/tests/standalone/snapshots/tests__standalone__attach__attach_no_auth-2.snap new file mode 100644 index 0000000000..57222c38de --- /dev/null +++ b/libsql-server/tests/standalone/snapshots/tests__standalone__attach__attach_no_auth-2.snap @@ -0,0 +1,23 @@ +--- +source: libsql-server/tests/standalone/attach.rs +expression: rows.next().await +--- +Ok( + Some( + Row { + cols: [ + Col { + name: Some( + "x", + ), + decltype: None, + }, + ], + inner: [ + Integer { + value: 43, + }, + ], + }, + ), +) diff --git a/libsql-server/tests/standalone/snapshots/tests__standalone__attach__attach_no_auth.snap b/libsql-server/tests/standalone/snapshots/tests__standalone__attach__attach_no_auth.snap new file mode 100644 index 0000000000..9eb2631d4f --- /dev/null +++ b/libsql-server/tests/standalone/snapshots/tests__standalone__attach__attach_no_auth.snap @@ -0,0 +1,9 @@ +--- +source: libsql-server/tests/standalone/attach.rs +expression: "bar_conn.execute(\"ATTACH foo as foo\", ()).await.unwrap_err()" +--- +Hrana( + Api( + "{\"error\":\"Internal Error: `Not authorized to execute query: Namespace `foo` doesn't allow attach`\"}", + ), +) diff --git a/libsql-sqlite3/src/func.c b/libsql-sqlite3/src/func.c index 2cdc22b907..a811671bbb 100644 --- a/libsql-sqlite3/src/func.c +++ b/libsql-sqlite3/src/func.c @@ -2646,9 +2646,9 @@ int libsql_try_initialize_wasm_func_table(sqlite3 *db) { sqlite3_finalize(stmt); return rc; } - const char *pName = sqlite3_column_text(stmt, 0); + const unsigned char *pName = sqlite3_column_text(stmt, 0); const void *pBody = body_type == SQLITE_TEXT ? sqlite3_column_text(stmt, 1) : sqlite3_column_blob(stmt, 1); - try_instantiate_wasm_function(db, pName, name_size, pBody, body_size, -1, NULL); + try_instantiate_wasm_function(db, (const char *)pName, name_size, pBody, body_size, -1, NULL); } } sqlite3_finalize(stmt);