From 01b0ec543ffa4a66000362156750ab63cf105cf2 Mon Sep 17 00:00:00 2001 From: Arran Schlosberg Date: Thu, 27 Jun 2024 13:18:23 +0100 Subject: [PATCH 1/7] feat: patching `vm.EVM` methods `*Call*()` to support stateful precompiles --- x/gethclone/astpatch/astpatch.go | 67 ++++++++++++---- x/gethclone/astpatch/astpatch_test.go | 8 +- x/gethclone/astpatch/funcs.go | 104 +++++++++++++++++++++++++ x/gethclone/gethclone.go | 49 +++++++++--- x/gethclone/go.mod | 12 ++- x/gethclone/go.sum | 107 +++++++++++++++++++++++++- x/gethclone/main.go | 3 + x/gethclone/stateful_precompiles.go | 96 +++++++++++++++++++++++ 8 files changed, 413 insertions(+), 33 deletions(-) create mode 100644 x/gethclone/astpatch/funcs.go create mode 100644 x/gethclone/stateful_precompiles.go diff --git a/x/gethclone/astpatch/astpatch.go b/x/gethclone/astpatch/astpatch.go index 3bc6fd4e3a..2bcf5b5c52 100644 --- a/x/gethclone/astpatch/astpatch.go +++ b/x/gethclone/astpatch/astpatch.go @@ -23,15 +23,35 @@ type ( PatchRegistry map[string]map[reflect.Type][]Patch ) -// Add is a convenience wrapper for registering a new `Patch` in the registry. -// The `zeroNode` can be any type (including nil pointers) that implements -// `ast.Node`. +// Apply is equivalent to `astutil.ApplyFunc()` except that it accepts +// `Patch`es. See `Patch` comment for error-handling semantics. +func Apply(root ast.Node, pre, post Patch) (ast.Node, error) { + var err error + x := func(p Patch) astutil.ApplyFunc { + return func(c *astutil.Cursor) bool { + if err != nil { + return false + } + if p == nil { + return true + } + err = p(c) + return err == nil + } + } + n := astutil.Apply(root, x(pre), x(post)) + return n, err +} + +// AddForType is a convenience wrapper for registering a new `Patch` in the +// registry. The `zeroNode` can be any type (including nil pointers) that +// implements `ast.Node`. // // The special `pkgPath` value "*" will match all package paths. While there is // no specific requirement for `pkgPath` other than it matching the equivalent // argument passed to `Apply()`, it is typically sourced from // `golang.org/x/tools/go/packages.Package.PkgPath`. -func (r PatchRegistry) Add(pkgPath string, zeroNode ast.Node, fn Patch) { +func (r PatchRegistry) AddForType(pkgPath string, zeroNode ast.Node, fn Patch) { pkg, ok := r[pkgPath] if !ok { pkg = make(map[reflect.Type][]Patch) @@ -42,6 +62,29 @@ func (r PatchRegistry) Add(pkgPath string, zeroNode ast.Node, fn Patch) { pkg[t] = append(pkg[t], fn) } +// A TypePatcher couples a `Patch` with the specific `ast.Node` type to which it +// applies. It is useful when `PatchRegistry.AddForType()` MUST receive a +// specific `Node` type for a particular `Patch`. +type TypePatcher interface { + Type() ast.Node + Patch(*astutil.Cursor) error +} + +// Add is a synonym of `AddForType()`, instead accepting an argument that +// provides the `Node` type and the `Patch`. +func (r PatchRegistry) Add(pkgPath string, tp TypePatcher) { + r.AddForType(pkgPath, tp.Type(), tp.Patch) +} + +// typePatcher implements the `TypePatcher` interface. +type typePatcher struct { + typ ast.Node + patch Patch +} + +func (p typePatcher) Type() ast.Node { return p.typ } +func (p typePatcher) Patch(c *astutil.Cursor) error { return p.patch(c) } + // Apply calls `astutil.Apply()` on `node`, calling the appropriate `Patch` // functions as the syntax tree is traversed. Patches are applied as the `pre` // argument to `astutil.Apply()`. @@ -51,19 +94,13 @@ func (r PatchRegistry) Add(pkgPath string, zeroNode ast.Node, fn Patch) { // // If any `Patch` returns an error then no further patches will be called, and // the error will be returned by `Apply()`. -func (r PatchRegistry) Apply(pkgPath string, node ast.Node) error { - var err error - astutil.Apply(node, func(c *astutil.Cursor) bool { - if err != nil { - return false - } - if err = r.applyToCursor("*", c); err != nil { - return false +func (r PatchRegistry) Apply(pkgPath string, node ast.Node) (ast.Node, error) { + return Apply(node, func(c *astutil.Cursor) error { + if err := r.applyToCursor("*", c); err != nil { + return err } - err = r.applyToCursor(pkgPath, c) - return err == nil + return r.applyToCursor(pkgPath, c) }, nil) - return err } func (r PatchRegistry) applyToCursor(pkgPath string, c *astutil.Cursor) error { diff --git a/x/gethclone/astpatch/astpatch_test.go b/x/gethclone/astpatch/astpatch_test.go index 22ad63af40..1a30db8303 100644 --- a/x/gethclone/astpatch/astpatch_test.go +++ b/x/gethclone/astpatch/astpatch_test.go @@ -79,11 +79,11 @@ func ` + errorIfFuncName + `() {} var spy patchSpy reg := make(PatchRegistry) - reg.Add("*", &ast.FuncDecl{}, spy.funcRecorder) + reg.AddForType("*", &ast.FuncDecl{}, spy.funcRecorder) const pkgPath = `github.com/the/repo/thepackage` - reg.Add(pkgPath, &ast.StructType{}, spy.structRecorder) + reg.AddForType(pkgPath, &ast.StructType{}, spy.structRecorder) - reg.Add("unknown/package/path", &ast.FuncDecl{}, func(c *astutil.Cursor) error { + reg.AddForType("unknown/package/path", &ast.FuncDecl{}, func(c *astutil.Cursor) error { t.Errorf("unexpected call to %T with different package path", (Patch)(nil)) return nil }) @@ -94,7 +94,7 @@ func ` + errorIfFuncName + `() {} // None of the `require.Equal*()` variants provide a check for exact // match (i.e. equivalent to ==) of the identical error being // propagated. - if gotErr := reg.Apply(pkgPath, file); gotErr != tt.wantErr { + if _, gotErr := reg.Apply(pkgPath, file); gotErr != tt.wantErr { t.Fatalf("%T.Apply(...) got err %v; want %v", reg, gotErr, tt.wantErr) } assert.Empty(t, cmp.Diff(tt.wantFuncs, spy.gotFuncs), "encountered function declarations (-want +got)") diff --git a/x/gethclone/astpatch/funcs.go b/x/gethclone/astpatch/funcs.go new file mode 100644 index 0000000000..fcdd872988 --- /dev/null +++ b/x/gethclone/astpatch/funcs.go @@ -0,0 +1,104 @@ +package astpatch + +import ( + "fmt" + "go/ast" + + "golang.org/x/tools/go/ast/astutil" +) + +// Method returns a `TypePatcher` that only applies to the specific method on +// the specific type. +// +// The `patch` argument functions like a regular `Patch` except that its +// parameters are extended to also accept the methods's AST declaration as its +// concrete type (i.e. `astutil.Cursor.Node().(*ast.FuncDecl)`). +// +// // Original declaration +// func (x *Thing) Do() { ... } +// +// // Patched with +// astpatch.Method("Thing", "Do", ...) +func Method(receiverType, methodName string, patch func(*astutil.Cursor, *ast.FuncDecl) error) TypePatcher { + return method(nil, receiverType, methodName, patch) +} + +// PointerMethod is identical to `Method()` except that it only matches methods +// with pointer receivers. +func PointerMethod(receiverType, methodName string, patch func(*astutil.Cursor, *ast.FuncDecl) error) TypePatcher { + ptr := true + return method(&ptr, receiverType, methodName, patch) +} + +// ValueMethod is identical to `Method()` except that it only matches methods +// with value receivers. +func ValueMethod(receiverType, methodName string, patch func(*astutil.Cursor, *ast.FuncDecl) error) TypePatcher { + ptr := false + return method(&ptr, receiverType, methodName, patch) +} + +func method(pointerReceiver *bool, receiverType, methodName string, patch func(*astutil.Cursor, *ast.FuncDecl) error) TypePatcher { + return typePatcher{ + typ: (*ast.FuncDecl)(nil), + patch: func(c *astutil.Cursor) error { + fn, ok := c.Node().(*ast.FuncDecl) + if !ok || fn.Recv == nil /*not a method*/ || fn.Name.Name != methodName { + return nil + } + if n := len(fn.Recv.List); n != 1 { + return fmt.Errorf("func receiver list length = %d (%v)", n, fn.Name) + } + + var rcvTypeName *ast.Ident + + switch rcvType := fn.Recv.List[0].Type.(type) { + case *ast.Ident: + if pointerReceiver != nil && *pointerReceiver { + return nil + } + rcvTypeName = rcvType + + case *ast.StarExpr: + if pointerReceiver != nil && !*pointerReceiver { + return nil + } + id, ok := rcvType.X.(*ast.Ident) + if !ok { + return fmt.Errorf("func receiver %T.X is not %T", rcvType, rcvTypeName) + } + rcvTypeName = id + + default: + return fmt.Errorf("unsupported %T.Recv.List.Type type %T", fn, rcvType) + } + + if rcvTypeName.Name != receiverType { + return nil + } + return patch(c, fn) + }, + } +} + +// UnqualifiedCall returns a patch that only applies to a call to the specific, +// unqualified function. A qualified function is one that has additional +// qualifiers before the selector (e.g. `foo.Bar()` or `pkg.Bar()`); an +// unqualified function lacks any such qualifiers and applies to builtin and +// package-internal functions. +// +// The `patch` argument functions like a regular `Patch` except that its +// parameters are extended to also accept the call's AST declaration as its +// concrete type (i.e. `astutil.Cursor.Node().(*ast.CallExpr)`). +func UnqualifiedCall(name string, patch func(*astutil.Cursor, *ast.CallExpr) error) Patch { + return func(c *astutil.Cursor) error { + call, ok := c.Node().(*ast.CallExpr) + if !ok { + return nil + } + fn, ok := call.Fun.(*ast.Ident) + if !ok || fn.Name != name { + return nil + } + return patch(c, call) + } +} diff --git a/x/gethclone/gethclone.go b/x/gethclone/gethclone.go index 46d822a01b..452d3997c8 100644 --- a/x/gethclone/gethclone.go +++ b/x/gethclone/gethclone.go @@ -20,7 +20,6 @@ import ( _ "embed" - // TODO(arr4n): change to using a git sub-module _ "github.com/ethereum/go-ethereum/common" ) @@ -34,11 +33,31 @@ type config struct { log *zap.SugaredLogger outputModule *modfile.Module astPatches astpatch.PatchRegistry + patchSets []patchSet processed set.Set[string] } -const geth = "github.com/ethereum/go-ethereum" +// A patchSet registers one or more patches on a `patch.PatchRegistry` and later +// validates that they were correctly applied. Validation is necessary because +// an error-free application of the registry doesn't guarantee that all expected +// nodes were actually visited. +type patchSet interface { + name() string + register(astpatch.PatchRegistry) + validate() error +} + +const gethMod = "github.com/ethereum/go-ethereum" + +// geth returns `gethMod`+`pkg` unless `pkg` already has `gethMod` as a prefix, +// in which case `pkg` is returned unchanged. +func geth(pkg string) string { + if strings.HasPrefix(pkg, gethMod) { + return pkg + } + return path.Join(gethMod, strings.TrimLeft(pkg, `/`)) +} func (c *config) run(ctx context.Context, logOpts ...zap.Option) (retErr error) { l, err := zap.NewDevelopment(logOpts...) @@ -50,9 +69,10 @@ func (c *config) run(ctx context.Context, logOpts ...zap.Option) (retErr error) defer c.log.Sync() for i, p := range c.packages { - if !strings.HasPrefix(p, geth) { - c.packages[i] = path.Join(geth, p) - } + c.packages[i] = geth(p) + } + for _, ps := range c.patchSets { + ps.register(c.astPatches) } mod, err := parseGoMod(c.outputGoMod) @@ -62,7 +82,16 @@ func (c *config) run(ctx context.Context, logOpts ...zap.Option) (retErr error) c.outputModule = mod.Module c.processed = make(set.Set[string]) - return c.loadAndParse(ctx, token.NewFileSet(), c.packages...) + if err := c.loadAndParse(ctx, token.NewFileSet(), c.packages...); err != nil { + return err + } + + for _, ps := range c.patchSets { + if err := ps.validate(); err != nil { + return fmt.Errorf("patch-set %q validation: %v", ps.name(), err) + } + } + return nil } func parseGoMod(filePath string) (*modfile.File, error) { @@ -112,7 +141,7 @@ func (c *config) parse(ctx context.Context, pkg *PackagePublic, fset *token.File } c.processed.Add(pkg.ImportPath) - shortPkgPath := strings.TrimPrefix(pkg.ImportPath, geth) + shortPkgPath := strings.TrimPrefix(pkg.ImportPath, gethMod) outDir := filepath.Join(filepath.Dir(c.outputGoMod), shortPkgPath) if err := os.MkdirAll(outDir, 0755); err != nil { @@ -138,7 +167,7 @@ func (c *config) parse(ctx context.Context, pkg *PackagePublic, fset *token.File List: []*ast.Comment{{Text: copyrightHeader}}, }}, file.Comments...) - if err := c.astPatches.Apply(pkg.ImportPath, file); err != nil { + if _, err := c.astPatches.Apply(pkg.ImportPath, file); err != nil { return fmt.Errorf("apply AST patches to %q: %v", pkg.ImportPath, err) } @@ -171,12 +200,12 @@ func (c *config) transformGethImports(fset *token.FileSet, file *ast.File) (set. imports := set.NewSet[string](len(file.Imports)) for _, im := range file.Imports { p := strings.Trim(im.Path.Value, `"`) - if !strings.HasPrefix(p, geth) { + if !strings.HasPrefix(p, gethMod) { continue } imports.Add(p) - if !astutil.RewriteImport(fset, file, p, strings.Replace(p, geth, c.outputModule.Mod.String(), 1)) { + if !astutil.RewriteImport(fset, file, p, strings.Replace(p, gethMod, c.outputModule.Mod.String(), 1)) { return nil, fmt.Errorf("failed to rewrite import %q", p) } } diff --git a/x/gethclone/go.mod b/x/gethclone/go.mod index d6e04f8ad6..cdbec45d6a 100644 --- a/x/gethclone/go.mod +++ b/x/gethclone/go.mod @@ -14,18 +14,28 @@ require ( ) require ( + github.com/bits-and-blooms/bitset v1.10.0 // indirect + github.com/btcsuite/btcd/btcec/v2 v2.3.2 // indirect + github.com/consensys/bavard v0.1.13 // indirect + github.com/consensys/gnark-crypto v0.12.1 // indirect + github.com/crate-crypto/go-kzg-4844 v0.7.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.1.0 // indirect + github.com/ethereum/c-kzg-4844 v0.4.0 // indirect github.com/google/renameio/v2 v2.0.0 // indirect github.com/gorilla/rpc v1.2.0 // indirect github.com/holiman/uint256 v1.2.4 // indirect + github.com/mmcloughlin/addchain v0.4.0 // indirect github.com/mr-tron/base58 v1.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/supranational/blst v0.3.11 // indirect go.uber.org/mock v0.4.0 // indirect go.uber.org/multierr v1.10.0 // indirect golang.org/x/crypto v0.21.0 // indirect golang.org/x/exp v0.0.0-20231127185646-65229373498e // indirect + golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.21.0 // indirect gonum.org/v1/gonum v0.11.0 // indirect - gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + rsc.io/tmplfunc v0.0.3 // indirect ) diff --git a/x/gethclone/go.sum b/x/gethclone/go.sum index f79ca393a4..531fc99969 100644 --- a/x/gethclone/go.sum +++ b/x/gethclone/go.sum @@ -1,38 +1,129 @@ +github.com/DataDog/zstd v1.5.2 h1:vUG4lAyuPCXO0TLbXvPv7EB7cNK1QV/luu55UHLrrn8= +github.com/DataDog/zstd v1.5.2/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= +github.com/VictoriaMetrics/fastcache v1.12.1 h1:i0mICQuojGDL3KblA7wUNlY5lOK6a4bwt3uRKnkZU40= +github.com/VictoriaMetrics/fastcache v1.12.1/go.mod h1:tX04vaqcNoQeGLD+ra5pU5sWkuxnzWhEzLwhP9w653o= github.com/ava-labs/avalanchego v1.11.8 h1:Q/der5bC/q3BQbIqxT7nNC0F30c+6X1G/eQzzMQ2CLk= github.com/ava-labs/avalanchego v1.11.8/go.mod h1:aPYTETkM0KjtC7vFwPO6S8z2L0QTKaXjVUi98pTdNO4= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bits-and-blooms/bitset v1.10.0 h1:ePXTeiPEazB5+opbv5fr8umg2R/1NlzgDsyepwsSr88= +github.com/bits-and-blooms/bitset v1.10.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= +github.com/btcsuite/btcd/btcec/v2 v2.3.2 h1:5n0X6hX0Zk+6omWcihdYvdAlGf2DfasC0GMf7DClJ3U= +github.com/btcsuite/btcd/btcec/v2 v2.3.2/go.mod h1:zYzJ8etWJQIv1Ogk7OzpWjowwOdXY1W/17j2MW85J04= +github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1 h1:q0rUy8C/TYNBQS1+CGKw68tLOFYSNEs0TFnxxnS9+4U= +github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cockroachdb/errors v1.9.1 h1:yFVvsI0VxmRShfawbt/laCIDy/mtTqqnvoNgiy5bEV8= +github.com/cockroachdb/errors v1.9.1/go.mod h1:2sxOtL2WIc096WSZqZ5h8fa17rdDq9HZOZLBCor4mBk= +github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b h1:r6VH0faHjZeQy818SGhaone5OnYfxFR/+AzdY3sf5aE= +github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b/go.mod h1:Vz9DsVWQQhf3vs21MhPMZpMGSht7O/2vFW2xusFUVOs= +github.com/cockroachdb/pebble v0.0.0-20230928194634-aa077af62593 h1:aPEJyR4rPBvDmeyi+l/FS/VtA00IWvjeFvjen1m1l1A= +github.com/cockroachdb/pebble v0.0.0-20230928194634-aa077af62593/go.mod h1:6hk1eMY/u5t+Cf18q5lFMUA1Rc+Sm5I6Ra1QuPyxXCo= +github.com/cockroachdb/redact v1.1.3 h1:AKZds10rFSIj7qADf0g46UixK8NNLwWTNdCIGS5wfSQ= +github.com/cockroachdb/redact v1.1.3/go.mod h1:BVNblN9mBWFyMyqK1k3AAiSxhvhfK2oOZZ2lK+dpvRg= +github.com/cockroachdb/tokenbucket v0.0.0-20230807174530-cc333fc44b06 h1:zuQyyAKVxetITBuuhv3BI9cMrmStnpT18zmgmTxunpo= +github.com/cockroachdb/tokenbucket v0.0.0-20230807174530-cc333fc44b06/go.mod h1:7nc4anLGjupUW/PeY5qiNYsdNXj7zopG+eqsS7To5IQ= +github.com/consensys/bavard v0.1.13 h1:oLhMLOFGTLdlda/kma4VOJazblc7IM5y5QPd2A/YjhQ= +github.com/consensys/bavard v0.1.13/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= +github.com/consensys/gnark-crypto v0.12.1 h1:lHH39WuuFgVHONRl3J0LRBtuYdQTumFSDtJF7HpyG8M= +github.com/consensys/gnark-crypto v0.12.1/go.mod h1:v2Gy7L/4ZRosZ7Ivs+9SfUDr0f5UlG+EM5t7MPHiLuY= +github.com/crate-crypto/go-ipa v0.0.0-20231025140028-3c0104f4b233 h1:d28BXYi+wUpz1KBmiF9bWrjEMacUEREV6MBi2ODnrfQ= +github.com/crate-crypto/go-ipa v0.0.0-20231025140028-3c0104f4b233/go.mod h1:geZJZH3SzKCqnz5VT0q/DyIG/tvu/dZk+VIfXicupJs= +github.com/crate-crypto/go-kzg-4844 v0.7.0 h1:C0vgZRk4q4EZ/JgPfzuSoxdCq3C3mOZMBShovmncxvA= +github.com/crate-crypto/go-kzg-4844 v0.7.0/go.mod h1:1kMhvPgI0Ky3yIa+9lFySEBUBXkYxeOi8ZF1sYioxhc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/decred/dcrd/crypto/blake256 v1.0.0 h1:/8DMNYp9SGi5f0w7uCm6d6M4OU2rGFK09Y2A4Xv7EE0= +github.com/decred/dcrd/crypto/blake256 v1.0.0/go.mod h1:sQl2p6Y26YV+ZOcSTP6thNdn47hh8kt6rqSlvmrXFAc= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.1.0 h1:HbphB4TFFXpv7MNrT52FGrrgVXF1owhMVTHFZIlnvd4= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.1.0/go.mod h1:DZGJHZMqrU4JJqFAWUS2UO1+lbSKsdiOoYi9Zzey7Fc= +github.com/ethereum/c-kzg-4844 v0.4.0 h1:3MS1s4JtA868KpJxroZoepdV0ZKBp3u/O5HcZ7R3nlY= +github.com/ethereum/c-kzg-4844 v0.4.0/go.mod h1:VewdlzQmpT5QSrVhbBuGoCdFJkpaJlO1aQputP83wc0= github.com/ethereum/go-ethereum v1.13.8 h1:1od+thJel3tM52ZUNQwvpYOeRHlbkVFZ5S8fhi0Lgsg= github.com/ethereum/go-ethereum v1.13.8/go.mod h1:sc48XYQxCzH3fG9BcrXCOOgQk2JfZzNAmIKnceogzsA= +github.com/gballet/go-verkle v0.1.1-0.20231031103413-a67434b50f46 h1:BAIP2GihuqhwdILrV+7GJel5lyPV3u1+PgzrWLc0TkE= +github.com/gballet/go-verkle v0.1.1-0.20231031103413-a67434b50f46/go.mod h1:QNpY22eby74jVhqH4WhDLDwxc/vqsern6pW+u2kbkpc= +github.com/getsentry/sentry-go v0.18.0 h1:MtBW5H9QgdcJabtZcuJG80BMOwaBpkRDZkxRkNC1sN0= +github.com/getsentry/sentry-go v0.18.0/go.mod h1:Kgon4Mby+FJ7ZWHFUAZgVaIa8sxHtnRJRLTXZr51aKQ= +github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= +github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= +github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb h1:PBC98N2aIaM3XXiurYmW7fx4GZkL8feAMVq7nEjURHk= +github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/renameio/v2 v2.0.0 h1:UifI23ZTGY8Tt29JbYFiuyIU3eX+RNFtUwefq9qAhxg= github.com/google/renameio/v2 v2.0.0/go.mod h1:BtmJXm5YlszgC+TD4HOEEUFgkJP3nLxehU6hfe7jRt4= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/gorilla/rpc v1.2.0 h1:WvvdC2lNeT1SP32zrIce5l0ECBfbAlmrmSBsuc57wfk= github.com/gorilla/rpc v1.2.0/go.mod h1:V4h9r+4sF5HnzqbwIez0fKSpANP0zlYd3qR7p36jkTQ= +github.com/holiman/bloomfilter/v2 v2.0.3 h1:73e0e/V0tCydx14a0SCYS/EWCxgwLZ18CZcZKVu0fao= +github.com/holiman/bloomfilter/v2 v2.0.3/go.mod h1:zpoh+gs7qcpqrHr3dB55AMiJwo0iURXE7ZOP9L9hSkA= github.com/holiman/uint256 v1.2.4 h1:jUc4Nk8fm9jZabQuqr2JzednajVmBpC+oiTiXZJEApU= github.com/holiman/uint256 v1.2.4/go.mod h1:EOMSn4q6Nyt9P6efbI3bueV4e1b3dGlUCXeiRV4ng7E= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/klauspost/compress v1.15.15 h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7yIO+lw= +github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/leanovate/gopter v0.2.9 h1:fQjYxZaynp97ozCzfOyOuAGOU4aU/z37zf/tOujFk7c= +github.com/leanovate/gopter v0.2.9/go.mod h1:U2L/78B+KVFIx2VmW6onHJQzXtFb+p5y3y2Sh+Jxxv8= +github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4OSgU= +github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= +github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= +github.com/mmcloughlin/addchain v0.4.0 h1:SobOdjm2xLj1KkXN5/n0xTIWyZA2+s99UCY1iPfkHRY= +github.com/mmcloughlin/addchain v0.4.0/go.mod h1:A86O+tHqZLMNO4w6ZZ4FlVQEadcoqkyU72HC5wJ4RlU= +github.com/mmcloughlin/profile v0.1.1/go.mod h1:IhHD7q1ooxgwTgjxQYkACGA77oFTDdFVejUS1/tS/qU= github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o= github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.14.0 h1:nJdhIvne2eSX/XRAFV9PcvFFRbrjbcTUj0VP62TMhnw= +github.com/prometheus/client_golang v1.14.0/go.mod h1:8vpkKitgIVNcqrRBWh1C4TIUQgYNtG/XQE4E/Zae36Y= +github.com/prometheus/client_model v0.3.0 h1:UBgGFHqYdG/TPFD1B1ogZywDqEkwp3fBMvqdiQ7Xew4= +github.com/prometheus/client_model v0.3.0/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w= +github.com/prometheus/common v0.42.0 h1:EKsfXEYo4JpWMHH5cg+KOUWeuJSov1Id8zGR8eeI1YM= +github.com/prometheus/common v0.42.0/go.mod h1:xBwqVerjNdUDjgODMpudtOMwlOwf2SaTr1yjz4b7Zbc= +github.com/prometheus/procfs v0.10.1 h1:kYK1Va/YMlutzCGazswoHKo//tZVlFpKYh+PymziUAg= +github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/sanity-io/litter v1.5.1 h1:dwnrSypP6q56o3lFxTU+t2fwQ9A+U5qrXVO4Qg9KwVU= github.com/sanity-io/litter v1.5.1/go.mod h1:5Z71SvaYy5kcGtyglXOC9rrUi3c1E8CamFWjQsazTh0= +github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= +github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/supranational/blst v0.3.11 h1:LyU6FolezeWAhvQk0k6O/d49jqgO52MSDDfYgbeoEm4= +github.com/supranational/blst v0.3.11/go.mod h1:jZJtfjgudtNl4en1tzwPIV3KjUnQUvG3/j+w+fVonLw= +github.com/syndtr/goleveldb v1.0.1-0.20220614013038-64ee5596c38a h1:1ur3QoCqvE5fl+nylMaIr9PVV1w343YRDtsy+Rwu7XI= +github.com/syndtr/goleveldb v1.0.1-0.20220614013038-64ee5596c38a/go.mod h1:RRCYJbIwD5jmqPI9XoAFR0OcDxqUctll6zUj/+B4S48= github.com/thepudds/fzgen v0.4.2 h1:HlEHl5hk2/cqEomf2uK5SA/FeJc12s/vIHmOG+FbACw= github.com/thepudds/fzgen v0.4.2/go.mod h1:kHCWdsv5tdnt32NIHYDdgq083m6bMtaY0M+ipiO9xWE= +github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= +github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= +github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= +github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= +github.com/yusufpapurcu/wmi v1.2.2 h1:KBNDSne4vP5mbSWnJbO+51IMOXJB67QiYCSBrubbPRg= +github.com/yusufpapurcu/wmi v1.2.2/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= @@ -47,14 +138,24 @@ golang.org/x/exp v0.0.0-20231127185646-65229373498e h1:Gvh4YaCaXNs6dKTlfgismwWZK golang.org/x/exp v0.0.0-20231127185646-65229373498e/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= gonum.org/v1/gonum v0.11.0 h1:f1IJhK4Km5tBJmaiJXtk/PkL4cdVX6J+tGiM187uT5E= gonum.org/v1/gonum v0.11.0/go.mod h1:fSG4YDCxxUZQJ7rKsQrj0gMOg00Il0Z96/qMA4bVQhA= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +rsc.io/tmplfunc v0.0.3 h1:53XFQh69AfOa8Tw0Jm7t+GV7KZhOi6jzsCzTtKbMvzU= +rsc.io/tmplfunc v0.0.3/go.mod h1:AG3sTPzElb1Io3Yg4voV9AGZJuleGAwaVRxL9M49PhA= diff --git a/x/gethclone/main.go b/x/gethclone/main.go index 68397821f1..351312e87f 100644 --- a/x/gethclone/main.go +++ b/x/gethclone/main.go @@ -14,6 +14,9 @@ import ( func main() { c := config{ astPatches: make(astpatch.PatchRegistry), + patchSets: []patchSet{ + &statefulPrecompiles{}, + }, } pflag.StringSliceVar(&c.packages, "packages", []string{"core/vm"}, `Geth packages to clone, with or without "github.com/ethereum/go-ethereum" prefix.`) diff --git a/x/gethclone/stateful_precompiles.go b/x/gethclone/stateful_precompiles.go new file mode 100644 index 0000000000..6fe7531451 --- /dev/null +++ b/x/gethclone/stateful_precompiles.go @@ -0,0 +1,96 @@ +package main + +import ( + "fmt" + "go/ast" + "math/big" + + "github.com/ava-labs/subnet-evm/x/gethclone/astpatch" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/vm" + "golang.org/x/tools/go/ast/astutil" +) + +// statefulPrecompiles is a `patchSet` that modifies the way the EVM call +// methods dispatch to precompiled contracts, allowing for integration with +// Avalanche stateful precompiles. +type statefulPrecompiles struct { + patchedMethods map[string]bool +} + +func (*statefulPrecompiles) name() string { + return "stateful-precompiles" +} + +func (*statefulPrecompiles) evmCallMethods() []string { + return []string{"Call", "CallCode", "DelegateCall", "StaticCall"} +} + +func (p *statefulPrecompiles) register(reg astpatch.PatchRegistry) { + if p.patchedMethods == nil { + p.patchedMethods = make(map[string]bool) + } + + for _, method := range p.evmCallMethods() { + reg.Add( + geth("core/vm"), + astpatch.Method("EVM", method, p.patchRunPrecompiledCalls), + ) + } +} + +// validate returns nil iff all `evmCallMethods()` were patched. +func (p *statefulPrecompiles) validate() error { + for _, m := range p.evmCallMethods() { + if !p.patchedMethods[m] { + return fmt.Errorf("%T.%s() not patched", (*vm.EVM)(nil), m) + } + } + return nil +} + +func (p *statefulPrecompiles) patchRunPrecompiledCalls(_ *astutil.Cursor, fn *ast.FuncDecl) error { + { + // This block only locks in the assumptions we're making in the patch + // that follows. By doing so, we (a) communicate intent should a merge + // conflict arise in the future, and (b) ensure that assumptions can be + // programatically verified (by the compiler). + + type ( + // We need to propagate the caller and do so by assuming that it's + // the first parameter. + callWithValue func(caller vm.ContractRef, _ common.Address, input []byte, gas uint64, value *big.Int) ([]byte, uint64, error) + callWithoutValue func(vm.ContractRef, common.Address, []byte, uint64) ([]byte, uint64, error) + ) + + var ( + evm = (*vm.EVM)(nil) + + _, _ callWithValue = evm.Call, evm.CallCode + _, _ callWithoutValue = evm.DelegateCall, evm.StaticCall + + // We simply extend the parameter list of the regular calls. + _ func(_ vm.PrecompiledContract, input []byte, gas uint64) ([]byte, uint64, error) = vm.RunPrecompiledContract + ) + } + + _, err := astpatch.Apply(fn, + astpatch.UnqualifiedCall("RunPrecompiledContract", func(_ *astutil.Cursor, call *ast.CallExpr) error { + call.Fun = ast.NewIdent("RunStatefulPrecompiledContract") + + call.Args = append( + call.Args, + ast.NewIdent(fn.Type.Params.List[0].Names[0].Name), + &ast.SelectorExpr{ + X: ast.NewIdent(fn.Recv.List[0].Names[0].Name), + Sel: ast.NewIdent("interpreter.readOnly"), + }, + ) + + p.patchedMethods[fn.Name.Name] = true + return nil + }), + nil, + ) + return err +} From 44a72855ebcbbb6a10eda8525eca8509b6678729 Mon Sep 17 00:00:00 2001 From: Arran Schlosberg Date: Thu, 27 Jun 2024 13:30:20 +0100 Subject: [PATCH 2/7] doc: explain outcome of patched `RunPrecompiledContract()` calls --- x/gethclone/stateful_precompiles.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/x/gethclone/stateful_precompiles.go b/x/gethclone/stateful_precompiles.go index 6fe7531451..1317dc9384 100644 --- a/x/gethclone/stateful_precompiles.go +++ b/x/gethclone/stateful_precompiles.go @@ -49,6 +49,17 @@ func (p *statefulPrecompiles) validate() error { return nil } +// patchRunPrecompiledCalls finds all `RunPrecompiledContract()` calls inside +// `fn` and changes them to (a) call a different function; and (b) also +// propagate fn's first argument (the caller) and `evm.interpreter.readOnly`. +// +// RunPrecompiledContract(p, input gas) +// // becomes +// RunStatefulPrecompiledContract(p, input, gas, caller, evm.interpreter.readOnly) +// +// The definition of `RunStatefulPrecompiledContract()` SHOULD be implemented as +// regular Go code. The determination of whether `p` is stateful or not can be +// achieved with a type switch. func (p *statefulPrecompiles) patchRunPrecompiledCalls(_ *astutil.Cursor, fn *ast.FuncDecl) error { { // This block only locks in the assumptions we're making in the patch From 6e719062f55058c31a7b631db740c3aefdd0d5da Mon Sep 17 00:00:00 2001 From: Arran Schlosberg Date: Thu, 27 Jun 2024 13:44:07 +0100 Subject: [PATCH 3/7] feat: rename `vm.RunPrecompiledContract` to avoid accidental calls --- x/gethclone/astpatch/funcs.go | 27 +++++++++++++++++++++++++++ x/gethclone/stateful_precompiles.go | 9 ++++++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/x/gethclone/astpatch/funcs.go b/x/gethclone/astpatch/funcs.go index fcdd872988..c00b405acb 100644 --- a/x/gethclone/astpatch/funcs.go +++ b/x/gethclone/astpatch/funcs.go @@ -102,3 +102,30 @@ func UnqualifiedCall(name string, patch func(*astutil.Cursor, *ast.CallExpr) err return patch(c, call) } } + +// Function returns a `TypePatcher` that only applies to the specific function +// declaration. +// +// The `patch` argument functions like a regular `Patch` except that its +// parameters are extended to also accept the methods's AST declaration as its +// concrete type (i.e. `astutil.Cursor.Node().(*ast.FuncDecl)`). +func Function(name string, patch func(*astutil.Cursor, *ast.FuncDecl) error) TypePatcher { + return typePatcher{ + typ: (*ast.FuncDecl)(nil), + patch: func(c *astutil.Cursor) error { + fn, ok := c.Node().(*ast.FuncDecl) + if !ok || fn.Name.Name != name { + return nil + } + return patch(c, fn) + }, + } +} + +// RenameFunction does what it says on the tin. +func RenameFunction(from, to string) TypePatcher { + return Function(from, func(c *astutil.Cursor, fn *ast.FuncDecl) error { + fn.Name.Name = to + return nil + }) +} diff --git a/x/gethclone/stateful_precompiles.go b/x/gethclone/stateful_precompiles.go index 1317dc9384..3b2e3b0ec6 100644 --- a/x/gethclone/stateful_precompiles.go +++ b/x/gethclone/stateful_precompiles.go @@ -37,6 +37,13 @@ func (p *statefulPrecompiles) register(reg astpatch.PatchRegistry) { astpatch.Method("EVM", method, p.patchRunPrecompiledCalls), ) } + + // Rename `RunPrecompiledContract()` to avoid it being called in production. + // An alias should be created in *test* file(s) that call the original. + reg.Add( + geth("core/vm"), + astpatch.RenameFunction("RunPrecompiledContract", "geth_RunPrecompiledContract"), + ) } // validate returns nil iff all `evmCallMethods()` were patched. @@ -85,7 +92,7 @@ func (p *statefulPrecompiles) patchRunPrecompiledCalls(_ *astutil.Cursor, fn *as ) } - _, err := astpatch.Apply(fn, + _, err := astpatch.Apply(fn.Body, astpatch.UnqualifiedCall("RunPrecompiledContract", func(_ *astutil.Cursor, call *ast.CallExpr) error { call.Fun = ast.NewIdent("RunStatefulPrecompiledContract") From 91e1a9916ae030fe22b30adfa7c67c83d21d2a68 Mon Sep 17 00:00:00 2001 From: Arran Schlosberg Date: Thu, 27 Jun 2024 14:04:05 +0100 Subject: [PATCH 4/7] chore: fix comment and shorten loop variable --- x/gethclone/stateful_precompiles.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/x/gethclone/stateful_precompiles.go b/x/gethclone/stateful_precompiles.go index 3b2e3b0ec6..ae14d3496b 100644 --- a/x/gethclone/stateful_precompiles.go +++ b/x/gethclone/stateful_precompiles.go @@ -31,10 +31,10 @@ func (p *statefulPrecompiles) register(reg astpatch.PatchRegistry) { p.patchedMethods = make(map[string]bool) } - for _, method := range p.evmCallMethods() { + for _, m := range p.evmCallMethods() { reg.Add( geth("core/vm"), - astpatch.Method("EVM", method, p.patchRunPrecompiledCalls), + astpatch.Method("EVM", m, p.patchRunPrecompiledCalls), ) } @@ -57,8 +57,9 @@ func (p *statefulPrecompiles) validate() error { } // patchRunPrecompiledCalls finds all `RunPrecompiledContract()` calls inside -// `fn` and changes them to (a) call a different function; and (b) also -// propagate fn's first argument (the caller) and `evm.interpreter.readOnly`. +// `fn.Body` and changes them to (a) call a different function; and (b) also +// propagate fn's first argument (the caller) and `evm.interpreter.readOnly`, +// where `evm` is the receiver name of the function being patched. // // RunPrecompiledContract(p, input gas) // // becomes From 98fca625638d0711a45c2b3977d4e69f866609a7 Mon Sep 17 00:00:00 2001 From: Arran Schlosberg Date: Thu, 27 Jun 2024 16:25:35 +0100 Subject: [PATCH 5/7] fix: `astpatch` documentation --- x/gethclone/astpatch/astpatch.go | 7 ++++--- x/gethclone/astpatch/funcs.go | 3 +-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/x/gethclone/astpatch/astpatch.go b/x/gethclone/astpatch/astpatch.go index 2bcf5b5c52..8b27ce5a35 100644 --- a/x/gethclone/astpatch/astpatch.go +++ b/x/gethclone/astpatch/astpatch.go @@ -49,8 +49,8 @@ func Apply(root ast.Node, pre, post Patch) (ast.Node, error) { // // The special `pkgPath` value "*" will match all package paths. While there is // no specific requirement for `pkgPath` other than it matching the equivalent -// argument passed to `Apply()`, it is typically sourced from -// `golang.org/x/tools/go/packages.Package.PkgPath`. +// argument passed to `Apply()`, it is typically the import path of the package +// being patched. func (r PatchRegistry) AddForType(pkgPath string, zeroNode ast.Node, fn Patch) { pkg, ok := r[pkgPath] if !ok { @@ -64,7 +64,8 @@ func (r PatchRegistry) AddForType(pkgPath string, zeroNode ast.Node, fn Patch) { // A TypePatcher couples a `Patch` with the specific `ast.Node` type to which it // applies. It is useful when `PatchRegistry.AddForType()` MUST receive a -// specific `Node` type for a particular `Patch`. +// specific `Node` type for a particular `Patch`, in which case +// `PatchRegistry.Add()` SHOULD be used instead. type TypePatcher interface { Type() ast.Node Patch(*astutil.Cursor) error diff --git a/x/gethclone/astpatch/funcs.go b/x/gethclone/astpatch/funcs.go index c00b405acb..6d8b2e6b68 100644 --- a/x/gethclone/astpatch/funcs.go +++ b/x/gethclone/astpatch/funcs.go @@ -14,9 +14,8 @@ import ( // parameters are extended to also accept the methods's AST declaration as its // concrete type (i.e. `astutil.Cursor.Node().(*ast.FuncDecl)`). // -// // Original declaration +// // Method declaration // func (x *Thing) Do() { ... } -// // // Patched with // astpatch.Method("Thing", "Do", ...) func Method(receiverType, methodName string, patch func(*astutil.Cursor, *ast.FuncDecl) error) TypePatcher { From e5d8e32dbbb60d7d8bf53000eda1a4eb0cce24b9 Mon Sep 17 00:00:00 2001 From: Arran Schlosberg Date: Fri, 28 Jun 2024 13:21:04 +0100 Subject: [PATCH 6/7] test: `astpatch.*Method()` + `UnqualifiedCall()` + `Function()` + `RenameFunction()` --- x/gethclone/astpatch/astpatch_test.go | 314 ++++++++++++++++++++++++-- 1 file changed, 300 insertions(+), 14 deletions(-) diff --git a/x/gethclone/astpatch/astpatch_test.go b/x/gethclone/astpatch/astpatch_test.go index 1a30db8303..5d95639a8f 100644 --- a/x/gethclone/astpatch/astpatch_test.go +++ b/x/gethclone/astpatch/astpatch_test.go @@ -4,18 +4,32 @@ import ( "bytes" "fmt" "go/ast" + "go/format" "go/parser" "go/token" + "reflect" "testing" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/tools/go/ast/astutil" ) type patchSpy struct { - gotFuncs, gotStructs []string + visited +} + +type visited struct { + Funcs, Structs, Calls []string + FuncParams [][]string +} + +// assertEqual asserts that `v == want`, reporting a diff otherwise. +func (v visited) assertEqual(tb testing.TB, want visited) { + tb.Helper() + assert.Empty(tb, cmp.Diff(want, v, cmpopts.EquateEmpty()), "visited nodes diff (-want +got)") } const errorIfFuncName = "ErrorFuncName" @@ -23,28 +37,68 @@ const errorIfFuncName = "ErrorFuncName" var errFuncName = fmt.Errorf("encountered sentinel function %q", errorIfFuncName) func (s *patchSpy) funcRecorder(c *astutil.Cursor) error { - name := c.Node().(*ast.FuncDecl).Name.String() + fn, ok := c.Node().(*ast.FuncDecl) + if !ok { + return fmt.Errorf("%T.funcRecorder() called with %T not %T", s, c.Node(), fn) + } + + name := fn.Name.String() if name == errorIfFuncName { return errFuncName } - s.gotFuncs = append(s.gotFuncs, name) + s.Funcs = append(s.Funcs, name) + + var params []string + for _, p := range fn.Type.Params.List { + // Params of the same type but different name are grouped together in + // AST nodes + for _, n := range p.Names { + params = append(params, n.Name) + } + } + s.FuncParams = append(s.FuncParams, params) + return nil } func (s *patchSpy) structRecorder(c *astutil.Cursor) error { switch p := c.Parent().(type) { case *ast.TypeSpec: // it's a `type x struct` not, for example, a `map[T]struct{}` - s.gotStructs = append(s.gotStructs, p.Name.String()) + s.Structs = append(s.Structs, p.Name.String()) } return nil } +func (s *patchSpy) funcDeclRecorder(c *astutil.Cursor, fn *ast.FuncDecl) error { + if !reflect.DeepEqual(c.Node(), fn) { + return fmt.Errorf("reflect.DeepEqual(%T.Node(), %T) = false; want true", c, fn) + } + return s.funcRecorder(c) +} + +func (s *patchSpy) callRecorder(c *astutil.Cursor, call *ast.CallExpr) error { + if !reflect.DeepEqual(c.Node(), call) { + return fmt.Errorf("reflect.DeepEqual(%T.Node(), %T) = false; want true", c, call) + } + + var name string + switch fn := call.Fun.(type) { + case *ast.Ident: + name = fn.Name + default: + return fmt.Errorf("incomplete test double: %T.callRecorder() called with %T.Fun of unsupported type %T", s, call, call.Fun) + } + s.Calls = append(s.Calls, name) + + return nil +} + func TestPatchRegistry(t *testing.T) { tests := []struct { - name string - src string - wantErr error - wantFuncs, wantStructs []string + name string + src string + wantErr error + want visited }{ { name: "happy path", @@ -58,8 +112,11 @@ type StructA struct{} type StructB struct{} `, - wantFuncs: []string{"FnA", "FnB"}, - wantStructs: []string{"StructA", "StructB"}, + want: visited{ + Funcs: []string{"FnA", "FnB"}, + FuncParams: [][]string{{}, {}}, + Structs: []string{"StructA", "StructB"}, + }, }, { name: "error propagation", @@ -69,8 +126,11 @@ func HappyFn() {} func ` + errorIfFuncName + `() {} `, - wantErr: errFuncName, - wantFuncs: []string{"HappyFn"}, + wantErr: errFuncName, + want: visited{ + Funcs: []string{"HappyFn"}, + FuncParams: [][]string{{}}, + }, }, } @@ -97,8 +157,7 @@ func ` + errorIfFuncName + `() {} if _, gotErr := reg.Apply(pkgPath, file); gotErr != tt.wantErr { t.Fatalf("%T.Apply(...) got err %v; want %v", reg, gotErr, tt.wantErr) } - assert.Empty(t, cmp.Diff(tt.wantFuncs, spy.gotFuncs), "encountered function declarations (-want +got)") - assert.Empty(t, cmp.Diff(tt.wantStructs, spy.gotStructs), "encountered struct-type declarations (-want +got)") + spy.visited.assertEqual(t, tt.want) }) } } @@ -119,3 +178,230 @@ func bestEffortLogAST(tb testing.TB, x any) { } tb.Logf("AST of parsed source:\n\n%s", buf.String()) } + +func TestTypePatchers(t *testing.T) { + const src = `package box + +type ( + TypeA struct {} + TypeB int +) + +func (TypeA) ValueMethod(a int) {} +func (TypeA) valueMethod(a int) {} +func (TypeB) ValueMethod(b int) {} +func (TypeB) valueMethod(b int) {} + +func (*TypeA) PointerMethod(a int) {} +func (*TypeA) pointerMethod(a int) {} +func (*TypeB) PointerMethod(b int) {} +func (*TypeB) pointerMethod(b int) {} + +func Fn() {} +func fn() {} + +func calledFn() {} +func notCalledFn() {} +func init() { + calledFn() +} +` + + tests := []struct { + name string + patcher func(*patchSpy) TypePatcher + want visited + }{ + { + name: "method agnostic to pointer/value receiver", + patcher: func(s *patchSpy) TypePatcher { + return Method("TypeA", "valueMethod", s.funcDeclRecorder) + }, + want: visited{ + Funcs: []string{"valueMethod"}, + FuncParams: [][]string{{"a"}}, + }, + }, + { + name: "exported method, otherwise same as earlier test", + patcher: func(s *patchSpy) TypePatcher { + return Method("TypeA", "ValueMethod", s.funcDeclRecorder) + }, + want: visited{ + Funcs: []string{"ValueMethod"}, + FuncParams: [][]string{{"a"}}, + }, + }, + { + name: "method on different type, otherwise same as earlier test", + patcher: func(s *patchSpy) TypePatcher { + return Method("TypeB", "valueMethod", s.funcDeclRecorder) + }, + want: visited{ + Funcs: []string{"valueMethod"}, + FuncParams: [][]string{{"b"}}, + }, + }, + { + name: "PointerMethod() with pointer receiver matches", + patcher: func(s *patchSpy) TypePatcher { + return PointerMethod("TypeA", "pointerMethod", s.funcDeclRecorder) + }, + want: visited{ + Funcs: []string{"pointerMethod"}, + FuncParams: [][]string{{"a"}}, + }, + }, + { + name: "PointerMethod() with value receiver ignores", + patcher: func(s *patchSpy) TypePatcher { + return PointerMethod("TypeA", "valueMethod", s.funcDeclRecorder) + }, + want: visited{}, + }, + { + name: "ValueMethod() with value receiver matches", + patcher: func(s *patchSpy) TypePatcher { + return ValueMethod("TypeA", "valueMethod", s.funcDeclRecorder) + }, + want: visited{ + Funcs: []string{"valueMethod"}, + FuncParams: [][]string{{"a"}}, + }, + }, + { + name: "ValueMethod() with pointer receiver ignores", + patcher: func(s *patchSpy) TypePatcher { + return ValueMethod("TypeA", "pointerMethod", s.funcDeclRecorder) + }, + want: visited{}, + }, + { + name: "function (not method) declaration", + patcher: func(s *patchSpy) TypePatcher { + return Function("fn", s.funcDeclRecorder) + }, + want: visited{ + Funcs: []string{"fn"}, + FuncParams: [][]string{{}}, + }, + }, + { + name: "function of different name to earlier test", + patcher: func(s *patchSpy) TypePatcher { + return Function("Fn", s.funcDeclRecorder) + }, + want: visited{ + Funcs: []string{"Fn"}, + FuncParams: [][]string{{}}, + }, + }, + { + name: "unqualified function call", + patcher: func(s *patchSpy) TypePatcher { + return typePatcher{ + typ: new(ast.CallExpr), + patch: UnqualifiedCall("calledFn", s.callRecorder), + } + }, + want: visited{Calls: []string{"calledFn"}}, + }, + { + name: "UnqualifiedCall() for function not actually called", + patcher: func(s *patchSpy) TypePatcher { + return typePatcher{ + typ: new(ast.CallExpr), + patch: UnqualifiedCall("notCalledFn", s.callRecorder), + } + }, + want: visited{Calls: nil}, + }, + } + + file := parseGoFile(t, token.NewFileSet(), src) + bestEffortLogAST(t, file) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var spy patchSpy + reg := make(PatchRegistry) + reg.Add("*", tt.patcher(&spy)) + + _, err := reg.Apply("", file) + require.NoErrorf(t, err, `%T.Apply(...)`, reg) + spy.visited.assertEqual(t, tt.want) + }) + } +} + +func TestRefactoring(t *testing.T) { + tests := []struct { + name, src string + patcher TypePatcher + want string + }{ + { + name: `RenameFunction("foo", "phew")`, + src: ` +package tape + +func foo() {} +func bar() {} +`, + patcher: RenameFunction("foo", "phew"), + want: ` +package tape + +func phew() {} +func bar() {} +`, + }, + { + name: `RenameFunction("bar", "pub")`, + src: ` +package tape + +func foo() {} +func bar() {} +`, + patcher: RenameFunction("bar", "pub"), + want: ` +package tape + +func foo() {} +func pub() {} +`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reg := make(PatchRegistry) + reg.Add("*", tt.patcher) + + fset := token.NewFileSet() + + gotNode, err := reg.Apply("", parseGoFile(t, fset, tt.src)) + require.NoErrorf(t, err, "%T.Apply(...)", reg) + + var got bytes.Buffer + require.NoErrorf(t, format.Node(&got, fset, gotNode), "format.Node(..., [output of %T.Apply(...)])", reg) + + assert.Equal(t, got.String(), formatGo(t, tt.want), "output of format.Node() after patching") + }) + } +} + +// formatGo parses the file represented by `src` into AST and returns the output +// of `format.Node()`. This allows expected test values to be formatted +// incorrectly without affecting equality checks. +func formatGo(tb testing.TB, src string) string { + tb.Helper() + + fset := token.NewFileSet() + file := parseGoFile(tb, fset, src) + + var buf bytes.Buffer + require.NoError(tb, format.Node(&buf, fset, file), "format.Node(parser.ParseFile(...)) round trip") + return buf.String() +} From a17694faf797d8c526af9ce4e0ae0036a20b2991 Mon Sep 17 00:00:00 2001 From: Arran Schlosberg Date: Sun, 30 Jun 2024 10:40:22 +0100 Subject: [PATCH 7/7] refactor: introduce generic `TypedPatch[N ast.Node]()` function type --- x/gethclone/astpatch/astpatch.go | 20 +++++--- x/gethclone/astpatch/astpatch_test.go | 66 ++++++++++++++++++--------- x/gethclone/astpatch/funcs.go | 36 +++++---------- 3 files changed, 70 insertions(+), 52 deletions(-) diff --git a/x/gethclone/astpatch/astpatch.go b/x/gethclone/astpatch/astpatch.go index 8b27ce5a35..a4509580ec 100644 --- a/x/gethclone/astpatch/astpatch.go +++ b/x/gethclone/astpatch/astpatch.go @@ -15,6 +15,12 @@ type ( // A non-nil error is equivalent to returning false and will also abort all // further calls to other patches. Patch func(*astutil.Cursor) error + // A TypedPatch functions identically to a `Patch` except that it also + // receives the `ast.Node` as its concrete type. + // + // Invariant: `c.Node()` and `n` MUST be of the same concrete type and point + // to the same memory; i.e. `c.Node().(N) == n` doesn't panic and is true. + TypedPatch[N ast.Node] func(c *astutil.Cursor, n N) error // A PatchRegistry maps [Go package path] -> [ast.Node concrete types] -> // [all `Patch` functions that must be applied to said node types in said // package]. @@ -62,29 +68,29 @@ func (r PatchRegistry) AddForType(pkgPath string, zeroNode ast.Node, fn Patch) { pkg[t] = append(pkg[t], fn) } -// A TypePatcher couples a `Patch` with the specific `ast.Node` type to which it +// A Patcher couples a `Patch` with the specific `ast.Node` type to which it // applies. It is useful when `PatchRegistry.AddForType()` MUST receive a // specific `Node` type for a particular `Patch`, in which case // `PatchRegistry.Add()` SHOULD be used instead. -type TypePatcher interface { +type Patcher interface { Type() ast.Node Patch(*astutil.Cursor) error } // Add is a synonym of `AddForType()`, instead accepting an argument that // provides the `Node` type and the `Patch`. -func (r PatchRegistry) Add(pkgPath string, tp TypePatcher) { +func (r PatchRegistry) Add(pkgPath string, tp Patcher) { r.AddForType(pkgPath, tp.Type(), tp.Patch) } -// typePatcher implements the `TypePatcher` interface. -type typePatcher struct { +// patcher implements the `Patcher` interface. +type patcher struct { typ ast.Node patch Patch } -func (p typePatcher) Type() ast.Node { return p.typ } -func (p typePatcher) Patch(c *astutil.Cursor) error { return p.patch(c) } +func (p patcher) Type() ast.Node { return p.typ } +func (p patcher) Patch(c *astutil.Cursor) error { return p.patch(c) } // Apply calls `astutil.Apply()` on `node`, calling the appropriate `Patch` // functions as the syntax tree is traversed. Patches are applied as the `pre` diff --git a/x/gethclone/astpatch/astpatch_test.go b/x/gethclone/astpatch/astpatch_test.go index 5d95639a8f..a1a27a8395 100644 --- a/x/gethclone/astpatch/astpatch_test.go +++ b/x/gethclone/astpatch/astpatch_test.go @@ -7,7 +7,6 @@ import ( "go/format" "go/parser" "go/token" - "reflect" "testing" "github.com/google/go-cmp/cmp" @@ -70,15 +69,15 @@ func (s *patchSpy) structRecorder(c *astutil.Cursor) error { } func (s *patchSpy) funcDeclRecorder(c *astutil.Cursor, fn *ast.FuncDecl) error { - if !reflect.DeepEqual(c.Node(), fn) { - return fmt.Errorf("reflect.DeepEqual(%T.Node(), %T) = false; want true", c, fn) + if err := checkTypedPatchInvariant(c, fn); err != nil { + return err } return s.funcRecorder(c) } func (s *patchSpy) callRecorder(c *astutil.Cursor, call *ast.CallExpr) error { - if !reflect.DeepEqual(c.Node(), call) { - return fmt.Errorf("reflect.DeepEqual(%T.Node(), %T) = false; want true", c, call) + if err := checkTypedPatchInvariant(c, call); err != nil { + return err } var name string @@ -179,7 +178,7 @@ func bestEffortLogAST(tb testing.TB, x any) { tb.Logf("AST of parsed source:\n\n%s", buf.String()) } -func TestTypePatchers(t *testing.T) { +func TestPatchers(t *testing.T) { const src = `package box type ( @@ -209,12 +208,12 @@ func init() { tests := []struct { name string - patcher func(*patchSpy) TypePatcher + patcher func(*patchSpy) Patcher want visited }{ { name: "method agnostic to pointer/value receiver", - patcher: func(s *patchSpy) TypePatcher { + patcher: func(s *patchSpy) Patcher { return Method("TypeA", "valueMethod", s.funcDeclRecorder) }, want: visited{ @@ -224,7 +223,7 @@ func init() { }, { name: "exported method, otherwise same as earlier test", - patcher: func(s *patchSpy) TypePatcher { + patcher: func(s *patchSpy) Patcher { return Method("TypeA", "ValueMethod", s.funcDeclRecorder) }, want: visited{ @@ -234,7 +233,7 @@ func init() { }, { name: "method on different type, otherwise same as earlier test", - patcher: func(s *patchSpy) TypePatcher { + patcher: func(s *patchSpy) Patcher { return Method("TypeB", "valueMethod", s.funcDeclRecorder) }, want: visited{ @@ -244,7 +243,7 @@ func init() { }, { name: "PointerMethod() with pointer receiver matches", - patcher: func(s *patchSpy) TypePatcher { + patcher: func(s *patchSpy) Patcher { return PointerMethod("TypeA", "pointerMethod", s.funcDeclRecorder) }, want: visited{ @@ -254,14 +253,14 @@ func init() { }, { name: "PointerMethod() with value receiver ignores", - patcher: func(s *patchSpy) TypePatcher { + patcher: func(s *patchSpy) Patcher { return PointerMethod("TypeA", "valueMethod", s.funcDeclRecorder) }, want: visited{}, }, { name: "ValueMethod() with value receiver matches", - patcher: func(s *patchSpy) TypePatcher { + patcher: func(s *patchSpy) Patcher { return ValueMethod("TypeA", "valueMethod", s.funcDeclRecorder) }, want: visited{ @@ -271,14 +270,14 @@ func init() { }, { name: "ValueMethod() with pointer receiver ignores", - patcher: func(s *patchSpy) TypePatcher { + patcher: func(s *patchSpy) Patcher { return ValueMethod("TypeA", "pointerMethod", s.funcDeclRecorder) }, want: visited{}, }, { name: "function (not method) declaration", - patcher: func(s *patchSpy) TypePatcher { + patcher: func(s *patchSpy) Patcher { return Function("fn", s.funcDeclRecorder) }, want: visited{ @@ -288,7 +287,7 @@ func init() { }, { name: "function of different name to earlier test", - patcher: func(s *patchSpy) TypePatcher { + patcher: func(s *patchSpy) Patcher { return Function("Fn", s.funcDeclRecorder) }, want: visited{ @@ -298,8 +297,8 @@ func init() { }, { name: "unqualified function call", - patcher: func(s *patchSpy) TypePatcher { - return typePatcher{ + patcher: func(s *patchSpy) Patcher { + return patcher{ typ: new(ast.CallExpr), patch: UnqualifiedCall("calledFn", s.callRecorder), } @@ -308,8 +307,8 @@ func init() { }, { name: "UnqualifiedCall() for function not actually called", - patcher: func(s *patchSpy) TypePatcher { - return typePatcher{ + patcher: func(s *patchSpy) Patcher { + return patcher{ typ: new(ast.CallExpr), patch: UnqualifiedCall("notCalledFn", s.callRecorder), } @@ -337,7 +336,7 @@ func init() { func TestRefactoring(t *testing.T) { tests := []struct { name, src string - patcher TypePatcher + patcher Patcher want string }{ { @@ -405,3 +404,28 @@ func formatGo(tb testing.TB, src string) string { require.NoError(tb, format.Node(&buf, fset, file), "format.Node(parser.ParseFile(...)) round trip") return buf.String() } + +// checkTypedPatchInvariant accepts the `astutil.Cursor` (`c`) and `ast.Node` +// (`node`) passed to a `TypedPatch` function, confirming that `c.Node()` is (a) +// of the same concrete type as `node`; and (b) points to the same memory. +func checkTypedPatchInvariant[ + N any, PtrN interface { + ast.Node + *N + }, +](c *astutil.Cursor, node PtrN) error { + var tp TypedPatch[PtrN] + + switch cNode := c.Node().(type) { + case PtrN: + if cNode != node { + return fmt.Errorf("%T argument invariant broken; %T.Node() and %T point to different memory", + tp, c, node) + } + default: + return fmt.Errorf("%T argument invariant broken; %T.Node() of concrete type %T; want %T", + tp, c, cNode, node) + } + + return nil +} diff --git a/x/gethclone/astpatch/funcs.go b/x/gethclone/astpatch/funcs.go index 6d8b2e6b68..fc31e97e9b 100644 --- a/x/gethclone/astpatch/funcs.go +++ b/x/gethclone/astpatch/funcs.go @@ -7,37 +7,33 @@ import ( "golang.org/x/tools/go/ast/astutil" ) -// Method returns a `TypePatcher` that only applies to the specific method on -// the specific type. -// -// The `patch` argument functions like a regular `Patch` except that its -// parameters are extended to also accept the methods's AST declaration as its -// concrete type (i.e. `astutil.Cursor.Node().(*ast.FuncDecl)`). +// Method returns a `Patcher` that only applies to the specific method on the +// specific type. // // // Method declaration // func (x *Thing) Do() { ... } // // Patched with // astpatch.Method("Thing", "Do", ...) -func Method(receiverType, methodName string, patch func(*astutil.Cursor, *ast.FuncDecl) error) TypePatcher { +func Method(receiverType, methodName string, patch TypedPatch[*ast.FuncDecl]) Patcher { return method(nil, receiverType, methodName, patch) } // PointerMethod is identical to `Method()` except that it only matches methods // with pointer receivers. -func PointerMethod(receiverType, methodName string, patch func(*astutil.Cursor, *ast.FuncDecl) error) TypePatcher { +func PointerMethod(receiverType, methodName string, patch TypedPatch[*ast.FuncDecl]) Patcher { ptr := true return method(&ptr, receiverType, methodName, patch) } // ValueMethod is identical to `Method()` except that it only matches methods // with value receivers. -func ValueMethod(receiverType, methodName string, patch func(*astutil.Cursor, *ast.FuncDecl) error) TypePatcher { +func ValueMethod(receiverType, methodName string, patch TypedPatch[*ast.FuncDecl]) Patcher { ptr := false return method(&ptr, receiverType, methodName, patch) } -func method(pointerReceiver *bool, receiverType, methodName string, patch func(*astutil.Cursor, *ast.FuncDecl) error) TypePatcher { - return typePatcher{ +func method(pointerReceiver *bool, receiverType, methodName string, patch TypedPatch[*ast.FuncDecl]) Patcher { + return patcher{ typ: (*ast.FuncDecl)(nil), patch: func(c *astutil.Cursor) error { fn, ok := c.Node().(*ast.FuncDecl) @@ -84,11 +80,7 @@ func method(pointerReceiver *bool, receiverType, methodName string, patch func(* // qualifiers before the selector (e.g. `foo.Bar()` or `pkg.Bar()`); an // unqualified function lacks any such qualifiers and applies to builtin and // package-internal functions. -// -// The `patch` argument functions like a regular `Patch` except that its -// parameters are extended to also accept the call's AST declaration as its -// concrete type (i.e. `astutil.Cursor.Node().(*ast.CallExpr)`). -func UnqualifiedCall(name string, patch func(*astutil.Cursor, *ast.CallExpr) error) Patch { +func UnqualifiedCall(name string, patch TypedPatch[*ast.CallExpr]) Patch { return func(c *astutil.Cursor) error { call, ok := c.Node().(*ast.CallExpr) if !ok { @@ -102,14 +94,10 @@ func UnqualifiedCall(name string, patch func(*astutil.Cursor, *ast.CallExpr) err } } -// Function returns a `TypePatcher` that only applies to the specific function +// Function returns a `Patcher` that only applies to the specific function // declaration. -// -// The `patch` argument functions like a regular `Patch` except that its -// parameters are extended to also accept the methods's AST declaration as its -// concrete type (i.e. `astutil.Cursor.Node().(*ast.FuncDecl)`). -func Function(name string, patch func(*astutil.Cursor, *ast.FuncDecl) error) TypePatcher { - return typePatcher{ +func Function(name string, patch TypedPatch[*ast.FuncDecl]) Patcher { + return patcher{ typ: (*ast.FuncDecl)(nil), patch: func(c *astutil.Cursor) error { fn, ok := c.Node().(*ast.FuncDecl) @@ -122,7 +110,7 @@ func Function(name string, patch func(*astutil.Cursor, *ast.FuncDecl) error) Typ } // RenameFunction does what it says on the tin. -func RenameFunction(from, to string) TypePatcher { +func RenameFunction(from, to string) Patcher { return Function(from, func(c *astutil.Cursor, fn *ast.FuncDecl) error { fn.Name.Name = to return nil