diff --git a/Cargo.toml b/Cargo.toml
index 7b72910f..880050e2 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -28,7 +28,7 @@ serde = "1.0.201"
 serde_json = "1.0.120"
 signal-hook = "0.3.17"
 signal-hook-async-std = "0.2.2"
-shared = { git = "https://github.com/paradedb/paradedb.git", rev = "4854652" }
+shared = { git = "https://github.com/paradedb/paradedb.git", branch = "add-util-record-batches" }
 supabase-wrappers = { git = "https://github.com/paradedb/wrappers.git", default-features = false, rev = "6c58451" }
 thiserror = "1.0.59"
 uuid = "1.9.1"
@@ -42,7 +42,7 @@ futures = "0.3.30"
 pgrx-tests = "0.11.3"
 rstest = "0.19.0"
 serde_arrow = { version = "0.11.3", features = ["arrow-51"] }
-shared = { git = "https://github.com/paradedb/paradedb.git", rev = "4854652", features = ["fixtures"] }
+shared = { git = "https://github.com/paradedb/paradedb.git", branch = "add-util-record-batches", features = ["fixtures"] }
 sqlx = { version = "0.7.4", features = [
   "postgres",
   "runtime-async-std",
diff --git a/src/schema/cell.rs b/src/schema/cell.rs
index 3c806c71..c5e5068b 100644
--- a/src/schema/cell.rs
+++ b/src/schema/cell.rs
@@ -25,8 +25,9 @@ use duckdb::arrow::array::types::{
 };
 use duckdb::arrow::array::{
     timezone::Tz, Array, ArrayAccessor, ArrayRef, ArrowPrimitiveType, AsArray, BinaryArray,
-    BooleanArray, Decimal128Array, Float16Array, Float32Array, Float64Array, GenericByteArray,
-    Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, StringArray,
+    BooleanArray, Date32Array, Date64Array, Decimal128Array, Float16Array, Float32Array,
+    Float64Array, GenericByteArray, Int16Array, Int32Array, Int64Array, Int8Array,
+    LargeBinaryArray, StringArray,
 };
 use duckdb::arrow::datatypes::{DataType, DecimalType, GenericStringType, IntervalUnit, TimeUnit};
 use pgrx::*;
@@ -1126,6 +1127,20 @@ where
                     None => Ok(None),
                 }
             }
+            DataType::Date32 => match self.get_primitive_value::<Date32Type>(index)? {
+                Some(timestamp_in_days) => {
+                    Ok(arrow_date32_to_postgres_timestamps(timestamp_in_days)?
+                        .map(Cell::Timestamptz))
+                }
+                None => Ok(None),
+            },
+            DataType::Date64 => match self.get_primitive_value::<Date64Type>(index)? {
+                Some(timestamp_in_milliseconds) => Ok(arrow_date64_to_postgres_timestamps(
+                    timestamp_in_milliseconds,
+                )?
+                .map(Cell::Timestamptz)),
+                None => Ok(None),
+            },
             unsupported => Err(DataTypeError::DataTypeMismatch(
                 name.to_string(),
                 unsupported.clone(),
diff --git a/src/schema/datetime.rs b/src/schema/datetime.rs
index cceef37e..10258a94 100644
--- a/src/schema/datetime.rs
+++ b/src/schema/datetime.rs
@@ -16,7 +16,7 @@
 // along with this program. If not, see <http://www.gnu.org/licenses/>.
 
 use chrono::{
-    DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, TimeDelta, TimeZone, Timelike,
+    DateTime, Datelike, Days, NaiveDate, NaiveDateTime, NaiveTime, TimeDelta, TimeZone, Timelike,
 };
 use pgrx::*;
 use std::fmt::Debug;
@@ -25,6 +25,47 @@
 use std::str::FromStr;
 
 const NANOSECONDS_IN_SECOND: u32 = 1_000_000_000;
+const MILLISECONDS_IN_SECOND: i64 = 1_000;
+
+const SECONDS_IN_DAY: i64 = 86_400;
+
+// Number of days between the Apache Arrow / UNIX epoch (1970-01-01)
+// and the PostgreSQL epoch (2000-01-01).
+const POSTGRES_BASE_DATE_OFFSET: Days = Days::new(10_957);
+
+/// Converts an [`i32`] stored in [`arrow::array::types::Date32Type`] to a PostgreSQL TimestampWithTimeZone.
+///
+/// Takes into account that [`arrow::array::types::Date32Type`] stores the number of days
+/// elapsed since the UNIX epoch (1970-01-01), while the Postgres
+/// [`datum::TimestampWithTimeZone`] type expects a timestamp in microseconds
+/// since the PostgreSQL epoch (2000-01-01).
+#[inline(always)]
+pub(crate) fn arrow_date32_to_postgres_timestamps(
+    timestamp_in_days: i32,
+) -> Result<Option<TimestampWithTimeZone>, FromTimeError> {
+    arrow_date64_to_postgres_timestamps(
+        timestamp_in_days as i64 * SECONDS_IN_DAY * MILLISECONDS_IN_SECOND,
+    )
+}
+
+/// Converts an [`i64`] stored in [`arrow::array::types::Date64Type`] to a PostgreSQL TimestampWithTimeZone.
+///
+/// Takes into account that [`arrow::array::types::Date64Type`] stores the number of milliseconds
+/// elapsed since the UNIX epoch (1970-01-01), while the Postgres
+/// [`datum::TimestampWithTimeZone`] type expects a timestamp in microseconds
+/// since the PostgreSQL epoch (2000-01-01).
+#[inline(always)]
+pub(crate) fn arrow_date64_to_postgres_timestamps(
+    timestamp_in_milliseconds: i64,
+) -> Result<Option<TimestampWithTimeZone>, FromTimeError> {
+    DateTime::from_timestamp_millis(timestamp_in_milliseconds)
+        .map(|date_time| date_time.naive_utc())
+        .and_then(|naive_date_time| naive_date_time.checked_sub_days(POSTGRES_BASE_DATE_OFFSET))
+        .map(|shifted_naive_date_time| shifted_naive_date_time.and_utc().timestamp_micros())
+        .map(TimestampWithTimeZone::try_from)
+        .transpose()
+}
+
 
 #[derive(Clone, Debug)]
 pub struct Date(pub NaiveDate);
diff --git a/tests/scan.rs b/tests/scan.rs
index bfa1c6f5..d4c61e2b 100644
--- a/tests/scan.rs
+++ b/tests/scan.rs
@@ -18,8 +18,12 @@
 mod fixtures;
 
 use std::fs::File;
+use std::sync::Arc;
 
 use anyhow::Result;
+use chrono::{DateTime, Datelike, TimeZone, Utc};
+use datafusion::arrow::array::*;
+use datafusion::arrow::datatypes::DataType;
 use datafusion::parquet::arrow::ArrowWriter;
 use deltalake::operations::create::CreateBuilder;
 use deltalake::writer::{DeltaWriter, RecordBatchWriter};
@@ -28,7 +32,7 @@ use rstest::*;
 use shared::fixtures::arrow::{
     delta_primitive_record_batch, primitive_record_batch, primitive_setup_fdw_local_file_delta,
     primitive_setup_fdw_local_file_listing, primitive_setup_fdw_s3_delta,
-    primitive_setup_fdw_s3_listing,
+    primitive_setup_fdw_s3_listing, setup_fdw_local_parquet_file_listing, FieldSpec,
 };
 use shared::fixtures::tempfile::TempDir;
 use sqlx::postgres::types::PgInterval;
@@ -36,11 +40,45 @@ use sqlx::types::{BigDecimal, Json, Uuid};
 use sqlx::PgConnection;
 use std::collections::HashMap;
 use std::str::FromStr;
+use temporal_conversions::SECONDS_IN_DAY;
 use time::macros::{date, datetime, time};
 
 const S3_TRIPS_BUCKET: &str = "test-trip-setup";
 const S3_TRIPS_KEY: &str = "test_trip_setup.parquet";
+fn date_time_record_batch() -> Result<(RecordBatch, FieldSpec, Vec<String>)> {
+    let field_spec = FieldSpec::from(vec![
+        ("date32_col", DataType::Date32, false, "date"),
+        ("date64_col", DataType::Date64, false, "date"),
+    ]);
+    let dates = vec![
+        "2023-04-01 21:10:00 +0000".to_string(), "2023-04-01 22:08:00 +0000".to_string(),
+        "2023-04-02 04:55:00 +0000".to_string(), "2023-04-02 11:45:00 +0000".to_string(),
+        "2023-04-03 01:20:00 +0000".to_string(), "2023-04-03 12:30:00 +0000".to_string(),
+    ];
+    let (dates_i32, dates_i64): (Vec<_>, Vec<_>) = dates
+        .iter()
+        .map(|date_str| {
+            let dt = date_str.parse::<DateTime<Utc>>().unwrap();
+            (
+                (dt.timestamp() / SECONDS_IN_DAY) as i32,
+                dt.timestamp_millis(),
+            )
+        })
+        .unzip();
+
+    let schema = Arc::new(field_spec.arrow_schema());
+    let batch = RecordBatch::try_new(
+        schema,
+        vec![
+            Arc::new(Date32Array::from(dates_i32)),
+            Arc::new(Date64Array::from(dates_i64)),
+        ],
+    )?;
+
+    Ok((batch, field_spec, dates))
+}
+
 
 #[rstest]
 async fn test_trip_count(#[future(awt)] s3: S3, mut conn: PgConnection) -> Result<()> {
     NycTripsTable::setup().execute(&mut conn);
@@ -287,3 +325,45 @@ async fn test_create_heap_from_parquet(mut conn: PgConnection, tempdir: TempDir)
 
     Ok(())
 }
+
+#[rstest]
+async fn test_date_functions_support_with_local_file(
+    mut conn: PgConnection,
+    tempdir: TempDir,
+) -> Result<()> {
+    let (stored_batch, field_spec, dates) = date_time_record_batch()?;
+    let parquet_path = tempdir.path().join("test_date_functions.parquet");
+    let parquet_file = File::create(&parquet_path)?;
+
+    let mut writer = ArrowWriter::try_new(parquet_file, stored_batch.schema(), None).unwrap();
+    writer.write(&stored_batch)?;
+    writer.close()?;
+
+    setup_fdw_local_parquet_file_listing(
+        parquet_path.as_path().to_str().unwrap(),
+        "dates",
+        &field_spec.postgres_schema(),
+    )
+    .execute(&mut conn);
+
+    let expected_rows: Vec<(f64, DateTime<Utc>)> = dates
+        .iter()
+        .map(|date_str| {
+            let dt = date_str.parse::<DateTime<Utc>>().unwrap();
+            (
+                dt.day() as f64,
+                Utc.with_ymd_and_hms(dt.year(), dt.month(), dt.day(), 0, 0, 0)
+                    .unwrap(),
+            )
+        })
+        .collect();
+
+    let fetched_rows =
+        "SELECT DATE_PART('day', date32_col), DATE_TRUNC('day', date64_col) FROM dates"
+            .fetch_result::<(f64, chrono::DateTime<Utc>)>(&mut conn)?;
+
+    assert_eq!(expected_rows.len(), fetched_rows.len());
+    assert_eq!(expected_rows, fetched_rows);
+
+    Ok(())
+}
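
Not part of the patch: a standalone sketch of the epoch-offset arithmetic the new `arrow_date32_to_postgres_timestamps` / `arrow_date64_to_postgres_timestamps` helpers rely on. It assumes chrono 0.4 (`Days`, `checked_sub_days`, `and_utc`); `date64_to_postgres_micros` is an illustrative stand-in that returns raw microseconds rather than the patch's pgrx `TimestampWithTimeZone`.

```rust
use chrono::{DateTime, Days, NaiveDate};

// Days between the UNIX epoch (1970-01-01) and the PostgreSQL epoch (2000-01-01):
// 30 * 365 + 7 leap days (1972, 1976, ..., 1996) = 10_957, matching POSTGRES_BASE_DATE_OFFSET.
const POSTGRES_EPOCH_OFFSET: Days = Days::new(10_957);

// Illustrative stand-in for the Date64 helper: shift milliseconds since the
// UNIX epoch to the PostgreSQL epoch and express the result in microseconds.
fn date64_to_postgres_micros(ms_since_unix_epoch: i64) -> Option<i64> {
    DateTime::from_timestamp_millis(ms_since_unix_epoch)
        .map(|dt| dt.naive_utc())
        .and_then(|ndt| ndt.checked_sub_days(POSTGRES_EPOCH_OFFSET))
        .map(|shifted| shifted.and_utc().timestamp_micros())
}

fn main() {
    // A Date32 value of 10_957 days is 2000-01-01, i.e. 0 µs in the PostgreSQL epoch;
    // the Date32 path in the patch scales days out to milliseconds first.
    let days: i32 = 10_957;
    assert_eq!(date64_to_postgres_micros(days as i64 * 86_400 * 1_000), Some(0));

    // One day later is 86_400_000_000 µs past the PostgreSQL epoch.
    assert_eq!(
        date64_to_postgres_micros((days as i64 + 1) * 86_400 * 1_000),
        Some(86_400_000_000)
    );

    // Sanity-check the day offset itself.
    let unix = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
    let pg = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap();
    assert_eq!((pg - unix).num_days(), 10_957);
}
```

Note that the patch routes Date32 through the Date64 helper by scaling days to milliseconds, so the epoch shift is maintained in a single place.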