Skip to content

Commit

Permalink
Move rpc.TransactionTrace to vm package (#1578)
Browse files Browse the repository at this point in the history
Co-authored-by: Ömer Faruk IRMAK <[email protected]>
  • Loading branch information
joshklop and omerfirmak authored Dec 20, 2023
1 parent 84a3ffe commit 800daf0
Show file tree
Hide file tree
Showing 8 changed files with 280 additions and 111 deletions.
6 changes: 3 additions & 3 deletions mocks/mock_vm.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 2 additions & 4 deletions node/throttled_vm.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package node

import (
"encoding/json"

"github.com/NethermindEth/juno/core"
"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/juno/utils"
Expand Down Expand Up @@ -32,9 +30,9 @@ func (tvm *ThrottledVM) Call(contractAddr, classHash, selector *felt.Felt, calld
func (tvm *ThrottledVM) Execute(txns []core.Transaction, declaredClasses []core.Class, blockNumber, blockTimestamp uint64,
sequencerAddress *felt.Felt, state core.StateReader, network utils.Network, paidFeesOnL1 []*felt.Felt,
skipChargeFee, skipValidate, errOnRevert bool, gasPriceWEI *felt.Felt, gasPriceSTRK *felt.Felt, legacyTraceJSON bool,
) ([]*felt.Felt, []json.RawMessage, error) {
) ([]*felt.Felt, []vm.TransactionTrace, error) {
var ret []*felt.Felt
var traces []json.RawMessage
var traces []vm.TransactionTrace
throttler := (*utils.Throttler[vm.VM])(tvm)
return ret, traces, throttler.Do(func(vm *vm.VM) error {
var err error
Expand Down
14 changes: 7 additions & 7 deletions rpc/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1368,23 +1368,23 @@ func (h *Handler) LegacyEstimateMessageFee(msg MsgFromL1, id BlockID) (*FeeEstim
//
// It follows the specification defined here:
// https://github.com/starkware-libs/starknet-specs/blob/1ae810e0137cc5d175ace4554892a4f43052be56/api/starknet_trace_api_openrpc.json#L11
func (h *Handler) TraceTransaction(ctx context.Context, hash felt.Felt) (json.RawMessage, *jsonrpc.Error) {
func (h *Handler) TraceTransaction(ctx context.Context, hash felt.Felt) (*vm.TransactionTrace, *jsonrpc.Error) {
return h.traceTransaction(ctx, &hash, false)
}

// LegacyTraceTransaction returns the trace for a given executed transaction, including internal calls
//
// It follows the specification defined here:
// https://github.com/starkware-libs/starknet-specs/blob/1ae810e0137cc5d175ace4554892a4f43052be56/api/starknet_trace_api_openrpc.json#L11
func (h *Handler) LegacyTraceTransaction(ctx context.Context, hash felt.Felt) (json.RawMessage, *jsonrpc.Error) {
func (h *Handler) LegacyTraceTransaction(ctx context.Context, hash felt.Felt) (*vm.TransactionTrace, *jsonrpc.Error) {
trace, err := h.traceTransaction(ctx, &hash, true)
if err != nil && err.Code == ErrTxnHashNotFound.Code {
err = ErrInvalidTxHash
}
return trace, err
}

func (h *Handler) traceTransaction(ctx context.Context, hash *felt.Felt, legacyTraceJSON bool) (json.RawMessage, *jsonrpc.Error) {
func (h *Handler) traceTransaction(ctx context.Context, hash *felt.Felt, legacyTraceJSON bool) (*vm.TransactionTrace, *jsonrpc.Error) {
_, _, blockNumber, err := h.bcReader.Receipt(hash)
if err != nil {
return nil, ErrTxnHashNotFound
Expand Down Expand Up @@ -1510,7 +1510,7 @@ func (h *Handler) simulateTransactions(id BlockID, transactions []BroadcastedTra
estimate.Unit = utils.Ptr(feeUnit)
}
result = append(result, SimulatedTransaction{
TransactionTrace: traces[i],
TransactionTrace: &traces[i],
FeeEstimation: estimate,
})
}
Expand Down Expand Up @@ -1638,10 +1638,10 @@ func (h *Handler) traceBlockTransactions(ctx context.Context, block *core.Block,
}

var result []TracedBlockTransaction
for i, trace := range traces {
for index := range traces {
result = append(result, TracedBlockTransaction{
TraceRoot: trace,
TransactionHash: block.Transactions[i].Hash(),
TraceRoot: &traces[index],
TransactionHash: block.Transactions[index].Hash(),
})
}

Expand Down
100 changes: 80 additions & 20 deletions rpc/handlers_test.go

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions rpc/simulation.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package rpc

import (
"encoding/json"
"fmt"

"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/juno/vm"
)

type SimulationFlag int
Expand All @@ -28,11 +28,11 @@ func (s *SimulationFlag) UnmarshalJSON(bytes []byte) (err error) {
}

type SimulatedTransaction struct {
TransactionTrace json.RawMessage `json:"transaction_trace,omitempty"`
FeeEstimation FeeEstimate `json:"fee_estimation,omitempty"`
TransactionTrace *vm.TransactionTrace `json:"transaction_trace,omitempty"`
FeeEstimation FeeEstimate `json:"fee_estimation,omitempty"`
}

type TracedBlockTransaction struct {
TraceRoot json.RawMessage `json:"trace_root,omitempty"`
TransactionHash *felt.Felt `json:"transaction_hash,omitempty"`
TraceRoot *vm.TransactionTrace `json:"trace_root,omitempty"`
TransactionHash *felt.Felt `json:"transaction_hash,omitempty"`
}
88 changes: 19 additions & 69 deletions rpc/trace.go
Original file line number Diff line number Diff line change
@@ -1,55 +1,14 @@
package rpc

import (
"encoding/json"
"errors"

"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/juno/starknet"
"github.com/NethermindEth/juno/utils"
"github.com/ethereum/go-ethereum/common"
"github.com/NethermindEth/juno/vm"
)

type TransactionTrace struct {
Type TransactionType `json:"type,omitempty"`
ValidateInvocation *FunctionInvocation `json:"validate_invocation,omitempty"`
ExecuteInvocation *ExecuteInvocation `json:"execute_invocation,omitempty"`
FeeTransferInvocation *FunctionInvocation `json:"fee_transfer_invocation,omitempty"`
ConstructorInvocation *FunctionInvocation `json:"constructor_invocation,omitempty"`
FunctionInvocation *FunctionInvocation `json:"function_invocation,omitempty"`
StateDiff *StateDiff `json:"state_diff,omitempty"`
}

type FunctionInvocation struct {
ContractAddress felt.Felt `json:"contract_address"`
EntryPointSelector *felt.Felt `json:"entry_point_selector,omitempty"`
Calldata []felt.Felt `json:"calldata"`
CallerAddress felt.Felt `json:"caller_address"`
ClassHash *felt.Felt `json:"class_hash,omitempty"`
EntryPointType string `json:"entry_point_type,omitempty"`
CallType string `json:"call_type,omitempty"`
Result []felt.Felt `json:"result"`
Calls []FunctionInvocation `json:"calls"`
Events []OrderedEvent `json:"events"`
Messages []OrderedL2toL1Message `json:"messages"`
ExecutionResources *ExecutionResources `json:"execution_resources,omitempty"`
}

type ExecuteInvocation struct {
RevertReason string `json:"revert_reason,omitempty"`
*FunctionInvocation `json:",omitempty"`
}

type OrderedEvent struct {
Order uint64 `json:"order"`
Event
}

type OrderedL2toL1Message struct {
Order uint64 `json:"order"`
MsgToL1
}

func adaptBlockTrace(block *BlockWithTxs, blockTrace *starknet.BlockTrace, legacyJSON bool) ([]TracedBlockTransaction, error) {
if blockTrace == nil {
return nil, nil
Expand All @@ -60,8 +19,8 @@ func adaptBlockTrace(block *BlockWithTxs, blockTrace *starknet.BlockTrace, legac
traces := make([]TracedBlockTransaction, 0, len(blockTrace.Traces))
for index := range blockTrace.Traces {
feederTrace := &blockTrace.Traces[index]
trace := TransactionTrace{}
trace.Type = block.Transactions[index].Type
trace := vm.TransactionTrace{}
trace.Type = vm.TransactionType(block.Transactions[index].Type)

trace.FeeTransferInvocation = adaptFunctionInvocation(feederTrace.FeeTransferInvocation, legacyJSON)
trace.ValidateInvocation = adaptFunctionInvocation(feederTrace.ValidateInvocation, legacyJSON)
Expand All @@ -73,7 +32,7 @@ func adaptBlockTrace(block *BlockWithTxs, blockTrace *starknet.BlockTrace, legac
case TxnDeployAccount:
trace.ConstructorInvocation = fnInvocation
case TxnInvoke:
trace.ExecuteInvocation = new(ExecuteInvocation)
trace.ExecuteInvocation = new(vm.ExecuteInvocation)
if feederTrace.RevertError != "" {
trace.ExecuteInvocation.RevertReason = feederTrace.RevertError
} else {
Expand All @@ -83,25 +42,20 @@ func adaptBlockTrace(block *BlockWithTxs, blockTrace *starknet.BlockTrace, legac
trace.FunctionInvocation = fnInvocation
}

traceJSON, err := json.Marshal(trace)
if err != nil {
return nil, err
}

traces = append(traces, TracedBlockTransaction{
TransactionHash: &feederTrace.TransactionHash,
TraceRoot: traceJSON,
TraceRoot: &trace,
})
}
return traces, nil
}

func adaptFunctionInvocation(snFnInvocation *starknet.FunctionInvocation, legacyJSON bool) *FunctionInvocation {
func adaptFunctionInvocation(snFnInvocation *starknet.FunctionInvocation, legacyJSON bool) *vm.FunctionInvocation {
if snFnInvocation == nil {
return nil
}

fnInvocation := FunctionInvocation{
fnInvocation := vm.FunctionInvocation{
ContractAddress: snFnInvocation.ContractAddress,
EntryPointSelector: snFnInvocation.Selector,
Calldata: snFnInvocation.Calldata,
Expand All @@ -110,9 +64,9 @@ func adaptFunctionInvocation(snFnInvocation *starknet.FunctionInvocation, legacy
EntryPointType: snFnInvocation.EntryPointType,
CallType: snFnInvocation.CallType,
Result: snFnInvocation.Result,
Calls: make([]FunctionInvocation, 0, len(snFnInvocation.InternalCalls)),
Events: make([]OrderedEvent, 0, len(snFnInvocation.Events)),
Messages: make([]OrderedL2toL1Message, 0, len(snFnInvocation.Messages)),
Calls: make([]vm.FunctionInvocation, 0, len(snFnInvocation.InternalCalls)),
Events: make([]vm.OrderedEvent, 0, len(snFnInvocation.Events)),
Messages: make([]vm.OrderedL2toL1Message, 0, len(snFnInvocation.Messages)),
ExecutionResources: adaptFeederExecutionResources(&snFnInvocation.ExecutionResources),
}

Expand All @@ -125,31 +79,27 @@ func adaptFunctionInvocation(snFnInvocation *starknet.FunctionInvocation, legacy
}
for index := range snFnInvocation.Events {
snEvent := &snFnInvocation.Events[index]
fnInvocation.Events = append(fnInvocation.Events, OrderedEvent{
fnInvocation.Events = append(fnInvocation.Events, vm.OrderedEvent{
Order: snEvent.Order,
Event: Event{
Keys: utils.Map(snEvent.Keys, utils.Ptr[felt.Felt]),
Data: utils.Map(snEvent.Data, utils.Ptr[felt.Felt]),
},
Keys: utils.Map(snEvent.Keys, utils.Ptr[felt.Felt]),
Data: utils.Map(snEvent.Data, utils.Ptr[felt.Felt]),
})
}
for index := range snFnInvocation.Messages {
snMessage := &snFnInvocation.Messages[index]
fnInvocation.Messages = append(fnInvocation.Messages, OrderedL2toL1Message{
Order: snMessage.Order,
MsgToL1: MsgToL1{
Payload: utils.Map(snMessage.Payload, utils.Ptr[felt.Felt]),
To: common.HexToAddress(snMessage.ToAddr),
},
fnInvocation.Messages = append(fnInvocation.Messages, vm.OrderedL2toL1Message{
Order: snMessage.Order,
Payload: utils.Map(snMessage.Payload, utils.Ptr[felt.Felt]),
To: snMessage.ToAddr,
})
}

return &fnInvocation
}

func adaptFeederExecutionResources(resources *starknet.ExecutionResources) *ExecutionResources {
func adaptFeederExecutionResources(resources *starknet.ExecutionResources) *vm.ExecutionResources {
builtins := &resources.BuiltinInstanceCounter
return &ExecutionResources{
return &vm.ExecutionResources{
Steps: resources.Steps,
MemoryHoles: resources.MemoryHoles,
Pedersen: builtins.Pedersen,
Expand Down
Loading

0 comments on commit 800daf0

Please sign in to comment.