diff --git a/libevm/hookstest/stub.go b/libevm/hookstest/stub.go index e8fda4310781..2915d487453f 100644 --- a/libevm/hookstest/stub.go +++ b/libevm/hookstest/stub.go @@ -32,7 +32,7 @@ type Stub struct { // Register is a convenience wrapper for registering s as both the // [params.ChainConfigHooks] and [params.RulesHooks] via [Register]. func (s *Stub) Register(tb testing.TB) { - Register(tb, params.Extras[Stub, Stub]{ + Register(tb, params.Extras[*Stub, *Stub]{ NewRules: func(_ *params.ChainConfig, _ *params.Rules, _ *Stub, blockNum *big.Int, isMerge bool, timestamp uint64) *Stub { return s }, diff --git a/libevm/pseudo/type.go b/libevm/pseudo/type.go index 8c453f4cb0e7..8d1568638ed5 100644 --- a/libevm/pseudo/type.go +++ b/libevm/pseudo/type.go @@ -60,6 +60,27 @@ func Zero[T any]() *Pseudo[T] { return From[T](x) } +// PointerTo is equivalent to [From] called with a pointer to the payload +// carried by `t`. It first confirms that the payload is of type `T`. +func PointerTo[T any](t *Type) (*Pseudo[*T], error) { + c, ok := t.val.(*concrete[T]) + if !ok { + var want *T + return nil, fmt.Errorf("cannot create *Pseudo[%T] from *Type carrying %T", want, t.val.get()) + } + return From(&c.val), nil +} + +// MustPointerTo is equivalent to [PointerTo] except that it panics instead of +// returning an error. +func MustPointerTo[T any](t *Type) *Pseudo[*T] { + p, err := PointerTo[T](t) + if err != nil { + panic(err) + } + return p +} + // Interface returns the wrapped value as an `any`, equivalent to // [reflect.Value.Interface]. Prefer [Value.Get]. func (t *Type) Interface() any { return t.val.get() } diff --git a/libevm/pseudo/type_test.go b/libevm/pseudo/type_test.go index 27ecf7e497ea..0b25c945ce29 100644 --- a/libevm/pseudo/type_test.go +++ b/libevm/pseudo/type_test.go @@ -77,3 +77,26 @@ func ExamplePseudo_TypeAndValue() { _ = typ _ = val } + +func TestPointer(t *testing.T) { + type carrier struct { + payload int + } + + typ, val := From(carrier{42}).TypeAndValue() + + t.Run("invalid type", func(t *testing.T) { + _, err := PointerTo[int](typ) + require.Errorf(t, err, "PointerTo[int](%T)", carrier{}) + }) + + t.Run("valid type", func(t *testing.T) { + ptrVal := MustPointerTo[carrier](typ).Value + + assert.Equal(t, 42, val.Get().payload, "before setting via pointer") + var ptr *carrier = ptrVal.Get() + ptr.payload = 314159 + assert.Equal(t, 314159, val.Get().payload, "after setting via pointer") + }) + +} diff --git a/params/config.libevm.go b/params/config.libevm.go index cf23310a1a5b..2f4730b8aa3d 100644 --- a/params/config.libevm.go +++ b/params/config.libevm.go @@ -25,25 +25,27 @@ type Extras[C ChainConfigHooks, R RulesHooks] struct { // NewRules, if non-nil is called at the end of [ChainConfig.Rules] with the // newly created [Rules] and other context from the method call. Its // returned value will be the extra payload of the [Rules]. If NewRules is - // nil then so too will the [Rules] extra payload be a nil `*R`. + // nil then so too will the [Rules] extra payload be a zero-value `R`. // // NewRules MAY modify the [Rules] but MUST NOT modify the [ChainConfig]. - NewRules func(_ *ChainConfig, _ *Rules, _ *C, blockNum *big.Int, isMerge bool, timestamp uint64) *R + // TODO(arr4n): add the [Rules] to the return signature to make it clearer + // that the caller can modify the generated Rules. + NewRules func(_ *ChainConfig, _ *Rules, _ C, blockNum *big.Int, isMerge bool, timestamp uint64) R } // RegisterExtras registers the types `C` and `R` such that they are carried as // extra payloads in [ChainConfig] and [Rules] structs, respectively. It is // expected to be called in an `init()` function and MUST NOT be called more -// than once. Both `C` and `R` MUST be structs. +// than once. Both `C` and `R` MUST be structs or pointers to structs. // // After registration, JSON unmarshalling of a [ChainConfig] will create a new -// `*C` and unmarshal the JSON key "extra" into it. Conversely, JSON marshalling -// will populate the "extra" key with the contents of the `*C`. Both the +// `C` and unmarshal the JSON key "extra" into it. Conversely, JSON marshalling +// will populate the "extra" key with the contents of the `C`. Both the // [json.Marshaler] and [json.Unmarshaler] interfaces are honoured if // implemented by `C` and/or `R.` // // Calls to [ChainConfig.Rules] will call the `NewRules` function of the -// registered [Extras] to create a new `*R`. +// registered [Extras] to create a new `R`. // // The payloads can be accessed via the [ExtraPayloadGetter.FromChainConfig] and // [ExtraPayloadGetter.FromRules] methods of the getter returned by @@ -54,16 +56,16 @@ func RegisterExtras[C ChainConfigHooks, R RulesHooks](e Extras[C, R]) ExtraPaylo if registeredExtras != nil { panic("re-registration of Extras") } - mustBeStruct[C]() - mustBeStruct[R]() + mustBeStructOrPointerToOne[C]() + mustBeStructOrPointerToOne[R]() getter := e.getter() registeredExtras = &extraConstructors{ - chainConfig: pseudo.NewConstructor[C](), - rules: pseudo.NewConstructor[R](), - reuseJSONRoot: e.ReuseJSONRoot, - newForRules: e.newForRules, - getter: getter, + newChainConfig: pseudo.NewConstructor[C]().Zero, + newRules: pseudo.NewConstructor[R]().Zero, + reuseJSONRoot: e.ReuseJSONRoot, + newForRules: e.newForRules, + getter: getter, } return getter } @@ -95,9 +97,9 @@ func TestOnlyClearRegisteredExtras() { var registeredExtras *extraConstructors type extraConstructors struct { - chainConfig, rules pseudo.Constructor - reuseJSONRoot bool - newForRules func(_ *ChainConfig, _ *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type + newChainConfig, newRules func() *pseudo.Type + reuseJSONRoot bool + newForRules func(_ *ChainConfig, _ *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type // use top-level hooksFrom() functions instead of these as they handle // instances where no [Extras] were registered. getter interface { @@ -108,7 +110,7 @@ type extraConstructors struct { func (e *Extras[C, R]) newForRules(c *ChainConfig, r *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type { if e.NewRules == nil { - return registeredExtras.rules.NilPointer() + return registeredExtras.newRules() } rExtra := e.NewRules(c, r, e.getter().FromChainConfig(c), blockNum, isMerge, timestamp) return pseudo.From(rExtra).Type @@ -116,19 +118,26 @@ func (e *Extras[C, R]) newForRules(c *ChainConfig, r *Rules, blockNum *big.Int, func (*Extras[C, R]) getter() (g ExtraPayloadGetter[C, R]) { return } -// mustBeStruct panics if `T` isn't a struct. -func mustBeStruct[T any]() { +// mustBeStructOrPointerToOne panics if `T` isn't a struct or a *struct. +func mustBeStructOrPointerToOne[T any]() { var x T - if k := reflect.TypeOf(x).Kind(); k != reflect.Struct { - panic(notStructMessage[T]()) + switch t := reflect.TypeOf(x); t.Kind() { + case reflect.Struct: + return + case reflect.Pointer: + if t.Elem().Kind() == reflect.Struct { + return + } } + panic(notStructMessage[T]()) } -// notStructMessage returns the message with which [mustBeStruct] might panic. -// It exists to avoid change-detector tests should the message contents change. +// notStructMessage returns the message with which [mustBeStructOrPointerToOne] +// might panic. It exists to avoid change-detector tests should the message +// contents change. func notStructMessage[T any]() string { var x T - return fmt.Sprintf("%T is not a struct", x) + return fmt.Sprintf("%T is not a struct nor a pointer to a struct", x) } // An ExtraPayloadGettter provides strongly typed access to the extra payloads @@ -139,33 +148,37 @@ type ExtraPayloadGetter[C ChainConfigHooks, R RulesHooks] struct { } // FromChainConfig returns the ChainConfig's extra payload. -func (ExtraPayloadGetter[C, R]) FromChainConfig(c *ChainConfig) *C { - return pseudo.MustNewValue[*C](c.extraPayload()).Get() +func (ExtraPayloadGetter[C, R]) FromChainConfig(c *ChainConfig) C { + return pseudo.MustNewValue[C](c.extraPayload()).Get() +} + +// PointerFromChainConfig returns a pointer to the ChainConfig's extra payload. +// This is guaranteed to be non-nil. +func (ExtraPayloadGetter[C, R]) PointerFromChainConfig(c *ChainConfig) *C { + return pseudo.MustPointerTo[C](c.extraPayload()).Value.Get() } // hooksFromChainConfig is equivalent to FromChainConfig(), but returns an // interface instead of the concrete type implementing it; this allows it to be -// used in non-generic code. If the concrete-type value is nil (typically -// because no [Extras] were registered) a [noopHooks] is returned so it can be -// used without nil checks. +// used in non-generic code. func (e ExtraPayloadGetter[C, R]) hooksFromChainConfig(c *ChainConfig) ChainConfigHooks { - if h := e.FromChainConfig(c); h != nil { - return *h - } - return NOOPHooks{} + return e.FromChainConfig(c) } // FromRules returns the Rules' extra payload. -func (ExtraPayloadGetter[C, R]) FromRules(r *Rules) *R { - return pseudo.MustNewValue[*R](r.extraPayload()).Get() +func (ExtraPayloadGetter[C, R]) FromRules(r *Rules) R { + return pseudo.MustNewValue[R](r.extraPayload()).Get() +} + +// PointerFromRules returns a pointer to the Rules's extra payload. This is +// guaranteed to be non-nil. +func (ExtraPayloadGetter[C, R]) PointerFromRules(r *Rules) *R { + return pseudo.MustPointerTo[R](r.extraPayload()).Value.Get() } // hooksFromRules is the [RulesHooks] equivalent of hooksFromChainConfig(). func (e ExtraPayloadGetter[C, R]) hooksFromRules(r *Rules) RulesHooks { - if h := e.FromRules(r); h != nil { - return *h - } - return NOOPHooks{} + return e.FromRules(r) } // addRulesExtra is called at the end of [ChainConfig.Rules]; it exists to @@ -189,7 +202,7 @@ func (c *ChainConfig) extraPayload() *pseudo.Type { panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", c)) } if c.extra == nil { - c.extra = registeredExtras.chainConfig.NilPointer() + c.extra = registeredExtras.newChainConfig() } return c.extra } @@ -201,7 +214,7 @@ func (r *Rules) extraPayload() *pseudo.Type { panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", r)) } if r.extra == nil { - r.extra = registeredExtras.rules.NilPointer() + r.extra = registeredExtras.newRules() } return r.extra } diff --git a/params/config.libevm_test.go b/params/config.libevm_test.go index 1a7fc680b36a..129cc76ec5bc 100644 --- a/params/config.libevm_test.go +++ b/params/config.libevm_test.go @@ -50,26 +50,36 @@ func TestRegisterExtras(t *testing.T) { name: "Rules payload copied from ChainConfig payload", register: func() { RegisterExtras(Extras[ccExtraA, rulesExtraA]{ - NewRules: func(cc *ChainConfig, r *Rules, ex *ccExtraA, _ *big.Int, _ bool, _ uint64) *rulesExtraA { - return &rulesExtraA{ + NewRules: func(cc *ChainConfig, r *Rules, ex ccExtraA, _ *big.Int, _ bool, _ uint64) rulesExtraA { + return rulesExtraA{ A: ex.A, } }, }) }, - ccExtra: pseudo.From(&ccExtraA{ + ccExtra: pseudo.From(ccExtraA{ A: "hello", }).Type, - wantRulesExtra: &rulesExtraA{ + wantRulesExtra: rulesExtraA{ A: "hello", }, }, { - name: "no NewForRules() function results in typed but nil pointer", + name: "no NewForRules() function results in zero value", register: func() { RegisterExtras(Extras[ccExtraB, rulesExtraB]{}) }, - ccExtra: pseudo.From(&ccExtraB{ + ccExtra: pseudo.From(ccExtraB{ + B: "world", + }).Type, + wantRulesExtra: rulesExtraB{}, + }, + { + name: "no NewForRules() function results in nil pointer", + register: func() { + RegisterExtras(Extras[ccExtraB, *rulesExtraB]{}) + }, + ccExtra: pseudo.From(ccExtraB{ B: "world", }).Type, wantRulesExtra: (*rulesExtraB)(nil), @@ -79,10 +89,10 @@ func TestRegisterExtras(t *testing.T) { register: func() { RegisterExtras(Extras[rawJSON, struct{ RulesHooks }]{}) }, - ccExtra: pseudo.From(&rawJSON{ + ccExtra: pseudo.From(rawJSON{ RawMessage: []byte(`"hello, world"`), }).Type, - wantRulesExtra: (*struct{ RulesHooks })(nil), + wantRulesExtra: struct{ RulesHooks }{}, }, } @@ -111,6 +121,75 @@ func TestRegisterExtras(t *testing.T) { } } +func TestModificationOfZeroExtras(t *testing.T) { + type ( + ccExtra struct { + X int + NOOPHooks + } + rulesExtra struct { + X int + NOOPHooks + } + ) + + TestOnlyClearRegisteredExtras() + t.Cleanup(TestOnlyClearRegisteredExtras) + getter := RegisterExtras(Extras[ccExtra, rulesExtra]{}) + + config := new(ChainConfig) + rules := new(Rules) + // These assertion helpers are defined before any modifications so that the + // closure is demonstrably over the original zero values. + assertChainConfigExtra := func(t *testing.T, want ccExtra, msg string) { + t.Helper() + assert.Equalf(t, want, getter.FromChainConfig(config), "%T: "+msg, &config) + } + assertRulesExtra := func(t *testing.T, want rulesExtra, msg string) { + t.Helper() + assert.Equalf(t, want, getter.FromRules(rules), "%T: "+msg, &rules) + } + + assertChainConfigExtra(t, ccExtra{}, "zero value") + assertRulesExtra(t, rulesExtra{}, "zero value") + + const answer = 42 + getter.PointerFromChainConfig(config).X = answer + assertChainConfigExtra(t, ccExtra{X: answer}, "after setting via pointer field") + + const pi = 314159 + getter.PointerFromRules(rules).X = pi + assertRulesExtra(t, rulesExtra{X: pi}, "after setting via pointer field") + + ccReplace := ccExtra{X: 142857} + *getter.PointerFromChainConfig(config) = ccReplace + assertChainConfigExtra(t, ccReplace, "after replacement of entire extra via `*pointer = x`") + + rulesReplace := rulesExtra{X: 18101986} + *getter.PointerFromRules(rules) = rulesReplace + assertRulesExtra(t, rulesReplace, "after replacement of entire extra via `*pointer = x`") + + if t.Failed() { + // The test of shallow copying is now guaranteed to fail. + return + } + t.Run("shallow copy", func(t *testing.T) { + ccCopy := *config + rCopy := *rules + + assert.Equal(t, getter.FromChainConfig(&ccCopy), ccReplace, "ChainConfig extras copied") + assert.Equal(t, getter.FromRules(&rCopy), rulesReplace, "Rules extras copied") + + const seqUp = 123456789 + getter.PointerFromChainConfig(&ccCopy).X = seqUp + assertChainConfigExtra(t, ccExtra{X: seqUp}, "original changed because copy only shallow") + + const seqDown = 987654321 + getter.PointerFromRules(&rCopy).X = seqDown + assertRulesExtra(t, rulesExtra{X: seqDown}, "original changed because copy only shallow") + }) +} + func TestExtrasPanic(t *testing.T) { TestOnlyClearRegisteredExtras() defer TestOnlyClearRegisteredExtras() @@ -131,7 +210,7 @@ func TestExtrasPanic(t *testing.T) { assertPanics( t, func() { - mustBeStruct[int]() + mustBeStructOrPointerToOne[int]() }, notStructMessage[int](), ) diff --git a/params/example.libevm_test.go b/params/example.libevm_test.go index 37e8b2e4b580..e2a7340d5610 100644 --- a/params/example.libevm_test.go +++ b/params/example.libevm_test.go @@ -40,8 +40,8 @@ var getter params.ExtraPayloadGetter[ChainConfigExtra, RulesExtra] // constructRulesExtra acts as an adjunct to the [params.ChainConfig.Rules] // method. Its primary purpose is to construct the extra payload for the // [params.Rules] but it MAY also modify the [params.Rules]. -func constructRulesExtra(c *params.ChainConfig, r *params.Rules, cEx *ChainConfigExtra, blockNum *big.Int, isMerge bool, timestamp uint64) *RulesExtra { - return &RulesExtra{ +func constructRulesExtra(c *params.ChainConfig, r *params.Rules, cEx ChainConfigExtra, blockNum *big.Int, isMerge bool, timestamp uint64) RulesExtra { + return RulesExtra{ IsMyFork: cEx.MyForkTime != nil && *cEx.MyForkTime <= timestamp, timestamp: timestamp, } @@ -66,12 +66,12 @@ type RulesExtra struct { } // FromChainConfig returns the extra payload carried by the ChainConfig. -func FromChainConfig(c *params.ChainConfig) *ChainConfigExtra { +func FromChainConfig(c *params.ChainConfig) ChainConfigExtra { return getter.FromChainConfig(c) } // FromRules returns the extra payload carried by the Rules. -func FromRules(r *params.Rules) *RulesExtra { +func FromRules(r *params.Rules) RulesExtra { return getter.FromRules(r) } @@ -137,16 +137,14 @@ func ExampleExtraPayloadGetter() { fmt.Println("Chain ID", config.ChainID) // original geth fields work as expected ccExtra := FromChainConfig(config) // extraparams.FromChainConfig() in practice - if ccExtra != nil && ccExtra.MyForkTime != nil { + if ccExtra.MyForkTime != nil { fmt.Println("Fork time", *ccExtra.MyForkTime) } for _, time := range []uint64{forkTime - 1, forkTime, forkTime + 1} { rules := config.Rules(nil, false, time) rExtra := FromRules(&rules) // extraparams.FromRules() in practice - if rExtra != nil { - fmt.Printf("IsMyFork at %v: %t\n", rExtra.timestamp, rExtra.IsMyFork) - } + fmt.Printf("IsMyFork at %v: %t\n", rExtra.timestamp, rExtra.IsMyFork) } // Output: diff --git a/params/json.libevm.go b/params/json.libevm.go index 1f57b39327c5..a60dd39062a7 100644 --- a/params/json.libevm.go +++ b/params/json.libevm.go @@ -30,7 +30,7 @@ func (c *ChainConfig) UnmarshalJSON(data []byte) error { return c.unmarshalJSONWithExtra(data) case reg != nil && reg.reuseJSONRoot: // although the latter is redundant, it's clearer - c.extra = reg.chainConfig.NilPointer() + c.extra = reg.newChainConfig() if err := json.Unmarshal(data, c.extra); err != nil { c.extra = nil return err @@ -47,7 +47,7 @@ func (c *ChainConfig) UnmarshalJSON(data []byte) error { func (c *ChainConfig) unmarshalJSONWithExtra(data []byte) error { cc := &chainConfigWithExportedExtra{ chainConfigWithoutMethods: (*chainConfigWithoutMethods)(c), - Extra: registeredExtras.chainConfig.NilPointer(), + Extra: registeredExtras.newChainConfig(), } if err := json.Unmarshal(data, cc); err != nil { return err diff --git a/params/json.libevm_test.go b/params/json.libevm_test.go index fa4e51aff6d0..780b8f220e6f 100644 --- a/params/json.libevm_test.go +++ b/params/json.libevm_test.go @@ -40,7 +40,7 @@ func TestChainConfigJSONRoundTrip(t *testing.T) { }, }, { - name: "reuse top-level JSON", + name: "reuse top-level JSON with non-pointer", register: func() { RegisterExtras(Extras[rootJSONChainConfigExtra, NOOPHooks]{ ReuseJSONRoot: true, @@ -50,13 +50,29 @@ func TestChainConfigJSONRoundTrip(t *testing.T) { "chainId": 5678, "foo": "hello" }`, + want: &ChainConfig{ + ChainID: big.NewInt(5678), + extra: pseudo.From(rootJSONChainConfigExtra{TopLevelFoo: "hello"}).Type, + }, + }, + { + name: "reuse top-level JSON with pointer", + register: func() { + RegisterExtras(Extras[*rootJSONChainConfigExtra, NOOPHooks]{ + ReuseJSONRoot: true, + }) + }, + jsonInput: `{ + "chainId": 5678, + "foo": "hello" + }`, want: &ChainConfig{ ChainID: big.NewInt(5678), extra: pseudo.From(&rootJSONChainConfigExtra{TopLevelFoo: "hello"}).Type, }, }, { - name: "nested JSON", + name: "nested JSON with non-pointer", register: func() { RegisterExtras(Extras[nestedChainConfigExtra, NOOPHooks]{ ReuseJSONRoot: false, // explicit zero value only for tests @@ -66,6 +82,22 @@ func TestChainConfigJSONRoundTrip(t *testing.T) { "chainId": 42, "extra": {"foo": "world"} }`, + want: &ChainConfig{ + ChainID: big.NewInt(42), + extra: pseudo.From(nestedChainConfigExtra{NestedFoo: "world"}).Type, + }, + }, + { + name: "nested JSON with pointer", + register: func() { + RegisterExtras(Extras[*nestedChainConfigExtra, NOOPHooks]{ + ReuseJSONRoot: false, // explicit zero value only for tests + }) + }, + jsonInput: `{ + "chainId": 42, + "extra": {"foo": "world"} + }`, want: &ChainConfig{ ChainID: big.NewInt(42), extra: pseudo.From(&nestedChainConfigExtra{NestedFoo: "world"}).Type,