From b6d61944d7b9987331251bbe0be9fd098efdb34f Mon Sep 17 00:00:00 2001 From: zepatrik Date: Tue, 25 Jun 2024 16:37:22 +0200 Subject: [PATCH 1/2] chore: small improvements and deprecations --- assertx/assertx.go | 11 +++++++---- cmdx/helper.go | 30 ++++++++++++++++++++++++++++-- cmdx/printing.go | 17 ++++++++--------- cmdx/printing_test.go | 11 +++++------ fetcher/fetcher.go | 2 +- jsonschemax/keys.go | 5 ++--- logrusx/logrus.go | 6 +++--- stringslice/filter.go | 19 ++++++------------- stringslice/has.go | 22 +++++++++------------- stringslice/merge.go | 10 ++++------ stringslice/reverse.go | 14 +++++++------- stringslice/unique.go | 7 ++++--- stringsx/default.go | 6 ++---- stringsx/ptr.go | 1 + stringsx/switch_case.go | 15 +++++++++------ 15 files changed, 96 insertions(+), 80 deletions(-) diff --git a/assertx/assertx.go b/assertx/assertx.go index 8b5b91ce..110b92f7 100644 --- a/assertx/assertx.go +++ b/assertx/assertx.go @@ -17,13 +17,15 @@ import ( "github.com/stretchr/testify/require" ) -func PrettifyJSONPayload(t *testing.T, payload interface{}) string { +func PrettifyJSONPayload(t testing.TB, payload interface{}) string { + t.Helper() o, err := json.MarshalIndent(payload, "", " ") require.NoError(t, err) return string(o) } -func EqualAsJSON(t *testing.T, expected, actual interface{}, args ...interface{}) { +func EqualAsJSON(t testing.TB, expected, actual interface{}, args ...interface{}) { + t.Helper() var eb, ab bytes.Buffer if len(args) == 0 { args = []interface{}{PrettifyJSONPayload(t, actual)} @@ -34,7 +36,8 @@ func EqualAsJSON(t *testing.T, expected, actual interface{}, args ...interface{} assert.JSONEq(t, strings.TrimSpace(eb.String()), strings.TrimSpace(ab.String()), args...) } -func EqualAsJSONExcept(t *testing.T, expected, actual interface{}, except []string, args ...interface{}) { +func EqualAsJSONExcept(t testing.TB, expected, actual interface{}, except []string, args ...interface{}) { + t.Helper() var eb, ab bytes.Buffer if len(args) == 0 { args = []interface{}{PrettifyJSONPayload(t, actual)} @@ -56,7 +59,7 @@ func EqualAsJSONExcept(t *testing.T, expected, actual interface{}, except []stri assert.JSONEq(t, strings.TrimSpace(ebs), strings.TrimSpace(abs), args...) } -func TimeDifferenceLess(t *testing.T, t1, t2 time.Time, seconds int) { +func TimeDifferenceLess(t testing.TB, t1, t2 time.Time, seconds int) { t.Helper() delta := math.Abs(float64(t1.Unix()) - float64(t2.Unix())) assert.Less(t, delta, float64(seconds)) diff --git a/cmdx/helper.go b/cmdx/helper.go index 594bdbb0..45eb294b 100644 --- a/cmdx/helper.go +++ b/cmdx/helper.go @@ -10,12 +10,14 @@ import ( "fmt" "io" "net/http" + "net/url" "os" "testing" "golang.org/x/sync/errgroup" "github.com/spf13/cobra" + "github.com/spf13/pflag" "github.com/stretchr/testify/require" "github.com/pkg/errors" @@ -25,9 +27,9 @@ import ( var ( // ErrNilDependency is returned if a dependency is missing. - ErrNilDependency = errors.New("a dependency was expected to be defined but is nil. Please open an issue with the stack trace") + ErrNilDependency = fmt.Errorf("a dependency was expected to be defined but is nil. Please open an issue with the stack trace") // ErrNoPrintButFail is returned to detect a failure state that was already reported to the user in some way - ErrNoPrintButFail = errors.New("this error should never be printed") + ErrNoPrintButFail = fmt.Errorf("this error should never be printed") debugStdout, debugStderr = io.Discard, io.Discard ) @@ -48,6 +50,7 @@ func FailSilently(cmd *cobra.Command) error { } // Must fatals with the optional message if err is not nil. +// Deprecated: do not use this function in commands, as it makes it impossible to test them. Instead, return the error. func Must(err error, message string, args ...interface{}) { if err == nil { return @@ -58,6 +61,7 @@ func Must(err error, message string, args ...interface{}) { } // CheckResponse fatals if err is nil or the response.StatusCode does not match the expectedStatusCode +// Deprecated: do not use this function in commands, as it makes it impossible to test them. Instead, return the error. func CheckResponse(err error, expectedStatusCode int, response *http.Response) { Must(err, "Command failed because error occurred: %s", err) @@ -85,6 +89,7 @@ Response payload: } // FormatResponse takes an object and prints a json.MarshalIdent version of it or fatals. +// Deprecated: do not use this function in commands, as it makes it impossible to test them. Instead, return the error. func FormatResponse(o interface{}) string { out, err := json.MarshalIndent(o, "", "\t") Must(err, `Command failed because an error occurred while prettifying output: %s`, err) @@ -92,6 +97,7 @@ func FormatResponse(o interface{}) string { } // Fatalf prints to os.Stderr and exists with code 1. +// Deprecated: do not use this function in commands, as it makes it impossible to test them. Instead, return the error. func Fatalf(message string, args ...interface{}) { if len(args) > 0 { _, _ = fmt.Fprintf(os.Stderr, message+"\n", args...) @@ -102,6 +108,7 @@ func Fatalf(message string, args ...interface{}) { } // ExpectDependency expects every dependency to be not nil or it fatals. +// Deprecated: do not use this function in commands, as it makes it impossible to test them. Instead, return the error. func ExpectDependency(logger *logrusx.Logger, dependencies ...interface{}) { if logger == nil { panic("missing logger for dependency check") @@ -225,3 +232,22 @@ func (c *CommandExecuter) ExecNoErr(t require.TestingT, args ...string) string { func (c *CommandExecuter) ExecExpectedErr(t require.TestingT, args ...string) string { return ExecExpectedErrCtx(c.Ctx, t, c.New(), append(c.PersistentArgs, args...)...) } + +type URL struct { + url.URL +} + +var _ pflag.Value = (*URL)(nil) + +func (u *URL) Set(s string) error { + uu, err := url.Parse(s) + if err != nil { + return err + } + u.URL = *uu + return nil +} + +func (*URL) Type() string { + return "url" +} diff --git a/cmdx/printing.go b/cmdx/printing.go index 98cd6688..bea36d03 100644 --- a/cmdx/printing.go +++ b/cmdx/printing.go @@ -291,26 +291,25 @@ func RegisterFormatFlags(flags *pflag.FlagSet) { flags.String(FlagFormat, string(FormatDefault), fmt.Sprintf("Set the output format. One of %s, %s, %s, %s, %s and %s.", FormatTable, FormatJSON, FormatYAML, FormatJSONPretty, FormatJSONPath, FormatJSONPointer)) } -type bodyer interface { - Body() []byte -} - func PrintOpenAPIError(cmd *cobra.Command, err error) error { if err == nil { return nil } - var be bodyer + var be interface { + Body() []byte + } if !errors.As(err, &be) { return err } - var didPrettyPrint bool - if message := gjson.GetBytes(be.Body(), "error.message"); message.Exists() { + body := be.Body() + didPrettyPrint := false + if message := gjson.GetBytes(body, "error.message"); message.Exists() { _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "%s\n", message.String()) didPrettyPrint = true } - if reason := gjson.GetBytes(be.Body(), "error.reason"); reason.Exists() { + if reason := gjson.GetBytes(body, "error.reason"); reason.Exists() { _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "%s\n", reason.String()) didPrettyPrint = true } @@ -319,7 +318,7 @@ func PrintOpenAPIError(cmd *cobra.Command, err error) error { return FailSilently(cmd) } - if body, err := json.MarshalIndent(json.RawMessage(be.Body()), "", " "); err == nil { + if body, err := json.MarshalIndent(json.RawMessage(body), "", " "); err == nil { _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "%s\nFailed to execute API request, see error above.\n", body) return FailSilently(cmd) } diff --git a/cmdx/printing_test.go b/cmdx/printing_test.go index 61ec190f..a28d6904 100644 --- a/cmdx/printing_test.go +++ b/cmdx/printing_test.go @@ -6,13 +6,12 @@ package cmdx import ( "bytes" "fmt" + "slices" "strconv" "testing" "github.com/spf13/cobra" - "github.com/ory/x/stringslice" - "github.com/spf13/pflag" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -170,8 +169,8 @@ func TestPrinting(t *testing.T) { for _, s := range tc.contained { assert.Contains(t, out.String(), s, "%s", out.String()) } - notContained := stringslice.Filter(allFields, func(s string) bool { - return stringslice.Has(tc.contained, s) + notContained := slices.DeleteFunc(allFields, func(s string) bool { + return slices.Contains(tc.contained, s) }) for _, s := range notContained { assert.NotContains(t, out.String(), s, "%s", out.String()) @@ -258,8 +257,8 @@ func TestPrinting(t *testing.T) { for _, s := range tc.contained { assert.Contains(t, out.String(), s, "%s", out.String()) } - notContained := stringslice.Filter(allFields, func(s string) bool { - return stringslice.Has(tc.contained, s) + notContained := slices.DeleteFunc(allFields, func(s string) bool { + return slices.Contains(tc.contained, s) }) for _, s := range notContained { assert.NotContains(t, out.String(), s, "%s", out.String()) diff --git a/fetcher/fetcher.go b/fetcher/fetcher.go index aeda9e1a..0ae8f20e 100644 --- a/fetcher/fetcher.go +++ b/fetcher/fetcher.go @@ -101,7 +101,7 @@ func (f *Fetcher) FetchContext(ctx context.Context, source string) (*bytes.Buffe // context that is used for HTTP requests. func (f *Fetcher) FetchBytes(ctx context.Context, source string) ([]byte, error) { switch s := stringsx.SwitchPrefix(source); { - case s.HasPrefix("http://"), s.HasPrefix("https://"): + case s.HasPrefix("http://", "https://"): return f.fetchRemote(ctx, source) case s.HasPrefix("file://"): return f.fetchFile(strings.TrimPrefix(source, "file://")) diff --git a/jsonschemax/keys.go b/jsonschemax/keys.go index ea5e9e66..ab9638c1 100644 --- a/jsonschemax/keys.go +++ b/jsonschemax/keys.go @@ -11,14 +11,13 @@ import ( "fmt" "math/big" "regexp" + "slices" "sort" "strings" "github.com/pkg/errors" "github.com/ory/jsonschema/v3" - - "github.com/ory/x/stringslice" ) type ( @@ -268,7 +267,7 @@ func listPaths(schema *jsonschema.Schema, parent *jsonschema.Schema, parents []s types = append(types, is.Ref.Types...) } } - types = stringslice.Unique(types) + types = slices.Compact(types) if len(types) == 1 { switch types[0] { case "boolean": diff --git a/logrusx/logrus.go b/logrusx/logrus.go index 0cbae82b..a56ed511 100644 --- a/logrusx/logrus.go +++ b/logrusx/logrus.go @@ -104,7 +104,7 @@ func setFormatter(l *logrus.Logger, o *options) { default: unknownFormat = true fallthrough - case format.AddCase("text"), format.AddCase(""): + case format.AddCase("text", ""): l.Formatter = &logrus.TextFormatter{ DisableQuote: true, DisableTimestamp: false, @@ -203,7 +203,7 @@ func New(name string, version string, opts ...Option) *Logger { name: name, version: version, leakSensitive: o.leakSensitive || o.c.Bool("log.leak_sensitive_values"), - redactionText: stringsx.DefaultIfEmpty(o.redactionText, `Value is sensitive and has been redacted. To see the value set config key "log.leak_sensitive_values = true" or environment variable "LOG_LEAK_SENSITIVE_VALUES=true".`), + redactionText: stringsx.Coalesce(o.redactionText, `Value is sensitive and has been redacted. To see the value set config key "log.leak_sensitive_values = true" or environment variable "LOG_LEAK_SENSITIVE_VALUES=true".`), Entry: newLogger(o.l, o).WithFields(logrus.Fields{ "audience": "application", "service_name": name, "service_version": version}), } @@ -215,7 +215,7 @@ func NewAudit(name string, version string, opts ...Option) *Logger { func (l *Logger) UseConfig(c configurator) { l.leakSensitive = l.leakSensitive || c.Bool("log.leak_sensitive_values") - l.redactionText = stringsx.DefaultIfEmpty(c.String("log.redaction_text"), l.redactionText) + l.redactionText = stringsx.Coalesce(c.String("log.redaction_text"), l.redactionText) o := newOptions(append(l.opts, WithConfigurator(c))) setLevel(l.Entry.Logger, o) setFormatter(l.Entry.Logger, o) diff --git a/stringslice/filter.go b/stringslice/filter.go index cdefd02e..2ebbee64 100644 --- a/stringslice/filter.go +++ b/stringslice/filter.go @@ -4,27 +4,19 @@ package stringslice import ( + "slices" "strings" "unicode" ) // Filter applies the provided filter function and removes all items from the slice for which the filter function returns true. -// This function uses append and might cause -func Filter(values []string, filter func(string) bool) (ret []string) { - for _, value := range values { - if !filter(value) { - ret = append(ret, value) - } - } - - if ret == nil { - return []string{} - } - - return +// Deprecated: use slices.DeleteFunc instead (changes semantics: the original slice is modified) +func Filter(values []string, filter func(string) bool) []string { + return slices.DeleteFunc(slices.Clone(values), filter) } // TrimEmptyFilter applies the strings.TrimFunc function and removes all empty strings +// Deprecated: use slices.DeleteFunc instead (changes semantics: the original slice is modified) func TrimEmptyFilter(values []string, trim func(rune) bool) (ret []string) { return Filter(values, func(value string) bool { return strings.TrimFunc(value, trim) == "" @@ -32,6 +24,7 @@ func TrimEmptyFilter(values []string, trim func(rune) bool) (ret []string) { } // TrimSpaceEmptyFilter applies the strings.TrimSpace function and removes all empty strings +// Deprecated: use slices.DeleteFunc with strings.TrimSpace instead (changes semantics: the original slice is modified) func TrimSpaceEmptyFilter(values []string) []string { return TrimEmptyFilter(values, unicode.IsSpace) } diff --git a/stringslice/has.go b/stringslice/has.go index b75a6e6b..e863fa84 100644 --- a/stringslice/has.go +++ b/stringslice/has.go @@ -3,24 +3,20 @@ package stringslice -import "strings" +import ( + "slices" + "strings" +) // Has returns true if the needle is in the haystack (case-sensitive) +// Deprecated: use slices.Contains instead func Has(haystack []string, needle string) bool { - for _, current := range haystack { - if current == needle { - return true - } - } - return false + return slices.Contains(haystack, needle) } // HasI returns true if the needle is in the haystack (case-insensitive) func HasI(haystack []string, needle string) bool { - for _, current := range haystack { - if strings.EqualFold(current, needle) { - return true - } - } - return false + return slices.ContainsFunc(haystack, func(value string) bool { + return strings.EqualFold(value, needle) + }) } diff --git a/stringslice/merge.go b/stringslice/merge.go index 92438732..fe0c887b 100644 --- a/stringslice/merge.go +++ b/stringslice/merge.go @@ -3,12 +3,10 @@ package stringslice +import "slices" + // Merge merges several string slices into one. +// Deprecated: use slices.Concat instead func Merge(parts ...[]string) []string { - var result []string - for _, part := range parts { - result = append(result, part...) - } - - return result + return slices.Concat(parts...) } diff --git a/stringslice/reverse.go b/stringslice/reverse.go index 054fc21f..ca205500 100644 --- a/stringslice/reverse.go +++ b/stringslice/reverse.go @@ -3,12 +3,12 @@ package stringslice -func Reverse(s []string) []string { - r := make([]string, len(s)) - - for i, j := 0, len(r)-1; i <= j; i, j = i+1, j-1 { - r[i], r[j] = s[j], s[i] - } +import "slices" - return r +// Reverse reverses the order of a string slice +// Deprecated: use slices.Reverse instead (changes semantics) +func Reverse(s []string) []string { + c := slices.Clone(s) + slices.Reverse(c) + return c } diff --git a/stringslice/unique.go b/stringslice/unique.go index 6228c5e4..7a649d45 100644 --- a/stringslice/unique.go +++ b/stringslice/unique.go @@ -3,14 +3,15 @@ package stringslice -// Unique returns the given string slice with unique values. +// Unique returns the given string slice with unique values, preserving order. +// Consider using slices.Compact with slices.Sort instead when you don't care about order. func Unique(i []string) []string { u := make([]string, 0, len(i)) - m := make(map[string]bool) + m := make(map[string]struct{}, len(i)) for _, val := range i { if _, ok := m[val]; !ok { - m[val] = true + m[val] = struct{}{} u = append(u, val) } } diff --git a/stringsx/default.go b/stringsx/default.go index 0ae2769b..e494b22d 100644 --- a/stringsx/default.go +++ b/stringsx/default.go @@ -3,9 +3,7 @@ package stringsx +// Deprecated: use Coalesce instead func DefaultIfEmpty(s string, defaultValue string) string { - if len(s) == 0 { - return defaultValue - } - return s + return Coalesce(s, defaultValue) } diff --git a/stringsx/ptr.go b/stringsx/ptr.go index 0a94acbf..990aa3f8 100644 --- a/stringsx/ptr.go +++ b/stringsx/ptr.go @@ -3,6 +3,7 @@ package stringsx +// Deprecated: use pointerx.Ptr instead func GetPointer(s string) *string { return &s } diff --git a/stringsx/switch_case.go b/stringsx/switch_case.go index 9e8f1101..dc5cb7fe 100644 --- a/stringsx/switch_case.go +++ b/stringsx/switch_case.go @@ -5,6 +5,7 @@ package stringsx import ( "fmt" + "slices" "strings" ) @@ -42,14 +43,16 @@ func SwitchPrefix(actual string) *RegisteredPrefixes { } } -func (r *RegisteredCases) AddCase(c string) bool { - r.cases = append(r.cases, c) - return r.actual == c +func (r *RegisteredCases) AddCase(cases ...string) bool { + r.cases = append(r.cases, cases...) + return slices.Contains(cases, r.actual) } -func (r *RegisteredPrefixes) HasPrefix(prefix string) bool { - r.prefixes = append(r.prefixes, prefix) - return strings.HasPrefix(r.actual, prefix) +func (r *RegisteredPrefixes) HasPrefix(prefixes ...string) bool { + r.prefixes = append(r.prefixes, prefixes...) + return slices.ContainsFunc(prefixes, func(s string) bool { + return strings.HasPrefix(r.actual, s) + }) } func (r *RegisteredCases) String() string { From 62ffc1d7fb2bed5ae850548939119694492d98b9 Mon Sep 17 00:00:00 2001 From: zepatrik Date: Tue, 25 Jun 2024 17:29:51 +0200 Subject: [PATCH 2/2] fix: side-effect in test --- cmdx/printing_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmdx/printing_test.go b/cmdx/printing_test.go index a28d6904..cf33cf2c 100644 --- a/cmdx/printing_test.go +++ b/cmdx/printing_test.go @@ -169,7 +169,7 @@ func TestPrinting(t *testing.T) { for _, s := range tc.contained { assert.Contains(t, out.String(), s, "%s", out.String()) } - notContained := slices.DeleteFunc(allFields, func(s string) bool { + notContained := slices.DeleteFunc(slices.Clone(allFields), func(s string) bool { return slices.Contains(tc.contained, s) }) for _, s := range notContained { @@ -257,7 +257,7 @@ func TestPrinting(t *testing.T) { for _, s := range tc.contained { assert.Contains(t, out.String(), s, "%s", out.String()) } - notContained := slices.DeleteFunc(allFields, func(s string) bool { + notContained := slices.DeleteFunc(slices.Clone(allFields), func(s string) bool { return slices.Contains(tc.contained, s) }) for _, s := range notContained {