Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Permit child MastNodeIds to exceed the MastNodeIds of their parents #1542

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
- Rename `EqHash` to `MastNodeFingerprint` and make it `pub` (#1539)
- [BREAKING] `DYN` operation now expects a memory address pointing to the procedure hash (#1535)
- [BREAKING] `DYNCALL` operation fixed, and now expects a memory address pointing to the procedure hash (#1535)

- Permit child `MastNodeId`s to exceed the `MastNodeId`s of their parents (#1542)

#### Fixes

Expand Down
37 changes: 28 additions & 9 deletions core/src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,22 +525,41 @@ impl MastNodeId {
value: u32,
mast_forest: &MastForest,
) -> Result<Self, DeserializationError> {
if (value as usize) < mast_forest.nodes.len() {
Ok(Self(value))
} else {
Err(DeserializationError::InvalidValue(format!(
"Invalid deserialized MAST node ID '{}', but only {} nodes in the forest",
value,
mast_forest.nodes.len(),
)))
}
Self::from_u32_with_node_count(value, mast_forest.nodes.len())
}

/// Returns a new [`MastNodeId`] from the given `value` without checking its validity.
pub(crate) fn new_unchecked(value: u32) -> Self {
Self(value)
}

/// Returns a new [`MastNodeId`] with the provided `id`, or an error if `id` is greater or equal
/// to `node_count`. The `node_count` is the total number of nodes in the [`MastForest`] for
/// which this ID is being constructed.
///
/// This function can be used when deserializing an id whose corresponding node is not yet in
/// the forest and [`Self::from_u32_safe`] would fail. For instance, when deserializing the ids
/// referenced by the Join node in this forest:
///
/// ```text
/// [Join(1, 2), Block(foo), Block(bar)]
/// ```
///
/// Since it is less safe than [`Self::from_u32_safe`] and usually not needed it is not public.
pub(super) fn from_u32_with_node_count(
id: u32,
node_count: usize,
) -> Result<Self, DeserializationError> {
if (id as usize) < node_count {
Ok(Self(id))
} else {
Err(DeserializationError::InvalidValue(format!(
"Invalid deserialized MAST node ID '{}', but {} is the number of nodes in the forest",
id, node_count,
)))
}
}
PhilippGackstatter marked this conversation as resolved.
Show resolved Hide resolved

pub fn as_usize(&self) -> usize {
self.0 as usize
}
Expand Down
21 changes: 13 additions & 8 deletions core/src/mast/serialization/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,14 @@ impl MastNodeInfo {
Self { ty, digest: mast_node.digest() }
}

/// Attempts to convert this [`MastNodeInfo`] into a [`MastNode`] for the given `mast_forest`.
///
/// The `node_count` is the total expected number of nodes in the [`MastForest`] **after
/// deserialization**.
pub fn try_into_mast_node(
self,
mast_forest: &mut MastForest,
mast_forest: &MastForest,
node_count: usize,
basic_block_data_decoder: &BasicBlockDataDecoder,
) -> Result<MastNode, DeserializationError> {
bobbinth marked this conversation as resolved.
Show resolved Hide resolved
match self.ty {
Expand All @@ -59,29 +64,29 @@ impl MastNodeInfo {
Ok(MastNode::Block(block))
},
MastNodeType::Join { left_child_id, right_child_id } => {
let left_child = MastNodeId::from_u32_safe(left_child_id, mast_forest)?;
let right_child = MastNodeId::from_u32_safe(right_child_id, mast_forest)?;
let left_child = MastNodeId::from_u32_with_node_count(left_child_id, node_count)?;
let right_child = MastNodeId::from_u32_with_node_count(right_child_id, node_count)?;
let join = JoinNode::new_unsafe([left_child, right_child], self.digest);
Ok(MastNode::Join(join))
},
MastNodeType::Split { if_branch_id, else_branch_id } => {
let if_branch = MastNodeId::from_u32_safe(if_branch_id, mast_forest)?;
let else_branch = MastNodeId::from_u32_safe(else_branch_id, mast_forest)?;
let if_branch = MastNodeId::from_u32_with_node_count(if_branch_id, node_count)?;
let else_branch = MastNodeId::from_u32_with_node_count(else_branch_id, node_count)?;
let split = SplitNode::new_unsafe([if_branch, else_branch], self.digest);
Ok(MastNode::Split(split))
},
MastNodeType::Loop { body_id } => {
let body_id = MastNodeId::from_u32_safe(body_id, mast_forest)?;
let body_id = MastNodeId::from_u32_with_node_count(body_id, node_count)?;
let loop_node = LoopNode::new_unsafe(body_id, self.digest);
Ok(MastNode::Loop(loop_node))
},
MastNodeType::Call { callee_id } => {
let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?;
let callee_id = MastNodeId::from_u32_with_node_count(callee_id, node_count)?;
let call = CallNode::new_unsafe(callee_id, self.digest);
Ok(MastNode::Call(call))
},
MastNodeType::SysCall { callee_id } => {
let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?;
let callee_id = MastNodeId::from_u32_with_node_count(callee_id, node_count)?;
let syscall = CallNode::new_syscall_unsafe(callee_id, self.digest);
Ok(MastNode::Call(syscall))
},
Expand Down
7 changes: 5 additions & 2 deletions core/src/mast/serialization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,11 @@ impl Deserializable for MastForest {
for _ in 0..node_count {
let mast_node_info = MastNodeInfo::read_from(source)?;

let node = mast_node_info
.try_into_mast_node(&mut mast_forest, &basic_block_data_decoder)?;
let node = mast_node_info.try_into_mast_node(
&mast_forest,
node_count,
&basic_block_data_decoder,
)?;

mast_forest.add_node(node).map_err(|e| {
DeserializationError::InvalidValue(format!(
Expand Down
45 changes: 45 additions & 0 deletions core/src/mast/serialization/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,51 @@ fn serialize_deserialize_all_nodes() {
assert_eq!(mast_forest, deserialized_mast_forest);
}

/// Test that a forest with a node whose child ids are larger than its own id serializes and
/// deserializes successfully.
#[test]
fn mast_forest_serialize_deserialize_with_child_ids_exceeding_parent_id() {
let mut forest = MastForest::new();
let deco0 = forest.add_decorator(Decorator::Trace(0)).unwrap();
let deco1 = forest.add_decorator(Decorator::Trace(1)).unwrap();
let zero = forest.add_block(vec![Operation::U32div], None).unwrap();
let first = forest.add_block(vec![Operation::U32add], Some(vec![(0, deco0)])).unwrap();
let second = forest.add_block(vec![Operation::U32and], Some(vec![(1, deco1)])).unwrap();
forest.add_join(first, second).unwrap();

// Move the Join node before its child nodes and remove the temporary zero node.
forest.nodes.swap_remove(zero.as_usize());

MastForest::read_from_bytes(&forest.to_bytes()).unwrap();
}

/// Test that a forest with a node whose referenced index is >= the max number of nodes in
/// the forest returns an error during deserialization.
#[test]
fn mast_forest_serialize_deserialize_with_overflowing_ids_fails() {
let mut overflow_forest = MastForest::new();
let id0 = overflow_forest.add_block(vec![Operation::Eqz], None).unwrap();
overflow_forest.add_block(vec![Operation::Eqz], None).unwrap();
let id2 = overflow_forest.add_block(vec![Operation::Eqz], None).unwrap();
let id_join = overflow_forest.add_join(id0, id2).unwrap();

let join_node = overflow_forest[id_join].clone();

// Add the Join(0, 2) to this forest which does not have a node with index 2.
let mut forest = MastForest::new();
let deco0 = forest.add_decorator(Decorator::Trace(0)).unwrap();
let deco1 = forest.add_decorator(Decorator::Trace(1)).unwrap();
forest
.add_block(vec![Operation::U32add], Some(vec![(0, deco0), (1, deco1)]))
.unwrap();
forest.add_node(join_node).unwrap();

assert_matches!(
MastForest::read_from_bytes(&forest.to_bytes()),
Err(DeserializationError::InvalidValue(msg)) if msg.contains("number of nodes")
);
}

#[test]
fn mast_forest_invalid_node_id() {
// Hydrate a forest smaller than the second
Expand Down
Loading