From c9489b79ca106661523e3c14921decf28cb2aefd Mon Sep 17 00:00:00 2001 From: "Jason (Yeli) Zhu" Date: Wed, 25 Oct 2023 22:41:23 +0000 Subject: [PATCH] fix: Fix an issue preventing NNS Governance from rolling back to an older version --- .../src/memory_manager_upgrade_storage.rs | 136 ++++++++++++++++-- 1 file changed, 128 insertions(+), 8 deletions(-) diff --git a/rs/nervous_system/common/src/memory_manager_upgrade_storage.rs b/rs/nervous_system/common/src/memory_manager_upgrade_storage.rs index ecfc22d5b79..e2e92b82d43 100644 --- a/rs/nervous_system/common/src/memory_manager_upgrade_storage.rs +++ b/rs/nervous_system/common/src/memory_manager_upgrade_storage.rs @@ -303,6 +303,16 @@ impl<'a, M: Memory> SizeAwareReader<'a, M> { } } +fn checked_div_mod(dividend: usize, divisor: usize) -> (usize, usize) { + let quotient = dividend + .checked_div(divisor) + .expect("Failed to calculate quotient"); + let remainder = dividend + .checked_rem(divisor) + .expect("Failed to calculate remainder"); + (quotient, remainder) +} + impl<'a, M: Memory> Buf for SizeAwareReader<'a, M> { fn remaining(&self) -> usize { // Our implementation only reads from stable memory up until the size indicated by size bytes @@ -326,18 +336,55 @@ impl<'a, M: Memory> Buf for SizeAwareReader<'a, M> { } fn advance(&mut self, cnt: usize) { - self.buffer_offset = self - .buffer_offset - .checked_add(cnt) - .expect("Tried to advance buffer beyond maximum offset"); + let remaining = self.remaining(); assert!( - self.buffer_offset <= self.buffer.len(), - "Buffer offset was greater than buffer length" + cnt <= remaining, + "Trying to advance {} bytes while only {} bytes remaining", + cnt, + remaining ); - if self.buffer_offset == self.buffer.len() { + // Why below is correct: + // + // Definition 1: the absolute address of the cursor is the address of + // `self.buffer[self.buffer_offset]` within the entire `Buf`, which can also be expressed as + // `self.stable_mem_offset - self.buffer.len() + self.buffer_offset`. + // + // Definition 2: the absolute address of the buffer start is the address of `self.buffer[0]` + // within the entire `Buf`. + // + // The intended effect of this method is to change the state(`self.buffer`, + // `self.buffer_offset`, `self.stable_mem_offset`) so that the absolute address it + // represents is `cnt` larger than its current state, while maintaining the invariant that + // `self.buffer_offset < self.buffer.len()`. + // + // Without considering the invariant, we can simply increment `self.buffer_offset` by `cnt`. + // However, to maintain the invariant: + // + // Every `read()` increases the absolute address of the buffer start by `buffer_size` (note + // that it does not always increase the buffer end by `buffer_size` because the last + // `read()` could read fewer than `buffer_size`). At the same time, if we decrease + // `self.buffer_offset` by `buffer_size`, the combined effect is that the absolute address + // of the cursor would be unchanged. Therefore, if we do the 2 things any number of times, + // we can still keep the absolute address of the cursor unchanged. + // + // Given the `checked_div_mod` arithmetic, we know that `self.buffer_offset + cnt = + // num_buffers_to_advance * buffer_size + new_buffer_offset`. + // + // Doing the above-mentioned 2 things `num_buffers_to_advance` times will result in: (1) + // calling read() `num_buffers_to_advance` times (2) set `self.buffer_offset = + // self.buffer_offset + cnt - num_buffers_to_advance * buffer_size = new_buffer_offset`. + let (num_buffers_to_advance, new_buffer_offset) = checked_div_mod( + self.buffer_offset + .checked_add(cnt) + .expect("Tried to advance buffer beyond maximum offset"), + self.buffer.capacity(), + ); + + for _ in 0..num_buffers_to_advance { self.read(); } + self.buffer_offset = new_buffer_offset; } } @@ -351,7 +398,7 @@ mod test { STORAGE_ENCODING_BYTES_RESERVED, }, }; - use bytes::BufMut; + use bytes::{Buf, BufMut}; use ic_nns_governance::pb::v1::{Governance, NetworkEconomics, Neuron}; use ic_stable_structures::{vec_mem::VectorMemory, Memory}; use prost::Message; @@ -441,6 +488,45 @@ mod test { assert_eq!(gov, decoded); } + #[test] + fn test_size_aware_reader_advance_as_buf() { + // We should be able to call `Buf::advance(cnt)` as long as `cnt < self.remaining()`. More + // specifically, we try to advance past one buffer size (100). + let memory = VectorMemory::default(); + + let mut vec: Vec<_> = (0u8..=255).collect(); + let mut size = vec.len().to_le_bytes().to_vec(); + + memory.borrow_mut().append(&mut size); + memory.borrow_mut().append(&mut vec); + + // There will be 3 pages (256 bytes with 100 per page). + let mut reader = SizeAwareReader::new(&memory, 100, 0); + + // Advancing 36 times will get to byte 252. + for i in 1..=36 { + // Advance in a way that cannot align with the buffer size 100, and advance() should not panic. + reader.advance(7); + assert_eq!(reader.remaining(), (256 - 7 * i) as usize); + assert_eq!(reader.chunk()[0], (7 * i) as u8); + } + } + + #[test] + #[should_panic] + fn test_size_aware_reader_should_panic_when_advancing_past_end() { + let memory = VectorMemory::default(); + + let mut vec = [1u8; 1000].to_vec(); + let mut size = vec.len().to_le_bytes().to_vec(); + + memory.borrow_mut().append(&mut size); + memory.borrow_mut().append(&mut vec); + + let mut reader = SizeAwareReader::new(&memory, 100, 0); + reader.advance(1001); + } + #[test] fn tiny_buffer_value() { let memory = VectorMemory::default(); @@ -492,6 +578,40 @@ mod test { assert_eq!(gov, decoded); } + #[derive(::prost::Message)] + pub struct TestMessageWithoutSubMessage { + #[prost(fixed32, repeated, tag = "1")] + pub x: ::prost::alloc::vec::Vec, + } + #[derive(::prost::Message)] + pub struct TestMessageWithSubMessage { + #[prost(fixed32, repeated, tag = "1")] + pub x: ::prost::alloc::vec::Vec, + #[prost(message, optional, tag = "2")] + pub sub: ::core::option::Option, + } + #[derive(::prost::Message)] + pub struct TestSubMessage { + #[prost(fixed32, repeated, tag = "1")] + pub y: ::prost::alloc::vec::Vec, + } + + #[test] + fn test_store_and_load_protobuf_with_missing_field() { + // The 'missing field' `sub` needs to be larger than 64KB, and 20000 * 4B > 64KB. + let m2 = TestMessageWithSubMessage { + x: (0..1000).collect(), + sub: Some(TestSubMessage { + y: (0..20000).collect(), + }), + }; + let memory = VectorMemory::default(); + + store_protobuf(&memory, &m2).expect("Storing failed in test"); + let _: TestMessageWithoutSubMessage = + load_protobuf(&memory).expect("Loading failed in test"); + } + #[test] fn test_multiple_writes_results_in_safe_read() { let gov1 = allocate_governance(3);