diff --git a/Cargo.lock b/Cargo.lock index c10c00b..ef7b2bb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1817,7 +1817,7 @@ dependencies = [ [[package]] name = "pyrevm" -version = "0.3.4" +version = "0.3.5" dependencies = [ "ethers-core", "ethers-providers", diff --git a/Cargo.toml b/Cargo.toml index a397903..639fd19 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pyrevm" -version = "0.3.4" +version = "0.3.5" edition = "2021" [lib] diff --git a/pyrevm.pyi b/pyrevm.pyi index 4700a7e..b210a13 100644 --- a/pyrevm.pyi +++ b/pyrevm.pyi @@ -201,7 +201,7 @@ class EVM: :param info: The account info. """ - def insert_account_storage(self: "EVM", address: str, index: int, value: int) -> None: + def insert_account_storage(self: "EVM", address: str, index: int, value: int) -> int: """ Inserts the provided value for slot of in the database at the specified address :param address: The address of the account. diff --git a/src/database.rs b/src/database.rs index 9d8a80a..507ef56 100644 --- a/src/database.rs +++ b/src/database.rs @@ -54,10 +54,12 @@ impl DB { value: U256, ) -> PyResult<()> { match self { - DB::Memory(db) => - db.insert_account_storage(address, slot, value).map_err(pyerr), - DB::Fork(db) => - db.insert_account_storage(address, slot, value).map_err(pyerr), + DB::Memory(db) => db + .insert_account_storage(address, slot, value) + .map_err(pyerr), + DB::Fork(db) => db + .insert_account_storage(address, slot, value) + .map_err(pyerr), } } diff --git a/src/evm.rs b/src/evm.rs index 90fed84..5e7c2f0 100644 --- a/src/evm.rs +++ b/src/evm.rs @@ -153,9 +153,30 @@ impl EVM { } /// Inserts the provided value for slot of in the database at the specified address - fn insert_account_storage(&mut self, address: &str, index: U256, value: U256) -> PyResult<()> { + fn insert_account_storage( + &mut self, + address: &str, + index: U256, + value: U256, + ) -> PyResult { let target = addr(address)?; - self.context.db.insert_insert_account_storage(target, index, value) + + match self.context.journaled_state.state.get_mut(&target) { + // account is cold, just insert into the DB + None => { + self.context + .db + .insert_insert_account_storage(target, index, value) + .map_err(pyerr)?; + self.context.load_account(target).map_err(pyerr)?; + Ok(U256::ZERO) + } + // just replace old value + Some(_) => { + let store_result = self.context.sstore(target, index, value).map_err(pyerr)?; + Ok(store_result.original_value) + } + } } /// Set the balance of a given address. diff --git a/tests/test_evm.py b/tests/test_evm.py index 4ce25e1..0c3bfe7 100644 --- a/tests/test_evm.py +++ b/tests/test_evm.py @@ -67,6 +67,25 @@ def test_set_into_storage(): value = evm.storage(weth, 0) assert value == 10 +def test_set_into_storage_with_update(): + weth = "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2" + evm = EVM(fork_url=fork_url, fork_block="latest") + evm.insert_account_storage(weth, 0, 10) + value = evm.storage(weth, 0) + assert value == 10 + evm.insert_account_storage(weth, 0, 20) + value = evm.storage(weth, 0) + assert value == 20 + +def test_set_into_storage_old_value(): + weth = "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2" + evm = EVM(fork_url=fork_url, fork_block="latest") + old_value = evm.insert_account_storage(weth, 0, 10) + assert old_value == 0 + old_value = evm.insert_account_storage(weth, 0, 20) + assert old_value == 10 + + def test_deploy(): evm = EVM()