Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: small improvements and deprecations #792

Merged
merged 2 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions assertx/assertx.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ import (
"github.com/stretchr/testify/require"
)

func PrettifyJSONPayload(t *testing.T, payload interface{}) string {
func PrettifyJSONPayload(t testing.TB, payload interface{}) string {
t.Helper()
o, err := json.MarshalIndent(payload, "", " ")
require.NoError(t, err)
return string(o)
}

func EqualAsJSON(t *testing.T, expected, actual interface{}, args ...interface{}) {
func EqualAsJSON(t testing.TB, expected, actual interface{}, args ...interface{}) {
t.Helper()
var eb, ab bytes.Buffer
if len(args) == 0 {
args = []interface{}{PrettifyJSONPayload(t, actual)}
Expand All @@ -34,7 +36,8 @@ func EqualAsJSON(t *testing.T, expected, actual interface{}, args ...interface{}
assert.JSONEq(t, strings.TrimSpace(eb.String()), strings.TrimSpace(ab.String()), args...)
}

func EqualAsJSONExcept(t *testing.T, expected, actual interface{}, except []string, args ...interface{}) {
func EqualAsJSONExcept(t testing.TB, expected, actual interface{}, except []string, args ...interface{}) {
t.Helper()
var eb, ab bytes.Buffer
if len(args) == 0 {
args = []interface{}{PrettifyJSONPayload(t, actual)}
Expand All @@ -56,7 +59,7 @@ func EqualAsJSONExcept(t *testing.T, expected, actual interface{}, except []stri
assert.JSONEq(t, strings.TrimSpace(ebs), strings.TrimSpace(abs), args...)
}

func TimeDifferenceLess(t *testing.T, t1, t2 time.Time, seconds int) {
func TimeDifferenceLess(t testing.TB, t1, t2 time.Time, seconds int) {
t.Helper()
delta := math.Abs(float64(t1.Unix()) - float64(t2.Unix()))
assert.Less(t, delta, float64(seconds))
Expand Down
30 changes: 28 additions & 2 deletions cmdx/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"testing"

"golang.org/x/sync/errgroup"

"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/stretchr/testify/require"

"github.com/pkg/errors"
Expand All @@ -25,9 +27,9 @@ import (

var (
// ErrNilDependency is returned if a dependency is missing.
ErrNilDependency = errors.New("a dependency was expected to be defined but is nil. Please open an issue with the stack trace")
ErrNilDependency = fmt.Errorf("a dependency was expected to be defined but is nil. Please open an issue with the stack trace")
// ErrNoPrintButFail is returned to detect a failure state that was already reported to the user in some way
ErrNoPrintButFail = errors.New("this error should never be printed")
ErrNoPrintButFail = fmt.Errorf("this error should never be printed")

debugStdout, debugStderr = io.Discard, io.Discard
)
Expand All @@ -48,6 +50,7 @@ func FailSilently(cmd *cobra.Command) error {
}

// Must fatals with the optional message if err is not nil.
// Deprecated: do not use this function in commands, as it makes it impossible to test them. Instead, return the error.
func Must(err error, message string, args ...interface{}) {
if err == nil {
return
Expand All @@ -58,6 +61,7 @@ func Must(err error, message string, args ...interface{}) {
}

// CheckResponse fatals if err is nil or the response.StatusCode does not match the expectedStatusCode
// Deprecated: do not use this function in commands, as it makes it impossible to test them. Instead, return the error.
func CheckResponse(err error, expectedStatusCode int, response *http.Response) {
Must(err, "Command failed because error occurred: %s", err)

Expand Down Expand Up @@ -85,13 +89,15 @@ Response payload:
}

// FormatResponse takes an object and prints a json.MarshalIdent version of it or fatals.
// Deprecated: do not use this function in commands, as it makes it impossible to test them. Instead, return the error.
func FormatResponse(o interface{}) string {
out, err := json.MarshalIndent(o, "", "\t")
Must(err, `Command failed because an error occurred while prettifying output: %s`, err)
return string(out)
}

// Fatalf prints to os.Stderr and exists with code 1.
// Deprecated: do not use this function in commands, as it makes it impossible to test them. Instead, return the error.
func Fatalf(message string, args ...interface{}) {
if len(args) > 0 {
_, _ = fmt.Fprintf(os.Stderr, message+"\n", args...)
Expand All @@ -102,6 +108,7 @@ func Fatalf(message string, args ...interface{}) {
}

// ExpectDependency expects every dependency to be not nil or it fatals.
// Deprecated: do not use this function in commands, as it makes it impossible to test them. Instead, return the error.
func ExpectDependency(logger *logrusx.Logger, dependencies ...interface{}) {
if logger == nil {
panic("missing logger for dependency check")
Expand Down Expand Up @@ -225,3 +232,22 @@ func (c *CommandExecuter) ExecNoErr(t require.TestingT, args ...string) string {
func (c *CommandExecuter) ExecExpectedErr(t require.TestingT, args ...string) string {
return ExecExpectedErrCtx(c.Ctx, t, c.New(), append(c.PersistentArgs, args...)...)
}

type URL struct {
url.URL
}

var _ pflag.Value = (*URL)(nil)

func (u *URL) Set(s string) error {
uu, err := url.Parse(s)
if err != nil {
return err
}
u.URL = *uu
return nil
}

func (*URL) Type() string {
return "url"
}
17 changes: 8 additions & 9 deletions cmdx/printing.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,26 +291,25 @@ func RegisterFormatFlags(flags *pflag.FlagSet) {
flags.String(FlagFormat, string(FormatDefault), fmt.Sprintf("Set the output format. One of %s, %s, %s, %s, %s and %s.", FormatTable, FormatJSON, FormatYAML, FormatJSONPretty, FormatJSONPath, FormatJSONPointer))
}

type bodyer interface {
Body() []byte
}

func PrintOpenAPIError(cmd *cobra.Command, err error) error {
if err == nil {
return nil
}

var be bodyer
var be interface {
Body() []byte
}
if !errors.As(err, &be) {
return err
}

var didPrettyPrint bool
if message := gjson.GetBytes(be.Body(), "error.message"); message.Exists() {
body := be.Body()
didPrettyPrint := false
if message := gjson.GetBytes(body, "error.message"); message.Exists() {
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "%s\n", message.String())
didPrettyPrint = true
}
if reason := gjson.GetBytes(be.Body(), "error.reason"); reason.Exists() {
if reason := gjson.GetBytes(body, "error.reason"); reason.Exists() {
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "%s\n", reason.String())
didPrettyPrint = true
}
Expand All @@ -319,7 +318,7 @@ func PrintOpenAPIError(cmd *cobra.Command, err error) error {
return FailSilently(cmd)
}

if body, err := json.MarshalIndent(json.RawMessage(be.Body()), "", " "); err == nil {
if body, err := json.MarshalIndent(json.RawMessage(body), "", " "); err == nil {
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "%s\nFailed to execute API request, see error above.\n", body)
return FailSilently(cmd)
}
Expand Down
11 changes: 5 additions & 6 deletions cmdx/printing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@ package cmdx
import (
"bytes"
"fmt"
"slices"
"strconv"
"testing"

"github.com/spf13/cobra"

"github.com/ory/x/stringslice"

"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -170,8 +169,8 @@ func TestPrinting(t *testing.T) {
for _, s := range tc.contained {
assert.Contains(t, out.String(), s, "%s", out.String())
}
notContained := stringslice.Filter(allFields, func(s string) bool {
return stringslice.Has(tc.contained, s)
notContained := slices.DeleteFunc(slices.Clone(allFields), func(s string) bool {
return slices.Contains(tc.contained, s)
})
for _, s := range notContained {
assert.NotContains(t, out.String(), s, "%s", out.String())
Expand Down Expand Up @@ -258,8 +257,8 @@ func TestPrinting(t *testing.T) {
for _, s := range tc.contained {
assert.Contains(t, out.String(), s, "%s", out.String())
}
notContained := stringslice.Filter(allFields, func(s string) bool {
return stringslice.Has(tc.contained, s)
notContained := slices.DeleteFunc(slices.Clone(allFields), func(s string) bool {
return slices.Contains(tc.contained, s)
})
for _, s := range notContained {
assert.NotContains(t, out.String(), s, "%s", out.String())
Expand Down
2 changes: 1 addition & 1 deletion fetcher/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (f *Fetcher) FetchContext(ctx context.Context, source string) (*bytes.Buffe
// context that is used for HTTP requests.
func (f *Fetcher) FetchBytes(ctx context.Context, source string) ([]byte, error) {
switch s := stringsx.SwitchPrefix(source); {
case s.HasPrefix("http://"), s.HasPrefix("https://"):
case s.HasPrefix("http://", "https://"):
return f.fetchRemote(ctx, source)
case s.HasPrefix("file://"):
return f.fetchFile(strings.TrimPrefix(source, "file://"))
Expand Down
5 changes: 2 additions & 3 deletions jsonschemax/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@ import (
"fmt"
"math/big"
"regexp"
"slices"
"sort"
"strings"

"github.com/pkg/errors"

"github.com/ory/jsonschema/v3"

"github.com/ory/x/stringslice"
)

type (
Expand Down Expand Up @@ -268,7 +267,7 @@ func listPaths(schema *jsonschema.Schema, parent *jsonschema.Schema, parents []s
types = append(types, is.Ref.Types...)
}
}
types = stringslice.Unique(types)
types = slices.Compact(types)
if len(types) == 1 {
switch types[0] {
case "boolean":
Expand Down
6 changes: 3 additions & 3 deletions logrusx/logrus.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func setFormatter(l *logrus.Logger, o *options) {
default:
unknownFormat = true
fallthrough
case format.AddCase("text"), format.AddCase(""):
case format.AddCase("text", ""):
l.Formatter = &logrus.TextFormatter{
DisableQuote: true,
DisableTimestamp: false,
Expand Down Expand Up @@ -203,7 +203,7 @@ func New(name string, version string, opts ...Option) *Logger {
name: name,
version: version,
leakSensitive: o.leakSensitive || o.c.Bool("log.leak_sensitive_values"),
redactionText: stringsx.DefaultIfEmpty(o.redactionText, `Value is sensitive and has been redacted. To see the value set config key "log.leak_sensitive_values = true" or environment variable "LOG_LEAK_SENSITIVE_VALUES=true".`),
redactionText: stringsx.Coalesce(o.redactionText, `Value is sensitive and has been redacted. To see the value set config key "log.leak_sensitive_values = true" or environment variable "LOG_LEAK_SENSITIVE_VALUES=true".`),
Entry: newLogger(o.l, o).WithFields(logrus.Fields{
"audience": "application", "service_name": name, "service_version": version}),
}
Expand All @@ -215,7 +215,7 @@ func NewAudit(name string, version string, opts ...Option) *Logger {

func (l *Logger) UseConfig(c configurator) {
l.leakSensitive = l.leakSensitive || c.Bool("log.leak_sensitive_values")
l.redactionText = stringsx.DefaultIfEmpty(c.String("log.redaction_text"), l.redactionText)
l.redactionText = stringsx.Coalesce(c.String("log.redaction_text"), l.redactionText)
o := newOptions(append(l.opts, WithConfigurator(c)))
setLevel(l.Entry.Logger, o)
setFormatter(l.Entry.Logger, o)
Expand Down
19 changes: 6 additions & 13 deletions stringslice/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,27 @@
package stringslice

import (
"slices"
"strings"
"unicode"
)

// Filter applies the provided filter function and removes all items from the slice for which the filter function returns true.
// This function uses append and might cause
func Filter(values []string, filter func(string) bool) (ret []string) {
for _, value := range values {
if !filter(value) {
ret = append(ret, value)
}
}

if ret == nil {
return []string{}
}

return
// Deprecated: use slices.DeleteFunc instead (changes semantics: the original slice is modified)
func Filter(values []string, filter func(string) bool) []string {
return slices.DeleteFunc(slices.Clone(values), filter)
}

// TrimEmptyFilter applies the strings.TrimFunc function and removes all empty strings
// Deprecated: use slices.DeleteFunc instead (changes semantics: the original slice is modified)
func TrimEmptyFilter(values []string, trim func(rune) bool) (ret []string) {
return Filter(values, func(value string) bool {
return strings.TrimFunc(value, trim) == ""
})
}

// TrimSpaceEmptyFilter applies the strings.TrimSpace function and removes all empty strings
// Deprecated: use slices.DeleteFunc with strings.TrimSpace instead (changes semantics: the original slice is modified)
func TrimSpaceEmptyFilter(values []string) []string {
return TrimEmptyFilter(values, unicode.IsSpace)
}
22 changes: 9 additions & 13 deletions stringslice/has.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,20 @@

package stringslice

import "strings"
import (
"slices"
"strings"
)

// Has returns true if the needle is in the haystack (case-sensitive)
// Deprecated: use slices.Contains instead
func Has(haystack []string, needle string) bool {
for _, current := range haystack {
if current == needle {
return true
}
}
return false
return slices.Contains(haystack, needle)
}

// HasI returns true if the needle is in the haystack (case-insensitive)
func HasI(haystack []string, needle string) bool {
for _, current := range haystack {
if strings.EqualFold(current, needle) {
return true
}
}
return false
return slices.ContainsFunc(haystack, func(value string) bool {
return strings.EqualFold(value, needle)
})
}
10 changes: 4 additions & 6 deletions stringslice/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@

package stringslice

import "slices"

// Merge merges several string slices into one.
// Deprecated: use slices.Concat instead
func Merge(parts ...[]string) []string {
var result []string
for _, part := range parts {
result = append(result, part...)
}

return result
return slices.Concat(parts...)
}
14 changes: 7 additions & 7 deletions stringslice/reverse.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

package stringslice

func Reverse(s []string) []string {
r := make([]string, len(s))

for i, j := 0, len(r)-1; i <= j; i, j = i+1, j-1 {
r[i], r[j] = s[j], s[i]
}
import "slices"

return r
// Reverse reverses the order of a string slice
// Deprecated: use slices.Reverse instead (changes semantics)
func Reverse(s []string) []string {
c := slices.Clone(s)
slices.Reverse(c)
return c
}
7 changes: 4 additions & 3 deletions stringslice/unique.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@

package stringslice

// Unique returns the given string slice with unique values.
// Unique returns the given string slice with unique values, preserving order.
// Consider using slices.Compact with slices.Sort instead when you don't care about order.
func Unique(i []string) []string {
u := make([]string, 0, len(i))
m := make(map[string]bool)
m := make(map[string]struct{}, len(i))

for _, val := range i {
if _, ok := m[val]; !ok {
m[val] = true
m[val] = struct{}{}
u = append(u, val)
}
}
Expand Down
Loading
Loading