diff --git a/pkg/env/env.go b/pkg/env/env.go index 2f4da4b7e3..431b44e174 100644 --- a/pkg/env/env.go +++ b/pkg/env/env.go @@ -3,7 +3,6 @@ package env import ( "bytes" "fmt" - "runtime" "runtime/debug" "strings" @@ -25,7 +24,9 @@ import ( "github.com/open-component-model/ocm/pkg/contexts/datacontext/attrs/vfsattr" "github.com/open-component-model/ocm/pkg/contexts/oci" ocm "github.com/open-component-model/ocm/pkg/contexts/ocm/cpi" + "github.com/open-component-model/ocm/pkg/testutils" "github.com/open-component-model/ocm/pkg/utils" + "github.com/open-component-model/ocm/pkg/utils/pkgutils" ) //////////////////////////////////////////////////////////////////////////////// @@ -232,23 +233,13 @@ func ModifiableTestData(paths ...string) tdOpt { } func projectTestData(modifiable bool, source string, dest ...string) Option { - path := "." - for count := 0; count < 20; count++ { - if ok, err := vfs.FileExists(osfs.OsFs, filepath.Join(path, "go.mod")); err != nil || ok { - if err != nil { - panic(err) - } - path = filepath.Join(path, source) - break - } - if count == 19 { - panic("could not find go.mod (within 20 steps)") - } - - path = filepath.Join(path, "..") + pathToRoot, err := testutils.GetRelativePathToProjectRoot() + if err != nil { + panic(err) } + pathToTestdata := filepath.Join(pathToRoot, source) - return testData(modifiable, path, general.OptionalDefaulted("/testdata", dest...)) + return testData(modifiable, pathToTestdata, general.OptionalDefaulted("/testdata", dest...)) } func ProjectTestData(source string, dest ...string) Option { @@ -260,29 +251,16 @@ func ModifiableProjectTestData(source string, dest ...string) Option { } func projectTestDataForCaller(modifiable bool, dest ...string) Option { - pc, _, _, ok := runtime.Caller(2) - if !ok { - panic("unable to find caller") - } - - // Get the function details from the program counter - caller := runtime.FuncForPC(pc) - if caller == nil { - panic("unable to find caller") + packagePath, err := pkgutils.GetPackageName(2) + if err != nil { + panic(err) } - fullFuncName := caller.Name() - - // Split the name to extract the package path - // Assuming the format: "package/path.functionName" - lastSlashIndex := strings.LastIndex(fullFuncName, "/") - if lastSlashIndex == -1 { - panic("unable to find package name") + moduleName, err := testutils.GetModuleName() + if err != nil { + panic(err) } - - funcIndex := strings.Index(fullFuncName[lastSlashIndex:], ".") - packagePath := fullFuncName[:lastSlashIndex+funcIndex] - path, ok := strings.CutPrefix(packagePath, "github.com/open-component-model/ocm/") + path, ok := strings.CutPrefix(packagePath, moduleName+"/") if !ok { panic("unable to find package name") } diff --git a/pkg/testutils/package.go b/pkg/testutils/package.go new file mode 100644 index 0000000000..2e2925195d --- /dev/null +++ b/pkg/testutils/package.go @@ -0,0 +1,78 @@ +package testutils + +import ( + "fmt" + "strings" + + "github.com/mandelsoft/filepath/pkg/filepath" + "github.com/mandelsoft/goutils/general" + "github.com/mandelsoft/vfs/pkg/osfs" + "github.com/mandelsoft/vfs/pkg/vfs" + "golang.org/x/mod/modfile" + + "github.com/open-component-model/ocm/pkg/utils/pkgutils" +) + +const GO_MOD = "go.mod" + +func GetPackagePathFromProjectRoot(i ...interface{}) (string, error) { + pkg, err := pkgutils.GetPackageName(i...) + if err != nil { + return "", err + } + mod, err := GetModuleName() + if err != nil { + return "", err + } + path, ok := strings.CutPrefix(pkg, mod+"/") + if !ok { + return "", fmt.Errorf("prefix %q not found in %q", mod, pkg) + } + return path, nil +} + +// GetModuleName returns a go modules module name by finding and parsing the go.mod file. +func GetModuleName() (string, error) { + pathToRoot, err := GetRelativePathToProjectRoot() + if err != nil { + return "", err + } + pathToGoMod := filepath.Join(pathToRoot, GO_MOD) + // Read the content of the go.mod file + data, err := vfs.ReadFile(osfs.OsFs, pathToGoMod) + if err != nil { + return "", err + } + + // Parse the go.mod file + modFile, err := modfile.Parse(GO_MOD, data, nil) + if err != nil { + return "", fmt.Errorf("error parsing %s file: %w", GO_MOD, err) + } + + // Print the module path + return modFile.Module.Mod.Path, nil +} + +// GetRelativePathToProjectRoot calculates the relative path to a go projects root directory. +// It therefore assumes that the project root is the directory containing the go.mod file. +// The optional parameter i determines how many directories the function will step up through, attempting to find a +// go.mod file. If it cannot find a directory with a go.mod file within i iterations, the function throws an error. +func GetRelativePathToProjectRoot(i ...int) (string, error) { + iterations := general.OptionalDefaulted(20, i...) + + path := "." + for count := 0; count < iterations; count++ { + if ok, err := vfs.FileExists(osfs.OsFs, filepath.Join(path, GO_MOD)); err != nil || ok { + if err != nil { + return "", fmt.Errorf("failed to check if %s exists: %w", GO_MOD, err) + } + return path, nil + } + if count == iterations { + return "", fmt.Errorf("could not find %s (within %d steps)", GO_MOD, iterations) + } + path = filepath.Join(path, "..") + } + return "", nil +} diff --git a/pkg/testutils/package_test.go b/pkg/testutils/package_test.go new file mode 100644 index 0000000000..8e5805c70a --- /dev/null +++ b/pkg/testutils/package_test.go @@ -0,0 +1,15 @@ +package testutils_test + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + me "github.com/open-component-model/ocm/pkg/testutils" +) + +var _ = Describe("package tests", func() { + It("go module name", func() { + mod := me.Must(me.GetModuleName()) + Expect(mod).To(Equal("github.com/open-component-model/ocm")) + }) +}) diff --git a/pkg/utils/package.go b/pkg/utils/package.go deleted file mode 100644 index 932a765982..0000000000 --- a/pkg/utils/package.go +++ /dev/null @@ -1,44 +0,0 @@ -package utils - -import ( - "fmt" - "reflect" - "runtime" - "strings" -) - -const MODULE_PATH = "github.com/open-component-model/ocm" - -func GetPackageNameForFunc(i interface{}) (string, error) { - // Get the function's pointer - ptr := reflect.ValueOf(i).Pointer() - // Retrieve the function's runtime information - funcForPC := runtime.FuncForPC(ptr) - if funcForPC == nil { - return "", fmt.Errorf("could not determine package name") - } - // Get the full name of the function, including the package path - fullFuncName := funcForPC.Name() - - // Split the name to extract the package path - // Assuming the format: "package/path.functionName" - lastSlashIndex := strings.LastIndex(fullFuncName, "/") - if lastSlashIndex == -1 { - return "", fmt.Errorf("could not determine package name") - } - - packagePath := fullFuncName[:lastSlashIndex] - return packagePath, nil -} - -func GetPackagePathFromProjectRootForFunc(i interface{}) (string, error) { - pkg, err := GetPackageNameForFunc(i) - if err != nil { - return "", err - } - path, ok := strings.CutPrefix(pkg, "github.com/open-component-model/ocm/") - if !ok { - return "", fmt.Errorf("prefix %q not found in %q", MODULE_PATH, pkg) - } - return path, nil -} diff --git a/pkg/utils/pkgutils/package.go b/pkg/utils/pkgutils/package.go new file mode 100644 index 0000000000..857ca195f5 --- /dev/null +++ b/pkg/utils/pkgutils/package.go @@ -0,0 +1,102 @@ +package pkgutils + +import ( + "fmt" + "reflect" + "runtime" + "strings" +) + +// GetPackageName gets the package name for an object, a type, a function or a caller offset. +// +// Examples: +// +// GetPackageName(1) +// GetPackageName(&MyStruct{}) +// GetPackageName(GetPackageName) +// GetPackageName(generics.TypeOf[MyStruct]()) +func GetPackageName(i ...interface{}) (string, error) { + if len(i) == 0 { + i = []interface{}{0} + } + if t, ok := i[0].(reflect.Type); ok { + pkgpath := t.PkgPath() + if pkgpath == "" { + return "", fmt.Errorf("unable to determine package name") + } + return pkgpath, nil + } + v := reflect.ValueOf(i[0]) + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + switch v.Kind() { + case reflect.Func: + return getPackageNameForFuncPC(v.Pointer()) + case reflect.Struct, reflect.Chan, reflect.Map, reflect.Slice, reflect.Array: + pkgpath := v.Type().PkgPath() + if pkgpath == "" { + return "", fmt.Errorf("unable to determine package name") + } + return pkgpath, nil + default: + offset, err := CastInt(v.Interface()) + if err != nil { + return "", err + } + pc, _, _, ok := runtime.Caller(offset + 1) + if !ok { + return "", fmt.Errorf("unable to find caller") + } + return getPackageNameForFuncPC(pc) + } +} + +func getPackageNameForFuncPC(pc uintptr) (string, error) { + // Retrieve the function's runtime information + funcForPC := runtime.FuncForPC(pc) + if funcForPC == nil { + return "", fmt.Errorf("could not determine package name") + } + // Get the full name of the function, including the package path + fullFuncName := funcForPC.Name() + + // Split the name to extract the package path + // Assuming the format: "package/path.functionName" + lastSlashIndex := strings.LastIndex(fullFuncName, "/") + if lastSlashIndex == -1 { + panic("unable to find package name") + } + + funcIndex := strings.Index(fullFuncName[lastSlashIndex:], ".") + packagePath := fullFuncName[:lastSlashIndex+funcIndex] + + return packagePath, nil +} + +func CastInt(i interface{}) (int, error) { + switch v := i.(type) { + case int: + return v, nil + case int8: + return int(v), nil + case int16: + return int(v), nil + case int32: + return int(v), nil + case int64: + return int(v), nil + case uint: + return int(v), nil + case uint8: + return int(v), nil + case uint16: + return int(v), nil + case uint32: + return int(v), nil + case uint64: + return int(v), nil + default: + return 0, fmt.Errorf("unable to cast %T into int", i) + } +} diff --git a/pkg/utils/pkgutils/package_test.go b/pkg/utils/pkgutils/package_test.go new file mode 100644 index 0000000000..b7ed7f1331 --- /dev/null +++ b/pkg/utils/pkgutils/package_test.go @@ -0,0 +1,32 @@ +package pkgutils_test + +import ( + "github.com/mandelsoft/goutils/generics" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + . "github.com/open-component-model/ocm/pkg/testutils" + me "github.com/open-component-model/ocm/pkg/utils/pkgutils" + "github.com/open-component-model/ocm/pkg/utils/pkgutils/testpackage" + "reflect" +) + +type typ struct{} + +var _ = Describe("package tests", func() { + DescribeTable("determine package type for ", func(typ interface{}) { + Expect(Must(me.GetPackageName(typ))).To(Equal(reflect.TypeOf(testpackage.MyStruct{}).PkgPath())) + }, + Entry("struct", &testpackage.MyStruct{}), + Entry("array", &testpackage.MyArray{}), + Entry("list", &testpackage.MyList{}), + Entry("map", &testpackage.MyMap{}), + Entry("chan", make(testpackage.MyChan)), + Entry("func", testpackage.MyFunc), + Entry("func type", generics.TypeOf[testpackage.MyFuncType]()), + Entry("struct type", generics.TypeOf[testpackage.MyStruct]()), + ) + It("determine package for caller func", func() { + Expect(Must(testpackage.MyFunc())).To(Equal(reflect.TypeOf(testpackage.MyStruct{}).PkgPath())) + Expect(Must(testpackage.MyFunc(1))).To(Equal(reflect.TypeOf(typ{}).PkgPath())) + }) +}) diff --git a/pkg/utils/pkgutils/suite_test.go b/pkg/utils/pkgutils/suite_test.go new file mode 100644 index 0000000000..2946fa4d3b --- /dev/null +++ b/pkg/utils/pkgutils/suite_test.go @@ -0,0 +1,13 @@ +package pkgutils_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestConfig(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Package Utils Test Suite") +} diff --git a/pkg/utils/pkgutils/testpackage/testtypes.go b/pkg/utils/pkgutils/testpackage/testtypes.go new file mode 100644 index 0000000000..057497e57a --- /dev/null +++ b/pkg/utils/pkgutils/testpackage/testtypes.go @@ -0,0 +1,21 @@ +package testpackage + +import ( + "github.com/mandelsoft/goutils/sliceutils" + + "github.com/open-component-model/ocm/pkg/utils/pkgutils" +) + +type ( + MyStruct struct{} + + MyList []int + MyArray [3]int + MyMap map[int]int + MyChan chan int + MyFuncType func() +) + +func MyFunc(i ...int) (string, error) { + return pkgutils.GetPackageName(sliceutils.Convert[interface{}](i)...) +}