Skip to content

Commit

Permalink
Add trace helper methods
Browse files Browse the repository at this point in the history
  • Loading branch information
joshklop authored and omerfirmak committed Dec 26, 2023
1 parent ac88a30 commit 69d18d3
Show file tree
Hide file tree
Showing 2 changed files with 298 additions and 0 deletions.
72 changes: 72 additions & 0 deletions vm/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package vm
import (
"errors"
"fmt"
"slices"

"github.com/NethermindEth/juno/core/felt"
)
Expand Down Expand Up @@ -105,6 +106,61 @@ type TransactionTrace struct {
StateDiff *StateDiff `json:"state_diff,omitempty"`
}

func (t *TransactionTrace) allInvocations() []*FunctionInvocation {
var executeInvocation *FunctionInvocation
if t.ExecuteInvocation != nil {
executeInvocation = t.ExecuteInvocation.FunctionInvocation
}
return slices.DeleteFunc([]*FunctionInvocation{
t.ConstructorInvocation,
t.ValidateInvocation,
t.FeeTransferInvocation,
executeInvocation,
t.FunctionInvocation,
}, func(i *FunctionInvocation) bool { return i == nil })
}

func (t *TransactionTrace) TotalExecutionResources() *ExecutionResources {
total := new(ExecutionResources)
for _, invocation := range t.allInvocations() {
r := invocation.ExecutionResources
total.Pedersen += r.Pedersen
total.RangeCheck += r.RangeCheck
total.Bitwise += r.Bitwise
total.Ecsda += r.Ecsda
total.EcOp += r.EcOp
total.Keccak += r.Keccak
total.Poseidon += r.Poseidon
total.SegmentArena += r.SegmentArena
total.MemoryHoles += r.MemoryHoles
total.Steps += r.Steps
}
return total
}

func (t *TransactionTrace) RevertReason() string {
if t.ExecuteInvocation == nil {
return ""
}
return t.ExecuteInvocation.RevertReason
}

func (t *TransactionTrace) AllEvents() []OrderedEvent {
events := make([]OrderedEvent, 0)
for _, invocation := range t.allInvocations() {
events = append(events, invocation.allEvents()...)
}
return events
}

func (t *TransactionTrace) AllMessages() []OrderedL2toL1Message {
messages := make([]OrderedL2toL1Message, 0)
for _, invocation := range t.allInvocations() {
messages = append(messages, invocation.allMessages()...)
}
return messages
}

type FunctionInvocation struct {
ContractAddress felt.Felt `json:"contract_address"`
EntryPointSelector *felt.Felt `json:"entry_point_selector,omitempty"`
Expand All @@ -120,6 +176,22 @@ type FunctionInvocation struct {
ExecutionResources *ExecutionResources `json:"execution_resources,omitempty"`
}

func (invocation *FunctionInvocation) allEvents() []OrderedEvent {
events := make([]OrderedEvent, 0)
for i := range invocation.Calls {
events = append(events, invocation.Calls[i].allEvents()...)
}
return append(events, invocation.Events...)
}

func (invocation *FunctionInvocation) allMessages() []OrderedL2toL1Message {
messages := make([]OrderedL2toL1Message, 0)
for i := range invocation.Calls {
messages = append(messages, invocation.Calls[i].allMessages()...)
}
return append(messages, invocation.Messages...)
}

type ExecuteInvocation struct {
RevertReason string `json:"revert_reason,omitempty"`
*FunctionInvocation `json:",omitempty"`
Expand Down
226 changes: 226 additions & 0 deletions vm/trace_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
package vm_test

import (
"testing"

"github.com/NethermindEth/juno/vm"
"github.com/stretchr/testify/require"
)

func TestRevertReason(t *testing.T) {
require.Equal(t, (&vm.TransactionTrace{
ExecuteInvocation: &vm.ExecuteInvocation{
RevertReason: "reason",
},
}).RevertReason(), "reason")

require.Empty(t, (&vm.TransactionTrace{
ExecuteInvocation: &vm.ExecuteInvocation{},
}).RevertReason())
}

//nolint:dupl
func TestAllEvents(t *testing.T) {
numEvents := uint64(10)
events := make([]vm.OrderedEvent, 0, numEvents)
for i := uint64(0); i < numEvents; i++ {
events = append(events, vm.OrderedEvent{Order: i})
}
tests := map[string]*vm.TransactionTrace{
"many top-level invocations": {
ValidateInvocation: &vm.FunctionInvocation{
Events: []vm.OrderedEvent{events[0]},
},
FunctionInvocation: &vm.FunctionInvocation{
Events: []vm.OrderedEvent{events[1]},
},
ConstructorInvocation: &vm.FunctionInvocation{
Events: []vm.OrderedEvent{events[2]},
},
ExecuteInvocation: &vm.ExecuteInvocation{
FunctionInvocation: &vm.FunctionInvocation{
Events: []vm.OrderedEvent{events[3]},
},
},
FeeTransferInvocation: &vm.FunctionInvocation{
Events: events[4:],
},
},
"only validate invocation": {
ValidateInvocation: &vm.FunctionInvocation{
Events: events,
},
},
"present in some sub-calls": {
ValidateInvocation: &vm.FunctionInvocation{
Events: []vm.OrderedEvent{events[0]},
Calls: []vm.FunctionInvocation{
{
Events: events[1:5],
},
},
},
FunctionInvocation: &vm.FunctionInvocation{
Events: []vm.OrderedEvent{events[5]},
Calls: []vm.FunctionInvocation{
{
Events: events[6:],
},
},
},
},
}

for description, trace := range tests {
t.Run(description, func(t *testing.T) {
require.ElementsMatch(t, events, trace.AllEvents())
})
}
}

//nolint:dupl
func TestAllMessages(t *testing.T) {
nummessages := uint64(10)
messages := make([]vm.OrderedL2toL1Message, 0, nummessages)
for i := uint64(0); i < nummessages; i++ {
messages = append(messages, vm.OrderedL2toL1Message{Order: i})
}
tests := map[string]*vm.TransactionTrace{
"many top-level invocations": {
ValidateInvocation: &vm.FunctionInvocation{
Messages: []vm.OrderedL2toL1Message{messages[0]},
},
FunctionInvocation: &vm.FunctionInvocation{
Messages: []vm.OrderedL2toL1Message{messages[1]},
},
ConstructorInvocation: &vm.FunctionInvocation{
Messages: []vm.OrderedL2toL1Message{messages[2]},
},
ExecuteInvocation: &vm.ExecuteInvocation{
FunctionInvocation: &vm.FunctionInvocation{
Messages: []vm.OrderedL2toL1Message{messages[3]},
},
},
FeeTransferInvocation: &vm.FunctionInvocation{
Messages: messages[4:],
},
},
"only validate invocation": {
ValidateInvocation: &vm.FunctionInvocation{
Messages: messages,
},
},
"present in some sub-calls": {
ValidateInvocation: &vm.FunctionInvocation{
Messages: []vm.OrderedL2toL1Message{messages[0]},
Calls: []vm.FunctionInvocation{
{
Messages: messages[1:5],
},
},
},
FunctionInvocation: &vm.FunctionInvocation{
Messages: []vm.OrderedL2toL1Message{messages[5]},
Calls: []vm.FunctionInvocation{
{
Messages: messages[6:],
},
},
},
},
}

for description, trace := range tests {
t.Run(description, func(t *testing.T) {
require.ElementsMatch(t, messages, trace.AllMessages())
})
}
}

func TestTotalExecutionResources(t *testing.T) {
resources := &vm.ExecutionResources{
Steps: 1,
MemoryHoles: 2,
Pedersen: 3,
RangeCheck: 4,
Bitwise: 5,
Ecsda: 6,
EcOp: 7,
Keccak: 8,
Poseidon: 9,
SegmentArena: 10,
}
tests := map[string]struct {
multiplier uint64
trace *vm.TransactionTrace
}{
"many top-level invocations": {
multiplier: 5,
trace: &vm.TransactionTrace{
ValidateInvocation: &vm.FunctionInvocation{
ExecutionResources: resources,
},
FunctionInvocation: &vm.FunctionInvocation{
ExecutionResources: resources,
},
ConstructorInvocation: &vm.FunctionInvocation{
ExecutionResources: resources,
},
ExecuteInvocation: &vm.ExecuteInvocation{
FunctionInvocation: &vm.FunctionInvocation{
ExecutionResources: resources,
},
},
FeeTransferInvocation: &vm.FunctionInvocation{
ExecutionResources: resources,
},
},
},
"only validate invocation": {
multiplier: 1,
trace: &vm.TransactionTrace{
ValidateInvocation: &vm.FunctionInvocation{
ExecutionResources: resources,
},
},
},
"present in some sub-calls": {
multiplier: 2,
trace: &vm.TransactionTrace{
ValidateInvocation: &vm.FunctionInvocation{
ExecutionResources: resources,
Calls: []vm.FunctionInvocation{
{
ExecutionResources: resources,
},
},
},
FunctionInvocation: &vm.FunctionInvocation{
ExecutionResources: resources,
Calls: []vm.FunctionInvocation{
{
ExecutionResources: resources,
},
},
},
},
},
}

for description, test := range tests {
t.Run(description, func(t *testing.T) {
require.Equal(t, &vm.ExecutionResources{
Steps: resources.Steps * test.multiplier,
MemoryHoles: resources.MemoryHoles * test.multiplier,
Pedersen: resources.Pedersen * test.multiplier,
RangeCheck: resources.RangeCheck * test.multiplier,
Bitwise: resources.Bitwise * test.multiplier,
Ecsda: resources.Ecsda * test.multiplier,
EcOp: resources.EcOp * test.multiplier,
Keccak: resources.Keccak * test.multiplier,
Poseidon: resources.Poseidon * test.multiplier,
SegmentArena: resources.SegmentArena * test.multiplier,
}, test.trace.TotalExecutionResources())
})
}
}

0 comments on commit 69d18d3

Please sign in to comment.