Skip to content

Commit

Permalink
Contract segmentation: Add bytecode_segment_lengths to CasmContractCl…
Browse files Browse the repository at this point in the history
…ass.

commit-id:68464189
  • Loading branch information
liorgold2 committed Dec 6, 2023
1 parent b92f4d5 commit eae1920
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 5 deletions.
22 changes: 20 additions & 2 deletions crates/cairo-lang-starknet/src/casm_contract_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,15 @@ use starknet_crypto::{poseidon_hash_many, FieldElement};
use thiserror::Error;

use crate::allowed_libfuncs::AllowedLibfuncsError;
use crate::compiler_version::{current_compiler_version_id, current_sierra_version_id, VersionId};
use crate::compiler_version::{
current_compiler_version_id, current_sierra_version_id, VersionId,
CONTRACT_SEGMENTATION_MINOR_VERSION,
};
use crate::contract::starknet_keccak;
use crate::contract_class::{ContractClass, ContractEntryPoint};
use crate::contract_segmentation::{
compute_bytecode_segment_lengths, NestedIntList, SegmentationError,
};
use crate::felt252_serde::{sierra_from_felt252s, Felt252SerdeError};

/// The expected gas cost of an entrypoint.
Expand All @@ -59,6 +65,8 @@ pub enum StarknetSierraCompilationError {
MetadataError(#[from] MetadataError),
#[error(transparent)]
AllowedLibfuncsError(#[from] AllowedLibfuncsError),
#[error(transparent)]
SegmentationError(#[from] SegmentationError),
#[error("Invalid entry point.")]
EntryPointError,
#[error("Missing arguments in the entry point.")]
Expand Down Expand Up @@ -93,6 +101,8 @@ pub struct CasmContractClass {
pub prime: BigUint,
pub compiler_version: String,
pub bytecode: Vec<BigUintAsHex>,
#[serde(skip_serializing_if = "skip_if_none")]
pub bytecode_segment_lengths: Option<NestedIntList>,
pub hints: Vec<(usize, Vec<Hint>)>,

// Optional pythonic hints in a format that can be executed by the python vm.
Expand Down Expand Up @@ -347,7 +357,7 @@ impl CasmContractClass {

let mut bytecode = vec![];
let mut hints = vec![];
for instruction in cairo_program.instructions {
for instruction in &cairo_program.instructions {
if !instruction.hints.is_empty() {
hints.push((bytecode.len(), instruction.hints.clone()))
}
Expand All @@ -360,6 +370,13 @@ impl CasmContractClass {
}))
}

let bytecode_segment_lengths =
if sierra_version.minor >= CONTRACT_SEGMENTATION_MINOR_VERSION {
Some(compute_bytecode_segment_lengths(&program, &cairo_program, bytecode.len())?)
} else {
None
};

let builtin_types = UnorderedHashSet::<GenericTypeId>::from_iter([
RangeCheckType::id(),
BitwiseType::id(),
Expand Down Expand Up @@ -470,6 +487,7 @@ impl CasmContractClass {
prime,
compiler_version,
bytecode,
bytecode_segment_lengths,
hints,
pythonic_hints,
entry_points_by_type: CasmContractEntryPoints {
Expand Down
2 changes: 2 additions & 0 deletions crates/cairo-lang-starknet/src/compiler_version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ pub struct VersionId {
pub patch: usize,
}

pub const CONTRACT_SEGMENTATION_MINOR_VERSION: usize = 5;

impl std::fmt::Display for VersionId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}.{}.{}", self.major, self.minor, self.patch)
Expand Down
42 changes: 39 additions & 3 deletions crates/cairo-lang-starknet/src/contract_segmentation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,24 @@ mod test;
use cairo_lang_sierra::program::{BranchTarget, Program, Statement, StatementIdx};
use cairo_lang_sierra_to_casm::compiler::CairoProgram;
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
use serde::{Deserialize, Serialize};
use thiserror::Error;

/// NestedIntList is either a list of NestedIntList or an integer.
/// E.g., `[0, [1, 2], [3, [4]]]`.
///
/// Used to represents the lengths of the segments in a contract, which are in a form of a tree.
///
/// For example, the contract may be segmented by functions, where each function is segmented by
/// its branches. It is also possible to have the inner segmentation only for some of the functions,
/// while others are kept as non-segmented leaves in the tree.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum NestedIntList {
Leaf(usize),
Node(Vec<NestedIntList>),
}

#[derive(Error, Debug, Eq, PartialEq)]
pub enum SegmentationError {
#[error("Expected a function start at index 0.")]
Expand All @@ -15,9 +31,31 @@ pub enum SegmentationError {
JumpOutsideFunction(StatementIdx),
}

/// Computes the bytecode_segment_length for the given contract.
pub fn compute_bytecode_segment_lengths(
program: &Program,
cairo_program: &CairoProgram,
bytecode_size: usize,
) -> Result<NestedIntList, SegmentationError> {
if bytecode_size == 0 {
return Ok(NestedIntList::Leaf(0));
}
let segment_start_statements = find_segments(program)?;
let segment_start_offsets = statement_ids_to_offsets(cairo_program, &segment_start_statements);
Ok(NestedIntList::Node(
get_segment_lengths(&segment_start_offsets, bytecode_size)
.iter()
.map(|segment| {
NestedIntList::Node(
segment.iter().map(|length| NestedIntList::Leaf(*length)).collect(),
)
})
.collect(),
))
}

/// Returns a vector of vectors, where each inner vector represents a function in the program,
/// and contains the starts (as statement indices) of the segments in the function.
#[allow(dead_code)]
fn find_segments(program: &Program) -> Result<Vec<Vec<usize>>, SegmentationError> {
// Get the set of function entry points.
let function_statement_ids: UnorderedHashSet<usize> =
Expand Down Expand Up @@ -46,7 +84,6 @@ fn find_segments(program: &Program) -> Result<Vec<Vec<usize>>, SegmentationError
}

/// Converts the result of [find_segments] from statement ids to bytecode offsets.
#[allow(dead_code)]
fn statement_ids_to_offsets(
cairo_program: &CairoProgram,
segment_starts_statements: &[Vec<usize>],
Expand All @@ -69,7 +106,6 @@ fn statement_ids_to_offsets(

/// Returns a vector of vectors, where each inner vector represents a function in the program,
/// and contains the lengths of the segments in the function.
#[allow(dead_code)]
fn get_segment_lengths(
segment_starts_offsets: &[Vec<usize>],
bytecode_len: usize,
Expand Down

0 comments on commit eae1920

Please sign in to comment.