Skip to content

Commit

Permalink
feat: types.HeaderHooks for RLP overrides (#89)
Browse files Browse the repository at this point in the history
## Why this should be merged

The `types.Header` fields of both
[`coreth`](https://pkg.go.dev/github.com/ava-labs/coreth/core/types#Header)
and
[`subnet-evm`](https://pkg.go.dev/github.com/ava-labs/subnet-evm/core/types#Header)
have been modified such that their RLP encodings (i.e. block hashes)
aren't compatible with vanilla `geth` nor each other. This PR adds
support for arbitrary RLP encoding coupled with type-safe extra
payloads.

## How this works

Equivalent to #1 (`params`) and #44 (`types.StateAccount`) registration
of pseudo-generic payloads. The only major difference is the guarantee
of a non-nil payload pointer, which means that the payload hooks are
never called on nil pointers as this would make it difficult to decode
RLP into them.

## How this was tested

Round-trip RLP {en,de}coding via a registered stub hook.

---------

Signed-off-by: Arran Schlosberg <[email protected]>
Co-authored-by: Quentin McGaw <[email protected]>
  • Loading branch information
ARR4N and qdm12 authored Dec 17, 2024
1 parent bd44839 commit dc7e27a
Show file tree
Hide file tree
Showing 11 changed files with 294 additions and 55 deletions.
25 changes: 13 additions & 12 deletions core/state/state.libevm.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,46 +19,47 @@ package state
import (
"github.com/ava-labs/libevm/common"
"github.com/ava-labs/libevm/core/types"
"github.com/ava-labs/libevm/libevm/pseudo"
)

// GetExtra returns the extra payload from the [types.StateAccount] associated
// with the address, or a zero-value `SA` if not found. The
// [types.ExtraPayloads] MUST be sourced from [types.RegisterExtras].
func GetExtra[SA any](s *StateDB, p types.ExtraPayloads[SA], addr common.Address) SA {
// with the address, or a zero-value `SA` if not found. The [pseudo.Accessor]
// MUST be sourced from [types.RegisterExtras].
func GetExtra[SA any](s *StateDB, a pseudo.Accessor[types.StateOrSlimAccount, SA], addr common.Address) SA {
stateObject := s.getStateObject(addr)
if stateObject != nil {
return p.StateAccount.Get(&stateObject.data)
return a.Get(&stateObject.data)
}
var zero SA
return zero
}

// SetExtra sets the extra payload for the address. See [GetExtra] for details.
func SetExtra[SA any](s *StateDB, p types.ExtraPayloads[SA], addr common.Address, extra SA) {
func SetExtra[SA any](s *StateDB, a pseudo.Accessor[types.StateOrSlimAccount, SA], addr common.Address, extra SA) {
stateObject := s.getOrNewStateObject(addr)
if stateObject != nil {
setExtraOnObject(stateObject, p, addr, extra)
setExtraOnObject(stateObject, a, addr, extra)
}
}

func setExtraOnObject[SA any](s *stateObject, p types.ExtraPayloads[SA], addr common.Address, extra SA) {
func setExtraOnObject[SA any](s *stateObject, a pseudo.Accessor[types.StateOrSlimAccount, SA], addr common.Address, extra SA) {
s.db.journal.append(extraChange[SA]{
payloads: p,
accessor: a,
account: &addr,
prev: p.StateAccount.Get(&s.data),
prev: a.Get(&s.data),
})
p.StateAccount.Set(&s.data, extra)
a.Set(&s.data, extra)
}

// extraChange is a [journalEntry] for [SetExtra] / [setExtraOnObject].
type extraChange[SA any] struct {
payloads types.ExtraPayloads[SA]
accessor pseudo.Accessor[types.StateOrSlimAccount, SA]
account *common.Address
prev SA
}

func (e extraChange[SA]) dirtied() *common.Address { return e.account }

func (e extraChange[SA]) revert(s *StateDB) {
e.payloads.StateAccount.Set(&s.getStateObject(*e.account).data, e.prev)
e.accessor.Set(&s.getStateObject(*e.account).data, e.prev)
}
4 changes: 2 additions & 2 deletions core/state/state.libevm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestGetSetExtra(t *testing.T) {
t.Cleanup(types.TestOnlyClearRegisteredExtras)
// Just as its Data field is a pointer, the registered type is a pointer to
// test deep copying.
payloads := types.RegisterExtras[*accountExtra]()
payloads := types.RegisterExtras[types.NOOPHeaderHooks, *types.NOOPHeaderHooks, *accountExtra]().StateAccount

rng := ethtest.NewPseudoRand(42)
addr := rng.Address()
Expand Down Expand Up @@ -87,7 +87,7 @@ func TestGetSetExtra(t *testing.T) {
Root: types.EmptyRootHash,
CodeHash: types.EmptyCodeHash[:],
}
payloads.StateAccount.Set(want, extra)
payloads.Set(want, extra)

if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("types.FullAccount(%T.Account()) diff (-want +got):\n%s", iter, diff)
Expand Down
6 changes: 3 additions & 3 deletions core/state/state_object.libevm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,21 @@ func TestStateObjectEmpty(t *testing.T) {
{
name: "explicit false bool",
registerAndSet: func(acc *types.StateAccount) {
types.RegisterExtras[bool]().StateAccount.Set(acc, false)
types.RegisterExtras[types.NOOPHeaderHooks, *types.NOOPHeaderHooks, bool]().StateAccount.Set(acc, false)
},
wantEmpty: true,
},
{
name: "implicit false bool",
registerAndSet: func(*types.StateAccount) {
types.RegisterExtras[bool]()
types.RegisterExtras[types.NOOPHeaderHooks, *types.NOOPHeaderHooks, bool]()
},
wantEmpty: true,
},
{
name: "true bool",
registerAndSet: func(acc *types.StateAccount) {
types.RegisterExtras[bool]().StateAccount.Set(acc, true)
types.RegisterExtras[types.NOOPHeaderHooks, *types.NOOPHeaderHooks, bool]().StateAccount.Set(acc, true)
},
wantEmpty: false,
},
Expand Down
3 changes: 3 additions & 0 deletions core/types/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

"github.com/ava-labs/libevm/common"
"github.com/ava-labs/libevm/common/hexutil"
"github.com/ava-labs/libevm/libevm/pseudo"
"github.com/ava-labs/libevm/rlp"
)

Expand Down Expand Up @@ -93,6 +94,8 @@ type Header struct {

// ParentBeaconRoot was added by EIP-4788 and is ignored in legacy headers.
ParentBeaconRoot *common.Hash `json:"parentBeaconBlockRoot" rlp:"optional"`

extra *pseudo.Type // See RegisterExtras()
}

// field type overrides for gencodec
Expand Down
63 changes: 62 additions & 1 deletion core/types/block.libevm.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,74 @@
package types

import (
"fmt"
"io"

"github.com/ava-labs/libevm/libevm/pseudo"
"github.com/ava-labs/libevm/rlp"
)

// HeaderHooks are required for all types registered with [RegisterExtras] for
// [Header] payloads.
type HeaderHooks interface {
EncodeRLP(*Header, io.Writer) error
DecodeRLP(*Header, *rlp.Stream) error
}

var _ interface {
rlp.Encoder
rlp.Decoder
} = (*Header)(nil)

// EncodeRLP implements the [rlp.Encoder] interface.
func (h *Header) EncodeRLP(w io.Writer) error {
if r := registeredExtras; r.Registered() {
return r.Get().hooks.hooksFromHeader(h).EncodeRLP(h, w)
}
return h.encodeRLP(w)
}

var _ rlp.Encoder = (*Header)(nil)
// decodeHeaderRLPDirectly bypasses the [Header.DecodeRLP] method to avoid
// infinite recursion.
func decodeHeaderRLPDirectly(h *Header, s *rlp.Stream) error {
type withoutMethods Header
return s.Decode((*withoutMethods)(h))
}

// DecodeRLP implements the [rlp.Decoder] interface.
func (h *Header) DecodeRLP(s *rlp.Stream) error {
if r := registeredExtras; r.Registered() {
return r.Get().hooks.hooksFromHeader(h).DecodeRLP(h, s)
}
return decodeHeaderRLPDirectly(h, s)
}

func (e ExtraPayloads[HPtr, SA]) hooksFromHeader(h *Header) HeaderHooks {
return e.Header.Get(h)
}

func (h *Header) extraPayload() *pseudo.Type {
r := registeredExtras
if !r.Registered() {
// See params.ChainConfig.extraPayload() for panic rationale.
panic(fmt.Sprintf("%T.extraPayload() called before RegisterExtras()", r))
}
if h.extra == nil {
h.extra = r.Get().newHeader()
}
return h.extra
}

// NOOPHeaderHooks implements [HeaderHooks] such that they are equivalent to
// no type having been registered.
type NOOPHeaderHooks struct{}

var _ HeaderHooks = (*NOOPHeaderHooks)(nil)

func (*NOOPHeaderHooks) EncodeRLP(h *Header, w io.Writer) error {
return h.encodeRLP(w)
}

func (*NOOPHeaderHooks) DecodeRLP(h *Header, s *rlp.Stream) error {
return decodeHeaderRLPDirectly(h, s)
}
113 changes: 113 additions & 0 deletions core/types/block.libevm_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright 2024 the libevm authors.
//
// The libevm additions to go-ethereum are free software: you can redistribute
// them and/or modify them under the terms of the GNU Lesser General Public License
// as published by the Free Software Foundation, either version 3 of the License,
// or (at your option) any later version.
//
// The libevm additions are distributed in the hope that they will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser
// General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see
// <http://www.gnu.org/licenses/>.

package types_test

import (
"errors"
"io"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

. "github.com/ava-labs/libevm/core/types"
"github.com/ava-labs/libevm/crypto"
"github.com/ava-labs/libevm/libevm/ethtest"
"github.com/ava-labs/libevm/rlp"
)

type stubHeaderHooks struct {
rlpSuffix []byte
gotRawRLPToDecode []byte
setHeaderToOnDecode Header

errEncode, errDecode error
}

func fakeHeaderRLP(h *Header, suffix []byte) []byte {
return append(crypto.Keccak256(h.ParentHash[:]), suffix...)
}

func (hh *stubHeaderHooks) EncodeRLP(h *Header, w io.Writer) error {
if _, err := w.Write(fakeHeaderRLP(h, hh.rlpSuffix)); err != nil {
return err
}
return hh.errEncode
}

func (hh *stubHeaderHooks) DecodeRLP(h *Header, s *rlp.Stream) error {
r, err := s.Raw()
if err != nil {
return err
}
hh.gotRawRLPToDecode = r
*h = hh.setHeaderToOnDecode
return hh.errDecode
}

func TestHeaderHooks(t *testing.T) {
TestOnlyClearRegisteredExtras()
defer TestOnlyClearRegisteredExtras()

extras := RegisterExtras[stubHeaderHooks, *stubHeaderHooks, struct{}]()
rng := ethtest.NewPseudoRand(13579)

t.Run("EncodeRLP", func(t *testing.T) {
suffix := rng.Bytes(8)

hdr := &Header{
ParentHash: rng.Hash(),
}
extras.Header.Get(hdr).rlpSuffix = append([]byte{}, suffix...)

got, err := rlp.EncodeToBytes(hdr)
require.NoError(t, err, "rlp.EncodeToBytes(%T)", hdr)
assert.Equal(t, fakeHeaderRLP(hdr, suffix), got)
})

t.Run("DecodeRLP", func(t *testing.T) {
input, err := rlp.EncodeToBytes(rng.Bytes(8))
require.NoError(t, err)

hdr := new(Header)
stub := &stubHeaderHooks{
setHeaderToOnDecode: Header{
Extra: []byte("arr4n was here"),
},
}
extras.Header.Set(hdr, stub)
err = rlp.DecodeBytes(input, hdr)
require.NoErrorf(t, err, "rlp.DecodeBytes(%#x)", input)

assert.Equal(t, input, stub.gotRawRLPToDecode, "raw RLP received by hooks")
assert.Equalf(t, &stub.setHeaderToOnDecode, hdr, "%T after RLP decoding with hook", hdr)
})

t.Run("error_propagation", func(t *testing.T) {
errEncode := errors.New("uh oh")
errDecode := errors.New("something bad happened")

hdr := new(Header)
extras.Header.Set(hdr, &stubHeaderHooks{
errEncode: errEncode,
errDecode: errDecode,
})

assert.Equal(t, errEncode, rlp.Encode(io.Discard, hdr), "via rlp.Encode()")
assert.Equal(t, errDecode, rlp.DecodeBytes([]byte{0}, hdr), "via rlp.DecodeBytes()")
})
}
45 changes: 41 additions & 4 deletions core/types/rlp_backwards_compat.libevm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,34 @@ import (
)

func TestHeaderRLPBackwardsCompatibility(t *testing.T) {
tests := []struct {
name string
register func()
}{
{
name: "no registered extras",
register: func() {},
},
{
name: "no-op header hooks",
register: func() {
RegisterExtras[NOOPHeaderHooks, *NOOPHeaderHooks, struct{}]()
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
TestOnlyClearRegisteredExtras()
defer TestOnlyClearRegisteredExtras()
tt.register()
testHeaderRLPBackwardsCompatibility(t)
})
}
}

//nolint:thelper
func testHeaderRLPBackwardsCompatibility(t *testing.T) {
// This is a deliberate change-detector test that locks in backwards
// compatibility of RLP encoding.
rng := ethtest.NewPseudoRand(42)
Expand Down Expand Up @@ -62,10 +90,19 @@ func TestHeaderRLPBackwardsCompatibility(t *testing.T) {
// WARNING: changing this hex might break backwards compatibility of RLP
// encoding (i.e. block hashes might change)!
const wantHex = `f9029aa01a571e7e4d774caf46053201cfe0001b3c355ffcc93f510e671e8809741f0eeda0756095410506ec72a2c287fe83ebf68efb0be177e61acec1c985277e90e52087941bfc3bc193012ba58912c01fb35a3454831a8971a00bc9f064144eb5965c5e5d1020f9f90392e7e06ded9225966abc7c754b410e61a0d942eab201424f4320ec1e1ffa9390baf941629b9349977b5d48e0502dbb9386a035d9d550a9c113f78689b4c161c4605609bb57b83061914c42ad244daa7fc38eb901004b31d39ae246d689f23176d679a62ff328f530407cbafd0146f45b2ed635282e2812f2705bfffe52576a6fb31df817f29efac71fa56b8e133334079f8e2a8fd2055451571021506f27190adb52a1313f6d28c77d66ae1aa3d3d6757a762476f4c8a2b7b2a37079a4b6a15d1bc44161190c82d5e1c8b55e05c7354f1e5f6512924c941fb3d93667dc3a8c304a3c164e6525dfc99b5f474110c5059485732153e20300c3482832d07b65f97958360da414cb438ce252aec6c2718d155798390a6c6782181d1bac1dd64cd956332b008412ddc735f2994e297c8a088c6bb4c637542295ba3cbc3cd399c8127076f4d834d74d5b11a36b6d02e2fe3a583216aa4ccea0f052df9a96e7a454256bebabdfc38c429079f25913e0f1d7416b2f056c4a115f88b85f0e9fd6d25717881f03d9985060087c88a2c54269dfd07ca388eb8f974b42a412da90c757012bf5479896165caf573cf82fb3a0aa10f6ebf6b62bef8ed36b8ea3d4b1ddb80c99afafa37cb8f3393eb6d802f5bc886c8cd6bcd168a7e0886d5b1345d948b818a0061a7182ff228a4e66bade4717e6f4d318ac98fca12a053af6f98805a764fb5d8890ed9cab2c5229908891c7e2f71857c77ca0523cb6f654ef3fc7294c7768cddd9ccf4bcda3066d382675f37dd1a18507b5fb`
want, err := hex.DecodeString(wantHex)
wantRLP, err := hex.DecodeString(wantHex)
require.NoError(t, err, "hex.DecodeString()")

got, err := rlp.EncodeToBytes(hdr)
require.NoErrorf(t, err, "rlp.EncodeToBytes(%T)", hdr)
assert.Equalf(t, want, got, "rlp.EncodeToBytes(%T)", hdr)
t.Run("Encode", func(t *testing.T) {
got, err := rlp.EncodeToBytes(hdr)
require.NoErrorf(t, err, "rlp.EncodeToBytes(%T)", hdr)
assert.Equalf(t, wantRLP, got, "rlp.EncodeToBytes(%T)", hdr)
})

t.Run("Decode", func(t *testing.T) {
got := new(Header)
err := rlp.DecodeBytes(wantRLP, got)
require.NoErrorf(t, err, "rlp.DecodeBytes(..., %T)", hdr)
assert.Equal(t, hdr, got)
})
}
Loading

0 comments on commit dc7e27a

Please sign in to comment.