diff --git a/CHANGELOG.md b/CHANGELOG.md index d37a550c6..95ad26bc1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Changelog +## 0.12.0 (TBD) + +#### Enhancements +- Added `miden_core::mast::MastForest::advice_map` to load it into the advice provider before the `MastForest` execution (#1574). + ## 0.11.0 (2024-11-04) #### Enhancements diff --git a/core/src/advice/map.rs b/core/src/advice/map.rs index c42a416ef..1e0384eb4 100644 --- a/core/src/advice/map.rs +++ b/core/src/advice/map.rs @@ -41,6 +41,16 @@ impl AdviceMap { pub fn remove(&mut self, key: RpoDigest) -> Option> { self.0.remove(&key) } + + /// Returns the number of key value pairs in the advice map. + pub fn len(&self) -> usize { + self.0.len() + } + + /// Returns true if the advice map is empty. + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } } impl From>> for AdviceMap { diff --git a/core/src/mast/merger/mod.rs b/core/src/mast/merger/mod.rs index ae0273371..3e6009a61 100644 --- a/core/src/mast/merger/mod.rs +++ b/core/src/mast/merger/mod.rs @@ -65,10 +65,11 @@ impl MastForestMerger { /// /// It does this in three steps: /// - /// 1. Merge all decorators, which is a case of deduplication and creating a decorator id + /// 1. Merge all advice maps, checking for key collisions. + /// 2. Merge all decorators, which is a case of deduplication and creating a decorator id /// mapping which contains how existing [`DecoratorId`]s map to [`DecoratorId`]s in the /// merged forest. - /// 2. Merge all nodes of forests. + /// 3. Merge all nodes of forests. /// - Similar to decorators, node indices might move during merging, so the merger keeps a /// node id mapping as it merges nodes. /// - This is a depth-first traversal over all forests to ensure all children are processed @@ -90,10 +91,13 @@ impl MastForestMerger { /// `replacement` node. Now we can simply add a mapping from the external node to the /// `replacement` node in our node id mapping which means all nodes that referenced the /// external node will point to the `replacement` instead. - /// 3. Finally, we merge all roots of all forests. Here we map the existing root indices to + /// 4. Finally, we merge all roots of all forests. Here we map the existing root indices to /// their potentially new indices in the merged forest and add them to the forest, /// deduplicating in the process, too. fn merge_inner(&mut self, forests: Vec<&MastForest>) -> Result<(), MastForestError> { + for other_forest in forests.iter() { + self.merge_advice_map(other_forest)?; + } for other_forest in forests.iter() { self.merge_decorators(other_forest)?; } @@ -163,6 +167,17 @@ impl MastForestMerger { Ok(()) } + fn merge_advice_map(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> { + for (key, value) in other_forest.advice_map.clone().into_iter() { + if self.mast_forest.advice_map().get(&key).is_some() { + return Err(MastForestError::AdviceMapKeyCollisionOnMerge(key)); + } else { + self.mast_forest.advice_map_mut().insert(key, value.clone()); + } + } + Ok(()) + } + fn merge_node( &mut self, forest_idx: usize, diff --git a/core/src/mast/merger/tests.rs b/core/src/mast/merger/tests.rs index b33ae9729..c9100d9f7 100644 --- a/core/src/mast/merger/tests.rs +++ b/core/src/mast/merger/tests.rs @@ -1,4 +1,4 @@ -use miden_crypto::{hash::rpo::RpoDigest, ONE}; +use miden_crypto::{hash::rpo::RpoDigest, Felt, ONE}; use super::*; use crate::{Decorator, Operation}; @@ -794,3 +794,54 @@ fn mast_forest_merge_invalid_decorator_index() { let err = MastForest::merge([&forest_a, &forest_b]).unwrap_err(); assert_matches!(err, MastForestError::DecoratorIdOverflow(_, _)); } + +/// Tests that forest's advice maps are merged correctly. +#[test] +fn mast_forest_merge_advice_maps_merged() { + let mut forest_a = MastForest::new(); + let id_foo = forest_a.add_node(block_foo()).unwrap(); + let id_call_a = forest_a.add_call(id_foo).unwrap(); + forest_a.make_root(id_call_a); + let key_a = RpoDigest::new([Felt::new(1), Felt::new(2), Felt::new(3), Felt::new(4)]); + let value_a = vec![ONE, ONE]; + forest_a.advice_map_mut().insert(key_a, value_a.clone()); + + let mut forest_b = MastForest::new(); + let id_bar = forest_b.add_node(block_bar()).unwrap(); + let id_call_b = forest_b.add_call(id_bar).unwrap(); + forest_b.make_root(id_call_b); + let key_b = RpoDigest::new([Felt::new(1), Felt::new(3), Felt::new(2), Felt::new(1)]); + let value_b = vec![Felt::new(2), Felt::new(2)]; + forest_b.advice_map_mut().insert(key_b, value_b.clone()); + + let (merged, _root_maps) = MastForest::merge([&forest_a, &forest_b]).unwrap(); + + let merged_advice_map = merged.advice_map(); + assert_eq!(merged_advice_map.len(), 2); + assert_eq!(merged_advice_map.get(&key_a).unwrap(), &value_a); + assert_eq!(merged_advice_map.get(&key_b).unwrap(), &value_b); +} + +/// Tests that an error is returned when advice maps have a key collision. +#[test] +fn mast_forest_merge_advice_maps_collision() { + let mut forest_a = MastForest::new(); + let id_foo = forest_a.add_node(block_foo()).unwrap(); + let id_call_a = forest_a.add_call(id_foo).unwrap(); + forest_a.make_root(id_call_a); + let key_a = RpoDigest::new([Felt::new(1), Felt::new(2), Felt::new(3), Felt::new(4)]); + let value_a = vec![ONE, ONE]; + forest_a.advice_map_mut().insert(key_a, value_a.clone()); + + let mut forest_b = MastForest::new(); + let id_bar = forest_b.add_node(block_bar()).unwrap(); + let id_call_b = forest_b.add_call(id_bar).unwrap(); + forest_b.make_root(id_call_b); + // The key collides with key_a in the forest_a. + let key_b = key_a; + let value_b = vec![Felt::new(2), Felt::new(2)]; + forest_b.advice_map_mut().insert(key_b, value_b.clone()); + + let err = MastForest::merge([&forest_a, &forest_b]).unwrap_err(); + assert_matches!(err, MastForestError::AdviceMapKeyCollisionOnMerge(_)); +} diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 444fedc5c..8e32a9c9d 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -16,7 +16,7 @@ pub use node::{ }; use winter_utils::{ByteWriter, DeserializationError, Serializable}; -use crate::{Decorator, DecoratorList, Operation}; +use crate::{AdviceMap, Decorator, DecoratorList, Operation}; mod serialization; @@ -50,6 +50,9 @@ pub struct MastForest { /// All the decorators included in the MAST forest. decorators: Vec, + + /// Advice map to be loaded into the VM prior to executing procedures from this MAST forest. + advice_map: AdviceMap, } // ------------------------------------------------------------------------------------------------ @@ -463,6 +466,14 @@ impl MastForest { pub fn nodes(&self) -> &[MastNode] { &self.nodes } + + pub fn advice_map(&self) -> &AdviceMap { + &self.advice_map + } + + pub fn advice_map_mut(&mut self) -> &mut AdviceMap { + &mut self.advice_map + } } impl Index for MastForest { @@ -689,4 +700,6 @@ pub enum MastForestError { EmptyBasicBlock, #[error("decorator root of child with node id {0} is missing but required for fingerprint computation")] ChildFingerprintMissing(MastNodeId), + #[error("advice map key already exists when merging forests: {0}")] + AdviceMapKeyCollisionOnMerge(RpoDigest), } diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index e76f775c5..8601d28a9 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -31,6 +31,7 @@ use string_table::{StringTable, StringTableBuilder}; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use super::{DecoratorId, MastForest, MastNode, MastNodeId}; +use crate::AdviceMap; mod decorator; @@ -71,7 +72,7 @@ const MAGIC: &[u8; 5] = b"MAST\0"; /// If future modifications are made to this format, the version should be incremented by 1. A /// version of `[255, 255, 255]` is reserved for future extensions that require extending the /// version field itself, but should be considered invalid for now. -const VERSION: [u8; 3] = [0, 0, 0]; +const VERSION: [u8; 3] = [0, 0, 1]; // MAST FOREST SERIALIZATION/DESERIALIZATION // ================================================================================================ @@ -161,6 +162,7 @@ impl Serializable for MastForest { // Write "before enter" and "after exit" decorators before_enter_decorators.write_into(target); after_exit_decorators.write_into(target); + self.advice_map.write_into(target); } } @@ -256,6 +258,7 @@ impl Deserializable for MastForest { let node_id = MastNodeId::from_u32_safe(node_id, &mast_forest)?; mast_forest.set_after_exit(node_id, decorator_ids); } + mast_forest.advice_map = AdviceMap::read_from(source)?; Ok(mast_forest) } diff --git a/core/src/mast/serialization/tests.rs b/core/src/mast/serialization/tests.rs index cb6e9e2c0..76a180b79 100644 --- a/core/src/mast/serialization/tests.rs +++ b/core/src/mast/serialization/tests.rs @@ -1,6 +1,6 @@ use alloc::{string::ToString, sync::Arc}; -use miden_crypto::{hash::rpo::RpoDigest, Felt}; +use miden_crypto::{hash::rpo::RpoDigest, Felt, ONE}; use super::*; use crate::{ @@ -435,3 +435,22 @@ fn mast_forest_invalid_node_id() { // Validate normal operations forest.add_join(first, second).unwrap(); } + +/// Test `MastForest::advice_map` serialization and deserialization. +#[test] +fn mast_forest_serialize_deserialize_advice_map() { + let mut forest = MastForest::new(); + let deco0 = forest.add_decorator(Decorator::Trace(0)).unwrap(); + let deco1 = forest.add_decorator(Decorator::Trace(1)).unwrap(); + let first = forest.add_block(vec![Operation::U32add], Some(vec![(0, deco0)])).unwrap(); + let second = forest.add_block(vec![Operation::U32and], Some(vec![(1, deco1)])).unwrap(); + forest.add_join(first, second).unwrap(); + + let key = RpoDigest::new([ONE, ONE, ONE, ONE]); + let value = vec![ONE, ONE]; + + forest.advice_map_mut().insert(key, value); + + let parsed = MastForest::read_from_bytes(&forest.to_bytes()).unwrap(); + assert_eq!(forest.advice_map, parsed.advice_map); +} diff --git a/miden/benches/program_execution.rs b/miden/benches/program_execution.rs index a8abd7050..e5a98266f 100644 --- a/miden/benches/program_execution.rs +++ b/miden/benches/program_execution.rs @@ -11,7 +11,7 @@ fn program_execution(c: &mut Criterion) { let stdlib = StdLibrary::default(); let mut host = DefaultHost::default(); - host.load_mast_forest(stdlib.as_ref().mast_forest().clone()); + host.load_mast_forest(stdlib.as_ref().mast_forest().clone()).unwrap(); group.bench_function("sha256", |bench| { let source = " diff --git a/miden/src/examples/blake3.rs b/miden/src/examples/blake3.rs index 4d4133b59..4f6fac657 100644 --- a/miden/src/examples/blake3.rs +++ b/miden/src/examples/blake3.rs @@ -22,7 +22,7 @@ pub fn get_example(n: usize) -> Example> { ); let mut host = DefaultHost::default(); - host.load_mast_forest(StdLibrary::default().mast_forest().clone()); + host.load_mast_forest(StdLibrary::default().mast_forest().clone()).unwrap(); let stack_inputs = StackInputs::try_from_ints(INITIAL_HASH_VALUE.iter().map(|&v| v as u64)).unwrap(); diff --git a/miden/src/repl/mod.rs b/miden/src/repl/mod.rs index 692e29df5..e8dd76b4f 100644 --- a/miden/src/repl/mod.rs +++ b/miden/src/repl/mod.rs @@ -318,7 +318,8 @@ fn execute( let stack_inputs = StackInputs::default(); let mut host = DefaultHost::default(); for library in provided_libraries { - host.load_mast_forest(library.mast_forest().clone()); + host.load_mast_forest(library.mast_forest().clone()) + .map_err(|err| format!("{err}"))?; } let state_iter = processor::execute_iter(&program, stack_inputs, host); diff --git a/miden/src/tools/mod.rs b/miden/src/tools/mod.rs index 39d9d9ea5..430403b26 100644 --- a/miden/src/tools/mod.rs +++ b/miden/src/tools/mod.rs @@ -38,7 +38,8 @@ impl Analyze { // fetch the stack and program inputs from the arguments let stack_inputs = input_data.parse_stack_inputs().map_err(Report::msg)?; let mut host = DefaultHost::new(input_data.parse_advice_provider().map_err(Report::msg)?); - host.load_mast_forest(StdLibrary::default().mast_forest().clone()); + host.load_mast_forest(StdLibrary::default().mast_forest().clone()) + .into_diagnostic()?; let execution_details: ExecutionDetails = analyze(program.as_str(), stack_inputs, host) .expect("Could not retrieve execution details"); diff --git a/miden/tests/integration/exec.rs b/miden/tests/integration/exec.rs new file mode 100644 index 000000000..029774217 --- /dev/null +++ b/miden/tests/integration/exec.rs @@ -0,0 +1,53 @@ +use assembly::Assembler; +use miden_vm::DefaultHost; +use processor::{ExecutionOptions, MastForest}; +use prover::{Digest, StackInputs}; +use vm_core::{assert_matches, Program, ONE}; + +#[test] +fn advice_map_loaded_before_execution() { + let source = "\ + begin + push.1.1.1.1 + adv.push_mapval + dropw + end"; + + // compile and execute program + let program_without_advice_map: Program = + Assembler::default().assemble_program(source).unwrap(); + + // Test `processor::execute` fails if no advice map provided with the program + let mut host = DefaultHost::default(); + match processor::execute( + &program_without_advice_map, + StackInputs::default(), + &mut host, + ExecutionOptions::default(), + ) { + Ok(_) => panic!("Expected error"), + Err(e) => { + assert_matches!(e, prover::ExecutionError::AdviceMapKeyNotFound(_)); + }, + } + + // Test `processor::execute` works if advice map provided with the program + let mast_forest: MastForest = (**program_without_advice_map.mast_forest()).clone(); + + let key = Digest::new([ONE, ONE, ONE, ONE]); + let value = vec![ONE, ONE]; + + let mut mast_forest = mast_forest.clone(); + mast_forest.advice_map_mut().insert(key, value); + let program_with_advice_map = + Program::new(mast_forest.into(), program_without_advice_map.entrypoint()); + + let mut host = DefaultHost::default(); + processor::execute( + &program_with_advice_map, + StackInputs::default(), + &mut host, + ExecutionOptions::default(), + ) + .unwrap(); +} diff --git a/miden/tests/integration/main.rs b/miden/tests/integration/main.rs index 10720fdb1..0ee815409 100644 --- a/miden/tests/integration/main.rs +++ b/miden/tests/integration/main.rs @@ -4,6 +4,7 @@ use test_utils::{build_op_test, build_test}; mod air; mod cli; +mod exec; mod exec_iters; mod flow_control; mod operations; diff --git a/processor/src/errors.rs b/processor/src/errors.rs index 230a37a1c..f2aa58b71 100644 --- a/processor/src/errors.rs +++ b/processor/src/errors.rs @@ -23,6 +23,7 @@ use super::{ #[derive(Debug, Clone, PartialEq, Eq)] pub enum ExecutionError { AdviceMapKeyNotFound(Word), + AdviceMapKeyAlreadyPresent(Word), AdviceStackReadFailed(RowIndex), CallerNotInSyscall, CircularExternalNode(Digest), @@ -96,6 +97,10 @@ impl Display for ExecutionError { let hex = to_hex(Felt::elements_as_bytes(key)); write!(f, "Value for key {hex} not present in the advice map") }, + AdviceMapKeyAlreadyPresent(key) => { + let hex = to_hex(Felt::elements_as_bytes(key)); + write!(f, "Value for key {hex} already present in the advice map") + }, AdviceStackReadFailed(step) => write!(f, "Advice stack read failed at step {step}"), CallerNotInSyscall => { write!(f, "Instruction `caller` used outside of kernel context") diff --git a/processor/src/host/mod.rs b/processor/src/host/mod.rs index 56b0529ee..68cb33147 100644 --- a/processor/src/host/mod.rs +++ b/processor/src/host/mod.rs @@ -334,8 +334,17 @@ where } } - pub fn load_mast_forest(&mut self, mast_forest: Arc) { - self.store.insert(mast_forest) + pub fn load_mast_forest(&mut self, mast_forest: Arc) -> Result<(), ExecutionError> { + // Load the MAST's advice data into the advice provider. + for (digest, values) in mast_forest.advice_map().clone().into_iter() { + if self.adv_provider.get_mapped_values(&digest).is_some() { + return Err(ExecutionError::AdviceMapKeyAlreadyPresent(digest.into())); + } else { + self.adv_provider.insert_into_map(digest.into(), values); + } + } + self.store.insert(mast_forest); + Ok(()) } #[cfg(any(test, feature = "testing"))] diff --git a/processor/src/lib.rs b/processor/src/lib.rs index 3cf34efe4..15dc2cac5 100644 --- a/processor/src/lib.rs +++ b/processor/src/lib.rs @@ -253,6 +253,18 @@ where return Err(ExecutionError::ProgramAlreadyExecuted); } + // Load the program's advice data into the advice provider + for (digest, values) in program.mast_forest().advice_map().clone().into_iter() { + if self.host.borrow().advice_provider().get_mapped_values(&digest).is_some() { + return Err(ExecutionError::AdviceMapKeyAlreadyPresent(digest.into())); + } else { + self.host + .borrow_mut() + .advice_provider_mut() + .insert_into_map(digest.into(), values); + } + } + self.execute_mast_node(program.entrypoint(), &program.mast_forest().clone())?; self.stack.build_stack_outputs() diff --git a/stdlib/tests/mem/mod.rs b/stdlib/tests/mem/mod.rs index 61c3836f9..69da86cef 100644 --- a/stdlib/tests/mem/mod.rs +++ b/stdlib/tests/mem/mod.rs @@ -31,7 +31,7 @@ fn test_memcopy() { assembler.assemble_program(source).expect("Failed to compile test source."); let mut host = DefaultHost::default(); - host.load_mast_forest(stdlib.mast_forest().clone()); + host.load_mast_forest(stdlib.mast_forest().clone()).unwrap(); let mut process = Process::new( program.kernel().clone(), diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index c2515702d..79484d709 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -240,10 +240,10 @@ impl Test { let (program, kernel) = self.compile().expect("Failed to compile test source."); let mut host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); if let Some(kernel) = kernel { - host.load_mast_forest(kernel.mast_forest().clone()); + host.load_mast_forest(kernel.mast_forest().clone()).unwrap(); } for library in &self.libraries { - host.load_mast_forest(library.mast_forest().clone()); + host.load_mast_forest(library.mast_forest().clone()).unwrap(); } // execute the test @@ -338,10 +338,10 @@ impl Test { let (program, kernel) = self.compile().expect("Failed to compile test source."); let mut host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); if let Some(kernel) = kernel { - host.load_mast_forest(kernel.mast_forest().clone()); + host.load_mast_forest(kernel.mast_forest().clone()).unwrap(); } for library in &self.libraries { - host.load_mast_forest(library.mast_forest().clone()); + host.load_mast_forest(library.mast_forest().clone()).unwrap(); } processor::execute(&program, self.stack_inputs.clone(), host, ExecutionOptions::default()) } @@ -354,10 +354,10 @@ impl Test { let (program, kernel) = self.compile().expect("Failed to compile test source."); let mut host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); if let Some(kernel) = kernel { - host.load_mast_forest(kernel.mast_forest().clone()); + host.load_mast_forest(kernel.mast_forest().clone()).unwrap(); } for library in &self.libraries { - host.load_mast_forest(library.mast_forest().clone()); + host.load_mast_forest(library.mast_forest().clone()).unwrap(); } let mut process = Process::new( @@ -378,10 +378,10 @@ impl Test { let (program, kernel) = self.compile().expect("Failed to compile test source."); let mut host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); if let Some(kernel) = kernel { - host.load_mast_forest(kernel.mast_forest().clone()); + host.load_mast_forest(kernel.mast_forest().clone()).unwrap(); } for library in &self.libraries { - host.load_mast_forest(library.mast_forest().clone()); + host.load_mast_forest(library.mast_forest().clone()).unwrap(); } let (mut stack_outputs, proof) = prover::prove(&program, stack_inputs.clone(), host, ProvingOptions::default()).unwrap(); @@ -403,10 +403,10 @@ impl Test { let (program, kernel) = self.compile().expect("Failed to compile test source."); let mut host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); if let Some(kernel) = kernel { - host.load_mast_forest(kernel.mast_forest().clone()); + host.load_mast_forest(kernel.mast_forest().clone()).unwrap(); } for library in &self.libraries { - host.load_mast_forest(library.mast_forest().clone()); + host.load_mast_forest(library.mast_forest().clone()).unwrap(); } processor::execute_iter(&program, self.stack_inputs.clone(), host) }