Skip to content

Commit

Permalink
Adds support for label cheatcode (#545)
Browse files Browse the repository at this point in the history
Adds support for the vm.label cheatcode. The label cheatcode will append the label to the TestChain. These labels can then be used by the execution tracer.

---------

Co-authored-by: Anish Naik <[email protected]>
  • Loading branch information
priyankabose and anishnaik authored Feb 1, 2025
1 parent 42a22ca commit ed1ea36
Show file tree
Hide file tree
Showing 11 changed files with 198 additions and 32 deletions.
11 changes: 11 additions & 0 deletions chain/standard_cheat_code_contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,17 @@ func getStandardCheatCodeContract(tracer *cheatCodeTracer) (*CheatCodeContract,
},
)

// Label: Sets a label for an address.
contract.addMethod(
"label", abi.Arguments{{Type: typeAddress}, {Type: typeString}}, abi.Arguments{},
func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) {
addr := inputs[0].(common.Address)
label := inputs[1].(string)
tracer.chain.Labels[addr] = label
return nil, nil
},
)

// Load: Loads a storage slot value from a given account.
contract.addMethod(
"load", abi.Arguments{{Type: typeAddress}, {Type: typeBytes32}}, abi.Arguments{{Type: typeBytes32}},
Expand Down
4 changes: 4 additions & 0 deletions chain/test_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ type TestChain struct {
// This is constructed over the kvstore.
db ethdb.Database

// Labels maps an address to its label if one exists. This is useful for execution tracing.
Labels map[common.Address]string

// callTracerRouter forwards tracers.Tracer and TestChainTracer calls to any instances added to it. This
// router is used for non-state changing calls.
callTracerRouter *TestChainTracerRouter
Expand Down Expand Up @@ -187,6 +190,7 @@ func NewTestChain(genesisAlloc types.GenesisAlloc, testChainConfig *config.TestC
db: db,
state: nil,
stateDatabase: stateDatabase,
Labels: make(map[common.Address]string),
transactionTracerRouter: transactionTracerRouter,
callTracerRouter: callTracerRouter,
testChainConfig: testChainConfig,
Expand Down
4 changes: 2 additions & 2 deletions fuzzing/calls/call_sequence.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func (cse *CallSequenceElement) String() string {
args, err := method.Inputs.Unpack(cse.Call.Data[4:])
argsText := "<unable to unpack args>"
if err == nil {
argsText, err = valuegeneration.EncodeABIArgumentsToString(method.Inputs, args)
argsText, err = valuegeneration.EncodeABIArgumentsToString(method.Inputs, args, nil)
if err != nil {
argsText = "<unresolved args>"
}
Expand All @@ -286,7 +286,7 @@ func (cse *CallSequenceElement) String() string {
cse.Call.GasLimit,
cse.Call.GasPrice.String(),
cse.Call.Value.String(),
utils.TrimLeadingZeroesFromAddress(cse.Call.From.String()),
cse.Call.From.String(),
)
}

Expand Down
2 changes: 1 addition & 1 deletion fuzzing/calls/call_sequence_execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func ExecuteCallSequence(chain *chain.TestChain, callSequence CallSequence) (Cal
// ExecuteCallSequenceWithExecutionTracer attaches an executiontracer.ExecutionTracer to ExecuteCallSequenceIteratively and attaches execution traces to the call sequence elements.
func ExecuteCallSequenceWithExecutionTracer(testChain *chain.TestChain, contractDefinitions contracts.Contracts, callSequence CallSequence, verboseTracing bool) (CallSequence, error) {
// Create a new execution tracer
executionTracer := executiontracer.NewExecutionTracer(contractDefinitions, testChain.CheatCodeContracts())
executionTracer := executiontracer.NewExecutionTracer(contractDefinitions, testChain)
defer executionTracer.Close()

// Execute our sequence with a simple fetch operation provided to obtain each element.
Expand Down
40 changes: 29 additions & 11 deletions fuzzing/executiontracer/execution_trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"encoding/hex"
"errors"
"fmt"
"github.com/crytic/medusa/utils"
"github.com/ethereum/go-ethereum/common"
"regexp"
"strings"

Expand All @@ -13,7 +15,6 @@ import (
"github.com/crytic/medusa/fuzzing/valuegeneration"
"github.com/crytic/medusa/logging"
"github.com/crytic/medusa/logging/colors"
"github.com/crytic/medusa/utils"
"github.com/ethereum/go-ethereum/accounts/abi"
coreTypes "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
Expand All @@ -29,13 +30,17 @@ type ExecutionTrace struct {
// contractDefinitions represents the known contract definitions at the time of tracing. This is used to help
// obtain any additional information regarding execution.
contractDefinitions contracts.Contracts

// labels is a mapping that maps an address to its string representation for cleaner execution traces
labels map[common.Address]string
}

// newExecutionTrace creates and returns a new ExecutionTrace, to be used by the ExecutionTracer.
func newExecutionTrace(contracts contracts.Contracts) *ExecutionTrace {
func newExecutionTrace(contracts contracts.Contracts, labels map[common.Address]string) *ExecutionTrace {
return &ExecutionTrace{
TopLevelCallFrame: nil,
contractDefinitions: contracts,
labels: labels,
}
}

Expand Down Expand Up @@ -70,10 +75,18 @@ func (t *ExecutionTrace) generateCallFrameEnterElements(callFrame *CallFrame) ([

// Resolve our contract names, as well as our method and its name from the code contract.
if callFrame.ToContractAbi != nil {
// Check to see if there is a label for the proxy address
proxyContractName = callFrame.ToContractName
if label, ok := t.labels[callFrame.ToAddress]; ok {
proxyContractName = label
}
}
if callFrame.CodeContractAbi != nil {
// Check to see if there is a label for the code address
codeContractName = callFrame.CodeContractName
if label, ok := t.labels[callFrame.CodeAddress]; ok {
codeContractName = label
}
if callFrame.IsContractCreation() {
methodName = "constructor"
method = &callFrame.CodeContractAbi.Constructor
Expand Down Expand Up @@ -102,8 +115,8 @@ func (t *ExecutionTrace) generateCallFrameEnterElements(callFrame *CallFrame) ([
// Unpack our input values and obtain a string to represent them
inputValues, err := method.Inputs.Unpack(abiDataInputBuffer)
if err == nil {
// Encode the ABI arguments into strings
encodedInputString, err := valuegeneration.EncodeABIArgumentsToString(method.Inputs, inputValues)
// Encode the ABI arguments into strings and provide the label overrides
encodedInputString, err := valuegeneration.EncodeABIArgumentsToString(method.Inputs, inputValues, t.labels)
if err == nil {
inputArgumentsDisplayText = &encodedInputString
}
Expand Down Expand Up @@ -137,24 +150,29 @@ func (t *ExecutionTrace) generateCallFrameEnterElements(callFrame *CallFrame) ([
inputArgumentsDisplayText = &temp
}

// Handle all label overrides
toAddress := utils.AttachLabelToAddress(callFrame.ToAddress, t.labels[callFrame.ToAddress])
senderAddress := utils.AttachLabelToAddress(callFrame.SenderAddress, t.labels[callFrame.SenderAddress])
codeAddress := utils.AttachLabelToAddress(callFrame.CodeAddress, t.labels[callFrame.CodeAddress])

// Generate the message we wish to output finally, using all these display string components.
// If we executed code, attach additional context such as the contract name, method, etc.
var callInfo string
if callFrame.IsProxyCall() {
if callFrame.ExecutedCode {
callInfo = fmt.Sprintf("%v -> %v.%v(%v) (addr=%v, code=%v, value=%v, sender=%v)", proxyContractName, codeContractName, methodName, *inputArgumentsDisplayText, utils.TrimLeadingZeroesFromAddress(callFrame.ToAddress.String()), utils.TrimLeadingZeroesFromAddress(callFrame.CodeAddress.String()), callFrame.CallValue, utils.TrimLeadingZeroesFromAddress(callFrame.SenderAddress.String()))
callInfo = fmt.Sprintf("%v -> %v.%v(%v) (addr=%v, code=%v, value=%v, sender=%v)", proxyContractName, codeContractName, methodName, *inputArgumentsDisplayText, toAddress, codeAddress, callFrame.CallValue, senderAddress)
} else {
callInfo = fmt.Sprintf("(addr=%v, value=%v, sender=%v)", utils.TrimLeadingZeroesFromAddress(callFrame.ToAddress.String()), callFrame.CallValue, utils.TrimLeadingZeroesFromAddress(callFrame.SenderAddress.String()))
callInfo = fmt.Sprintf("(addr=%v, value=%v, sender=%v)", toAddress, callFrame.CallValue, senderAddress)
}
} else {
if callFrame.ExecutedCode {
if callFrame.ToAddress == chain.ConsoleLogContractAddress {
callInfo = fmt.Sprintf("%v.%v(%v)", codeContractName, methodName, *inputArgumentsDisplayText)
} else {
callInfo = fmt.Sprintf("%v.%v(%v) (addr=%v, value=%v, sender=%v)", codeContractName, methodName, *inputArgumentsDisplayText, utils.TrimLeadingZeroesFromAddress(callFrame.ToAddress.String()), callFrame.CallValue, utils.TrimLeadingZeroesFromAddress(callFrame.SenderAddress.String()))
callInfo = fmt.Sprintf("%v.%v(%v) (addr=%v, value=%v, sender=%v)", codeContractName, methodName, *inputArgumentsDisplayText, toAddress, callFrame.CallValue, senderAddress)
}
} else {
callInfo = fmt.Sprintf("(addr=%v, value=%v, sender=%v)", utils.TrimLeadingZeroesFromAddress(callFrame.ToAddress.String()), callFrame.CallValue, utils.TrimLeadingZeroesFromAddress(callFrame.SenderAddress.String()))
callInfo = fmt.Sprintf("(addr=%v, value=%v, sender=%v)", toAddress, callFrame.CallValue, senderAddress)
}
}

Expand Down Expand Up @@ -189,7 +207,7 @@ func (t *ExecutionTrace) generateCallFrameExitElements(callFrame *CallFrame) []a
if callFrame.ReturnError == nil {
outputValues, err := method.Outputs.Unpack(callFrame.ReturnData)
if err == nil {
encodedOutputString, err := valuegeneration.EncodeABIArgumentsToString(method.Outputs, outputValues)
encodedOutputString, err := valuegeneration.EncodeABIArgumentsToString(method.Outputs, outputValues, t.labels)
if err == nil {
outputArgumentsDisplayText = &encodedOutputString
}
Expand Down Expand Up @@ -232,7 +250,7 @@ func (t *ExecutionTrace) generateCallFrameExitElements(callFrame *CallFrame) []a
// Try to unpack a custom Solidity error from the return values.
matchedCustomError, unpackedCustomErrorArgs := abiutils.GetSolidityCustomRevertError(callFrame.CodeContractAbi, callFrame.ReturnError, callFrame.ReturnData)
if matchedCustomError != nil {
customErrorArgsDisplayText, err := valuegeneration.EncodeABIArgumentsToString(matchedCustomError.Inputs, unpackedCustomErrorArgs)
customErrorArgsDisplayText, err := valuegeneration.EncodeABIArgumentsToString(matchedCustomError.Inputs, unpackedCustomErrorArgs, t.labels)
if err == nil {
elements = append(elements, colors.RedBold, fmt.Sprintf("[revert (error: %v(%v))]", matchedCustomError.Name, customErrorArgsDisplayText), colors.Reset, "\n")
return elements
Expand Down Expand Up @@ -276,7 +294,7 @@ func (t *ExecutionTrace) generateEventEmittedElements(callFrame *CallFrame, even
// If we resolved an event definition and unpacked data.
if event != nil {
// Format the values as a comma-separated string
encodedEventValuesString, err := valuegeneration.EncodeABIArgumentsToString(event.Inputs, eventInputValues)
encodedEventValuesString, err := valuegeneration.EncodeABIArgumentsToString(event.Inputs, eventInputValues, t.labels)
if err == nil {
// Format our event display text finally, with the event name.
temp := fmt.Sprintf("%v(%v)", event.Name, encodedEventValuesString)
Expand Down
21 changes: 12 additions & 9 deletions fuzzing/executiontracer/execution_tracer.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
// Returns the ExecutionTrace for the call or an error if one occurs.
func CallWithExecutionTrace(testChain *chain.TestChain, contractDefinitions contracts.Contracts, msg *core.Message, state *state.StateDB) (*core.ExecutionResult, *ExecutionTrace, error) {
// Create an execution tracer
executionTracer := NewExecutionTracer(contractDefinitions, testChain.CheatCodeContracts())
executionTracer := NewExecutionTracer(contractDefinitions, testChain)
defer executionTracer.Close()

// Call the contract on our chain with the provided state.
Expand All @@ -48,6 +48,11 @@ type ExecutionTracer struct {
// trace represents the current execution trace captured by this tracer.
trace *ExecutionTrace

// testChain represents the underlying chain that the execution tracer runs on
testChain *chain.TestChain

// traceMap describes a mapping that allows someone to retrieve the execution trace for a common transaction
// hash.
traceMap map[common.Hash]*ExecutionTrace

// currentCallFrame references the current call frame being traced.
Expand All @@ -56,23 +61,21 @@ type ExecutionTracer struct {
// contractDefinitions represents the contract definitions to match for execution traces.
contractDefinitions contracts.Contracts

// cheatCodeContracts represents the cheat code contract definitions to match for execution traces.
cheatCodeContracts map[common.Address]*chain.CheatCodeContract

// onNextCaptureState refers to methods which should be executed the next time OnOpcode executes.
// OnOpcode is called prior to execution of an instruction. This allows actions to be performed
// after some state is captured, on the next state capture (e.g. detecting a log instruction, but
// using this structure to execute code later once the log is committed).
onNextCaptureState []func()

// nativeTracer is the underlying tracer interface that the execution tracer follows
nativeTracer *chain.TestChainTracer
}

// NewExecutionTracer creates a ExecutionTracer and returns it.
func NewExecutionTracer(contractDefinitions contracts.Contracts, cheatCodeContracts map[common.Address]*chain.CheatCodeContract) *ExecutionTracer {
func NewExecutionTracer(contractDefinitions contracts.Contracts, testChain *chain.TestChain) *ExecutionTracer {
tracer := &ExecutionTracer{
contractDefinitions: contractDefinitions,
cheatCodeContracts: cheatCodeContracts,
testChain: testChain,
traceMap: make(map[common.Hash]*ExecutionTrace),
}
innerTracer := &tracers.Tracer{
Expand Down Expand Up @@ -122,7 +125,7 @@ func (t *ExecutionTracer) OnTxEnd(receipt *coretypes.Receipt, err error) {
// OnTxStart is called upon the start of transaction execution, as defined by tracers.Tracer.
func (t *ExecutionTracer) OnTxStart(vm *tracing.VMContext, tx *coretypes.Transaction, from common.Address) {
// Reset our capture state
t.trace = newExecutionTrace(t.contractDefinitions)
t.trace = newExecutionTrace(t.contractDefinitions, t.testChain.Labels)
t.currentCallFrame = nil
t.onNextCaptureState = nil
t.traceMap = make(map[common.Hash]*ExecutionTrace)
Expand Down Expand Up @@ -151,7 +154,7 @@ func (t *ExecutionTracer) resolveCallFrameContractDefinitions(callFrame *CallFra
// Try to resolve contract definitions for "to" address
if callFrame.ToContractAbi == nil {
// Try to resolve definitions from cheat code contracts
if cheatCodeContract, ok := t.cheatCodeContracts[callFrame.ToAddress]; ok {
if cheatCodeContract, ok := t.testChain.CheatCodeContracts()[callFrame.ToAddress]; ok {
callFrame.ToContractName = cheatCodeContract.Name()
callFrame.ToContractAbi = cheatCodeContract.Abi()
callFrame.ExecutedCode = true
Expand All @@ -175,7 +178,7 @@ func (t *ExecutionTracer) resolveCallFrameContractDefinitions(callFrame *CallFra
// Try to resolve contract definitions for "code" address
if callFrame.CodeContractAbi == nil {
// Try to resolve definitions from cheat code contracts
if cheatCodeContract, ok := t.cheatCodeContracts[callFrame.CodeAddress]; ok {
if cheatCodeContract, ok := t.testChain.CheatCodeContracts()[callFrame.CodeAddress]; ok {
callFrame.CodeContractName = cheatCodeContract.Name()
callFrame.CodeContractAbi = cheatCodeContract.Abi()
callFrame.ExecutedCode = true
Expand Down
50 changes: 50 additions & 0 deletions fuzzing/fuzzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,56 @@ func TestExecutionTraces(t *testing.T) {
}
}

// TestLabelCheatCode tests the vm.label cheatcode.
func TestLabelCheatCode(t *testing.T) {
// These are the expected messages in the execution trace
expectedTraceMessages := []string{
"ProxyContract.testVMLabel()()",
"addr=ProxyContract [0xA647ff3c36cFab592509E13860ab8c4F28781a66]",
"sender=MySender [0x10000]",
"ProxyContract -> ImplementationContract.emitEvent(address)(ProxyContract [0xA647ff3c36cFab592509E13860ab8c4F28781a66])",
"code=ImplementationContract [0x54919A19522Ce7c842E25735a9cFEcef1c0a06dA]",
"[event] TestEvent(RandomAddress [0x20000])",
"[return (ProxyContract [0xA647ff3c36cFab592509E13860ab8c4F28781a66])]",
}
runFuzzerTest(t, &fuzzerSolcFileTest{
filePath: "testdata/contracts/cheat_codes/utils/label.sol",
configUpdates: func(config *config.ProjectConfig) {
config.Fuzzing.TargetContracts = []string{"TestContract"}
// Only allow for one sender for proper testing of this unit test
config.Fuzzing.SenderAddresses = []string{"0x10000"}
config.Fuzzing.Testing.PropertyTesting.Enabled = false
config.Fuzzing.Testing.OptimizationTesting.Enabled = false
config.Slither.UseSlither = false
},
method: func(f *fuzzerTestContext) {
// Start the fuzzer
err := f.fuzzer.Start()
assert.NoError(t, err)

// Check for failed assertion tests.
failedTestCase := f.fuzzer.TestCasesWithStatus(TestCaseStatusFailed)
assert.NotEmpty(t, failedTestCase, "expected to have failed test cases")

// Obtain our first failed test case, get the message, and verify it contains our assertion failed.
failingSequence := *failedTestCase[0].CallSequence()
assert.NotEmpty(t, failingSequence, "expected to have calls in the call sequence failing an assertion test")

// Obtain the last call
lastCall := failingSequence[len(failingSequence)-1]
assert.NotNilf(t, lastCall.ExecutionTrace, "expected to have an execution trace attached to call sequence for this test")

// Get the execution trace message
executionTraceMsg := lastCall.ExecutionTrace.Log().String()

// Verify it contains all expected strings
for _, expectedTraceMessage := range expectedTraceMessages {
assert.Contains(t, executionTraceMsg, expectedTraceMessage)
}
},
})
}

// TestTestingScope runs tests to ensure dynamically deployed contracts are tested when the "test all contracts"
// config option is specified. It also runs the fuzzer without the option enabled to ensure they are not tested.
func TestTestingScope(t *testing.T) {
Expand Down
14 changes: 9 additions & 5 deletions fuzzing/fuzzer_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,11 +570,6 @@ func (fw *FuzzerWorker) run(baseTestChain *chain.TestChain) (bool, error) {
initializedChain.Events.ContractDeploymentAddedEventEmitter.Subscribe(fw.onChainContractDeploymentAddedEvent)
initializedChain.Events.ContractDeploymentRemovedEventEmitter.Subscribe(fw.onChainContractDeploymentRemovedEvent)

// Emit an event indicating the worker has created its chain.
err = fw.Events.FuzzerWorkerChainCreated.Publish(FuzzerWorkerChainCreatedEvent{
Worker: fw,
Chain: initializedChain,
})
if err != nil {
return fmt.Errorf("error returned by an event handler when emitting a worker chain created event: %v", err)
}
Expand All @@ -584,6 +579,15 @@ func (fw *FuzzerWorker) run(baseTestChain *chain.TestChain) (bool, error) {
fw.coverageTracer = coverage.NewCoverageTracer()
initializedChain.AddTracer(fw.coverageTracer.NativeTracer(), true, false)
}

// Copy the labels from the base chain to the worker's chain
initializedChain.Labels = maps.Clone(baseTestChain.Labels)

// Emit an event indicating the worker has created its chain.
err = fw.Events.FuzzerWorkerChainCreated.Publish(FuzzerWorkerChainCreatedEvent{
Worker: fw,
Chain: initializedChain,
})
return nil
})

Expand Down
Loading

0 comments on commit ed1ea36

Please sign in to comment.