Skip to content

Commit

Permalink
refactor(konfetty): add tests and fix some of issues they provoked
Browse files Browse the repository at this point in the history
  • Loading branch information
nikoksr committed Jul 25, 2024
1 parent 62c4f07 commit 6a3a61a
Show file tree
Hide file tree
Showing 7 changed files with 875 additions and 99 deletions.
23 changes: 17 additions & 6 deletions Taskfile.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,6 @@ tasks:
cmds:
- go mod tidy

build:
desc: Build the konfetty library
cmds:
- go build ./...
silent: true

fmt:
desc: Format the code
cmds:
Expand All @@ -28,6 +22,23 @@ tasks:
cmds:
- golangci-lint run ./...

test:
desc: Run tests
cmds:
- go build ./...
- go test -failfast -race ./...

gen-coverage:
desc: Generate coverage report
cmds:
- go test -race -covermode=atomic -coverprofile=coverage.out ./... > /dev/null

coverage-html:
desc: Generate coverage report and open it in the browser
cmds:
- task gen-coverage
- go tool cover -html=coverage.out -o cover.html

help:
desc: Show help
cmds:
Expand Down
70 changes: 56 additions & 14 deletions defaults.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,34 @@
package konfetty

import "reflect"
import (
"errors"
"reflect"
)

const maxDepth = 100 // Adjust this value as needed

// DefaultProvider is an interface for types that can provide their own default values
type DefaultProvider interface {
Defaults() any
}

// fillDefaults recursively fills in default values for structs that implement DefaultProvider
func fillDefaults(v any) error {
return fillDefaultsRecursive(reflect.ValueOf(v))
func fillDefaults(v any, maxDepth int) error {
return fillDefaultsRecursive(reflect.ValueOf(v), 0, maxDepth)
}

func fillDefaultsRecursive(v reflect.Value) error {
func fillDefaultsRecursive(v reflect.Value, depth, maxDepth int) error {
if depth > maxDepth {
return errors.New("maximum recursion depth exceeded, possible circular dependency")
}

// Handle pointer types
if v.Kind() == reflect.Ptr {
if v.IsNil() {
// Check if we can set the value before creating a new instance
if !v.CanSet() {
return nil // Skip if we can't set the value
}
// If the pointer is nil, create a new instance of the pointed-to type
v.Set(reflect.New(v.Type().Elem()))
}
Expand All @@ -33,18 +46,28 @@ func fillDefaultsRecursive(v reflect.Value) error {
field := v.Field(i)
fieldType := t.Field(i)

// Skip unexported fields
if !field.CanSet() {
continue
}

// Handle embedded fields
if fieldType.Anonymous {
if field.Kind() == reflect.Ptr && field.IsNil() {
// Check if we can set the field before creating a new instance
if !field.CanSet() {
continue
}
// If the embedded field is a nil pointer, create a new instance
field.Set(reflect.New(field.Type().Elem()))
}
if err := fillDefaultsRecursive(field); err != nil {
if err := fillDefaultsRecursive(field, depth+1, maxDepth); err != nil {
return err
}

// Apply defaults for the embedded field if it implements DefaultProvider
if defaulter, ok := field.Addr().Interface().(DefaultProvider); ok {
if field.CanAddr() && field.Addr().Type().Implements(reflect.TypeOf((*DefaultProvider)(nil)).Elem()) {
defaulter := field.Addr().Interface().(DefaultProvider)
defaults := reflect.ValueOf(defaulter.Defaults())
if defaults.Kind() == reflect.Ptr {
defaults = defaults.Elem()
Expand All @@ -58,27 +81,42 @@ func fillDefaultsRecursive(v reflect.Value) error {
switch field.Kind() {
case reflect.Ptr:
if field.IsNil() {
// Check if we can set the field before creating a new instance
if !field.CanSet() {
continue
}
// If the field is a nil pointer, create a new instance
field.Set(reflect.New(field.Type().Elem()))
}
if err := fillDefaultsRecursive(field.Elem()); err != nil {
if err := fillDefaultsRecursive(field.Elem(), depth+1, maxDepth); err != nil {
return err
}
case reflect.Struct:
if err := fillDefaultsRecursive(field); err != nil {
if err := fillDefaultsRecursive(field, depth+1, maxDepth); err != nil {
return err
}
case reflect.Slice:
for j := 0; j < field.Len(); j++ {
if err := fillDefaultsRecursive(field.Index(j)); err != nil {
if err := fillDefaultsRecursive(field.Index(j), depth+1, maxDepth); err != nil {
return err
}
}
case reflect.Map:
// Handle maps (new addition)
for _, key := range field.MapKeys() {
value := field.MapIndex(key)
if value.CanAddr() {
if err := fillDefaultsRecursive(value, depth+1, maxDepth); err != nil {
return err
}
}
}
}
}

// Apply defaults to the current struct if it implements DefaultProvider
if defaulter, ok := v.Addr().Interface().(DefaultProvider); ok {
if v.CanAddr() && v.Addr().Type().Implements(reflect.TypeOf((*DefaultProvider)(nil)).Elem()) {
defaulter := v.Addr().Interface().(DefaultProvider)
defaults := reflect.ValueOf(defaulter.Defaults())
if defaults.Kind() == reflect.Ptr {
defaults = defaults.Elem()
Expand All @@ -96,12 +134,16 @@ func fillFromDefaults(dst, src reflect.Value) {

// Check if the destination has this field
dstField := dst.FieldByName(srcFieldName)
if !dstField.IsValid() {
continue // Skip fields that don't exist in the destination
if !dstField.IsValid() || !dstField.CanSet() {
continue // Skip fields that don't exist in the destination or can't be set
}

// Check if the types are compatible
if !srcField.Type().AssignableTo(dstField.Type()) {
continue // Skip if types are not compatible
}

if dstField.CanSet() && isZeroValue(dstField) {
// Only set the value if it's settable and currently zero
if isZeroValue(dstField) {
dstField.Set(srcField)
}
}
Expand Down
Loading

0 comments on commit 6a3a61a

Please sign in to comment.