Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

apply TypeConverter to args, not just fields #11

Merged
merged 9 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ jobs:
strategy:
matrix:
containerGoVer:
- "1.20"
- "1.21-rc"
- "1.21"
- "1.22"
runs-on: ubuntu-latest
container: golang:${{ matrix.containerGoVer }}
services:
Expand Down Expand Up @@ -57,8 +57,8 @@ jobs:
strategy:
matrix:
containerGoVer:
- "1.20"
- "1.21-rc"
- "1.21"
- "1.22"
runs-on: ubuntu-latest
container: golang:${{ matrix.containerGoVer }}
steps:
Expand All @@ -76,8 +76,7 @@ jobs:
strategy:
matrix:
containerGoVer:
- "1.20"
- "1.21-rc"
- "latest"
runs-on: ubuntu-latest
container: golang:${{ matrix.containerGoVer }}
steps:
Expand Down
71 changes: 71 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,11 @@ func (m *DbMap) Select(ctx context.Context, i interface{}, query string, args ..
expandSliceArgs(&query, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return nil, err
}

return hookedselect(ctx, m, m, i, query, args...)
}

Expand All @@ -607,6 +612,11 @@ func (m *DbMap) ExecContext(ctx context.Context, query string, args ...interface
expandSliceArgs(&query, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return nil, err
}

if m.logger != nil {
now := time.Now()
defer m.trace(now, query, args...)
Expand All @@ -620,6 +630,11 @@ func (m *DbMap) SelectInt(ctx context.Context, query string, args ...interface{}
expandSliceArgs(&query, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return 0, err
}

return SelectInt(ctx, m, query, args...)
}

Expand All @@ -629,6 +644,11 @@ func (m *DbMap) SelectNullInt(ctx context.Context, query string, args ...interfa
expandSliceArgs(&query, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return sql.NullInt64{}, err
}

return SelectNullInt(ctx, m, query, args...)
}

Expand All @@ -638,6 +658,11 @@ func (m *DbMap) SelectFloat(ctx context.Context, query string, args ...interface
expandSliceArgs(&query, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return 0, err
}

return SelectFloat(ctx, m, query, args...)
}

Expand All @@ -647,6 +672,11 @@ func (m *DbMap) SelectNullFloat(ctx context.Context, query string, args ...inter
expandSliceArgs(&query, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return sql.NullFloat64{}, err
}

return SelectNullFloat(ctx, m, query, args...)
}

Expand All @@ -656,6 +686,11 @@ func (m *DbMap) SelectStr(ctx context.Context, query string, args ...interface{}
expandSliceArgs(&query, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return "", err
}

return SelectStr(ctx, m, query, args...)
}

Expand All @@ -665,6 +700,11 @@ func (m *DbMap) SelectNullStr(ctx context.Context, query string, args ...interfa
expandSliceArgs(&query, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return sql.NullString{}, err
}

return SelectNullStr(ctx, m, query, args...)
}

Expand All @@ -674,6 +714,11 @@ func (m *DbMap) SelectOne(ctx context.Context, holder interface{}, query string,
expandSliceArgs(&query, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return err
}

return SelectOne(ctx, m, m, holder, query, args...)
}

Expand Down Expand Up @@ -798,6 +843,11 @@ func (m *DbMap) QueryRowContext(ctx context.Context, query string, args ...inter
expandSliceArgs(&query, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return nil
}

if m.logger != nil {
now := time.Now()
defer m.trace(now, query, args...)
Expand All @@ -811,6 +861,11 @@ func (m *DbMap) QueryContext(ctx context.Context, q string, args ...interface{})
expandSliceArgs(&q, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return nil, err
}

if m.logger != nil {
now := time.Now()
defer m.trace(now, q, args...)
Expand All @@ -829,6 +884,22 @@ func (m *DbMap) trace(started time.Time, query string, args ...interface{}) {
}
}

// convertArgs passes each argument through the TypeConverter, if any,
// and returns the result (which may be identical to the input).
func (m *DbMap) convertArgs(args ...interface{}) ([]interface{}, error) {
if m.TypeConverter == nil {
return args, nil
}
for i, arg := range args {
converted, err := m.TypeConverter.ToDb(arg)
if err != nil {
return nil, err
}
args[i] = converted
}
return args, nil
}

type stringer interface {
ToStringSlice() []string
}
Expand Down
101 changes: 99 additions & 2 deletions gorp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1734,7 +1734,8 @@ func TestTypeConversionExample(t *testing.T) {
t.Errorf("tc2 %v != %v", expected, tc2)
}

tc2.Name = CustomStringType("hi2")
hi2 := CustomStringType("hi2")
tc2.Name = hi2
tc2.PersonJSON = Person{FName: "Jane", LName: "Doe"}
_update(dbmap, tc2)

Expand All @@ -1744,10 +1745,106 @@ func TestTypeConversionExample(t *testing.T) {
t.Errorf("tc3 %v != %v", expected, tc3)
}

d := dbmap.Dialect
pj := d.QuoteField("PersonJSON")
id := d.QuoteField("Id")
name := d.QuoteField("Name")
bv0 := d.BindVar(0)
bv1 := d.BindVar(1)

// Test that the Person argument to Select goes through the
// type converter
var holder TypeConversionExample
personJSON := Person{FName: "Jane", LName: "Doe"}
_, err := dbmap.Select(context.Background(),
holder,
`select * from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

err = dbmap.SelectOne(context.Background(),
&holder,
`select * from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

_, err = dbmap.SelectInt(context.Background(),
`select `+id+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

_, err = dbmap.SelectInt(context.Background(),
`select `+id+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

_, err = dbmap.SelectNullInt(context.Background(),
`select `+id+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

_, err = dbmap.SelectFloat(context.Background(),
`select `+id+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

_, err = dbmap.SelectNullFloat(context.Background(),
`select `+id+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

_, err = dbmap.SelectStr(context.Background(),
`select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

_, err = dbmap.SelectNullStr(context.Background(),
`select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

_, err = dbmap.QueryContext(context.Background(),
`select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

row := dbmap.QueryRowContext(context.Background(),
`select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if row == nil || row.Err() != nil {
t.Errorf(`QueryRowContext failed: %s`, row.Err())
}

_, err = dbmap.ExecContext(context.Background(),
`select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1,
personJSON, hi2)
if err != nil {
t.Errorf(`Select failed: %s`, err)
}

if _del(dbmap, tc) != 1 {
t.Errorf("Did not delete row with Id: %d", tc.Id)
}

}

func TestWithEmbeddedStruct(t *testing.T) {
Expand Down
Loading