-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(konfetty)!: make defaults work
- Loading branch information
Showing
6 changed files
with
304 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
package konfetty | ||
|
||
import "reflect" | ||
|
||
// DefaultProvider is an interface for types that can provide their own default values | ||
type DefaultProvider interface { | ||
Defaults() any | ||
} | ||
|
||
// Defaulter is an interface for types that can set their own default values | ||
type Defaulter interface { | ||
SetDefaults() | ||
} | ||
|
||
// fillDefaults recursively fills in default values for structs that implement DefaultProvider | ||
func fillDefaults(v any) error { | ||
return fillDefaultsRecursive(reflect.ValueOf(v)) | ||
} | ||
|
||
func fillDefaultsRecursive(v reflect.Value) error { | ||
// Handle pointer types | ||
if v.Kind() == reflect.Ptr { | ||
if v.IsNil() { | ||
// If the pointer is nil, create a new instance of the pointed-to type | ||
v.Set(reflect.New(v.Type().Elem())) | ||
} | ||
v = v.Elem() | ||
} | ||
|
||
if v.Kind() != reflect.Struct { | ||
return nil | ||
} | ||
|
||
t := v.Type() | ||
|
||
// Iterate through all fields | ||
for i := 0; i < v.NumField(); i++ { | ||
field := v.Field(i) | ||
fieldType := t.Field(i) | ||
|
||
// Handle embedded fields | ||
if fieldType.Anonymous { | ||
if field.Kind() == reflect.Ptr && field.IsNil() { | ||
// If the embedded field is a nil pointer, create a new instance | ||
field.Set(reflect.New(field.Type().Elem())) | ||
} | ||
if err := fillDefaultsRecursive(field); err != nil { | ||
return err | ||
} | ||
|
||
// Apply defaults for the embedded field if it implements DefaultProvider | ||
if defaulter, ok := field.Addr().Interface().(DefaultProvider); ok { | ||
defaults := reflect.ValueOf(defaulter.Defaults()) | ||
if defaults.Kind() == reflect.Ptr { | ||
defaults = defaults.Elem() | ||
} | ||
fillFromDefaults(field, defaults) | ||
} | ||
|
||
continue | ||
} | ||
|
||
switch field.Kind() { | ||
case reflect.Ptr: | ||
if field.IsNil() { | ||
// If the field is a nil pointer, create a new instance | ||
field.Set(reflect.New(field.Type().Elem())) | ||
} | ||
if err := fillDefaultsRecursive(field.Elem()); err != nil { | ||
return err | ||
} | ||
case reflect.Struct: | ||
if err := fillDefaultsRecursive(field); err != nil { | ||
return err | ||
} | ||
case reflect.Slice: | ||
for j := 0; j < field.Len(); j++ { | ||
if err := fillDefaultsRecursive(field.Index(j)); err != nil { | ||
return err | ||
} | ||
} | ||
} | ||
} | ||
|
||
// Apply defaults to the current struct if it implements DefaultProvider | ||
if defaulter, ok := v.Addr().Interface().(DefaultProvider); ok { | ||
defaults := reflect.ValueOf(defaulter.Defaults()) | ||
if defaults.Kind() == reflect.Ptr { | ||
defaults = defaults.Elem() | ||
} | ||
fillFromDefaults(v, defaults) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func fillFromDefaults(dst, src reflect.Value) { | ||
for i := 0; i < src.NumField(); i++ { | ||
srcField := src.Field(i) | ||
srcFieldName := src.Type().Field(i).Name | ||
|
||
// Check if the destination has this field | ||
dstField := dst.FieldByName(srcFieldName) | ||
if !dstField.IsValid() { | ||
continue // Skip fields that don't exist in the destination | ||
} | ||
|
||
if dstField.CanSet() && isZeroValue(dstField) { | ||
// Only set the value if it's settable and currently zero | ||
dstField.Set(srcField) | ||
} | ||
} | ||
} | ||
|
||
func isZeroValue(v reflect.Value) bool { | ||
zero := reflect.Zero(v.Type()).Interface() | ||
return reflect.DeepEqual(v.Interface(), zero) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
package konfetty | ||
|
||
import ( | ||
"testing" | ||
"time" | ||
|
||
"github.com/google/go-cmp/cmp" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
// Structs provided by the user | ||
type Profile struct { | ||
Checks Checks | ||
} | ||
|
||
type Checks struct { | ||
Ping []PingCheck | ||
} | ||
|
||
type PingCheck struct { | ||
*BaseCheck | ||
Host string | ||
} | ||
|
||
type BaseCheck struct { | ||
Name string | ||
Interval time.Duration | ||
Timeout time.Duration | ||
} | ||
|
||
// Implement DefaultProvider for BaseCheck | ||
func (b BaseCheck) Defaults() interface{} { | ||
return BaseCheck{ | ||
Interval: 30 * time.Second, | ||
Timeout: 5 * time.Second, | ||
} | ||
} | ||
|
||
func TestFillDefaults(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
input *Profile | ||
expected *Profile | ||
}{ | ||
{ | ||
name: "Profile with partially filled checks", | ||
input: &Profile{ | ||
Checks: Checks{ | ||
Ping: []PingCheck{ | ||
{ | ||
BaseCheck: &BaseCheck{Name: "Custom Ping"}, | ||
Host: "example.com", | ||
}, | ||
}, | ||
}, | ||
}, | ||
expected: &Profile{ | ||
Checks: Checks{ | ||
Ping: []PingCheck{ | ||
{ | ||
BaseCheck: &BaseCheck{ | ||
Name: "Custom Ping", | ||
Interval: 30 * time.Second, | ||
Timeout: 5 * time.Second, | ||
}, | ||
Host: "example.com", | ||
}, | ||
}, | ||
}, | ||
}, | ||
}, | ||
{ | ||
name: "Profile with nil checks", | ||
input: &Profile{ | ||
Checks: Checks{ | ||
Ping: []PingCheck{ | ||
{ | ||
Host: "example.com", | ||
}, | ||
}, | ||
}, | ||
}, | ||
expected: &Profile{ | ||
Checks: Checks{ | ||
Ping: []PingCheck{ | ||
{ | ||
BaseCheck: &BaseCheck{ | ||
Interval: 30 * time.Second, | ||
Timeout: 5 * time.Second, | ||
}, | ||
Host: "example.com", | ||
}, | ||
}, | ||
}, | ||
}, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
err := fillDefaults(tt.input) | ||
require.NoError(t, err) | ||
|
||
diff := cmp.Diff(tt.expected, tt.input) | ||
if diff != "" { | ||
t.Errorf("FillDefaults() mismatch (-want +got):\n%s", diff) | ||
} | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.