Skip to content

Commit

Permalink
chore: clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
tohrnii committed Feb 8, 2023
1 parent 54b5b88 commit 07906a1
Show file tree
Hide file tree
Showing 8 changed files with 240 additions and 120 deletions.
2 changes: 1 addition & 1 deletion air-script/tests/list_comprehension/list_comprehension.air
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
def TraceColGroupAir
def ListComprehensionAir

trace_columns:
main: [clk, fmp[2], ctx]
Expand Down
6 changes: 3 additions & 3 deletions air-script/tests/list_comprehension/list_comprehension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@ impl Serializable for PublicInputs {
}
}

pub struct TraceColGroupAir {
pub struct ListComprehensionAir {
context: AirContext<Felt>,
stack_inputs: [Felt; 16],
}

impl TraceColGroupAir {
impl ListComprehensionAir {
pub fn last_step(&self) -> usize {
self.trace_length() - self.context().num_transition_exemptions()
}
}

impl Air for TraceColGroupAir {
impl Air for ListComprehensionAir {
type BaseField = Felt;
type PublicInputs = PublicInputs;

Expand Down
1 change: 1 addition & 0 deletions air-script/tests/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ fn random_values() {

#[test]
fn list_comprehension() {
// TODO: Improve this test to include more complicated expressions
let generated_air = Test::new("tests/list_comprehension/list_comprehension.air".to_string())
.transpile()
.unwrap();
Expand Down
75 changes: 26 additions & 49 deletions ir/src/constraints/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,10 @@ impl AlgebraicGraph {
let node_index = if let Expression::Const(rhs) = **rhs {
self.insert_op(Operation::Exp(lhs.root_idx(), rhs as usize))
} else {
todo!()
Err(SemanticError::InvalidUsage(
"Non const exponents are only allowed inside list comprehensions"
.to_string(),
))?
};

Ok(ExprDetails::new(
Expand All @@ -172,60 +175,34 @@ impl AlgebraicGraph {
}
Expression::ListFolding(lf_type) => match lf_type {
ListFoldingType::Sum(lc) => {
if let VariableType::Vector(list) = parse_lc(lc, symbol_table)? {
let mut sum = self.insert_expr(
symbol_table,
&list[0],
variable_roots,
default_domain,
let list = parse_lc(lc, symbol_table)?;
let mut sum =
self.insert_expr(symbol_table, &list[0], variable_roots, default_domain)?;
for elem in list.iter().skip(1) {
let expr =
self.insert_expr(symbol_table, elem, variable_roots, default_domain)?;
sum = self.insert_bin_op(
&sum,
&expr,
Operation::Add(sum.root_idx(), expr.root_idx()),
)?;
for elem in list.iter().skip(1) {
let expr = self.insert_expr(
symbol_table,
elem,
variable_roots,
default_domain,
)?;
sum = self.insert_bin_op(
&sum,
&expr,
Operation::Add(sum.root_idx(), expr.root_idx()),
)?;
}
Ok(sum)
} else {
Err(SemanticError::InvalidListFolding(
"List folding not allowed for non vectors".to_string(),
))
}
Ok(sum)
}
ListFoldingType::Prod(lc) => {
if let VariableType::Vector(list) = parse_lc(lc, symbol_table)? {
let mut prod = self.insert_expr(
symbol_table,
&list[0],
variable_roots,
default_domain,
let list = parse_lc(lc, symbol_table)?;
let mut prod =
self.insert_expr(symbol_table, &list[0], variable_roots, default_domain)?;
for elem in list.iter().skip(1) {
let expr =
self.insert_expr(symbol_table, elem, variable_roots, default_domain)?;
prod = self.insert_bin_op(
&prod,
&expr,
Operation::Mul(prod.root_idx(), expr.root_idx()),
)?;
for elem in list.iter().skip(1) {
let expr = self.insert_expr(
symbol_table,
elem,
variable_roots,
default_domain,
)?;
prod = self.insert_bin_op(
&prod,
&expr,
Operation::Mul(prod.root_idx(), expr.root_idx()),
)?;
}
Ok(prod)
} else {
Err(SemanticError::InvalidListFolding(
"List folding not allowed for non vectors".to_string(),
))
}
Ok(prod)
}
},
}
Expand Down
88 changes: 54 additions & 34 deletions ir/src/constraints/list_comprehension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,24 @@ use air_script_core::{
};

/// Parses a list comprehension and returns the corresponding vector of expressions.
///
/// Returns an error if there is an error while parsing any of the expressions in the expanded
/// vector from the list comprehension.
pub fn parse_lc(
lc: &ListComprehension,
symbol_table: &SymbolTable,
) -> Result<VariableType, SemanticError> {
) -> Result<Vec<Expression>, SemanticError> {
let lc_length = lc_length(lc, symbol_table)?;
let mut vector = Vec::new();
for i in 0..lc_length {
vector.push(parse_lc_expr(lc.expression(), lc, symbol_table, i)?);
}
Ok(VariableType::Vector(vector))
let vector = (0..lc_length)
.map(|i| parse_lc_expr(lc.expression(), lc, symbol_table, i))
.collect::<Result<Vec<_>, _>>()?;
Ok(vector)
}

/// Parses a list comprehension expression and creates an expression based on the index of the
/// expression in the list comprehension.
///
/// Returns an error if there is an error while parsing the sub-expression.
fn parse_lc_expr(
expression: &Expression,
lc: &ListComprehension,
Expand All @@ -41,12 +45,12 @@ fn parse_lc_expr(
Expression::Sub(lhs, rhs) => {
let lhs = parse_lc_expr(lhs, lc, symbol_table, i)?;
let rhs = parse_lc_expr(rhs, lc, symbol_table, i)?;
Ok(Expression::Add(Box::new(lhs), Box::new(rhs)))
Ok(Expression::Sub(Box::new(lhs), Box::new(rhs)))
}
Expression::Mul(lhs, rhs) => {
let lhs = parse_lc_expr(lhs, lc, symbol_table, i)?;
let rhs = parse_lc_expr(rhs, lc, symbol_table, i)?;
Ok(Expression::Add(Box::new(lhs), Box::new(rhs)))
Ok(Expression::Mul(Box::new(lhs), Box::new(rhs)))
}
Expression::Exp(lhs, rhs) => {
let lhs = parse_lc_expr(lhs, lc, symbol_table, i)?;
Expand All @@ -61,6 +65,16 @@ fn parse_lc_expr(
}

/// Parses an identifier in a list comprehension expression.
///
/// # Errors
/// - Returns an error if the iterable is an identifier and that identifier does not correspond to
/// a vector.
/// - Returns an error if the iterable is an identifier but is not of a type in set:
/// { TraceColumns, IntegrityVariable, PublicInput, RandomValuesBinding }.
/// - Returns an error if the iterable is a slice and that identifier does not correspond to
/// a vector.
/// - Returns an error if the iterable is an identifier but is not of a type in set:
/// { TraceColumns, IntegrityVariable, PublicInput, RandomValuesBinding }.
fn parse_elem(
ident: &Identifier,
expression: &Expression,
Expand Down Expand Up @@ -165,6 +179,13 @@ fn parse_elem(
}

/// Parses a named trace access in a list comprehension expression.
///
/// # Errors
/// - Returns an error if the iterable is an identifier and that identifier does not correspond to
/// a trace column.
/// - Returns an error if the iterable is a range.
/// - Returns an error if the iterable is a slice and that identifier does not correspond to a
/// trace column.
fn parse_named_trace_access(
named_trace_access: &NamedTraceAccess,
expression: &Expression,
Expand Down Expand Up @@ -227,36 +248,28 @@ fn parse_list_folding(
) -> Result<Expression, SemanticError> {
match lf_type {
ListFoldingType::Sum(lc) => {
if let VariableType::Vector(list) = parse_lc(lc, symbol_table)? {
let mut sum = parse_lc_expr(expression, lc, symbol_table, i)?;
for elem in list.iter().skip(1) {
let expr = parse_lc_expr(elem, lc, symbol_table, i)?;
sum = Expression::Add(Box::new(sum), Box::new(expr));
}
Ok(sum)
} else {
Err(SemanticError::InvalidListFolding(
"List folding not allowed for non vectors".to_string(),
))
let list = parse_lc(lc, symbol_table)?;
let mut sum = parse_lc_expr(expression, lc, symbol_table, i)?;
for elem in list.iter().skip(1) {
let expr = parse_lc_expr(elem, lc, symbol_table, i)?;
sum = Expression::Add(Box::new(sum), Box::new(expr));
}
Ok(sum)
}
ListFoldingType::Prod(lc) => {
if let VariableType::Vector(list) = parse_lc(lc, symbol_table)? {
let mut sum = parse_lc_expr(expression, lc, symbol_table, i)?;
for elem in list.iter().skip(1) {
let expr = parse_lc_expr(elem, lc, symbol_table, i)?;
sum = Expression::Add(Box::new(sum), Box::new(expr));
}
Ok(sum)
} else {
Err(SemanticError::InvalidListFolding(
"List folding not allowed for non vectors".to_string(),
))
let list = parse_lc(lc, symbol_table)?;
let mut prod = parse_lc_expr(expression, lc, symbol_table, i)?;
for elem in list.iter().skip(1) {
let expr = parse_lc_expr(elem, lc, symbol_table, i)?;
prod = Expression::Mul(Box::new(prod), Box::new(expr));
}
Ok(prod)
}
}
}

/// Validates and returns the length of a list comprehension. Checks that the length of all iterables
/// in the list comprehension is the same and returns an error if it's not.
fn lc_length(lc: &ListComprehension, symbol_table: &SymbolTable) -> Result<usize, SemanticError> {
let lc_len = iterable_length(symbol_table, &lc.context()[0].1)?;
for (_, iterable) in lc.context().iter().skip(1) {
Expand All @@ -271,6 +284,12 @@ fn lc_length(lc: &ListComprehension, symbol_table: &SymbolTable) -> Result<usize
}

/// Returns the length of an iterable.
///
/// # Errors
/// - Returns an error if the iterable identifier is anything other than a vector in the symbol
/// table if it's a variable.
/// - Returns an error if the iterable is not of type in set:
/// { IntegrityVariable, PublicInput, TraceColumns }
fn iterable_length(
symbol_table: &SymbolTable,
iterable: &Iterable,
Expand All @@ -296,13 +315,14 @@ fn iterable_length(
}
}

/// Checks if the access index is valid.
/// Checks if the access index is valid. Returns an error if the access index is greater than
/// the size of the vector.
fn validate_access(i: usize, size: usize) -> Result<(), SemanticError> {
if i >= size {
if i < size {
Ok(())
} else {
Err(SemanticError::InvalidListComprehension(
"Invalid access index".to_string(),
))
} else {
Ok(())
}
}
2 changes: 1 addition & 1 deletion ir/src/constraints/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ impl Constraints {
let vector = parse_lc(list_comprehension, symbol_table)?;
symbol_table.insert_integrity_variable(&Variable::new(
Identifier(variable.name().to_string()),
vector,
VariableType::Vector(vector),
))?
} else {
symbol_table.insert_integrity_variable(variable)?
Expand Down
1 change: 0 additions & 1 deletion ir/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ pub enum SemanticError {
InvalidConstraintDomain(String),
InvalidIdentifier(String),
InvalidListComprehension(String),
InvalidListFolding(String),
InvalidPeriodicColumn(String),
InvalidUsage(String),
MissingDeclaration(String),
Expand Down
Loading

0 comments on commit 07906a1

Please sign in to comment.