diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index bc9cc8e3b..f0b5f9c37 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -466,6 +466,32 @@ impl MastNodeId { } } + /// Returns a new [`MastNodeId`] with the provided `id`, or an error if `id` is greater or equal + /// to `node_count`. + /// + /// This function can be used when deserializing an id whose corresponding node is not yet in + /// the forest and [`Self::from_u32_safe`] would fail. For instance, when deserializing the ids + /// referenced by the Join node in this forest: + /// + /// ```text + /// [Join(1, 2), Block(foo), Block(bar)] + /// ``` + /// + /// Since it is less safe than [`Self::from_u32_safe`] and usually not needed it is not public. + pub(crate) fn from_u32_with_node_count( + id: u32, + node_count: usize, + ) -> Result { + if (id as usize) < node_count { + Ok(Self(id)) + } else { + Err(DeserializationError::InvalidValue(format!( + "Invalid deserialized MAST node ID '{}', but {} is the number of nodes in the forest", + id, node_count, + ))) + } + } + pub fn as_usize(&self) -> usize { self.0 as usize } diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 4a3fa5865..bf9e417d0 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -44,7 +44,8 @@ impl MastNodeInfo { pub fn try_into_mast_node( self, - mast_forest: &mut MastForest, + mast_forest: &MastForest, + node_count: usize, basic_block_data_decoder: &BasicBlockDataDecoder, ) -> Result { match self.ty { @@ -59,29 +60,29 @@ impl MastNodeInfo { Ok(MastNode::Block(block)) }, MastNodeType::Join { left_child_id, right_child_id } => { - let left_child = MastNodeId::from_u32_safe(left_child_id, mast_forest)?; - let right_child = MastNodeId::from_u32_safe(right_child_id, mast_forest)?; + let left_child = MastNodeId::from_u32_with_node_count(left_child_id, node_count)?; + let right_child = MastNodeId::from_u32_with_node_count(right_child_id, node_count)?; let join = JoinNode::new_unsafe([left_child, right_child], self.digest); Ok(MastNode::Join(join)) }, MastNodeType::Split { if_branch_id, else_branch_id } => { - let if_branch = MastNodeId::from_u32_safe(if_branch_id, mast_forest)?; - let else_branch = MastNodeId::from_u32_safe(else_branch_id, mast_forest)?; + let if_branch = MastNodeId::from_u32_with_node_count(if_branch_id, node_count)?; + let else_branch = MastNodeId::from_u32_with_node_count(else_branch_id, node_count)?; let split = SplitNode::new_unsafe([if_branch, else_branch], self.digest); Ok(MastNode::Split(split)) }, MastNodeType::Loop { body_id } => { - let body_id = MastNodeId::from_u32_safe(body_id, mast_forest)?; + let body_id = MastNodeId::from_u32_with_node_count(body_id, node_count)?; let loop_node = LoopNode::new_unsafe(body_id, self.digest); Ok(MastNode::Loop(loop_node)) }, MastNodeType::Call { callee_id } => { - let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?; + let callee_id = MastNodeId::from_u32_with_node_count(callee_id, node_count)?; let call = CallNode::new_unsafe(callee_id, self.digest); Ok(MastNode::Call(call)) }, MastNodeType::SysCall { callee_id } => { - let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?; + let callee_id = MastNodeId::from_u32_with_node_count(callee_id, node_count)?; let syscall = CallNode::new_syscall_unsafe(callee_id, self.digest); Ok(MastNode::Call(syscall)) }, diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index cf67c17a3..e76f775c5 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -209,8 +209,11 @@ impl Deserializable for MastForest { for _ in 0..node_count { let mast_node_info = MastNodeInfo::read_from(source)?; - let node = mast_node_info - .try_into_mast_node(&mut mast_forest, &basic_block_data_decoder)?; + let node = mast_node_info.try_into_mast_node( + &mast_forest, + node_count, + &basic_block_data_decoder, + )?; mast_forest.add_node(node).map_err(|e| { DeserializationError::InvalidValue(format!( diff --git a/core/src/mast/serialization/tests.rs b/core/src/mast/serialization/tests.rs index c76a3ff3a..1e4fae689 100644 --- a/core/src/mast/serialization/tests.rs +++ b/core/src/mast/serialization/tests.rs @@ -347,6 +347,51 @@ fn serialize_deserialize_all_nodes() { assert_eq!(mast_forest, deserialized_mast_forest); } +/// Test that a forest with a node whose child ids are larger than its own id serializes and +/// deserializes successfully. +#[test] +fn mast_forest_serialize_deserialize_with_child_ids_exceeding_parent_id() { + let mut forest = MastForest::new(); + let deco0 = forest.add_decorator(Decorator::Trace(0)).unwrap(); + let deco1 = forest.add_decorator(Decorator::Trace(1)).unwrap(); + let zero = forest.add_block(vec![Operation::U32div], None).unwrap(); + let first = forest.add_block(vec![Operation::U32add], Some(vec![(0, deco0)])).unwrap(); + let second = forest.add_block(vec![Operation::U32and], Some(vec![(1, deco1)])).unwrap(); + forest.add_join(first, second).unwrap(); + + // Move the Join node before its child nodes and remove the temporary zero node. + forest.nodes.swap_remove(zero.as_usize()); + + MastForest::read_from_bytes(&forest.to_bytes()).unwrap(); +} + +/// Test that a forest with a node whose referenced index is >= the max number of nodes in +/// the forest returns an error during deserialization. +#[test] +fn mast_forest_serialize_deserialize_with_overflowing_ids_fails() { + let mut overflow_forest = MastForest::new(); + let id0 = overflow_forest.add_block(vec![Operation::Eqz], None).unwrap(); + overflow_forest.add_block(vec![Operation::Eqz], None).unwrap(); + let id2 = overflow_forest.add_block(vec![Operation::Eqz], None).unwrap(); + let id_join = overflow_forest.add_join(id0, id2).unwrap(); + + let join_node = overflow_forest[id_join].clone(); + + // Add the Join(0, 2) to this forest which does not have a node with index 2. + let mut forest = MastForest::new(); + let deco0 = forest.add_decorator(Decorator::Trace(0)).unwrap(); + let deco1 = forest.add_decorator(Decorator::Trace(1)).unwrap(); + forest + .add_block(vec![Operation::U32add], Some(vec![(0, deco0), (1, deco1)])) + .unwrap(); + forest.add_node(join_node).unwrap(); + + assert_matches!( + MastForest::read_from_bytes(&forest.to_bytes()), + Err(DeserializationError::InvalidValue(msg)) if msg.contains("number of nodes") + ); +} + #[test] fn mast_forest_invalid_node_id() { // Hydrate a forest smaller than the second