Skip to content

Commit

Permalink
chore: clean up boilerplate on scalar functions (#3105)
Browse files Browse the repository at this point in the history
  • Loading branch information
tychoish authored Jul 26, 2024
1 parent 335fb9c commit cff3020
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 418 deletions.
14 changes: 9 additions & 5 deletions crates/sqlbuiltins/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::error::ArrowError;
use datafusion::error::DataFusionError;
use datafusion_ext::errors::ExtensionError;
use datasources::json::jaq::JaqError;

#[derive(Clone, Debug, thiserror::Error)]
pub enum BuiltinError {
Expand Down Expand Up @@ -39,13 +40,16 @@ pub enum BuiltinError {
DataFusionExtension(String),

#[error("serde_json: {0}")]
SerdeJsonError(String),
SerdeJson(String),

#[error(transparent)]
BsonSer(#[from] bson::ser::Error),

#[error("jaq: {0}")]
JaqError(String),
#[error("jaq_internal: {0}")]
JaqInternal(String),

#[error(transparent)]
Jaq(#[from] JaqError),
}

pub type Result<T, E = BuiltinError> = std::result::Result<T, E>;
Expand Down Expand Up @@ -82,12 +86,12 @@ impl From<ArrowError> for BuiltinError {

impl From<serde_json::Error> for BuiltinError {
fn from(e: serde_json::Error) -> Self {
BuiltinError::SerdeJsonError(e.to_string())
BuiltinError::SerdeJson(e.to_string())
}
}

impl From<jaq_interpret::Error> for BuiltinError {
fn from(e: jaq_interpret::Error) -> Self {
BuiltinError::JaqError(e.to_string())
BuiltinError::JaqInternal(e.to_string())
}
}
10 changes: 5 additions & 5 deletions crates/sqlbuiltins/src/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,11 @@ impl FunctionRegistry {
Arc::new(ConnectionId),
Arc::new(Version),
// KDL functions
Arc::new(KDLMatches::new()),
Arc::new(KDLSelect::new()),
Arc::new(KDLMatches),
Arc::new(KDLSelect),
// JAQ functions
Arc::new(JAQMatches::new()),
Arc::new(JAQSelect::new()),
Arc::new(JAQMatches),
Arc::new(JAQSelect),
// Converters
Arc::new(Bson2Json),
Arc::new(Json2Bson),
Expand All @@ -258,7 +258,7 @@ impl FunctionRegistry {
// OpenAI
Arc::new(OpenAIEmbed),
// Similarity
Arc::new(CosineSimilarity::new()),
Arc::new(CosineSimilarity),
];
let udfs = udfs
.into_iter()
Expand Down
1 change: 1 addition & 0 deletions crates/sqlbuiltins/src/functions/scalars/bson2json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use super::apply_op_to_col_array;
use crate::errors::BuiltinError;
use crate::functions::{BuiltinScalarUDF, ConstBuiltinFunction};

/// Field-less marker type for the `Bson2Json` builtin.
///
/// It implements `ConstBuiltinFunction` and is registered in the function
/// registry's "Converters" group as `Arc::new(Bson2Json)`, so it carries no
/// per-instance state.
#[derive(Debug)]
pub struct Bson2Json;

impl ConstBuiltinFunction for Bson2Json {
Expand Down
240 changes: 77 additions & 163 deletions crates/sqlbuiltins/src/functions/scalars/jaq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,15 @@ use datafusion::arrow::datatypes::DataType;
use datafusion::error::{DataFusionError, Result as DataFusionResult};
use datafusion::logical_expr::expr::ScalarFunction;
use datafusion::logical_expr::{
ColumnarValue,
ReturnTypeFunction,
ScalarFunctionImplementation,
ScalarUDF,
ScalarUDFImpl,
Signature,
TypeSignature,
Volatility,
};
use datafusion::prelude::Expr;
use datafusion::scalar::ScalarValue;
use datasources::json::errors::JsonError;
use datasources::json::jaq::compile_jaq_query;
use jaq_interpret::{Ctx, FilterT, RcIter, Val};
use protogen::metastore::types::catalog::FunctionType;
Expand All @@ -26,10 +23,8 @@ use super::{get_nth_string_fn_arg, get_nth_string_value};
use crate::errors::BuiltinError;
use crate::functions::{BuiltinScalarUDF, ConstBuiltinFunction};

#[derive(Debug)]
pub struct JAQSelect {
signature: Signature,
}
/// Field-less handle for the `jaq_select` scalar function.
///
/// The function accepts two string arguments, `<FIELD>` and `<QUERY>`, each of
/// which may be `Utf8` or `LargeUtf8`, and its return type is declared as
/// `Utf8`. The type stores nothing: `ConstBuiltinFunction::signature` builds
/// the `Signature` each time it is called.
#[derive(Default, Debug)]
pub struct JAQSelect;

impl ConstBuiltinFunction for JAQSelect {
const NAME: &'static str = "jaq_select";
Expand All @@ -38,84 +33,48 @@ impl ConstBuiltinFunction for JAQSelect {
const FUNCTION_TYPE: FunctionType = FunctionType::Scalar;

fn signature(&self) -> Option<Signature> {
Some(self.signature.clone())
}
}

impl Default for JAQSelect {
fn default() -> Self {
Self::new()
}
}

impl JAQSelect {
pub fn new() -> Self {
Self {
signature: Signature::new(
// args: <FIELD>, <QUERY>
TypeSignature::OneOf(vec![
TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
]),
Volatility::Immutable,
),
}
}
}

impl ScalarUDFImpl for JAQSelect {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
Self::NAME
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _: &[DataType]) -> datafusion::error::Result<DataType> {
Ok(DataType::Utf8)
}

fn invoke(&self, input: &[ColumnarValue]) -> datafusion::error::Result<ColumnarValue> {
let filter =
compile_jaq_query(get_nth_string_fn_arg(input, 1)?).map_err(JsonError::from)?;

get_nth_string_value(
input,
0,
&|value: &String| -> Result<ScalarValue, BuiltinError> {
let val: Value = serde_json::from_str(value)?;
let inputs = RcIter::new(core::iter::empty());

let output = filter
.run((Ctx::new([], &inputs), Val::from(val)))
.map(|res| res.map(|v| jaq_to_scalar_string(&v)))
.collect::<Result<Vec<_>, _>>()?;

Ok(match output.len() {
0 => ScalarValue::Utf8(None),
1 => output.first().unwrap().to_owned(),
_ => ScalarValue::List(ScalarValue::new_list(&output, &DataType::Utf8)),
})
},
)
.map_err(DataFusionError::from)
Some(Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
],
Volatility::Immutable,
))
}
}


impl BuiltinScalarUDF for JAQSelect {
fn try_as_expr(&self, _: &SessionCatalog, args: Vec<Expr>) -> DataFusionResult<Expr> {
let return_type_fn: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Utf8)));

let scalar_fn_impl: ScalarFunctionImplementation =
Arc::new(move |input| ScalarUDFImpl::invoke(&Self::new(), input));
let scalar_fn_impl: ScalarFunctionImplementation = Arc::new(move |input| {
let filter = compile_jaq_query(get_nth_string_fn_arg(input, 1)?)
.map_err(|e| DataFusionError::from(BuiltinError::from(e)))?;

get_nth_string_value(
input,
0,
&|value: &String| -> Result<ScalarValue, BuiltinError> {
let val: Value = serde_json::from_str(value)?;
let inputs = RcIter::new(core::iter::empty());

let output = filter
.run((Ctx::new([], &inputs), Val::from(val)))
.map(|res| res.map(|v| jaq_to_scalar_string(&v)))
.collect::<Result<Vec<_>, _>>()?;

Ok(match output.len() {
0 => ScalarValue::Utf8(None),
1 => output.first().unwrap().to_owned(),
_ => ScalarValue::List(ScalarValue::new_list(&output, &DataType::Utf8)),
})
},
)
.map_err(DataFusionError::from)
});


Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
Arc::new(ScalarUDF::new(
Expand All @@ -127,39 +86,10 @@ impl BuiltinScalarUDF for JAQSelect {
args,
)))
}

fn try_into_scalar_udf(self: Arc<Self>) -> datafusion::error::Result<ScalarUDF> {
Ok(Self::new().into())
}
}

#[derive(Debug)]
pub struct JAQMatches {
signature: Signature,
}

impl Default for JAQMatches {
fn default() -> Self {
Self::new()
}
}

impl JAQMatches {
pub fn new() -> Self {
Self {
signature: Signature::new(
// args: <FIELD>, <QUERY>
TypeSignature::OneOf(vec![
TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
]),
Volatility::Immutable,
),
}
}
}
/// Field-less handle for the `jaq_matches` scalar function.
///
/// The function accepts two string arguments (any combination of `Utf8` and
/// `LargeUtf8`) and its return type is declared as `Boolean`. It evaluates to
/// `true` as soon as the compiled query yields an output that is neither null
/// nor an empty string, and to `false` if no such output is produced.
#[derive(Debug, Default)]
pub struct JAQMatches;

impl ConstBuiltinFunction for JAQMatches {
const NAME: &'static str = "jaq_matches";
Expand All @@ -169,63 +99,51 @@ impl ConstBuiltinFunction for JAQMatches {
const FUNCTION_TYPE: FunctionType = FunctionType::Scalar;

fn signature(&self) -> Option<Signature> {
Some(self.signature.clone())
}
}

impl ScalarUDFImpl for JAQMatches {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
Self::NAME
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _: &[DataType]) -> datafusion::error::Result<DataType> {
Ok(DataType::Boolean)
}

fn invoke(&self, input: &[ColumnarValue]) -> datafusion::error::Result<ColumnarValue> {
let filter =
compile_jaq_query(get_nth_string_fn_arg(input, 1)?).map_err(JsonError::from)?;

get_nth_string_value(
input,
0,
&|value: &String| -> Result<ScalarValue, BuiltinError> {
let val: Value = serde_json::from_str(value)?;
let input = RcIter::new(core::iter::empty());

let output = filter.run((Ctx::new([], &input), Val::from(val)));

for res in output {
match res? {
Val::Null => continue,
Val::Str(s) if s.is_empty() => continue,
Val::Str(_) => return Ok(ScalarValue::Boolean(Some(true))),
other if other.to_string().is_empty() => continue,
_ => return Ok(ScalarValue::Boolean(Some(true))),
}
}

Ok(ScalarValue::Boolean(Some(false)))
},
)
.map_err(DataFusionError::from)
Some(Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
],
Volatility::Immutable,
))
}
}

impl BuiltinScalarUDF for JAQMatches {
fn try_as_expr(&self, _: &SessionCatalog, args: Vec<Expr>) -> DataFusionResult<Expr> {
let return_type_fn: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Boolean)));

let scalar_fn_impl: ScalarFunctionImplementation =
Arc::new(move |input| ScalarUDFImpl::invoke(&Self::new(), input));
let scalar_fn_impl: ScalarFunctionImplementation = Arc::new(move |input| {
let filter = compile_jaq_query(get_nth_string_fn_arg(input, 1)?)
.map_err(|e| DataFusionError::from(BuiltinError::from(e)))?;

get_nth_string_value(
input,
0,
&|value: &String| -> Result<ScalarValue, BuiltinError> {
let val: Value = serde_json::from_str(value)?;
let input = RcIter::new(core::iter::empty());

let output = filter.run((Ctx::new([], &input), Val::from(val)));

for res in output {
match res? {
Val::Null => continue,
Val::Str(s) if s.is_empty() => continue,
Val::Str(_) => return Ok(ScalarValue::Boolean(Some(true))),
other if other.to_string().is_empty() => continue,
_ => return Ok(ScalarValue::Boolean(Some(true))),
}
}

Ok(ScalarValue::Boolean(Some(false)))
},
)
.map_err(DataFusionError::from)
});


Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
Arc::new(ScalarUDF::new(
Expand All @@ -237,10 +155,6 @@ impl BuiltinScalarUDF for JAQMatches {
args,
)))
}

fn try_into_scalar_udf(self: Arc<Self>) -> datafusion::error::Result<ScalarUDF> {
Ok(Self::new().into())
}
}


Expand Down
Loading

0 comments on commit cff3020

Please sign in to comment.