diff --git a/Taskfile.yaml b/Taskfile.yaml index 32f611e..9f3fefb 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -10,7 +10,7 @@ tasks: cmds: - go mod tidy - build:lib: + build: desc: Build the konfetty library cmds: - go build ./... diff --git a/defaults.go b/defaults.go new file mode 100644 index 0000000..bb32de6 --- /dev/null +++ b/defaults.go @@ -0,0 +1,118 @@ +package konfetty + +import "reflect" + +// DefaultProvider is an interface for types that can provide their own default values +type DefaultProvider interface { + Defaults() any +} + +// Defaulter is an interface for types that can set their own default values +type Defaulter interface { + SetDefaults() +} + +// fillDefaults recursively fills in default values for structs that implement DefaultProvider +func fillDefaults(v any) error { + return fillDefaultsRecursive(reflect.ValueOf(v)) +} + +func fillDefaultsRecursive(v reflect.Value) error { + // Handle pointer types + if v.Kind() == reflect.Ptr { + if v.IsNil() { + // If the pointer is nil, create a new instance of the pointed-to type + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + + if v.Kind() != reflect.Struct { + return nil + } + + t := v.Type() + + // Iterate through all fields + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + fieldType := t.Field(i) + + // Handle embedded fields + if fieldType.Anonymous { + if field.Kind() == reflect.Ptr && field.IsNil() { + // If the embedded field is a nil pointer, create a new instance + field.Set(reflect.New(field.Type().Elem())) + } + if err := fillDefaultsRecursive(field); err != nil { + return err + } + + // Apply defaults for the embedded field if it implements DefaultProvider + if defaulter, ok := field.Addr().Interface().(DefaultProvider); ok { + defaults := reflect.ValueOf(defaulter.Defaults()) + if defaults.Kind() == reflect.Ptr { + defaults = defaults.Elem() + } + fillFromDefaults(field, defaults) + } + + continue + } + + switch field.Kind() { + case reflect.Ptr: + if field.IsNil() { + // If the field is a nil pointer, create a new instance + field.Set(reflect.New(field.Type().Elem())) + } + if err := fillDefaultsRecursive(field.Elem()); err != nil { + return err + } + case reflect.Struct: + if err := fillDefaultsRecursive(field); err != nil { + return err + } + case reflect.Slice: + for j := 0; j < field.Len(); j++ { + if err := fillDefaultsRecursive(field.Index(j)); err != nil { + return err + } + } + } + } + + // Apply defaults to the current struct if it implements DefaultProvider + if defaulter, ok := v.Addr().Interface().(DefaultProvider); ok { + defaults := reflect.ValueOf(defaulter.Defaults()) + if defaults.Kind() == reflect.Ptr { + defaults = defaults.Elem() + } + fillFromDefaults(v, defaults) + } + + return nil +} + +func fillFromDefaults(dst, src reflect.Value) { + for i := 0; i < src.NumField(); i++ { + srcField := src.Field(i) + srcFieldName := src.Type().Field(i).Name + + // Check if the destination has this field + dstField := dst.FieldByName(srcFieldName) + if !dstField.IsValid() { + continue // Skip fields that don't exist in the destination + } + + if dstField.CanSet() && isZeroValue(dstField) { + // Only set the value if it's settable and currently zero + dstField.Set(srcField) + } + } +} + +func isZeroValue(v reflect.Value) bool { + zero := reflect.Zero(v.Type()).Interface() + return reflect.DeepEqual(v.Interface(), zero) +} diff --git a/defaults_test.go b/defaults_test.go new file mode 100644 index 0000000..44ea640 --- /dev/null +++ b/defaults_test.go @@ -0,0 +1,110 @@ +package konfetty + +import ( + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" +) + +// Structs provided by the user +type Profile struct { + Checks Checks +} + +type Checks struct { + Ping []PingCheck +} + +type PingCheck struct { + *BaseCheck + Host string +} + +type BaseCheck struct { + Name string + Interval time.Duration + Timeout time.Duration +} + +// Implement DefaultProvider for BaseCheck +func (b BaseCheck) Defaults() interface{} { + return BaseCheck{ + Interval: 30 * time.Second, + Timeout: 5 * time.Second, + } +} + +func TestFillDefaults(t *testing.T) { + tests := []struct { + name string + input *Profile + expected *Profile + }{ + { + name: "Profile with partially filled checks", + input: &Profile{ + Checks: Checks{ + Ping: []PingCheck{ + { + BaseCheck: &BaseCheck{Name: "Custom Ping"}, + Host: "example.com", + }, + }, + }, + }, + expected: &Profile{ + Checks: Checks{ + Ping: []PingCheck{ + { + BaseCheck: &BaseCheck{ + Name: "Custom Ping", + Interval: 30 * time.Second, + Timeout: 5 * time.Second, + }, + Host: "example.com", + }, + }, + }, + }, + }, + { + name: "Profile with nil checks", + input: &Profile{ + Checks: Checks{ + Ping: []PingCheck{ + { + Host: "example.com", + }, + }, + }, + }, + expected: &Profile{ + Checks: Checks{ + Ping: []PingCheck{ + { + BaseCheck: &BaseCheck{ + Interval: 30 * time.Second, + Timeout: 5 * time.Second, + }, + Host: "example.com", + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := fillDefaults(tt.input) + require.NoError(t, err) + + diff := cmp.Diff(tt.expected, tt.input) + if diff != "" { + t.Errorf("FillDefaults() mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/go.mod b/go.mod index c6dd499..1152218 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,18 @@ module github.com/nikoksr/konfetty go 1.22.5 require ( + github.com/google/go-cmp v0.6.0 github.com/knadh/koanf/parsers/yaml v0.1.0 github.com/knadh/koanf/providers/env v0.1.0 github.com/knadh/koanf/providers/file v1.0.0 github.com/knadh/koanf/v2 v2.1.1 + github.com/stretchr/testify v1.9.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect ) require ( @@ -22,6 +30,7 @@ require ( ) require ( - dario.cat/mergo v1.0.0 + github.com/knadh/koanf/parsers/json v0.1.0 + github.com/knadh/koanf/parsers/toml/v2 v2.1.0 gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 2da2443..d86c829 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,19 @@ -dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= -dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= -github.com/go-viper/mapstructure/v2 v2.0.0-alpha.1 h1:TQcrn6Wq+sKGkpyPvppOz99zsMBaUOKXq6HSv655U1c= -github.com/go-viper/mapstructure/v2 v2.0.0-alpha.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/go-viper/mapstructure/v2 v2.0.0 h1:dhn8MZ1gZ0mzeodTG3jt5Vj/o87xZKuNAprG2mQfMfc= github.com/go-viper/mapstructure/v2 v2.0.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/knadh/koanf/maps v0.1.1 h1:G5TjmUh2D7G2YWf5SQQqSiHRJEjaicvU0KpypqB3NIs= github.com/knadh/koanf/maps v0.1.1/go.mod h1:npD/QZY3V6ghQDdcQzl1W4ICNVTkohC8E73eI2xW4yI= +github.com/knadh/koanf/parsers/json v0.1.0 h1:dzSZl5pf5bBcW0Acnu20Djleto19T0CfHcvZ14NJ6fU= +github.com/knadh/koanf/parsers/json v0.1.0/go.mod h1:ll2/MlXcZ2BfXD6YJcjVFzhG9P0TdJ207aIBKQhV2hY= +github.com/knadh/koanf/parsers/toml/v2 v2.1.0 h1:EUdIKIeezfDj6e1ABDhIjhbURUpyrP1HToqW6tz8R0I= +github.com/knadh/koanf/parsers/toml/v2 v2.1.0/go.mod h1:0KtwfsWJt4igUTQnsn0ZjFWVrP80Jv7edTBRbQFd2ho= github.com/knadh/koanf/parsers/yaml v0.1.0 h1:ZZ8/iGfRLvKSaMEECEBPM1HQslrZADk8fP1XFUxVI5w= github.com/knadh/koanf/parsers/yaml v0.1.0/go.mod h1:cvbUDC7AL23pImuQP0oRw/hPuccrNBS2bps8asS0CwY= github.com/knadh/koanf/providers/env v0.1.0 h1:LqKteXqfOWyx5Ab9VfGHmjY9BvRXi+clwyZozgVRiKg= @@ -27,18 +30,28 @@ github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa1 github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/konfetty.go b/konfetty.go index 535d039..14d22d0 100644 --- a/konfetty.go +++ b/konfetty.go @@ -1,17 +1,29 @@ package konfetty import ( + "errors" "fmt" "strings" - "dario.cat/mergo" "github.com/go-viper/mapstructure/v2" + "github.com/knadh/koanf/parsers/json" + "github.com/knadh/koanf/parsers/toml/v2" "github.com/knadh/koanf/parsers/yaml" "github.com/knadh/koanf/providers/env" "github.com/knadh/koanf/providers/file" "github.com/knadh/koanf/v2" ) +type FileFormat uint8 + +const ( + FileFormatYAML FileFormat = iota + FileFormatJSON + FileFormatTOML +) + +const defaultStructTag = "konfetty" + // Loader is the main interface for loading and validating configurations. type Loader[T any] interface { Load(paths ...string) (*T, error) @@ -25,16 +37,20 @@ type Option[T any] func(*loader[T]) type loader[T any] struct { k *koanf.Koanf envPrefix string - fileFormat string + fileFormat FileFormat + structTag string validateFn func(*T) error - defaultsFn func() T } // NewLoader creates a new configuration loader. func NewLoader[T any](options ...Option[T]) Loader[T] { l := &loader[T]{ - k: koanf.New("."), - fileFormat: "yaml", // default to YAML + k: koanf.NewWithConf(koanf.Conf{ + Delim: ".", + StrictMerge: true, + }), + fileFormat: FileFormatYAML, // default to YAML + structTag: defaultStructTag, } for _, option := range options { @@ -52,23 +68,23 @@ func WithEnvPrefix[T any](prefix string) Option[T] { } // WithFileFormat sets the format for configuration files. -func WithFileFormat[T any](format string) Option[T] { +func WithFileFormat[T any](format FileFormat) Option[T] { return func(l *loader[T]) { l.fileFormat = format } } -// WithValidator sets a custom validation function. -func WithValidator[T any](fn func(*T) error) Option[T] { +// WithStructTag sets the struct tag for configuration fields. +func WithStructTag[T any](tag string) Option[T] { return func(l *loader[T]) { - l.validateFn = fn + l.structTag = tag } } -// WithDefaults sets a function to provide default values. -func WithDefaults[T any](fn func() T) Option[T] { +// WithValidator sets a custom validation function. +func WithValidator[T any](fn func(*T) error) Option[T] { return func(l *loader[T]) { - l.defaultsFn = fn + l.validateFn = fn } } @@ -89,25 +105,27 @@ func (l *loader[T]) Load(paths ...string) (*T, error) { } // Unmarshal into the config struct + decodeHook := mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), // Convert strings to time.Duration + mapstructure.StringToSliceHookFunc(","), // Convert comma-separated strings to slices + ) + if err := l.k.UnmarshalWithConf("", &cfg, koanf.UnmarshalConf{ + Tag: l.structTag, DecoderConfig: &mapstructure.DecoderConfig{ Result: &cfg, WeaklyTypedInput: true, - DecodeHook: mapstructure.ComposeDecodeHookFunc( - mapstructure.StringToTimeDurationHookFunc(), - mapstructure.StringToSliceHookFunc(","), - ), + Squash: true, + TagName: l.structTag, + DecodeHook: decodeHook, }, }); err != nil { return nil, fmt.Errorf("unmarshal config: %w", err) } - // Merge with defaults - if l.defaultsFn != nil { - defaults := l.defaultsFn() - if err := mergo.Merge(&cfg, defaults); err != nil { - return nil, fmt.Errorf("merge defaults: %w", err) - } + // Apply defaults + if err := fillDefaults(&cfg); err != nil { + return nil, fmt.Errorf("fill defaults: %w", err) } // Validate @@ -123,11 +141,14 @@ func (l *loader[T]) loadFile(path string) error { var parser koanf.Parser switch l.fileFormat { - case "yaml": + case FileFormatYAML: parser = yaml.Parser() - // Add more formats here as needed + case FileFormatJSON: + parser = json.Parser() + case FileFormatTOML: + parser = toml.Parser() default: - return fmt.Errorf("unsupported file format: %s", l.fileFormat) + return errors.New("unsupported file format") } if err := l.k.Load(file.Provider(path), parser); err != nil {