Skip to content

Commit

Permalink
chore: Used sqlparser to parse create view statement
Browse files Browse the repository at this point in the history
  • Loading branch information
Weijun-H committed Oct 1, 2024
1 parent c3e6f03 commit b1f0fa5
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 88 deletions.
7 changes: 4 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pgrx = "0.12.5"
serde = "1.0.210"
serde_json = "1.0.128"
signal-hook = "0.3.17"
sqlparser = "0.50.0"
sqlparser = { version = "0.51.0", features = ["visitor"] }
strum = { version = "0.26.3", features = ["derive"] }
supabase-wrappers = { git = "https://github.com/paradedb/wrappers.git", default-features = false, rev = "c32abb7" }
thiserror = "1.0.63"
Expand Down
118 changes: 34 additions & 84 deletions src/hooks/utility.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@

#![allow(clippy::too_many_arguments)]

use std::ffi::{CStr, CString};
use std::{ffi::CString, ops::ControlFlow};

use anyhow::{bail, Result};
use pg_sys::NodeTag;
use pgrx::*;
use sqlparser::{ast::Statement, dialect::PostgreSqlDialect, parser::Parser};
use sqlparser::{
ast::{visit_relations, Statement},
dialect::PostgreSqlDialect,
parser::Parser,
};

use crate::{
duckdb::connection::{execute, view_exists},
Expand Down Expand Up @@ -77,9 +81,7 @@ pub async fn process_utility_hook(
pstmt.utilityStmt as *mut pg_sys::ExplainStmt,
dest.as_ptr(),
)?,
pg_sys::NodeTag::T_ViewStmt => {
view_query(query_string, pstmt.utilityStmt as *mut pg_sys::ViewStmt)?
}
pg_sys::NodeTag::T_ViewStmt => view_query(query_string)?,
_ => bail!("unexpected statement type in utility hook"),
};

Expand All @@ -103,97 +105,45 @@ fn is_support_utility(stmt_type: NodeTag) -> bool {
stmt_type == pg_sys::NodeTag::T_ExplainStmt || stmt_type == pg_sys::NodeTag::T_ViewStmt
}

fn view_query(query_string: &core::ffi::CStr, stmt: *mut pg_sys::ViewStmt) -> Result<bool> {
fn view_query(query_string: &core::ffi::CStr) -> Result<bool> {
// Get the current schema in Postgres
let current_schema = get_postgres_current_schema();
// Set DuckDB search path according search path in Postgres
set_search_path_by_pg()?;

let query = unsafe { (*stmt).query as *mut pg_sys::SelectStmt };
let from_clause = unsafe { (*query).fromClause };
unsafe {
let elements = (*from_clause).elements;
for i in 0..(*from_clause).length {
let element = (*elements.offset(i as isize)).ptr_value as *mut pg_sys::Node;

match (*element).type_ {
pg_sys::NodeTag::T_RangeVar => {
if !analyze_range_var(
element as *mut pg_sys::RangeVar,
current_schema.as_str(),
)? {
return Ok(true);
}
}
pg_sys::NodeTag::T_JoinExpr => {
if !analyze_join_expr(
element as *mut pg_sys::JoinExpr,
current_schema.as_str(),
)? {
return Ok(true);
}
}
_ => {
continue;
}
}
let dialect = PostgreSqlDialect {};
let statements = Parser::parse_sql(&dialect, query_string.to_str()?)?;
// visit statements, capturing relations (table names)
let mut visited = vec![];

visit_relations(&statements, |relation| {
visited.push(relation.clone());
ControlFlow::<()>::Continue(())
});

for relation in visited.iter() {
let (schema_name, relation_name) = if relation.0.len() == 1 {
(current_schema.clone(), relation.0[0].to_string())
} else if relation.0.len() == 2 {
(relation.0[0].to_string(), relation.0[1].to_string())
} else {
error!(
"it is not possible to create a view with more than 2 parts in the relation name"
);
};

if !view_exists(&relation_name, &schema_name)? {
fallback_warning!(format!(
"{schema_name}.{relation_name} does not exist in DuckDB"
));
return Ok(true);
}
}

// Push down the view creation query to DuckDB
execute(query_string.to_str()?, [])?;
Ok(true)
}

/// Analyze the RangeVar to check if the relation exists in DuckDB
fn analyze_range_var(rv: *mut pg_sys::RangeVar, current_schema: &str) -> Result<bool> {
let relation_name = unsafe { CStr::from_ptr((*rv).relname).to_str()? };
let schema_name = unsafe {
if (*rv).schemaname.is_null() {
current_schema
} else {
CStr::from_ptr((*rv).schemaname).to_str()?
}
};

if !view_exists(relation_name, schema_name)? {
fallback_warning!(format!(
"{schema_name}.{relation_name} does not exist in DuckDB"
));
Ok(false)
} else {
Ok(true)
}
}

/// Analyze the join expression to check if the relations in the join expression exist in DuckDB
fn analyze_join_expr(join_expr: *mut pg_sys::JoinExpr, current_schema: &str) -> Result<bool> {
unsafe {
let ltree = (*join_expr).larg;
let rtree = (*join_expr).rarg;

Ok(analyze_tree(ltree, current_schema)? && analyze_tree(rtree, current_schema)?)
}
}

/// Analyze the tree to check if the relations in the tree exist in DuckDB
fn analyze_tree(mut tree: *mut pg_sys::Node, current_schema: &str) -> Result<bool> {
while !tree.is_null() {
unsafe {
match (*tree).type_ {
pg_sys::NodeTag::T_RangeVar => {
return analyze_range_var(tree as *mut pg_sys::RangeVar, current_schema);
}
pg_sys::NodeTag::T_JoinExpr => {
tree = (*(tree as *mut pg_sys::JoinExpr)).larg;
}
_ => break,
}
}
}
Ok(true)
}

fn explain_query(
query_string: &core::ffi::CStr,
stmt: *mut pg_sys::ExplainStmt,
Expand Down

0 comments on commit b1f0fa5

Please sign in to comment.