Skip to content

Commit

Permalink
feat(core): Permit child MastNodeIds to exceed parent ids
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilippGackstatter committed Oct 22, 2024
1 parent b2294a1 commit 41b8a82
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 10 deletions.
26 changes: 26 additions & 0 deletions core/src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,32 @@ impl MastNodeId {
}
}

/// Returns a new [`MastNodeId`] with the provided `id`, or an error if `id` is greater or equal
/// to `node_count`.
///
/// This function can be used when deserializing an id whose corresponding node is not yet in
/// the forest and [`Self::from_u32_safe`] would fail. For instance, when deserializing the ids
/// referenced by the Join node in this forest:
///
/// ```text
/// [Join(1, 2), Block(foo), Block(bar)]
/// ```
///
/// Since it is less safe than [`Self::from_u32_safe`] and usually not needed it is not public.
pub(crate) fn from_u32_with_node_count(
id: u32,
node_count: usize,
) -> Result<Self, DeserializationError> {
if (id as usize) < node_count {
Ok(Self(id))
} else {
Err(DeserializationError::InvalidValue(format!(
"Invalid deserialized MAST node ID '{}', but {} is the number of nodes in the forest",
id, node_count,
)))
}
}

pub fn as_usize(&self) -> usize {
self.0 as usize
}
Expand Down
17 changes: 9 additions & 8 deletions core/src/mast/serialization/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ impl MastNodeInfo {

pub fn try_into_mast_node(
self,
mast_forest: &mut MastForest,
mast_forest: &MastForest,
node_count: usize,
basic_block_data_decoder: &BasicBlockDataDecoder,
) -> Result<MastNode, DeserializationError> {
match self.ty {
Expand All @@ -59,29 +60,29 @@ impl MastNodeInfo {
Ok(MastNode::Block(block))
},
MastNodeType::Join { left_child_id, right_child_id } => {
let left_child = MastNodeId::from_u32_safe(left_child_id, mast_forest)?;
let right_child = MastNodeId::from_u32_safe(right_child_id, mast_forest)?;
let left_child = MastNodeId::from_u32_with_node_count(left_child_id, node_count)?;
let right_child = MastNodeId::from_u32_with_node_count(right_child_id, node_count)?;
let join = JoinNode::new_unsafe([left_child, right_child], self.digest);
Ok(MastNode::Join(join))
},
MastNodeType::Split { if_branch_id, else_branch_id } => {
let if_branch = MastNodeId::from_u32_safe(if_branch_id, mast_forest)?;
let else_branch = MastNodeId::from_u32_safe(else_branch_id, mast_forest)?;
let if_branch = MastNodeId::from_u32_with_node_count(if_branch_id, node_count)?;
let else_branch = MastNodeId::from_u32_with_node_count(else_branch_id, node_count)?;
let split = SplitNode::new_unsafe([if_branch, else_branch], self.digest);
Ok(MastNode::Split(split))
},
MastNodeType::Loop { body_id } => {
let body_id = MastNodeId::from_u32_safe(body_id, mast_forest)?;
let body_id = MastNodeId::from_u32_with_node_count(body_id, node_count)?;
let loop_node = LoopNode::new_unsafe(body_id, self.digest);
Ok(MastNode::Loop(loop_node))
},
MastNodeType::Call { callee_id } => {
let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?;
let callee_id = MastNodeId::from_u32_with_node_count(callee_id, node_count)?;
let call = CallNode::new_unsafe(callee_id, self.digest);
Ok(MastNode::Call(call))
},
MastNodeType::SysCall { callee_id } => {
let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?;
let callee_id = MastNodeId::from_u32_with_node_count(callee_id, node_count)?;
let syscall = CallNode::new_syscall_unsafe(callee_id, self.digest);
Ok(MastNode::Call(syscall))
},
Expand Down
7 changes: 5 additions & 2 deletions core/src/mast/serialization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,11 @@ impl Deserializable for MastForest {
for _ in 0..node_count {
let mast_node_info = MastNodeInfo::read_from(source)?;

let node = mast_node_info
.try_into_mast_node(&mut mast_forest, &basic_block_data_decoder)?;
let node = mast_node_info.try_into_mast_node(
&mast_forest,
node_count,
&basic_block_data_decoder,
)?;

mast_forest.add_node(node).map_err(|e| {
DeserializationError::InvalidValue(format!(
Expand Down
45 changes: 45 additions & 0 deletions core/src/mast/serialization/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,51 @@ fn serialize_deserialize_all_nodes() {
assert_eq!(mast_forest, deserialized_mast_forest);
}

/// Test that a forest with a node whose child ids are larger than its own id serializes and
/// deserializes successfully.
#[test]
fn mast_forest_serialize_deserialize_with_child_ids_exceeding_parent_id() {
let mut forest = MastForest::new();
let deco0 = forest.add_decorator(Decorator::Trace(0)).unwrap();
let deco1 = forest.add_decorator(Decorator::Trace(1)).unwrap();
let zero = forest.add_block(vec![Operation::U32div], None).unwrap();
let first = forest.add_block(vec![Operation::U32add], Some(vec![(0, deco0)])).unwrap();
let second = forest.add_block(vec![Operation::U32and], Some(vec![(1, deco1)])).unwrap();
forest.add_join(first, second).unwrap();

// Move the Join node before its child nodes and remove the temporary zero node.
forest.nodes.swap_remove(zero.as_usize());

MastForest::read_from_bytes(&forest.to_bytes()).unwrap();
}

/// Test that a forest with a node whose referenced index is >= the max number of nodes in
/// the forest returns an error during deserialization.
#[test]
fn mast_forest_serialize_deserialize_with_overflowing_ids_fails() {
let mut overflow_forest = MastForest::new();
let id0 = overflow_forest.add_block(vec![Operation::Eqz], None).unwrap();
overflow_forest.add_block(vec![Operation::Eqz], None).unwrap();
let id2 = overflow_forest.add_block(vec![Operation::Eqz], None).unwrap();
let id_join = overflow_forest.add_join(id0, id2).unwrap();

let join_node = overflow_forest[id_join].clone();

// Add the Join(0, 2) to this forest which does not have a node with index 2.
let mut forest = MastForest::new();
let deco0 = forest.add_decorator(Decorator::Trace(0)).unwrap();
let deco1 = forest.add_decorator(Decorator::Trace(1)).unwrap();
forest
.add_block(vec![Operation::U32add], Some(vec![(0, deco0), (1, deco1)]))
.unwrap();
forest.add_node(join_node).unwrap();

assert_matches!(
MastForest::read_from_bytes(&forest.to_bytes()),
Err(DeserializationError::InvalidValue(msg)) if msg.contains("number of nodes")
);
}

#[test]
fn mast_forest_invalid_node_id() {
// Hydrate a forest smaller than the second
Expand Down

0 comments on commit 41b8a82

Please sign in to comment.