Skip to content

Commit

Permalink
Allow passing class hash to the VM if it is readily available
Browse files Browse the repository at this point in the history
  • Loading branch information
omerfirmak committed Nov 20, 2023
1 parent c1a7bde commit 7052692
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 19 deletions.
8 changes: 4 additions & 4 deletions mocks/mock_vm.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions node/throttled_vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ func NewThrottledVM(res vm.VM, concurrenyBudget uint, maxQueueLen int32) *Thrott
return (*ThrottledVM)(utils.NewThrottler[vm.VM](concurrenyBudget, &res).WithMaxQueueLen(maxQueueLen))
}

func (tvm *ThrottledVM) Call(contractAddr, selector *felt.Felt, calldata []felt.Felt, blockNumber,
func (tvm *ThrottledVM) Call(contractAddr, classHash, selector *felt.Felt, calldata []felt.Felt, blockNumber,
blockTimestamp uint64, state core.StateReader, network utils.Network,
) ([]*felt.Felt, error) {
var ret []*felt.Felt
throttler := (*utils.Throttler[vm.VM])(tvm)
return ret, throttler.Do(func(vm *vm.VM) error {
var err error
ret, err = (*vm).Call(contractAddr, selector, calldata, blockNumber, blockTimestamp, state, network)
ret, err = (*vm).Call(contractAddr, classHash, selector, calldata, blockNumber, blockTimestamp, state, network)
return err
})
}
Expand Down
5 changes: 3 additions & 2 deletions rpc/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1129,7 +1129,7 @@ func (h *Handler) Call(call FunctionCall, id BlockID) ([]*felt.Felt, *jsonrpc.Er
return nil, ErrBlockNotFound
}

_, err = state.ContractClassHash(&call.ContractAddress)
classHash, err := state.ContractClassHash(&call.ContractAddress)
if err != nil {
return nil, ErrContractNotFound
}
Expand All @@ -1143,7 +1143,8 @@ func (h *Handler) Call(call FunctionCall, id BlockID) ([]*felt.Felt, *jsonrpc.Er
blockNumber = height + 1
}

res, err := h.vm.Call(&call.ContractAddress, &call.EntryPointSelector, call.Calldata, blockNumber, header.Timestamp, state, h.network)
res, err := h.vm.Call(&call.ContractAddress, classHash, &call.EntryPointSelector,
call.Calldata, blockNumber, header.Timestamp, state, h.network)
if err != nil {
if errors.Is(err, utils.ErrResourceBusy) {
return nil, ErrUnexpectedError.CloneWithData(err.Error())
Expand Down
10 changes: 8 additions & 2 deletions vm/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use starknet_api::{
transaction::Fee,
};
use starknet_api::{
core::{ChainId, ContractAddress, EntryPointSelector},
core::{ChainId, ContractAddress, EntryPointSelector, ClassHash},
hash::StarkHash,
transaction::TransactionVersion,
};
Expand All @@ -58,6 +58,7 @@ const N_STEPS_FEE_WEIGHT: f64 = 0.01;
#[no_mangle]
pub extern "C" fn cairoVMCall(
contract_address: *const c_uchar,
class_hash: *const c_uchar,
entry_point_selector: *const c_uchar,
calldata: *const *const c_uchar,
len_calldata: usize,
Expand All @@ -68,6 +69,11 @@ pub extern "C" fn cairoVMCall(
) {
let reader = JunoStateReader::new(reader_handle);
let contract_addr_felt = ptr_to_felt(contract_address);
let class_hash = if class_hash.is_null() {
None
} else {
Some(ClassHash(ptr_to_felt(class_hash)))
};
let entry_point_selector_felt = ptr_to_felt(entry_point_selector);
let chain_id_str = unsafe { CStr::from_ptr(chain_id) }.to_str().unwrap();

Expand All @@ -86,7 +92,7 @@ pub extern "C" fn cairoVMCall(
calldata: Calldata(calldata_vec.into()),
storage_address: contract_addr_felt.try_into().unwrap(),
call_type: CallType::Call,
class_hash: None,
class_hash: class_hash,
code_address: None,
caller_address: ContractAddress::default(),
initial_gas: INITIAL_GAS_COST,
Expand Down
16 changes: 11 additions & 5 deletions vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package vm
//#include <stdint.h>
//#include <stdlib.h>
//#include <stddef.h>
// extern void cairoVMCall(char* contract_address, char* entry_point_selector, char** calldata, size_t len_calldata,
// uintptr_t readerHandle, unsigned long long block_number, unsigned long long block_timestamp,
// char* chain_id);
// extern void cairoVMCall(char* contract_address, char* class_hash, char* entry_point_selector, char** calldata,
// size_t len_calldata, uintptr_t readerHandle, unsigned long long block_number,
// unsigned long long block_timestamp, char* chain_id);
//
// extern void cairoVMExecute(char* txns_json, char* classes_json, uintptr_t readerHandle, unsigned long long block_number,
// unsigned long long block_timestamp, char* chain_id, char* sequencer_address, char* paid_fees_on_l1_json,
Expand All @@ -28,7 +28,7 @@ import (

//go:generate mockgen -destination=../mocks/mock_vm.go -package=mocks github.com/NethermindEth/juno/vm VM
type VM interface {
Call(contractAddr, selector *felt.Felt, calldata []felt.Felt, blockNumber,
Call(contractAddr, classHash, selector *felt.Felt, calldata []felt.Felt, blockNumber,
blockTimestamp uint64, state core.StateReader, network utils.Network,
) ([]*felt.Felt, error)
Execute(txns []core.Transaction, declaredClasses []core.Class, blockNumber, blockTimestamp uint64,
Expand Down Expand Up @@ -105,7 +105,7 @@ func makePtrFromFelt(val *felt.Felt) unsafe.Pointer {
return C.CBytes(feltBytes[:])
}

func (v *vm) Call(contractAddr, selector *felt.Felt, calldata []felt.Felt, blockNumber,
func (v *vm) Call(contractAddr, classHash, selector *felt.Felt, calldata []felt.Felt, blockNumber,
blockTimestamp uint64, state core.StateReader, network utils.Network,
) ([]*felt.Felt, error) {
context := &callContext{
Expand All @@ -129,8 +129,14 @@ func (v *vm) Call(contractAddr, selector *felt.Felt, calldata []felt.Felt, block
calldataArrPtr = unsafe.Pointer(&calldataPtrs[0])
}

classHashPtr := (*byte)(nil)
if classHash != nil {
classHashBytes := classHash.Bytes()
classHashPtr = &classHashBytes[0]
}
chainID := C.CString(network.ChainIDString())
C.cairoVMCall((*C.char)(unsafe.Pointer(&addrBytes[0])),
(*C.char)(unsafe.Pointer(classHashPtr)),
(*C.char)(unsafe.Pointer(&selectorBytes[0])),
(**C.char)(calldataArrPtr),
C.size_t(len(calldataPtrs)),
Expand Down
8 changes: 4 additions & 4 deletions vm/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func TestV0Call(t *testing.T) {
}))

entryPoint := utils.HexToFelt(t, "0x39e11d48192e4333233c7eb19d10ad67c362bb28580c604d67884c85da39695")
ret, err := New(nil).Call(contractAddr, entryPoint, nil, 0, 0, testState, utils.MAINNET)
ret, err := New(nil).Call(contractAddr, classHash, entryPoint, nil, 0, 0, testState, utils.MAINNET)
require.NoError(t, err)
assert.Equal(t, []*felt.Felt{&felt.Zero}, ret)

Expand All @@ -70,7 +70,7 @@ func TestV0Call(t *testing.T) {
},
}, nil))

ret, err = New(nil).Call(contractAddr, entryPoint, nil, 1, 0, testState, utils.MAINNET)
ret, err = New(nil).Call(contractAddr, classHash, entryPoint, nil, 1, 0, testState, utils.MAINNET)
require.NoError(t, err)
assert.Equal(t, []*felt.Felt{new(felt.Felt).SetUint64(1337)}, ret)
}
Expand Down Expand Up @@ -112,7 +112,7 @@ func TestV1Call(t *testing.T) {
// test_storage_read
entryPoint := utils.HexToFelt(t, "0x5df99ae77df976b4f0e5cf28c7dcfe09bd6e81aab787b19ac0c08e03d928cf")
storageLocation := utils.HexToFelt(t, "0x44")
ret, err := New(nil).Call(contractAddr, entryPoint, []felt.Felt{
ret, err := New(nil).Call(contractAddr, nil, entryPoint, []felt.Felt{
*storageLocation,
}, 0, 0, testState, utils.GOERLI)
require.NoError(t, err)
Expand All @@ -133,7 +133,7 @@ func TestV1Call(t *testing.T) {
},
}, nil))

ret, err = New(nil).Call(contractAddr, entryPoint, []felt.Felt{
ret, err = New(nil).Call(contractAddr, nil, entryPoint, []felt.Felt{
*storageLocation,
}, 1, 0, testState, utils.GOERLI)
require.NoError(t, err)
Expand Down

0 comments on commit 7052692

Please sign in to comment.