From a14415b013b51637149d32e93619f117b056babb Mon Sep 17 00:00:00 2001 From: Siyuan Huang <73871299+kysshsy@users.noreply.github.com> Date: Fri, 3 Jan 2025 19:12:13 +0800 Subject: [PATCH] feat: support DataType::List cast to json and jsonb (#187) * feat: support Datatype::List cast to json and jsonb * feat: fix array type downcast * test: primitive list to json test * test: add struct_list to json test --- src/schema/cell.rs | 125 +++++++++++++++++++ tests/tests/json.rs | 298 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 421 insertions(+), 2 deletions(-) diff --git a/src/schema/cell.rs b/src/schema/cell.rs index 2602bafd..9d0e1bb4 100644 --- a/src/schema/cell.rs +++ b/src/schema/cell.rs @@ -327,6 +327,118 @@ where } } +pub trait GetListValue +where + Self: Array + AsArray, +{ + fn get_list_value(&self, index: usize) -> Result> { + let downcast_array = self.as_list::(); + + if downcast_array.nulls().is_some() && downcast_array.is_null(index) { + return Ok(None); + } + + match downcast_array.value_type() { + DataType::Boolean => { + let list_array: ArrayRef = Arc::new(downcast_array.clone()); + let values = list_array + .get_primitive_list_value::>(index)? + .map_or(vec![], |arr| { + arr.into_iter() + .map(|opt| opt.map_or(Value::Null, Value::from)) + .collect() + }); + Ok(Some(datum::JsonB(Value::Array(values)))) + } + DataType::Int8 => { + let list_array: ArrayRef = Arc::new(downcast_array.clone()); + let values = list_array + .get_primitive_list_value::>(index)? + .map_or(vec![], |arr| { + arr.into_iter() + .map(|opt| opt.map_or(Value::Null, |v| Value::Number(Number::from(v)))) + .collect() + }); + Ok(Some(datum::JsonB(Value::Array(values)))) + } + DataType::Int16 => { + let list_array: ArrayRef = Arc::new(downcast_array.clone()); + let values = list_array + .get_primitive_list_value::>(index)? + .map_or(vec![], |arr| { + arr.into_iter() + .map(|opt| opt.map_or(Value::Null, |v| Value::Number(Number::from(v)))) + .collect() + }); + Ok(Some(datum::JsonB(Value::Array(values)))) + } + DataType::Int32 => { + let list_array: ArrayRef = Arc::new(downcast_array.clone()); + let values = list_array + .get_primitive_list_value::>(index)? + .map_or(vec![], |arr| { + arr.into_iter() + .map(|opt| opt.map_or(Value::Null, |v| Value::Number(Number::from(v)))) + .collect() + }); + Ok(Some(datum::JsonB(Value::Array(values)))) + } + DataType::Int64 => { + let list_array: ArrayRef = Arc::new(downcast_array.clone()); + let values = list_array + .get_primitive_list_value::>(index)? + .map_or(vec![], |arr| { + arr.into_iter() + .map(|opt| opt.map_or(Value::Null, |v| Value::Number(Number::from(v)))) + .collect() + }); + Ok(Some(datum::JsonB(Value::Array(values)))) + } + DataType::Utf8 => { + let list_array: ArrayRef = Arc::new(downcast_array.clone()); + let values = list_array + .get_string_list_value(index)? + .map_or(vec![], |arr| { + arr.into_iter() + .map(|opt| opt.map_or(Value::Null, Value::String)) + .collect() + }); + Ok(Some(datum::JsonB(Value::Array(values)))) + } + DataType::LargeUtf8 => { + let list_array: ArrayRef = Arc::new(downcast_array.clone()); + let mut values = vec![]; + for i in 0..list_array.len() { + let string_value = list_array + .get_primitive_value::(i)? + .map_or(Value::Null, |v| Value::String(v.to_string())); + values.push(string_value); + } + Ok(Some(datum::JsonB(Value::Array(values)))) + } + DataType::Struct(_) => { + let list_array = downcast_array.value(index); + let mut values = vec![]; + for i in 0..list_array.len() { + let struct_value = list_array.get_struct_value(i)?.map_or(Value::Null, |v| v.0); + values.push(struct_value); + } + Ok(Some(datum::JsonB(Value::Array(values)))) + } + DataType::List(_) => { + let list_array = downcast_array.value(index); + let mut values = vec![]; + for i in 0..list_array.len() { + let list_value = list_array.get_list_value(i)?.map_or(Value::Null, |v| v.0); + values.push(list_value); + } + Ok(Some(datum::JsonB(Value::Array(values)))) + } + unsupported => bail!("List with {:?} types are not yet supported", unsupported), + } + } +} + pub trait GetDecimalValue where Self: Array + AsArray, @@ -559,6 +671,7 @@ where + GetIntervalDayTimeValue + GetIntervalMonthDayNanoValue + GetIntervalYearMonthValue + + GetListValue + GetPrimitiveValue + GetPrimitiveListValue + GetStringListValue @@ -1040,6 +1153,13 @@ where None => Ok(None), } } + DataType::List(_) => match self.get_list_value(index)? { + Some(value) => { + let json_value: serde_json::Value = serde_json::to_value(value)?; + Ok(Some(Cell::Json(datum::Json(json_value)))) + } + None => Ok(None), + }, unsupported => Err(DataTypeError::DataTypeMismatch( name.to_string(), unsupported.clone(), @@ -1068,6 +1188,10 @@ where None => Ok(None), } } + DataType::List(_) => match self.get_list_value(index)? { + Some(value) => Ok(Some(Cell::JsonB(value))), + None => Ok(None), + }, unsupported => Err(DataTypeError::DataTypeMismatch( name.to_string(), unsupported.clone(), @@ -1266,6 +1390,7 @@ impl GetDecimalValue for ArrayRef {} impl GetIntervalDayTimeValue for ArrayRef {} impl GetIntervalMonthDayNanoValue for ArrayRef {} impl GetIntervalYearMonthValue for ArrayRef {} +impl GetListValue for ArrayRef {} impl GetPrimitiveValue for ArrayRef {} impl GetPrimitiveListValue for ArrayRef {} impl GetStringListValue for ArrayRef {} diff --git a/tests/tests/json.rs b/tests/tests/json.rs index 82646ef1..7bfeb7d1 100644 --- a/tests/tests/json.rs +++ b/tests/tests/json.rs @@ -18,8 +18,13 @@ mod fixtures; use anyhow::Result; -use datafusion::arrow::array::{LargeStringArray, StringArray}; -use datafusion::arrow::datatypes::{DataType, Field, Schema}; +use datafusion::arrow::array::{ + ArrayBuilder, ArrowPrimitiveType, BooleanBuilder, LargeStringArray, LargeStringBuilder, + ListArray, ListBuilder, PrimitiveBuilder, StringArray, StringBuilder, StructBuilder, +}; +use datafusion::arrow::datatypes::{ + DataType, Field, Fields, Int16Type, Int32Type, Int64Type, Int8Type, Schema, +}; use datafusion::{arrow::record_batch::RecordBatch, parquet::arrow::ArrowWriter}; use rstest::*; use serde_json::json; @@ -52,6 +57,217 @@ pub fn json_string_record_batch() -> Result { )?) } +fn boolean_list_array(boolean_values: Vec>>) -> ListArray { + let boolean_builder = BooleanBuilder::new(); + let mut list_builder = ListBuilder::new(boolean_builder); + + for values in boolean_values { + for value in values { + list_builder.values().append_option(value); + } + list_builder.append(true); + } + + list_builder.finish() +} + +fn primitive_list_array>, V>( + values: Vec>>, +) -> ListArray { + let builder = PrimitiveBuilder::::new(); + let mut list_builder = ListBuilder::new(builder); + + for sublist in values { + for value in sublist { + list_builder + .values() + .append_option(value.map(|v| T::Native::from(v))); + } + list_builder.append(true); + } + + list_builder.finish() +} + +fn string_list_array(values: Vec>>) -> ListArray { + let builder = StringBuilder::new(); + let mut list_builder = ListBuilder::new(builder); + + for sublist in values { + for value in sublist { + list_builder.values().append_option(value); + } + list_builder.append(true); + } + + list_builder.finish() +} + +fn large_string_list_array(values: Vec>>) -> ListArray { + let builder = LargeStringBuilder::new(); + let mut list_builder = ListBuilder::new(builder); + + for sublist in values { + for value in sublist { + list_builder.values().append_option(value); + } + list_builder.append(true); + } + + list_builder.finish() +} + +pub fn json_list_record_batch() -> Result { + let fields = vec![ + Field::new( + "boolean_array", + DataType::List(Arc::new(Field::new("item", DataType::Boolean, true))), + false, + ), + Field::new( + "int8_array", + DataType::List(Arc::new(Field::new("item", DataType::Int8, true))), + false, + ), + Field::new( + "int16_array", + DataType::List(Arc::new(Field::new("item", DataType::Int16, true))), + false, + ), + Field::new( + "int32_array", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + false, + ), + Field::new( + "int64_array", + DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + false, + ), + Field::new( + "string_array", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + false, + ), + Field::new( + "large_string_array", + DataType::List(Arc::new(Field::new("item", DataType::LargeUtf8, true))), + false, + ), + ]; + + let schema = Arc::new(Schema::new(fields)); + + let boolean_values = vec![ + vec![None, Some(false), Some(true)], + vec![None, Some(true)], + vec![Some(true), None, Some(false), Some(false)], + ]; + let int_values = vec![ + vec![None, Some(1), Some(2)], + vec![None, Some(3)], + vec![Some(4), Some(5), None, Some(6)], + ]; + let string_values = vec![ + vec![Some("abc"), None, Some("b")], + vec![None, Some("ce")], + vec![Some("d"), Some("e"), None, Some("f")], + ]; + + let boolean_array = Arc::new(boolean_list_array(boolean_values)); + let int8_array = Arc::new(primitive_list_array::(int_values.clone())); + let int16_array = Arc::new(primitive_list_array::(int_values.clone())); + let int32_array = Arc::new(primitive_list_array::(int_values.clone())); + let int64_array = Arc::new(primitive_list_array::(int_values.clone())); + let string_array = Arc::new(string_list_array(string_values.clone())); + let large_string_array = Arc::new(large_string_list_array(string_values.clone())); + + Ok(RecordBatch::try_new( + schema, + vec![ + boolean_array, + int8_array, + int16_array, + int32_array, + int64_array, + string_array, + large_string_array, + ], + )?) +} + +pub fn struct_list_record_batch() -> Result { + let struct_fileds = vec![ + Field::new("name", DataType::Utf8, false), + Field::new("age", DataType::Int32, false), + ]; + let fields = vec![Field::new( + "struct_array", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(Fields::from(struct_fileds.clone())), + true, + ))), + false, + )]; + + let schema = Arc::new(Schema::new(fields)); + + let struct_values = vec![ + vec![ + Some(("joe", 12)), + None, + Some(("jane", 13)), + Some(("jim", 14)), + ], + vec![Some(("joe", 12))], + ]; + + let struct_array = { + let mut struct_list_builder = ListBuilder::new(StructBuilder::new( + struct_fileds, + vec![ + Box::new(StringBuilder::new()) as Box, + Box::new(PrimitiveBuilder::::new()) as Box, + ], + )); + + for sublist in struct_values { + for value in sublist { + if let Some((name, age)) = value { + struct_list_builder.values().append(true); + struct_list_builder + .values() + .field_builder::(0) + .unwrap() + .append_value(name); + struct_list_builder + .values() + .field_builder::>(1) + .unwrap() + .append_value(age); + } else { + struct_list_builder.values().append(false); + struct_list_builder + .values() + .field_builder::(0) + .unwrap() + .append_null(); + struct_list_builder + .values() + .field_builder::>(1) + .unwrap() + .append_null(); + } + } + struct_list_builder.append(true); + } + struct_list_builder.finish() + }; + + Ok(RecordBatch::try_new(schema, vec![Arc::new(struct_array)])?) +} + #[rstest] async fn test_json_cast_from_string(mut conn: PgConnection, tempdir: TempDir) -> Result<()> { let stored_batch = json_string_record_batch()?; @@ -97,3 +313,81 @@ async fn test_json_cast_from_string(mut conn: PgConnection, tempdir: TempDir) -> Ok(()) } + +#[rstest] +fn test_json_cast_from_list(mut conn: PgConnection, tempdir: TempDir) -> Result<()> { + let stored_batch = json_list_record_batch()?; + let parquet_path = tempdir.path().join("test_json_cast_from_list.parquet"); + let parquet_file = File::create(&parquet_path)?; + + let mut writer = ArrowWriter::try_new(parquet_file, stored_batch.schema(), None).unwrap(); + writer.write(&stored_batch)?; + writer.close()?; + + primitive_create_foreign_data_wrapper( + "parquet_wrapper", + "parquet_fdw_handler", + "parquet_fdw_validator", + ) + .execute(&mut conn); + primitive_create_server("parquet_server", "parquet_wrapper").execute(&mut conn); + format!( + "CREATE FOREIGN TABLE json_table ( + boolean_array jsonb, + int8_array jsonb, + int16_array jsonb, + int32_array jsonb, + int64_array jsonb, + string_array jsonb, + large_string_array jsonb + ) SERVER parquet_server OPTIONS (files '{}')", + parquet_path.to_str().unwrap() + ) + .execute(&mut conn); + + let r = "SELECT * FROM json_table".execute_result(&mut conn); + assert!(r.is_ok(), "error in query:'{}'", r.unwrap_err()); + + let row: (Json,) = + "SELECT int8_array FROM json_table where int8_array = '[null, 3]'".fetch_one(&mut conn); + assert_eq!(row.0, Json::from(json!([null, 3]))); + + Ok(()) +} + +#[rstest] +fn test_json_cast_from_struct_list(mut conn: PgConnection, tempdir: TempDir) -> Result<()> { + let stored_batch = struct_list_record_batch()?; + let parquet_path = tempdir + .path() + .join("test_json_cast_from_struct_list.parquet"); + let parquet_file = File::create(&parquet_path)?; + + let mut writer = ArrowWriter::try_new(parquet_file, stored_batch.schema(), None).unwrap(); + writer.write(&stored_batch)?; + writer.close()?; + + primitive_create_foreign_data_wrapper( + "parquet_wrapper", + "parquet_fdw_handler", + "parquet_fdw_validator", + ) + .execute(&mut conn); + primitive_create_server("parquet_server", "parquet_wrapper").execute(&mut conn); + format!( + "CREATE FOREIGN TABLE json_table () + SERVER parquet_server OPTIONS (files '{}')", + parquet_path.to_str().unwrap() + ) + .execute(&mut conn); + + let r = "SELECT * FROM json_table".execute_result(&mut conn); + assert!(r.is_ok(), "error in query:'{}'", r.unwrap_err()); + + let row: (Json,) = + "SELECT struct_array FROM json_table where struct_array = '[{\"name\": \"joe\", \"age\": 12}]'" + .fetch_one(&mut conn); + assert_eq!(row.0, Json::from(json!([{"name": "joe", "age": 12}]))); + + Ok(()) +}