From b1f0fa57ba929e93336308ff50b37b800ff69b47 Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Tue, 1 Oct 2024 16:38:31 +0800 Subject: [PATCH] chore: Used sqlparser to parse create view statement --- Cargo.lock | 7 +-- Cargo.toml | 2 +- src/hooks/utility.rs | 118 +++++++++++++------------------------------ 3 files changed, 39 insertions(+), 88 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dd9a35a0..4a8eef78 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4028,7 +4028,7 @@ dependencies = [ "serde_json", "signal-hook", "soa_derive", - "sqlparser 0.50.0", + "sqlparser 0.51.0", "sqlx", "strum 0.26.3", "supabase-wrappers", @@ -5405,11 +5405,12 @@ dependencies = [ [[package]] name = "sqlparser" -version = "0.50.0" +version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2e5b515a2bd5168426033e9efbfd05500114833916f1d5c268f938b4ee130ac" +checksum = "5fe11944a61da0da3f592e19a45ebe5ab92dc14a779907ff1f08fbb797bfefc7" dependencies = [ "log", + "sqlparser_derive", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 020fd677..4aacdad3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ pgrx = "0.12.5" serde = "1.0.210" serde_json = "1.0.128" signal-hook = "0.3.17" -sqlparser = "0.50.0" +sqlparser = { version = "0.51.0", features = ["visitor"] } strum = { version = "0.26.3", features = ["derive"] } supabase-wrappers = { git = "https://github.com/paradedb/wrappers.git", default-features = false, rev = "c32abb7" } thiserror = "1.0.63" diff --git a/src/hooks/utility.rs b/src/hooks/utility.rs index d3f6f04d..8285d844 100644 --- a/src/hooks/utility.rs +++ b/src/hooks/utility.rs @@ -17,12 +17,16 @@ #![allow(clippy::too_many_arguments)] -use std::ffi::{CStr, CString}; +use std::{ffi::CString, ops::ControlFlow}; use anyhow::{bail, Result}; use pg_sys::NodeTag; use pgrx::*; -use sqlparser::{ast::Statement, dialect::PostgreSqlDialect, parser::Parser}; +use sqlparser::{ + ast::{visit_relations, Statement}, + dialect::PostgreSqlDialect, + parser::Parser, +}; use crate::{ duckdb::connection::{execute, view_exists}, @@ -77,9 +81,7 @@ pub async fn process_utility_hook( pstmt.utilityStmt as *mut pg_sys::ExplainStmt, dest.as_ptr(), )?, - pg_sys::NodeTag::T_ViewStmt => { - view_query(query_string, pstmt.utilityStmt as *mut pg_sys::ViewStmt)? - } + pg_sys::NodeTag::T_ViewStmt => view_query(query_string)?, _ => bail!("unexpected statement type in utility hook"), }; @@ -103,97 +105,45 @@ fn is_support_utility(stmt_type: NodeTag) -> bool { stmt_type == pg_sys::NodeTag::T_ExplainStmt || stmt_type == pg_sys::NodeTag::T_ViewStmt } -fn view_query(query_string: &core::ffi::CStr, stmt: *mut pg_sys::ViewStmt) -> Result { +fn view_query(query_string: &core::ffi::CStr) -> Result { // Get the current schema in Postgres let current_schema = get_postgres_current_schema(); // Set DuckDB search path according search path in Postgres set_search_path_by_pg()?; - let query = unsafe { (*stmt).query as *mut pg_sys::SelectStmt }; - let from_clause = unsafe { (*query).fromClause }; - unsafe { - let elements = (*from_clause).elements; - for i in 0..(*from_clause).length { - let element = (*elements.offset(i as isize)).ptr_value as *mut pg_sys::Node; - - match (*element).type_ { - pg_sys::NodeTag::T_RangeVar => { - if !analyze_range_var( - element as *mut pg_sys::RangeVar, - current_schema.as_str(), - )? { - return Ok(true); - } - } - pg_sys::NodeTag::T_JoinExpr => { - if !analyze_join_expr( - element as *mut pg_sys::JoinExpr, - current_schema.as_str(), - )? { - return Ok(true); - } - } - _ => { - continue; - } - } + let dialect = PostgreSqlDialect {}; + let statements = Parser::parse_sql(&dialect, query_string.to_str()?)?; + // visit statements, capturing relations (table names) + let mut visited = vec![]; + + visit_relations(&statements, |relation| { + visited.push(relation.clone()); + ControlFlow::<()>::Continue(()) + }); + + for relation in visited.iter() { + let (schema_name, relation_name) = if relation.0.len() == 1 { + (current_schema.clone(), relation.0[0].to_string()) + } else if relation.0.len() == 2 { + (relation.0[0].to_string(), relation.0[1].to_string()) + } else { + error!( + "it is not possible to create a view with more than 2 parts in the relation name" + ); + }; + + if !view_exists(&relation_name, &schema_name)? { + fallback_warning!(format!( + "{schema_name}.{relation_name} does not exist in DuckDB" + )); + return Ok(true); } } - // Push down the view creation query to DuckDB execute(query_string.to_str()?, [])?; Ok(true) } -/// Analyze the RangeVar to check if the relation exists in DuckDB -fn analyze_range_var(rv: *mut pg_sys::RangeVar, current_schema: &str) -> Result { - let relation_name = unsafe { CStr::from_ptr((*rv).relname).to_str()? }; - let schema_name = unsafe { - if (*rv).schemaname.is_null() { - current_schema - } else { - CStr::from_ptr((*rv).schemaname).to_str()? - } - }; - - if !view_exists(relation_name, schema_name)? { - fallback_warning!(format!( - "{schema_name}.{relation_name} does not exist in DuckDB" - )); - Ok(false) - } else { - Ok(true) - } -} - -/// Analyze the join expression to check if the relations in the join expression exist in DuckDB -fn analyze_join_expr(join_expr: *mut pg_sys::JoinExpr, current_schema: &str) -> Result { - unsafe { - let ltree = (*join_expr).larg; - let rtree = (*join_expr).rarg; - - Ok(analyze_tree(ltree, current_schema)? && analyze_tree(rtree, current_schema)?) - } -} - -/// Analyze the tree to check if the relations in the tree exist in DuckDB -fn analyze_tree(mut tree: *mut pg_sys::Node, current_schema: &str) -> Result { - while !tree.is_null() { - unsafe { - match (*tree).type_ { - pg_sys::NodeTag::T_RangeVar => { - return analyze_range_var(tree as *mut pg_sys::RangeVar, current_schema); - } - pg_sys::NodeTag::T_JoinExpr => { - tree = (*(tree as *mut pg_sys::JoinExpr)).larg; - } - _ => break, - } - } - } - Ok(true) -} - fn explain_query( query_string: &core::ffi::CStr, stmt: *mut pg_sys::ExplainStmt,