diff --git a/src/api/csv.rs b/src/api/csv.rs index 8b4fb1eb..3e2faf72 100644 --- a/src/api/csv.rs +++ b/src/api/csv.rs @@ -70,7 +70,7 @@ fn sniff_csv_impl(files: &str, sample_size: Option) -> Result>() .join(", "); - let conn = get_global_connection(); + let conn = get_global_connection()?; let conn = conn.lock().unwrap(); let query = format!("SELECT * FROM sniff_csv({schema_str})"); let mut stmt = conn.prepare(&query)?; diff --git a/src/api/duckdb.rs b/src/api/duckdb.rs index 42f77b34..354329c1 100644 --- a/src/api/duckdb.rs +++ b/src/api/duckdb.rs @@ -37,7 +37,7 @@ pub fn duckdb_settings() -> iter::TableIterator< #[inline] fn duckdb_settings_impl() -> Result> { - let conn = get_global_connection(); + let conn = get_global_connection()?; let conn = conn.lock().unwrap(); let mut stmt = conn.prepare("SELECT * FROM duckdb_settings()")?; diff --git a/src/api/parquet.rs b/src/api/parquet.rs index e6fada24..be222698 100644 --- a/src/api/parquet.rs +++ b/src/api/parquet.rs @@ -88,7 +88,7 @@ pub fn parquet_schema( #[inline] fn parquet_schema_impl(files: &str) -> Result> { let schema_str = utils::format_csv(files); - let conn = get_global_connection(); + let conn = get_global_connection()?; let conn = conn.lock().unwrap(); let query = format!("SELECT * FROM parquet_schema({schema_str})"); let mut stmt = conn.prepare(&query)?; @@ -116,7 +116,7 @@ fn parquet_schema_impl(files: &str) -> Result> { #[inline] fn parquet_describe_impl(files: &str) -> Result> { let schema_str = utils::format_csv(files); - let conn = get_global_connection(); + let conn = get_global_connection()?; let conn = conn.lock().unwrap(); let query = format!("DESCRIBE SELECT * FROM {schema_str}"); let mut stmt = conn.prepare(&query)?; diff --git a/src/duckdb/connection.rs b/src/duckdb/connection.rs index 79c7550d..dab72486 100644 --- a/src/duckdb/connection.rs +++ b/src/duckdb/connection.rs @@ -48,7 +48,7 @@ fn init_globals() { let mut signals = Signals::new([SIGTERM, SIGINT, SIGQUIT]).expect("error 
registering signal listener"); for _ in signals.forever() { - let conn = get_global_connection(); + let conn = get_global_connection().expect("failed to get connection"); let conn = conn.lock().unwrap(); conn.interrupt(); } @@ -56,7 +56,7 @@ fn init_globals() { } fn iceberg_loaded() -> Result { - let conn = get_global_connection(); + let conn = get_global_connection()?; let conn = conn.lock().unwrap(); let mut statement = conn.prepare("SELECT * FROM duckdb_extensions() WHERE extension_name = 'iceberg' AND installed = true AND loaded = true")?; match statement.query([])?.next() { @@ -126,7 +126,7 @@ pub fn create_parquet_relation( pub fn create_arrow(sql: &str) -> Result { unsafe { - let conn = get_global_connection(); + let conn = get_global_connection()?; let conn = conn.lock().unwrap(); let statement = conn.prepare(sql)?; let static_statement: Statement<'static> = std::mem::transmute(statement); @@ -179,13 +179,13 @@ pub fn get_batches() -> Result> { } pub fn execute(sql: &str, params: P) -> Result { - let conn = get_global_connection(); + let conn = get_global_connection()?; let conn = conn.lock().unwrap(); conn.execute(sql, params).map_err(|err| anyhow!("{err}")) } pub fn drop_relation(table_name: &str, schema_name: &str) -> Result<()> { - let conn = get_global_connection(); + let conn = get_global_connection()?; let conn = conn.lock().unwrap(); let mut statement = conn.prepare(format!("SELECT table_type from information_schema.tables WHERE table_schema = '{schema_name}' AND table_name = '{table_name}' LIMIT 1").as_str())?; if let Ok(Some(row)) = statement.query([])?.next() { diff --git a/src/env.rs b/src/env.rs index 23f98c13..b6daa2b6 100644 --- a/src/env.rs +++ b/src/env.rs @@ -1,16 +1,17 @@ +use anyhow::{anyhow, Result}; use duckdb::Connection; use pgrx::*; use std::ffi::CStr; use std::path::PathBuf; use std::sync::{Arc, Mutex}; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct DuckdbConnection(pub Arc>); unsafe impl PGRXSharedMemory for 
DuckdbConnection {}
 
 // One connection per database, so 128 databases can have a DuckDB connection
 const MAX_CONNECTIONS: usize = 128;
 
-pub static DUCKDB_CONNECTION: PgLwLock<
+pub static DUCKDB_CONNECTION_CACHE: PgLwLock<
     heapless::FnvIndexMap<u32, DuckdbConnection, MAX_CONNECTIONS>,
 > = PgLwLock::new();
@@ -32,15 +33,26 @@ impl Default for DuckdbConnection {
     }
 }
 
-pub fn get_global_connection() -> Arc<Mutex<Connection>> {
-    match DUCKDB_CONNECTION.exclusive().entry(postgres_database_oid()) {
-        heapless::Entry::Occupied(entry) => entry.get().0.clone(),
-        heapless::Entry::Vacant(entry) => {
-            let conn = DuckdbConnection::default();
-            let _ = entry.insert(conn.clone());
-            conn.0.clone()
-        }
+pub fn get_global_connection() -> Result<Arc<Mutex<Connection>>> {
+    let database_id = postgres_database_oid();
+    let connection_cached = DUCKDB_CONNECTION_CACHE.share().contains_key(&database_id);
+
+    if !connection_cached {
+        let conn = DuckdbConnection::default();
+        // FnvIndexMap::insert returns Ok(previous_value) — None on first insert —
+        // and Err((k, v)) only when the map is full, so don't unwrap the Ok payload.
+        DUCKDB_CONNECTION_CACHE
+            .exclusive()
+            .insert(database_id, conn.clone())
+            .map_err(|_| anyhow!("connection cache full: MAX_CONNECTIONS exceeded"))?;
+        return Ok(conn.0);
     }
+
+    Ok(DUCKDB_CONNECTION_CACHE
+        .share()
+        .get(&database_id)
+        .ok_or_else(|| anyhow!("connection not found"))?
+        .0
+        .clone())
 }
 
 pub fn postgres_data_dir_path() -> PathBuf {
diff --git a/src/fdw/trigger.rs b/src/fdw/trigger.rs
index 699cc592..260c5b5c 100644
--- a/src/fdw/trigger.rs
+++ b/src/fdw/trigger.rs
@@ -144,7 +144,7 @@ unsafe fn auto_create_schema_impl(fcinfo: pg_sys::FunctionCallInfo) -> Result<()
     pg_sys::RelationClose(relation);
 
     // Get DuckDB schema
-    let conn = get_global_connection();
+    let conn = get_global_connection()?;
     let conn = conn.lock().unwrap();
     let query = format!("DESCRIBE {schema_name}.{table_name}");
     let mut stmt = conn.prepare(&query)?;
diff --git a/src/lib.rs b/src/lib.rs
index 96d68497..18f78b99 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -22,7 +22,6 @@ mod fdw;
 mod hooks;
 mod schema;
 
-use env::DUCKDB_CONNECTION;
 use hooks::ExtensionHook;
 use pgrx::*;
 use shared::{
@@ -45,7 +44,7 @@ pub extern "C" fn _PG_init() {
     };
 
     GUCS.init("pg_analytics");
-    pg_shmem_init!(DUCKDB_CONNECTION);
+    pg_shmem_init!(env::DUCKDB_CONNECTION_CACHE);
 
     // TODO: Change to ParadeExtension::PgAnalytics
     setup_telemetry_background_worker(ParadeExtension::PgLakehouse);