Skip to content

Commit

Permalink
Fix conversion logic
Browse files Browse the repository at this point in the history
  • Loading branch information
kenshaw committed Nov 20, 2024
1 parent 6f3ff60 commit 32bc8c3
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 29 deletions.
8 changes: 4 additions & 4 deletions _examples/bind/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ func main() {
arg string
u *url.URL
urlSet bool
strings []string
ints []int
m map[string]string
strings []string
m map[int]int
)
k.Run(context.Background(),
func(ctx context.Context, args []string) error {
Expand All @@ -35,7 +35,7 @@ func main() {
k.Flags().
String("arg", "an arg", k.Bind(&arg)).
URL("url", "a url", k.Short("u"), k.BindSet(&u, &urlSet)).
Slice("str", "a string", k.Short("s"), k.Bind(&strings), k.Bind(&ints)).
Map("map", "a map", k.Short("m"), k.Bind(&m)),
Slice("str", "a string", k.Short("s"), k.Bind(&strings), k.Bind(&ints), k.Uint64T).
Map("map", "a map", k.Short("m"), k.Bind(&m), k.IntT, k.MapKey(k.IntT)),
)
}
1 change: 1 addition & 0 deletions cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ func (fs *FlagSet) Hook(name, desc string, f func(context.Context) error, opts .
type Flag struct {
Type Type
Sub Type
Key Type
Descs []Desc
Def any
NoArg bool
Expand Down
47 changes: 30 additions & 17 deletions conv.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,24 +149,29 @@ func convValue(v reflect.Value, val any) (ok bool) {
switch v = v.Elem(); v.Kind() {
case reflect.String:
v.SetString(toString(val))
return true
case reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Int:
if i := toInt[int64](val); !v.OverflowInt(i) {
if i, ok := asInt64(val); ok && !v.OverflowInt(i) {
v.SetInt(i)
return true
}
case reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8, reflect.Uint:
if u := toUint[uint64](val); !v.OverflowUint(u) {
if u, ok := asUint64(val); ok && !v.OverflowUint(u) {
v.SetUint(u)
return true
}
case reflect.Float64, reflect.Float32:
if f := toFloat[float64](val); !v.OverflowFloat(f) {
if f, ok := asFloat64(val); ok && !v.OverflowFloat(f) {
v.SetFloat(f)
return true
}
case reflect.Complex128, reflect.Complex64:
if c := toComplex[complex128](val); !overflowComplex(v, c) {
if c, ok := asComplex128(val); ok && !overflowComplex(v, c) {
v.SetComplex(c)
return true
}
}
return true
return false
}

// unmarshalTextValue creates a new value and unmarshals the value to it.
Expand Down Expand Up @@ -305,6 +310,11 @@ func asInt64(val any) (int64, bool) {
case interface{ Int() int }:
return int64(v.Int()), true
}
if s, ok := asString(val); ok {
if i, err := strconv.ParseInt(s, 10, 64); err == nil {
return i, true
}
}
return 0, false
}

Expand Down Expand Up @@ -332,6 +342,11 @@ func asUint64(val any) (uint64, bool) {
case interface{ Uint() uint }:
return uint64(v.Uint()), true
}
if s, ok := asString(val); ok {
if u, err := strconv.ParseUint(s, 10, 64); err == nil {
return u, true
}
}
return 0, false
}

Expand All @@ -347,6 +362,11 @@ func asFloat64(val any) (float64, bool) {
case interface{ Float32() float32 }:
return float64(v.Float32()), true
}
if s, ok := asString(val); ok {
if f, err := strconv.ParseFloat(s, 64); err == nil {
return f, true
}
}
return 0, false
}

Expand All @@ -362,6 +382,11 @@ func asComplex128(val any) (complex128, bool) {
case interface{ Complex64() complex64 }:
return complex128(v.Complex64()), true
}
if s, ok := asString(val); ok {
if c, err := strconv.ParseComplex(s, 128); err == nil {
return c, true
}
}
return 0, false
}

Expand Down Expand Up @@ -451,9 +476,6 @@ func toInt[T inti](val any) T {
if v, ok := asUint64(val); ok {
return T(v)
}
if v, err := strconv.ParseInt(toString(val), 10, 64); err == nil {
return T(v)
}
var v T
return v
}
Expand All @@ -466,9 +488,6 @@ func toUint[T uinti](val any) T {
if v, ok := asInt64(val); ok {
return T(v)
}
if v, err := strconv.ParseUint(toString(val), 10, 64); err == nil {
return T(v)
}
var v T
return v
}
Expand All @@ -478,9 +497,6 @@ func toFloat[T floati](val any) T {
if v, ok := asFloat64(val); ok {
return T(v)
}
if v, err := strconv.ParseFloat(toString(val), 64); err == nil {
return T(v)
}
var v T
return v
}
Expand All @@ -490,9 +506,6 @@ func toComplex[T complexi](val any) T {
if v, ok := asComplex128(val); ok {
return T(v)
}
if v, err := strconv.ParseComplex(toString(val), 128); err == nil {
return T(v)
}
var v T
return v
}
Expand Down
2 changes: 2 additions & 0 deletions kobra.go
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,8 @@ const (
ErrMissingArgument Error = "missing argument"
// ErrInvalidValue is the invalid value error.
ErrInvalidValue Error = "invalid value"
// ErrInvalidKeyConversion is the invalid key conversion error.
ErrInvalidKeyConversion Error = "invalid key conversion"
// ErrInvalidConversion is the invalid conversion error.
ErrInvalidConversion Error = "invalid conversion"
// ErrTypeMismatch is the type mismatch error.
Expand Down
15 changes: 13 additions & 2 deletions opts.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,22 @@ func Sub(f func(context.Context, []string) error, opts ...Option) Option {
}
}

// MapKey is a flag option to set the map key type.
func MapKey(opts ...Option) Option {
return option{
flag: func(g *Flag) error {
if g.Type == MapT {
}
return nil
},
}
}

// BindSet is a flag option to set a binding variable and a set flag.
func BindSet[T *E, E any](v T, b *bool) Option {
return option{
flag: func(g *Flag) error {
val, err := newBind[T](v, b)
val, err := newBind(v, b)
if err != nil {
return err
}
Expand All @@ -195,7 +206,7 @@ func BindSet[T *E, E any](v T, b *bool) Option {

// Bind is a flag option to set a binding variable.
func Bind[T *E, E any](v T) Option {
return BindSet[T](v, nil)
return BindSet(v, nil)
}

// Default is a flag option to set the flag's default value.
Expand Down
2 changes: 1 addition & 1 deletion type.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ func registerMarshaler(f func() (any, error), m map[reflect.Type]marshalDesc, op
New: f,
}
for _, o := range opts {
o.apply(&d)
_ = o.apply(&d)
}
m[typ] = d
}
34 changes: 29 additions & 5 deletions var.go
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ func (val *bindVal[T, E]) Val() any {
return val.v
}

func (val *bindVal[T, E]) Set(ctx context.Context, s string) error {
func (val *bindVal[T, E]) Set(_ context.Context, s string) error {
switch reflect.TypeOf(*val.v).Kind() {
case reflect.Slice:
return val.sliceSet(s)
Expand Down Expand Up @@ -769,6 +769,28 @@ func (val *bindVal[T, E]) sliceSet(s string) error {
}

func (val *bindVal[T, E]) mapSet(s string) error {
k, value, ok := strings.Cut(s, "=")
if !ok {
return ErrInvalidMapValue
}
m := reflect.ValueOf(val.v).Elem()
// create map if nil
if m.IsNil() {
m.Set(reflect.MakeMap(m.Type()))
var ok bool
if *val.v, ok = m.Interface().(E); !ok {
return ErrInvalidConversion
}
}
key := reflect.New(m.Type().Key())
if !convValue(key, k) {
return ErrInvalidKeyConversion
}
v := reflect.New(m.Type().Elem())
if !convValue(v, value) {
return ErrInvalidValue
}
m.SetMapIndex(reflect.Indirect(key), reflect.Indirect(v))
return nil
}

Expand Down Expand Up @@ -839,6 +861,7 @@ func (val *sliceVal) Len() int {
// mapVal is a map value.
type mapVal struct {
typ Type
sub Type
v map[string]*VarSet
}

Expand All @@ -847,6 +870,7 @@ func NewMap(opts ...Option) func() (Value, error) {
return func() (Value, error) {
val := &mapVal{
typ: StringT,
sub: StringT,
}
for _, o := range opts {
if err := o.apply(val); err != nil {
Expand All @@ -858,15 +882,15 @@ func NewMap(opts ...Option) func() (Value, error) {
}

func (val *mapVal) Type() Type {
return "map[string]" + val.typ
return "map[" + val.sub + "]" + val.typ
}

func (val *mapVal) Val() any {
return val.v
}

func (val *mapVal) Get() (string, error) {
return string(val.Type()), nil
return "", nil
}

func (val *mapVal) Set(ctx context.Context, s string) error {
Expand Down Expand Up @@ -924,9 +948,9 @@ func (vars Vars) Set(ctx context.Context, g *Flag, value string, wasSet bool) er
if err != nil {
return err
}
for _, val := range g.Binds {
for i, val := range g.Binds {
if err := val.Set(ctx, value); err != nil {
return fmt.Errorf("cannot bind %s to %T: %w", g.Name(), val.Val(), err)
return fmt.Errorf("flag %s: bind %d (%T): cannot set %q: %w", g.Name(), i, val.Val(), value, err)
}
}
vars[name] = vs
Expand Down

0 comments on commit 32bc8c3

Please sign in to comment.