Skip to content

Commit

Permalink
feat: bson2json and json2bson scalar converters (#3104)
Browse files Browse the repository at this point in the history
  • Loading branch information
tychoish authored Jul 26, 2024
1 parent 9e7b551 commit 335fb9c
Show file tree
Hide file tree
Showing 9 changed files with 218 additions and 12 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/sqlbuiltins/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,4 @@ arrow-cast = { version = "50.0.0" } # MUST synchronize sync with the datafusion:
lance-linalg = { git = "https://github.com/GlareDB/lance", branch = "df36" } # omits duckdb submodule
jaq-interpret = "1.5.0"
jaq-parse = "1.0.2"
bson = "2.11.0"
3 changes: 3 additions & 0 deletions crates/sqlbuiltins/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ pub enum BuiltinError {
#[error("serde_json: {0}")]
SerdeJsonError(String),

#[error(transparent)]
BsonSer(#[from] bson::ser::Error),

#[error("jaq: {0}")]
JaqError(String),
}
Expand Down
5 changes: 5 additions & 0 deletions crates/sqlbuiltins/src/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use datafusion::logical_expr::{
};
use once_cell::sync::Lazy;
use protogen::metastore::types::catalog::FunctionType;
use scalars::bson2json::Bson2Json;
use scalars::df_scalars::ArrowCastFunction;
use scalars::hashing::{FnvHash, PartitionResults, SipHash};
use scalars::jaq::{JAQMatches, JAQSelect};
Expand All @@ -41,6 +42,7 @@ use scalars::{ConnectionId, Version};
use table::{BuiltinTableFuncs, TableFunc};

use self::alias_map::AliasMap;
use crate::functions::scalars::bson2json::Json2Bson;
use crate::functions::scalars::df_scalars::{Decode, Encode, IsNan, NullIf};
use crate::functions::scalars::openai::OpenAIEmbed;
use crate::functions::scalars::similarity::CosineSimilarity;
Expand Down Expand Up @@ -246,6 +248,9 @@ impl FunctionRegistry {
// JAQ functions
Arc::new(JAQMatches::new()),
Arc::new(JAQSelect::new()),
// Converters
Arc::new(Bson2Json),
Arc::new(Json2Bson),
// Hashing/Partitioning
Arc::new(SipHash),
Arc::new(FnvHash),
Expand Down
153 changes: 153 additions & 0 deletions crates/sqlbuiltins/src/functions/scalars/bson2json.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
use std::sync::Arc;

use bson::Bson;
use catalog::session_catalog::SessionCatalog;
use datafusion::arrow::datatypes::DataType;
use datafusion::error::{DataFusionError, Result as DataFusionResult};
use datafusion::logical_expr::expr::ScalarFunction;
use datafusion::logical_expr::{
ColumnarValue,
ReturnTypeFunction,
ScalarFunctionImplementation,
ScalarUDF,
Signature,
TypeSignature,
Volatility,
};
use datafusion::prelude::Expr;
use datafusion::scalar::ScalarValue;
use protogen::metastore::types::catalog::FunctionType;

use super::apply_op_to_col_array;
use crate::errors::BuiltinError;
use crate::functions::{BuiltinScalarUDF, ConstBuiltinFunction};

pub struct Bson2Json;

impl ConstBuiltinFunction for Bson2Json {
const NAME: &'static str = "bson2json";
const DESCRIPTION: &'static str = "Converts a bson value to a (relaxed extended) json string";
const EXAMPLE: &'static str = "bson2json(<value>)";
const FUNCTION_TYPE: FunctionType = FunctionType::Scalar;

fn signature(&self) -> Option<Signature> {
Some(Signature::one_of(
vec![TypeSignature::OneOf(vec![
TypeSignature::Exact(vec![]),
TypeSignature::Exact(vec![DataType::Binary]),
TypeSignature::Exact(vec![DataType::LargeBinary]),
])],
Volatility::Immutable,
))
}
}

impl Bson2Json {
fn convert(scalar: &ScalarValue) -> Result<ScalarValue, BuiltinError> {
match scalar {
ScalarValue::Binary(Some(v)) | ScalarValue::LargeBinary(Some(v)) => {
Ok(ScalarValue::new_utf8(
bson::de::from_slice::<Bson>(v)
.map_err(|e| DataFusionError::External(Box::new(e)))?
.into_relaxed_extjson()
.to_string(),
))
}
ScalarValue::Binary(None) | ScalarValue::LargeBinary(None) => {
Ok(ScalarValue::Utf8(None))
}
other => Err(BuiltinError::IncorrectType(
other.data_type(),
DataType::Binary,
)),
}
}
}

impl BuiltinScalarUDF for Bson2Json {
fn try_as_expr(&self, _: &SessionCatalog, args: Vec<Expr>) -> DataFusionResult<Expr> {
let return_type_fn: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Utf8)));
let scalar_fn_impl: ScalarFunctionImplementation = Arc::new(move |input| {
Ok(match input {
[] => ColumnarValue::Scalar(ScalarValue::new_utf8("{}")),
[ColumnarValue::Scalar(scalar)] => ColumnarValue::Scalar(Self::convert(scalar)?),
[ColumnarValue::Array(array)] => {
ColumnarValue::Array(apply_op_to_col_array(array, &Self::convert)?)
}
_ => unreachable!("bson2json expects exactly one argument"),
})
});
Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
Arc::new(ScalarUDF::new(
Self::NAME,
&ConstBuiltinFunction::signature(self).unwrap(),
&return_type_fn,
&scalar_fn_impl,
)),
args,
)))
}
}


pub struct Json2Bson;

impl ConstBuiltinFunction for Json2Bson {
const NAME: &'static str = "json2bson";
const DESCRIPTION: &'static str = "Converts a json string value to Bson";
const EXAMPLE: &'static str = "json2bson(<value>)";
const FUNCTION_TYPE: FunctionType = FunctionType::Scalar;

fn signature(&self) -> Option<Signature> {
Some(Signature::one_of(
vec![TypeSignature::OneOf(vec![
TypeSignature::Exact(vec![]),
TypeSignature::Exact(vec![DataType::Utf8]),
TypeSignature::Exact(vec![DataType::LargeUtf8]),
])],
Volatility::Immutable,
))
}
}

impl Json2Bson {
fn convert(scalar: &ScalarValue) -> Result<ScalarValue, BuiltinError> {
match scalar {
ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) => {
Ok(ScalarValue::Binary(Some(bson::ser::to_vec(
&serde_json::from_str::<serde_json::Value>(v)?,
)?)))
}
ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) => Ok(ScalarValue::Binary(None)),
other => Err(BuiltinError::IncorrectType(
other.data_type(),
DataType::Utf8,
)),
}
}
}

impl BuiltinScalarUDF for Json2Bson {
fn try_as_expr(&self, _: &SessionCatalog, args: Vec<Expr>) -> DataFusionResult<Expr> {
let return_type_fn: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Binary)));
let scalar_fn_impl: ScalarFunctionImplementation = Arc::new(move |input| {
Ok(match input {
[] => ColumnarValue::Scalar(ScalarValue::Binary(Some(Vec::new()))),
[ColumnarValue::Scalar(scalar)] => ColumnarValue::Scalar(Self::convert(scalar)?),
[ColumnarValue::Array(array)] => {
ColumnarValue::Array(apply_op_to_col_array(array, &Self::convert)?)
}
_ => unreachable!("json2bson expects exactly one argument"),
})
});
Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
Arc::new(ScalarUDF::new(
Self::NAME,
&ConstBuiltinFunction::signature(self).unwrap(),
&return_type_fn,
&scalar_fn_impl,
)),
args,
)))
}
}
8 changes: 4 additions & 4 deletions crates/sqlbuiltins/src/functions/scalars/jaq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ impl ScalarUDFImpl for JAQSelect {
get_nth_string_value(
input,
0,
&|value: String| -> Result<ScalarValue, BuiltinError> {
let val: Value = serde_json::from_str(&value)?;
&|value: &String| -> Result<ScalarValue, BuiltinError> {
let val: Value = serde_json::from_str(value)?;
let inputs = RcIter::new(core::iter::empty());

let output = filter
Expand Down Expand Up @@ -197,8 +197,8 @@ impl ScalarUDFImpl for JAQMatches {
get_nth_string_value(
input,
0,
&|value: String| -> Result<ScalarValue, BuiltinError> {
let val: Value = serde_json::from_str(&value)?;
&|value: &String| -> Result<ScalarValue, BuiltinError> {
let val: Value = serde_json::from_str(value)?;
let input = RcIter::new(core::iter::empty());

let output = filter.run((Ctx::new([], &input), Val::from(val)));
Expand Down
4 changes: 2 additions & 2 deletions crates/sqlbuiltins/src/functions/scalars/kdl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl ScalarUDFImpl for KDLSelect {
get_nth_string_value(
input,
0,
&|value: String| -> Result<ScalarValue, BuiltinError> {
&|value: &String| -> Result<ScalarValue, BuiltinError> {
let sdoc: kdl::KdlDocument = value.parse().map_err(BuiltinError::KdlError)?;

let out: Vec<&KdlNode> = sdoc
Expand Down Expand Up @@ -200,7 +200,7 @@ impl ScalarUDFImpl for KDLMatches {
get_nth_string_value(
input,
0,
&|value: String| -> Result<ScalarValue, BuiltinError> {
&|value: &String| -> Result<ScalarValue, BuiltinError> {
let doc: kdl::KdlDocument = value.parse().map_err(BuiltinError::KdlError)?;

Ok(ScalarValue::Boolean(Some(
Expand Down
12 changes: 6 additions & 6 deletions crates/sqlbuiltins/src/functions/scalars/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod bson2json;
pub mod df_scalars;
pub mod hashing;
pub mod jaq;
Expand All @@ -19,7 +20,6 @@ use crate::document;
use crate::errors::BuiltinError;
use crate::functions::{BuiltinFunction, BuiltinScalarUDF, ConstBuiltinFunction};


pub struct ConnectionId;

impl ConstBuiltinFunction for ConnectionId {
Expand Down Expand Up @@ -67,11 +67,11 @@ impl BuiltinScalarUDF for Version {
fn get_nth_scalar_value(
input: &[ColumnarValue],
n: usize,
op: &dyn Fn(ScalarValue) -> Result<ScalarValue, BuiltinError>,
op: &dyn Fn(&ScalarValue) -> Result<ScalarValue, BuiltinError>,
) -> Result<ColumnarValue, BuiltinError> {
match input.get(n) {
Some(input) => match input {
ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar(op(scalar.clone())?)),
ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar(op(scalar)?)),
ColumnarValue::Array(arr) => Ok(ColumnarValue::Array(apply_op_to_col_array(arr, op)?)),
},
None => Err(BuiltinError::MissingValueAtIndex(n)),
Expand All @@ -80,7 +80,7 @@ fn get_nth_scalar_value(

fn apply_op_to_col_array(
arr: &dyn Array,
op: &dyn Fn(ScalarValue) -> Result<ScalarValue, BuiltinError>,
op: &dyn Fn(&ScalarValue) -> Result<ScalarValue, BuiltinError>,
) -> Result<Arc<dyn Array>, BuiltinError> {
let mut check_err: Result<(), BuiltinError> = Ok(());

Expand All @@ -104,7 +104,7 @@ fn apply_op_to_col_array(
let iter = (0..arr.len()).filter_map(|idx| {
let scalar_res = ScalarValue::try_from_array(arr, idx).map_err(BuiltinError::from);
let scalar = filter_fn(&mut check_err, scalar_res)?;
filter_fn(&mut check_err, op(scalar))
filter_fn(&mut check_err, op(&scalar))
});

// NB: ScalarValue::iter_to_array accepts an iterator over
Expand Down Expand Up @@ -207,7 +207,7 @@ fn get_nth_string_fn_arg(input: &[ColumnarValue], idx: usize) -> Result<String,
fn get_nth_string_value(
input: &[ColumnarValue],
n: usize,
op: &dyn Fn(String) -> Result<ScalarValue, BuiltinError>,
op: &dyn Fn(&String) -> Result<ScalarValue, BuiltinError>,
) -> Result<ColumnarValue, BuiltinError> {
get_nth_scalar_value(input, n, &|scalar| -> Result<ScalarValue, BuiltinError> {
match scalar {
Expand Down
43 changes: 43 additions & 0 deletions testdata/sqllogictests/functions/bson.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
statement ok
CREATE TEMP TABLE bson_conversions (id int, json text, bson bytea);

statement ok
INSERT INTO bson_conversions
VALUES
(0, '{"a":1}', json2bson('{"a":1}')),
(1, '{"b":2}', json2bson('{"b":2}'));

query
SELECT jaq_select(json, '.a')
FROM bson_conversions
WHERE id = 0;
----
1

query
SELECT jaq_select(bson2json(bson), '.a')
FROM bson_conversions
WHERE id = 0;
----
1

query
SELECT jaq_select(bson2json(bson), '.a')
FROM bson_conversions
WHERE id = 1;
----
NULL

query
SELECT bson2json(bson) = json
FROM bson_conversions
WHERE id = 0
----
t

query
SELECT json2bson(json) = bson
FROM bson_conversions
WHERE id = 0
----
t

0 comments on commit 335fb9c

Please sign in to comment.