Skip to content

Commit

Permalink
feat(torii-grpc): IN operator for comparison (#2812)
Browse files Browse the repository at this point in the history
* feat: in operator grpc finish

* feat: clean up grpc

* fmt

* fmt
  • Loading branch information
Larkooo authored Dec 17, 2024
1 parent 28ec8e8 commit fcf9e26
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 27 deletions.
7 changes: 7 additions & 0 deletions crates/torii/grpc/proto/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,14 @@ message MemberValue {
oneof value_type {
Primitive primitive = 1;
string string = 2;
MemberValueList list = 3;
}
}

message MemberValueList {
repeated MemberValue values = 1;
}

message MemberClause {
string model = 2;
string member = 3;
Expand Down Expand Up @@ -152,6 +157,8 @@ enum ComparisonOperator {
GTE = 3;
LT = 4;
LTE = 5;
IN = 6;
NOT_IN = 7;
}

message Token {
Expand Down
99 changes: 72 additions & 27 deletions crates/torii/grpc/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -773,15 +773,32 @@ impl DojoWorld {
let comparison_operator = ComparisonOperator::from_repr(member_clause.operator as usize)
.expect("invalid comparison operator");

let comparison_value =
match member_clause.value.ok_or(QueryError::MissingParam("value".into()))?.value_type {
Some(ValueType::String(value)) => value,
fn prepare_comparison(
value: &proto::types::MemberValue,
bind_values: &mut Vec<String>,
) -> Result<String, Error> {
match &value.value_type {
Some(ValueType::String(value)) => {
bind_values.push(value.to_string());
Ok("?".to_string())
}
Some(ValueType::Primitive(value)) => {
let primitive: Primitive = value.try_into()?;
primitive.to_sql_value()
let primitive: Primitive = (value.clone()).try_into()?;
bind_values.push(primitive.to_sql_value());
Ok("?".to_string())
}
None => return Err(QueryError::MissingParam("value_type".into()).into()),
};
Some(ValueType::List(values)) => Ok(format!(
"({})",
values
.values
.iter()
.map(|v| prepare_comparison(v, bind_values))
.collect::<Result<Vec<String>, Error>>()?
.join(", ")
)),
None => Err(QueryError::MissingParam("value_type".into()).into()),
}
}

let (namespace, model) = member_clause
.model
Expand Down Expand Up @@ -822,8 +839,15 @@ impl DojoWorld {
self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect();

// Use the member name directly as the column name since it's already flattened
let mut where_clause =
format!("[{}].[{}] {comparison_operator} ?", member_clause.model, member_clause.member);
let mut bind_values = Vec::new();
let value = prepare_comparison(
&member_clause.value.clone().ok_or(QueryError::MissingParam("value".into()))?,
&mut bind_values,
)?;
let mut where_clause = format!(
"[{}].[{}] {comparison_operator} {value}",
member_clause.model, member_clause.member
);
if entity_updated_after.is_some() {
where_clause += &format!(" AND {table}.updated_at >= ?");
}
Expand All @@ -837,15 +861,19 @@ impl DojoWorld {
limit,
offset,
)?;
let mut count_query = sqlx::query_scalar(&count_query);
for value in &bind_values {
count_query = count_query.bind(value);
}
if let Some(entity_updated_after) = entity_updated_after.clone() {
count_query = count_query.bind(entity_updated_after);
}
let total_count = count_query.fetch_optional(&self.pool).await?.unwrap_or(0);

let total_count = sqlx::query_scalar(&count_query)
.bind(comparison_value.clone())
.bind(entity_updated_after.clone())
.fetch_optional(&self.pool)
.await?
.unwrap_or(0);

let mut query = sqlx::query(&entity_query).bind(comparison_value);
let mut query = sqlx::query(&entity_query);
for value in &bind_values {
query = query.bind(value);
}
if let Some(entity_updated_after) = entity_updated_after.clone() {
query = query.bind(entity_updated_after);
}
Expand Down Expand Up @@ -1356,17 +1384,34 @@ fn build_composite_clause(
ClauseType::Member(member) => {
let comparison_operator = ComparisonOperator::from_repr(member.operator as usize)
.expect("invalid comparison operator");
let value = member.value.clone();
let comparison_value =
match value.ok_or(QueryError::MissingParam("value".into()))?.value_type {
Some(ValueType::String(value)) => value,
let value = member.value.clone().ok_or(QueryError::MissingParam("value".into()))?;
fn prepare_comparison(
value: &proto::types::MemberValue,
bind_values: &mut Vec<String>,
) -> Result<String, Error> {
match &value.value_type {
Some(ValueType::String(value)) => {
bind_values.push(value.to_string());
Ok("?".to_string())
}
Some(ValueType::Primitive(value)) => {
let primitive: Primitive = value.try_into()?;
primitive.to_sql_value()
let primitive: Primitive = (value.clone()).try_into()?;
bind_values.push(primitive.to_sql_value());
Ok("?".to_string())
}
None => return Err(QueryError::MissingParam("value_type".into()).into()),
};
bind_values.push(comparison_value);
Some(ValueType::List(values)) => Ok(format!(
"({})",
values
.values
.iter()
.map(|v| prepare_comparison(v, bind_values))
.collect::<Result<Vec<String>, Error>>()?
.join(", ")
)),
None => Err(QueryError::MissingParam("value_type".into()).into()),
}
}
let value = prepare_comparison(&value, &mut bind_values)?;

let model = member.model.clone();
// Get or create unique alias for this model
Expand Down Expand Up @@ -1394,7 +1439,7 @@ fn build_composite_clause(

// Use the column name directly since it's already flattened
where_clauses
.push(format!("([{alias}].[{}] {comparison_operator} ?)", member.member));
.push(format!("([{alias}].[{}] {comparison_operator} {value})", member.member));
}
ClauseType::Composite(nested) => {
// Handle nested composite by recursively building the clause
Expand Down
6 changes: 6 additions & 0 deletions crates/torii/grpc/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ pub enum ComparisonOperator {
Gte,
Lt,
Lte,
In,
NotIn,
}

impl fmt::Display for ComparisonOperator {
Expand All @@ -198,6 +200,8 @@ impl fmt::Display for ComparisonOperator {
ComparisonOperator::Lte => write!(f, "<="),
ComparisonOperator::Neq => write!(f, "!="),
ComparisonOperator::Eq => write!(f, "="),
ComparisonOperator::In => write!(f, "IN"),
ComparisonOperator::NotIn => write!(f, "NOT IN"),
}
}
}
Expand All @@ -211,6 +215,8 @@ impl From<proto::types::ComparisonOperator> for ComparisonOperator {
proto::types::ComparisonOperator::Lt => ComparisonOperator::Lt,
proto::types::ComparisonOperator::Lte => ComparisonOperator::Lte,
proto::types::ComparisonOperator::Neq => ComparisonOperator::Neq,
proto::types::ComparisonOperator::In => ComparisonOperator::In,
proto::types::ComparisonOperator::NotIn => ComparisonOperator::NotIn,
}
}
}
Expand Down

0 comments on commit fcf9e26

Please sign in to comment.