diff --git a/Cargo.toml b/Cargo.toml index 54750274..769e554d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ serde = "1.0.201" serde_json = "1.0.120" signal-hook = "0.3.17" strum = { version = "0.26.3", features = ["derive"] } -supabase-wrappers = { git = "https://github.com/paradedb/wrappers.git", default-features = false, rev = "c6f5e79" } +supabase-wrappers = { git = "https://github.com/paradedb/wrappers.git", default-features = false, rev = "719628f" } thiserror = "1.0.59" uuid = "1.9.1" diff --git a/src/duckdb/csv.rs b/src/duckdb/csv.rs index 0e5414fb..fe82de31 100644 --- a/src/duckdb/csv.rs +++ b/src/duckdb/csv.rs @@ -19,6 +19,8 @@ use anyhow::{anyhow, Result}; use std::collections::HashMap; use strum::{AsRefStr, EnumIter}; +use crate::fdw::base::OptionValidator; + use super::utils; #[derive(EnumIter, AsRefStr, PartialEq, Debug)] @@ -93,8 +95,8 @@ pub enum CsvOption { UnionByName, } -impl CsvOption { - pub fn is_required(&self) -> bool { +impl OptionValidator for CsvOption { + fn is_required(&self) -> bool { match self { Self::AllVarchar => false, Self::AllowQuotedNulls => false, diff --git a/src/duckdb/delta.rs b/src/duckdb/delta.rs index 412f1d2b..0d95e65a 100644 --- a/src/duckdb/delta.rs +++ b/src/duckdb/delta.rs @@ -15,6 +15,7 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . +use crate::fdw::base::OptionValidator; use anyhow::{anyhow, Result}; use std::collections::HashMap; use strum::{AsRefStr, EnumIter}; @@ -29,8 +30,8 @@ pub enum DeltaOption { Select, } -impl DeltaOption { - pub fn is_required(&self) -> bool { +impl OptionValidator for DeltaOption { + fn is_required(&self) -> bool { match self { Self::Files => true, Self::PreserveCasing => false, diff --git a/src/duckdb/iceberg.rs b/src/duckdb/iceberg.rs index 851e9fb4..689afc52 100644 --- a/src/duckdb/iceberg.rs +++ b/src/duckdb/iceberg.rs @@ -19,6 +19,8 @@ use anyhow::{anyhow, Result}; use std::collections::HashMap; use strum::{AsRefStr, EnumIter}; +use crate::fdw::base::OptionValidator; + #[derive(EnumIter, AsRefStr, PartialEq, Debug)] pub enum IcebergOption { #[strum(serialize = "allow_moved_paths")] @@ -31,8 +33,8 @@ pub enum IcebergOption { Select, } -impl IcebergOption { - pub fn is_required(&self) -> bool { +impl OptionValidator for IcebergOption { + fn is_required(&self) -> bool { match self { Self::AllowMovedPaths => false, Self::Files => true, diff --git a/src/duckdb/parquet.rs b/src/duckdb/parquet.rs index c79c39aa..96e45ea2 100644 --- a/src/duckdb/parquet.rs +++ b/src/duckdb/parquet.rs @@ -19,6 +19,8 @@ use anyhow::{anyhow, Result}; use std::collections::HashMap; use strum::{AsRefStr, EnumIter}; +use crate::fdw::base::OptionValidator; + use super::utils; #[derive(EnumIter, AsRefStr, PartialEq, Debug)] @@ -46,8 +48,8 @@ pub enum ParquetOption { // TODO: EncryptionConfig } -impl ParquetOption { - pub fn is_required(&self) -> bool { +impl OptionValidator for ParquetOption { + fn is_required(&self) -> bool { match self { Self::BinaryAsString => false, Self::FileName => false, diff --git a/src/duckdb/secret.rs b/src/duckdb/secret.rs index cce3c561..81104152 100644 --- a/src/duckdb/secret.rs +++ b/src/duckdb/secret.rs @@ -19,6 +19,8 @@ use anyhow::{anyhow, bail, Result}; use std::collections::HashMap; use strum::{AsRefStr, EnumIter}; +use crate::fdw::base::OptionValidator; + #[derive(EnumIter, AsRefStr, PartialEq, Debug)] pub enum UserMappingOptions { // Universal @@ -70,9 +72,8 @@ pub enum UserMappingOptions { ProxyPassword, } -impl UserMappingOptions { - #[allow(unused)] - pub fn is_required(&self) -> bool { +impl OptionValidator for UserMappingOptions { + fn is_required(&self) -> bool { match self { Self::Type => true, Self::Provider => false, diff --git a/src/duckdb/spatial.rs b/src/duckdb/spatial.rs index db330f6a..b0d54e24 100644 --- a/src/duckdb/spatial.rs +++ b/src/duckdb/spatial.rs @@ -20,6 +20,8 @@ use std::collections::HashMap; use strum::IntoEnumIterator; use strum::{AsRefStr, EnumIter}; +use crate::fdw::base::OptionValidator; + /// SpatialOption is an enum that represents the options that can be passed to the st_read function. /// Reference https://github.com/duckdb/duckdb_spatial/blob/main/docs/functions.md#st_read #[derive(EnumIter, AsRefStr, PartialEq, Debug)] @@ -44,8 +46,8 @@ pub enum SpatialOption { KeepWkb, } -impl SpatialOption { - pub fn is_required(&self) -> bool { +impl OptionValidator for SpatialOption { + fn is_required(&self) -> bool { match self { Self::Files => true, Self::SequentialLayerScan => false, diff --git a/src/fdw/base.rs b/src/fdw/base.rs index 3f7e64fa..01d10365 100644 --- a/src/fdw/base.rs +++ b/src/fdw/base.rs @@ -19,6 +19,7 @@ use anyhow::{anyhow, bail, Result}; use duckdb::arrow::array::RecordBatch; use pgrx::*; use std::collections::HashMap; +use strum::IntoEnumIterator; use supabase_wrappers::prelude::*; use thiserror::Error; @@ -288,3 +289,22 @@ impl DuckDbFormatter { Self {} } } + +pub(crate) trait OptionValidator { + fn is_required(&self) -> bool; +} + +pub fn validate_mapping_option>( + opt_list: Vec>, +) -> Result<()> { + let valid_options: Vec = T::iter().map(|opt| opt.as_ref().to_string()).collect(); + + validate_options(opt_list.clone(), valid_options)?; + + for opt in T::iter() { + if opt.is_required() { + check_options_contain(&opt_list, opt.as_ref())?; + } + } + Ok(()) +} diff --git a/src/fdw/csv.rs b/src/fdw/csv.rs index 069804fd..24c1d977 100644 --- a/src/fdw/csv.rs +++ b/src/fdw/csv.rs @@ -20,11 +20,10 @@ use async_std::task; use duckdb::arrow::array::RecordBatch; use pgrx::*; use std::collections::HashMap; -use strum::IntoEnumIterator; use supabase_wrappers::prelude::*; use super::base::*; -use crate::duckdb::csv::CsvOption; +use crate::duckdb::{csv::CsvOption, secret::UserMappingOptions}; #[wrappers_fdw( author = "ParadeDB", @@ -111,23 +110,14 @@ impl ForeignDataWrapper for CsvFdw { FOREIGN_DATA_WRAPPER_RELATION_ID => {} FOREIGN_SERVER_RELATION_ID => {} FOREIGN_TABLE_RELATION_ID => { - let valid_options: Vec = CsvOption::iter() - .map(|opt| opt.as_ref().to_string()) - .collect(); - - validate_options(opt_list.clone(), valid_options)?; - - for opt in CsvOption::iter() { - if opt.is_required() { - check_options_contain(&opt_list, opt.as_ref())?; - } - } + validate_mapping_option::(opt_list)?; + } + USER_MAPPING_RELATION_ID => { + validate_mapping_option::(opt_list)?; } - // TODO: Sanitize user mapping options _ => {} } } - Ok(()) } diff --git a/src/fdw/delta.rs b/src/fdw/delta.rs index 3a06e584..a2db0043 100644 --- a/src/fdw/delta.rs +++ b/src/fdw/delta.rs @@ -20,11 +20,10 @@ use async_std::task; use duckdb::arrow::array::RecordBatch; use pgrx::*; use std::collections::HashMap; -use strum::IntoEnumIterator; use supabase_wrappers::prelude::*; use super::base::*; -use crate::duckdb::delta::DeltaOption; +use crate::duckdb::{delta::DeltaOption, secret::UserMappingOptions}; #[wrappers_fdw( author = "ParadeDB", @@ -111,23 +110,14 @@ impl ForeignDataWrapper for DeltaFdw { FOREIGN_DATA_WRAPPER_RELATION_ID => {} FOREIGN_SERVER_RELATION_ID => {} FOREIGN_TABLE_RELATION_ID => { - let valid_options: Vec = DeltaOption::iter() - .map(|opt| opt.as_ref().to_string()) - .collect(); - - validate_options(opt_list.clone(), valid_options)?; - - for opt in DeltaOption::iter() { - if opt.is_required() { - check_options_contain(&opt_list, opt.as_ref())?; - } - } + validate_mapping_option::(opt_list)?; + } + USER_MAPPING_RELATION_ID => { + validate_mapping_option::(opt_list)?; } - // TODO: Sanitize user mapping options _ => {} } } - Ok(()) } diff --git a/src/fdw/iceberg.rs b/src/fdw/iceberg.rs index 33f3be86..6938a6f5 100644 --- a/src/fdw/iceberg.rs +++ b/src/fdw/iceberg.rs @@ -20,11 +20,10 @@ use async_std::task; use duckdb::arrow::array::RecordBatch; use pgrx::*; use std::collections::HashMap; -use strum::IntoEnumIterator; use supabase_wrappers::prelude::*; use super::base::*; -use crate::duckdb::iceberg::IcebergOption; +use crate::duckdb::{iceberg::IcebergOption, secret::UserMappingOptions}; #[wrappers_fdw( author = "ParadeDB", @@ -111,19 +110,11 @@ impl ForeignDataWrapper for IcebergFdw { FOREIGN_DATA_WRAPPER_RELATION_ID => {} FOREIGN_SERVER_RELATION_ID => {} FOREIGN_TABLE_RELATION_ID => { - let valid_options: Vec = IcebergOption::iter() - .map(|opt| opt.as_ref().to_string()) - .collect(); - - validate_options(opt_list.clone(), valid_options)?; - - for opt in IcebergOption::iter() { - if opt.is_required() { - check_options_contain(&opt_list, opt.as_ref())?; - } - } + validate_mapping_option::(opt_list)?; + } + USER_MAPPING_RELATION_ID => { + validate_mapping_option::(opt_list)?; } - // TODO: Sanitize user mapping options _ => {} } } diff --git a/src/fdw/parquet.rs b/src/fdw/parquet.rs index 60fc3d83..3e60406e 100644 --- a/src/fdw/parquet.rs +++ b/src/fdw/parquet.rs @@ -20,11 +20,10 @@ use async_std::task; use duckdb::arrow::array::RecordBatch; use pgrx::*; use std::collections::HashMap; -use strum::IntoEnumIterator; use supabase_wrappers::prelude::*; use super::base::*; -use crate::duckdb::parquet::ParquetOption; +use crate::duckdb::{parquet::ParquetOption, secret::UserMappingOptions}; #[wrappers_fdw( author = "ParadeDB", @@ -111,23 +110,14 @@ impl ForeignDataWrapper for ParquetFdw { FOREIGN_DATA_WRAPPER_RELATION_ID => {} FOREIGN_SERVER_RELATION_ID => {} FOREIGN_TABLE_RELATION_ID => { - let valid_options: Vec = ParquetOption::iter() - .map(|opt| opt.as_ref().to_string()) - .collect(); - - validate_options(opt_list.clone(), valid_options)?; - - for opt in ParquetOption::iter() { - if opt.is_required() { - check_options_contain(&opt_list, opt.as_ref())?; - } - } + validate_mapping_option::(opt_list)?; + } + USER_MAPPING_RELATION_ID => { + validate_mapping_option::(opt_list)?; } - // TODO: Sanitize user mapping options _ => {} } } - Ok(()) } diff --git a/src/fdw/spatial.rs b/src/fdw/spatial.rs index d49b8078..2cb67f97 100644 --- a/src/fdw/spatial.rs +++ b/src/fdw/spatial.rs @@ -20,11 +20,10 @@ use async_std::task; use duckdb::arrow::array::RecordBatch; use pgrx::*; use std::collections::HashMap; -use strum::IntoEnumIterator; use supabase_wrappers::prelude::*; use super::base::*; -use crate::duckdb::spatial::SpatialOption; +use crate::duckdb::{secret::UserMappingOptions, spatial::SpatialOption}; #[wrappers_fdw( author = "ParadeDB", @@ -111,19 +110,11 @@ impl ForeignDataWrapper for SpatialFdw { FOREIGN_DATA_WRAPPER_RELATION_ID => {} FOREIGN_SERVER_RELATION_ID => {} FOREIGN_TABLE_RELATION_ID => { - let valid_options: Vec = SpatialOption::iter() - .map(|opt| opt.as_ref().to_string()) - .collect(); - - validate_options(opt_list.clone(), valid_options)?; - - for opt in SpatialOption::iter() { - if opt.is_required() { - check_options_contain(&opt_list, opt.as_ref())?; - } - } + validate_mapping_option::(opt_list)?; + } + USER_MAPPING_RELATION_ID => { + validate_mapping_option::(opt_list)?; } - // TODO: Sanitize user mapping options _ => {} } } diff --git a/tests/scan.rs b/tests/scan.rs index c9a1ec21..9b2d12ce 100644 --- a/tests/scan.rs +++ b/tests/scan.rs @@ -20,7 +20,8 @@ mod fixtures; use std::fs::File; use crate::fixtures::arrow::{ - delta_primitive_record_batch, primitive_create_table, primitive_record_batch, + delta_primitive_record_batch, primitive_create_foreign_data_wrapper, primitive_create_server, + primitive_create_table, primitive_create_user_mapping_options, primitive_record_batch, primitive_setup_fdw_local_file_delta, primitive_setup_fdw_local_file_listing, primitive_setup_fdw_s3_delta, primitive_setup_fdw_s3_listing, }; @@ -96,6 +97,46 @@ async fn test_arrow_types_s3_listing(#[future(awt)] s3: S3, mut conn: PgConnecti Ok(()) } +#[rstest] +async fn test_wrong_user_mapping_s3_listing( + #[future(awt)] s3: S3, + mut conn: PgConnection, +) -> Result<()> { + let s3_bucket = "test-wrong-user-mapping-s3-listing"; + let s3_key = "test_wrong_user_mapping_s3_listing.parquet"; + let s3_endpoint = s3.url.clone(); + let s3_object_path = format!("s3://{s3_bucket}/{s3_key}"); + + let stored_batch = primitive_record_batch()?; + s3.create_bucket(s3_bucket).await?; + s3.put_batch(s3_bucket, s3_key, &stored_batch).await?; + + let create_foreign_data_wrapper = primitive_create_foreign_data_wrapper( + "parquet_wrapper", + "parquet_fdw_handler", + "parquet_fdw_validator", + ); + let create_user_mapping_options = + primitive_create_user_mapping_options("public", "parquet_server"); + let create_server = primitive_create_server("parquet_server", "parquet_wrapper"); + let create_table = primitive_create_table("parquet_server", "primitive"); + + // this is the wrong user mapping because the type is not provided + let wrong_user_mapping = format!( + r#" + {create_foreign_data_wrapper}; + {create_server}; + {create_user_mapping_options} OPTIONS (region 'us-east-1', endpoint '{s3_endpoint}', use_ssl 'false', url_style 'path'); + {create_table} OPTIONS (files '{s3_object_path}'); + "# + ); + + let result = wrong_user_mapping.execute_result(&mut conn); + assert!(result.is_err()); + + Ok(()) +} + #[rstest] async fn test_arrow_types_s3_delta( #[future(awt)] s3: S3,