Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: params extra types are zero values not nil pointers by default #13

Merged
merged 5 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libevm/hookstest/stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type Stub struct {
// Register is a convenience wrapper for registering s as both the
// [params.ChainConfigHooks] and [params.RulesHooks] via [Register].
func (s *Stub) Register(tb testing.TB) {
Register(tb, params.Extras[Stub, Stub]{
Register(tb, params.Extras[*Stub, *Stub]{
NewRules: func(_ *params.ChainConfig, _ *params.Rules, _ *Stub, blockNum *big.Int, isMerge bool, timestamp uint64) *Stub {
return s
},
Expand Down
21 changes: 21 additions & 0 deletions libevm/pseudo/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,27 @@ func Zero[T any]() *Pseudo[T] {
return From[T](x)
}

// PointerTo is equivalent to [From] called with a pointer to the payload
// carried by `t`. It first confirms that the payload is of type `T`.
func PointerTo[T any](t *Type) (*Pseudo[*T], error) {
c, ok := t.val.(*concrete[T])
if !ok {
var want *T
return nil, fmt.Errorf("cannot create *Pseudo[%T] from *Type carrying %T", want, t.val.get())
}
return From(&c.val), nil
}

// MustPointerTo is equivalent to [PointerTo] except that it panics instead of
// returning an error.
func MustPointerTo[T any](t *Type) *Pseudo[*T] {
p, err := PointerTo[T](t)
if err != nil {
panic(err)
}
return p
}

// Interface returns the wrapped value as an `any`, equivalent to
// [reflect.Value.Interface]. Prefer [Value.Get].
func (t *Type) Interface() any { return t.val.get() }
Expand Down
23 changes: 23 additions & 0 deletions libevm/pseudo/type_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,26 @@ func ExamplePseudo_TypeAndValue() {
_ = typ
_ = val
}

func TestPointer(t *testing.T) {
type carrier struct {
payload int
}

typ, val := From(carrier{42}).TypeAndValue()

t.Run("invalid type", func(t *testing.T) {
_, err := PointerTo[int](typ)
require.Errorf(t, err, "PointerTo[int](%T)", carrier{})
})

t.Run("valid type", func(t *testing.T) {
ptrVal := MustPointerTo[carrier](typ).Value

assert.Equal(t, 42, val.Get().payload, "before setting via pointer")
var ptr *carrier = ptrVal.Get()
ptr.payload = 314159
assert.Equal(t, 314159, val.Get().payload, "after setting via pointer")
})

}
95 changes: 54 additions & 41 deletions params/config.libevm.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,27 @@ type Extras[C ChainConfigHooks, R RulesHooks] struct {
// NewRules, if non-nil is called at the end of [ChainConfig.Rules] with the
// newly created [Rules] and other context from the method call. Its
// returned value will be the extra payload of the [Rules]. If NewRules is
// nil then so too will the [Rules] extra payload be a nil `*R`.
// nil then so too will the [Rules] extra payload be a zero-value `R`.
//
// NewRules MAY modify the [Rules] but MUST NOT modify the [ChainConfig].
NewRules func(_ *ChainConfig, _ *Rules, _ *C, blockNum *big.Int, isMerge bool, timestamp uint64) *R
// TODO(arr4n): add the [Rules] to the return signature to make it clearer
// that the caller can modify the generated Rules.
NewRules func(_ *ChainConfig, _ *Rules, _ C, blockNum *big.Int, isMerge bool, timestamp uint64) R
}

// RegisterExtras registers the types `C` and `R` such that they are carried as
// extra payloads in [ChainConfig] and [Rules] structs, respectively. It is
// expected to be called in an `init()` function and MUST NOT be called more
// than once. Both `C` and `R` MUST be structs.
// than once. Both `C` and `R` MUST be structs or pointers to structs.
//
// After registration, JSON unmarshalling of a [ChainConfig] will create a new
// `*C` and unmarshal the JSON key "extra" into it. Conversely, JSON marshalling
// will populate the "extra" key with the contents of the `*C`. Both the
// `C` and unmarshal the JSON key "extra" into it. Conversely, JSON marshalling
// will populate the "extra" key with the contents of the `C`. Both the
// [json.Marshaler] and [json.Unmarshaler] interfaces are honoured if
// implemented by `C` and/or `R.`
//
// Calls to [ChainConfig.Rules] will call the `NewRules` function of the
// registered [Extras] to create a new `*R`.
// registered [Extras] to create a new `R`.
//
// The payloads can be accessed via the [ExtraPayloadGetter.FromChainConfig] and
// [ExtraPayloadGetter.FromRules] methods of the getter returned by
Expand All @@ -54,16 +56,16 @@ func RegisterExtras[C ChainConfigHooks, R RulesHooks](e Extras[C, R]) ExtraPaylo
if registeredExtras != nil {
panic("re-registration of Extras")
}
mustBeStruct[C]()
mustBeStruct[R]()
mustBeStructOrPointerToOne[C]()
mustBeStructOrPointerToOne[R]()

getter := e.getter()
registeredExtras = &extraConstructors{
chainConfig: pseudo.NewConstructor[C](),
rules: pseudo.NewConstructor[R](),
reuseJSONRoot: e.ReuseJSONRoot,
newForRules: e.newForRules,
getter: getter,
newChainConfig: pseudo.NewConstructor[C]().Zero,
newRules: pseudo.NewConstructor[R]().Zero,
reuseJSONRoot: e.ReuseJSONRoot,
newForRules: e.newForRules,
getter: getter,
}
return getter
}
Expand Down Expand Up @@ -95,9 +97,9 @@ func TestOnlyClearRegisteredExtras() {
var registeredExtras *extraConstructors

type extraConstructors struct {
chainConfig, rules pseudo.Constructor
reuseJSONRoot bool
newForRules func(_ *ChainConfig, _ *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type
newChainConfig, newRules func() *pseudo.Type
reuseJSONRoot bool
newForRules func(_ *ChainConfig, _ *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type
// use top-level hooksFrom<X>() functions instead of these as they handle
// instances where no [Extras] were registered.
getter interface {
Expand All @@ -108,27 +110,34 @@ type extraConstructors struct {

func (e *Extras[C, R]) newForRules(c *ChainConfig, r *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type {
if e.NewRules == nil {
return registeredExtras.rules.NilPointer()
return registeredExtras.newRules()
}
rExtra := e.NewRules(c, r, e.getter().FromChainConfig(c), blockNum, isMerge, timestamp)
return pseudo.From(rExtra).Type
}

func (*Extras[C, R]) getter() (g ExtraPayloadGetter[C, R]) { return }

// mustBeStruct panics if `T` isn't a struct.
func mustBeStruct[T any]() {
// mustBeStructOrPointerToOne panics if `T` isn't a struct or a *struct.
func mustBeStructOrPointerToOne[T any]() {
var x T
if k := reflect.TypeOf(x).Kind(); k != reflect.Struct {
panic(notStructMessage[T]())
switch t := reflect.TypeOf(x); t.Kind() {
case reflect.Struct:
return
case reflect.Pointer:
if t.Elem().Kind() == reflect.Struct {
return
}
}
panic(notStructMessage[T]())
}

// notStructMessage returns the message with which [mustBeStruct] might panic.
// It exists to avoid change-detector tests should the message contents change.
// notStructMessage returns the message with which [mustBeStructOrPointerToOne]
// might panic. It exists to avoid change-detector tests should the message
// contents change.
func notStructMessage[T any]() string {
var x T
return fmt.Sprintf("%T is not a struct", x)
return fmt.Sprintf("%T is not a struct nor a pointer to a struct", x)
}

// An ExtraPayloadGettter provides strongly typed access to the extra payloads
Expand All @@ -139,33 +148,37 @@ type ExtraPayloadGetter[C ChainConfigHooks, R RulesHooks] struct {
}

// FromChainConfig returns the ChainConfig's extra payload.
func (ExtraPayloadGetter[C, R]) FromChainConfig(c *ChainConfig) *C {
return pseudo.MustNewValue[*C](c.extraPayload()).Get()
func (ExtraPayloadGetter[C, R]) FromChainConfig(c *ChainConfig) C {
return pseudo.MustNewValue[C](c.extraPayload()).Get()
}

// PointerFromChainConfig returns a pointer to the ChainConfig's extra payload.
// This is guaranteed to be non-nil.
func (ExtraPayloadGetter[C, R]) PointerFromChainConfig(c *ChainConfig) *C {
return pseudo.MustPointerTo[C](c.extraPayload()).Value.Get()
}

// hooksFromChainConfig is equivalent to FromChainConfig(), but returns an
// interface instead of the concrete type implementing it; this allows it to be
// used in non-generic code. If the concrete-type value is nil (typically
// because no [Extras] were registered) a [noopHooks] is returned so it can be
// used without nil checks.
// used in non-generic code.
func (e ExtraPayloadGetter[C, R]) hooksFromChainConfig(c *ChainConfig) ChainConfigHooks {
if h := e.FromChainConfig(c); h != nil {
return *h
}
return NOOPHooks{}
return e.FromChainConfig(c)
}

// FromRules returns the Rules' extra payload.
func (ExtraPayloadGetter[C, R]) FromRules(r *Rules) *R {
return pseudo.MustNewValue[*R](r.extraPayload()).Get()
func (ExtraPayloadGetter[C, R]) FromRules(r *Rules) R {
return pseudo.MustNewValue[R](r.extraPayload()).Get()
}

// PointerFromRules returns a pointer to the Rules's extra payload. This is
// guaranteed to be non-nil.
func (ExtraPayloadGetter[C, R]) PointerFromRules(r *Rules) *R {
return pseudo.MustPointerTo[R](r.extraPayload()).Value.Get()
}

// hooksFromRules is the [RulesHooks] equivalent of hooksFromChainConfig().
func (e ExtraPayloadGetter[C, R]) hooksFromRules(r *Rules) RulesHooks {
if h := e.FromRules(r); h != nil {
return *h
}
return NOOPHooks{}
return e.FromRules(r)
}

// addRulesExtra is called at the end of [ChainConfig.Rules]; it exists to
Expand All @@ -189,7 +202,7 @@ func (c *ChainConfig) extraPayload() *pseudo.Type {
panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", c))
}
if c.extra == nil {
c.extra = registeredExtras.chainConfig.NilPointer()
c.extra = registeredExtras.newChainConfig()
}
return c.extra
}
Expand All @@ -201,7 +214,7 @@ func (r *Rules) extraPayload() *pseudo.Type {
panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", r))
}
if r.extra == nil {
r.extra = registeredExtras.rules.NilPointer()
r.extra = registeredExtras.newRules()
}
return r.extra
}
97 changes: 88 additions & 9 deletions params/config.libevm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,26 +50,36 @@ func TestRegisterExtras(t *testing.T) {
name: "Rules payload copied from ChainConfig payload",
register: func() {
RegisterExtras(Extras[ccExtraA, rulesExtraA]{
NewRules: func(cc *ChainConfig, r *Rules, ex *ccExtraA, _ *big.Int, _ bool, _ uint64) *rulesExtraA {
return &rulesExtraA{
NewRules: func(cc *ChainConfig, r *Rules, ex ccExtraA, _ *big.Int, _ bool, _ uint64) rulesExtraA {
return rulesExtraA{
A: ex.A,
}
},
})
},
ccExtra: pseudo.From(&ccExtraA{
ccExtra: pseudo.From(ccExtraA{
A: "hello",
}).Type,
wantRulesExtra: &rulesExtraA{
wantRulesExtra: rulesExtraA{
A: "hello",
},
},
{
name: "no NewForRules() function results in typed but nil pointer",
name: "no NewForRules() function results in zero value",
register: func() {
RegisterExtras(Extras[ccExtraB, rulesExtraB]{})
},
ccExtra: pseudo.From(&ccExtraB{
ccExtra: pseudo.From(ccExtraB{
B: "world",
}).Type,
wantRulesExtra: rulesExtraB{},
},
{
name: "no NewForRules() function results in nil pointer",
register: func() {
RegisterExtras(Extras[ccExtraB, *rulesExtraB]{})
},
ccExtra: pseudo.From(ccExtraB{
B: "world",
}).Type,
wantRulesExtra: (*rulesExtraB)(nil),
Expand All @@ -79,10 +89,10 @@ func TestRegisterExtras(t *testing.T) {
register: func() {
RegisterExtras(Extras[rawJSON, struct{ RulesHooks }]{})
},
ccExtra: pseudo.From(&rawJSON{
ccExtra: pseudo.From(rawJSON{
RawMessage: []byte(`"hello, world"`),
}).Type,
wantRulesExtra: (*struct{ RulesHooks })(nil),
wantRulesExtra: struct{ RulesHooks }{},
},
}

Expand Down Expand Up @@ -111,6 +121,75 @@ func TestRegisterExtras(t *testing.T) {
}
}

func TestModificationOfZeroExtras(t *testing.T) {
type (
ccExtra struct {
X int
NOOPHooks
}
rulesExtra struct {
X int
NOOPHooks
}
)

TestOnlyClearRegisteredExtras()
t.Cleanup(TestOnlyClearRegisteredExtras)
getter := RegisterExtras(Extras[ccExtra, rulesExtra]{})

config := new(ChainConfig)
rules := new(Rules)
// These assertion helpers are defined before any modifications so that the
// closure is demonstrably over the original zero values.
assertChainConfigExtra := func(t *testing.T, want ccExtra, msg string) {
t.Helper()
assert.Equalf(t, want, getter.FromChainConfig(config), "%T: "+msg, &config)
}
assertRulesExtra := func(t *testing.T, want rulesExtra, msg string) {
t.Helper()
assert.Equalf(t, want, getter.FromRules(rules), "%T: "+msg, &rules)
}

assertChainConfigExtra(t, ccExtra{}, "zero value")
assertRulesExtra(t, rulesExtra{}, "zero value")

const answer = 42
getter.PointerFromChainConfig(config).X = answer
assertChainConfigExtra(t, ccExtra{X: answer}, "after setting via pointer field")

const pi = 314159
getter.PointerFromRules(rules).X = pi
assertRulesExtra(t, rulesExtra{X: pi}, "after setting via pointer field")

ccReplace := ccExtra{X: 142857}
*getter.PointerFromChainConfig(config) = ccReplace
assertChainConfigExtra(t, ccReplace, "after replacement of entire extra via `*pointer = x`")

rulesReplace := rulesExtra{X: 18101986}
*getter.PointerFromRules(rules) = rulesReplace
assertRulesExtra(t, rulesReplace, "after replacement of entire extra via `*pointer = x`")

if t.Failed() {
// The test of shallow copying is now guaranteed to fail.
return
}
t.Run("shallow copy", func(t *testing.T) {
ccCopy := *config
rCopy := *rules

assert.Equal(t, getter.FromChainConfig(&ccCopy), ccReplace, "ChainConfig extras copied")
assert.Equal(t, getter.FromRules(&rCopy), rulesReplace, "Rules extras copied")

const seqUp = 123456789
getter.PointerFromChainConfig(&ccCopy).X = seqUp
assertChainConfigExtra(t, ccExtra{X: seqUp}, "original changed because copy only shallow")

const seqDown = 987654321
getter.PointerFromRules(&rCopy).X = seqDown
assertRulesExtra(t, rulesExtra{X: seqDown}, "original changed because copy only shallow")
})
}

func TestExtrasPanic(t *testing.T) {
TestOnlyClearRegisteredExtras()
defer TestOnlyClearRegisteredExtras()
Expand All @@ -131,7 +210,7 @@ func TestExtrasPanic(t *testing.T) {

assertPanics(
t, func() {
mustBeStruct[int]()
mustBeStructOrPointerToOne[int]()
},
notStructMessage[int](),
)
Expand Down
Loading
Loading