Skip to content

Commit

Permalink
make global connection not exclusive
Browse files Browse the repository at this point in the history
  • Loading branch information
rebasedming committed Aug 5, 2024
1 parent 1784935 commit 7ef6b0f
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/api/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ fn sniff_csv_impl(files: &str, sample_size: Option<i64>) -> Result<Vec<SniffCsvR
.flatten()
.collect::<Vec<String>>()
.join(", ");
let conn = get_global_connection();
let conn = get_global_connection()?;
let conn = conn.lock().unwrap();
let query = format!("SELECT * FROM sniff_csv({schema_str})");
let mut stmt = conn.prepare(&query)?;
Expand Down
2 changes: 1 addition & 1 deletion src/api/duckdb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub fn duckdb_settings() -> iter::TableIterator<

#[inline]
fn duckdb_settings_impl() -> Result<Vec<DuckdbSettingsRow>> {
let conn = get_global_connection();
let conn = get_global_connection()?;
let conn = conn.lock().unwrap();
let mut stmt = conn.prepare("SELECT * FROM duckdb_settings()")?;

Expand Down
4 changes: 2 additions & 2 deletions src/api/parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ pub fn parquet_schema(
#[inline]
fn parquet_schema_impl(files: &str) -> Result<Vec<ParquetSchemaRow>> {
let schema_str = utils::format_csv(files);
let conn = get_global_connection();
let conn = get_global_connection()?;
let conn = conn.lock().unwrap();
let query = format!("SELECT * FROM parquet_schema({schema_str})");
let mut stmt = conn.prepare(&query)?;
Expand Down Expand Up @@ -116,7 +116,7 @@ fn parquet_schema_impl(files: &str) -> Result<Vec<ParquetSchemaRow>> {
#[inline]
fn parquet_describe_impl(files: &str) -> Result<Vec<ParquetDescribeRow>> {
let schema_str = utils::format_csv(files);
let conn = get_global_connection();
let conn = get_global_connection()?;
let conn = conn.lock().unwrap();
let query = format!("DESCRIBE SELECT * FROM {schema_str}");
let mut stmt = conn.prepare(&query)?;
Expand Down
10 changes: 5 additions & 5 deletions src/duckdb/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ fn init_globals() {
let mut signals =
Signals::new([SIGTERM, SIGINT, SIGQUIT]).expect("error registering signal listener");
for _ in signals.forever() {
let conn = get_global_connection();
let conn = get_global_connection().expect("failed to get connection");
let conn = conn.lock().unwrap();
conn.interrupt();
}
});
}

fn iceberg_loaded() -> Result<bool> {
let conn = get_global_connection();
let conn = get_global_connection()?;
let conn = conn.lock().unwrap();
let mut statement = conn.prepare("SELECT * FROM duckdb_extensions() WHERE extension_name = 'iceberg' AND installed = true AND loaded = true")?;
match statement.query([])?.next() {
Expand Down Expand Up @@ -126,7 +126,7 @@ pub fn create_parquet_relation(

pub fn create_arrow(sql: &str) -> Result<bool> {
unsafe {
let conn = get_global_connection();
let conn = get_global_connection()?;
let conn = conn.lock().unwrap();
let statement = conn.prepare(sql)?;
let static_statement: Statement<'static> = std::mem::transmute(statement);
Expand Down Expand Up @@ -179,13 +179,13 @@ pub fn get_batches() -> Result<Vec<RecordBatch>> {
}

pub fn execute<P: Params>(sql: &str, params: P) -> Result<usize> {
let conn = get_global_connection();
let conn = get_global_connection()?;
let conn = conn.lock().unwrap();
conn.execute(sql, params).map_err(|err| anyhow!("{err}"))
}

pub fn drop_relation(table_name: &str, schema_name: &str) -> Result<()> {
let conn = get_global_connection();
let conn = get_global_connection()?;
let conn = conn.lock().unwrap();
let mut statement = conn.prepare(format!("SELECT table_type from information_schema.tables WHERE table_schema = '{schema_name}' AND table_name = '{table_name}' LIMIT 1").as_str())?;
if let Ok(Some(row)) = statement.query([])?.next() {
Expand Down
32 changes: 22 additions & 10 deletions src/env.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
use anyhow::{anyhow, Result};
use duckdb::Connection;
use pgrx::*;
use std::ffi::CStr;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};

#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct DuckdbConnection(pub Arc<Mutex<Connection>>);
unsafe impl PGRXSharedMemory for DuckdbConnection {}

// One connection per database, so 128 databases can have a DuckDB connection
const MAX_CONNECTIONS: usize = 128;
pub static DUCKDB_CONNECTION: PgLwLock<
pub static DUCKDB_CONNECTION_CACHE: PgLwLock<
heapless::FnvIndexMap<u32, DuckdbConnection, MAX_CONNECTIONS>,
> = PgLwLock::new();

Expand All @@ -32,15 +33,26 @@ impl Default for DuckdbConnection {
}
}

pub fn get_global_connection() -> Arc<Mutex<Connection>> {
match DUCKDB_CONNECTION.exclusive().entry(postgres_database_oid()) {
heapless::Entry::Occupied(entry) => entry.get().0.clone(),
heapless::Entry::Vacant(entry) => {
let conn = DuckdbConnection::default();
let _ = entry.insert(conn.clone());
conn.0.clone()
}
/// Returns the DuckDB connection cached for the current Postgres database,
/// creating and caching a new one the first time the database is seen.
///
/// The common case (connection already cached) only takes the shared lock on
/// `DUCKDB_CONNECTION_CACHE`; the exclusive lock is taken only to insert.
///
/// # Errors
///
/// Returns an error if the connection cannot be inserted into the cache,
/// i.e. the fixed-capacity map has no room left for another database.
pub fn get_global_connection() -> Result<Arc<Mutex<Connection>>> {
    let database_id = postgres_database_oid();

    // Fast path: look the connection up under the shared lock. The guard is
    // scoped so it is released before the exclusive lock is requested below.
    {
        let cache = DUCKDB_CONNECTION_CACHE.share();
        if let Some(cached) = cache.get(&database_id) {
            return Ok(Arc::clone(&cached.0));
        }
    }

    // Slow path: build the connection before taking the exclusive lock so no
    // lock is held while the connection is being constructed.
    let conn = DuckdbConnection::default();
    let mut cache = DUCKDB_CONNECTION_CACHE.exclusive();

    // Re-check under the exclusive lock: another backend may have cached a
    // connection between our shared-lock miss and this point. Keep theirs so
    // every caller shares one connection per database.
    if let Some(existing) = cache.get(&database_id) {
        return Ok(Arc::clone(&existing.0));
    }

    // `insert` yields `Ok(None)` for a new key (the previous value, if any,
    // otherwise), so hand back our own handle rather than unwrapping its result.
    let handle = Arc::clone(&conn.0);
    cache
        .insert(database_id, conn)
        .map_err(|_| anyhow!("failed to cache connection: connection cache is full"))?;

    Ok(handle)
}

pub fn postgres_data_dir_path() -> PathBuf {
Expand Down
2 changes: 1 addition & 1 deletion src/fdw/trigger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ unsafe fn auto_create_schema_impl(fcinfo: pg_sys::FunctionCallInfo) -> Result<()
pg_sys::RelationClose(relation);

// Get DuckDB schema
let conn = get_global_connection();
let conn = get_global_connection()?;
let conn = conn.lock().unwrap();
let query = format!("DESCRIBE {schema_name}.{table_name}");
let mut stmt = conn.prepare(&query)?;
Expand Down
3 changes: 1 addition & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ mod fdw;
mod hooks;
mod schema;

use env::DUCKDB_CONNECTION;
use hooks::ExtensionHook;
use pgrx::*;
use shared::{
Expand All @@ -45,7 +44,7 @@ pub extern "C" fn _PG_init() {
};

GUCS.init("pg_analytics");
pg_shmem_init!(DUCKDB_CONNECTION);
pg_shmem_init!(env::DUCKDB_CONNECTION_CACHE);

// TODO: Change to ParadeExtension::PgAnalytics
setup_telemetry_background_worker(ParadeExtension::PgLakehouse);
Expand Down

0 comments on commit 7ef6b0f

Please sign in to comment.