Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

apply TypeConverter to args, not just fields #11

Merged
merged 9 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -597,16 +597,40 @@ func (m *DbMap) Select(ctx context.Context, i interface{}, query string, args ..
expandSliceArgs(&query, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return nil, err
}

return hookedselect(ctx, m, m, i, query, args...)
}

func (m *DbMap) convertArgs(args ...interface{}) ([]interface{}, error) {
jsha marked this conversation as resolved.
Show resolved Hide resolved
if m.TypeConverter == nil {
return args, nil
}
for i, arg := range args {
converted, err := m.TypeConverter.ToDb(arg)
if err != nil {
return nil, err
}
args[i] = converted
}
return args, nil
}

// Exec runs an arbitrary SQL statement. args represent the bind parameters.
// This is equivalent to running: ExecContext() using database/sql
func (m *DbMap) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
if m.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return nil, err
}

if m.logger != nil {
now := time.Now()
defer m.trace(now, query, args...)
Expand All @@ -620,6 +644,11 @@ func (m *DbMap) SelectInt(ctx context.Context, query string, args ...interface{}
expandSliceArgs(&query, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return 0, err
}

return SelectInt(ctx, m, query, args...)
}

Expand All @@ -629,6 +658,11 @@ func (m *DbMap) SelectNullInt(ctx context.Context, query string, args ...interfa
expandSliceArgs(&query, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return sql.NullInt64{}, err
}

return SelectNullInt(ctx, m, query, args...)
}

Expand All @@ -638,6 +672,11 @@ func (m *DbMap) SelectFloat(ctx context.Context, query string, args ...interface
expandSliceArgs(&query, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return 0, err
}

return SelectFloat(ctx, m, query, args...)
}

Expand All @@ -647,6 +686,11 @@ func (m *DbMap) SelectNullFloat(ctx context.Context, query string, args ...inter
expandSliceArgs(&query, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return sql.NullFloat64{}, err
}

return SelectNullFloat(ctx, m, query, args...)
}

Expand All @@ -656,6 +700,11 @@ func (m *DbMap) SelectStr(ctx context.Context, query string, args ...interface{}
expandSliceArgs(&query, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return "", err
}

return SelectStr(ctx, m, query, args...)
}

Expand All @@ -665,6 +714,11 @@ func (m *DbMap) SelectNullStr(ctx context.Context, query string, args ...interfa
expandSliceArgs(&query, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return sql.NullString{}, err
}

return SelectNullStr(ctx, m, query, args...)
}

Expand All @@ -674,6 +728,11 @@ func (m *DbMap) SelectOne(ctx context.Context, holder interface{}, query string,
expandSliceArgs(&query, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return err
}

return SelectOne(ctx, m, m, holder, query, args...)
}

Expand Down Expand Up @@ -798,6 +857,11 @@ func (m *DbMap) QueryRowContext(ctx context.Context, query string, args ...inter
expandSliceArgs(&query, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return nil
}

if m.logger != nil {
now := time.Now()
defer m.trace(now, query, args...)
Expand All @@ -811,6 +875,11 @@ func (m *DbMap) QueryContext(ctx context.Context, q string, args ...interface{})
expandSliceArgs(&q, args...)
}

args, err := m.convertArgs(args...)
if err != nil {
return nil, err
}

if m.logger != nil {
now := time.Now()
defer m.trace(now, q, args...)
Expand Down
94 changes: 92 additions & 2 deletions gorp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1734,7 +1734,8 @@ func TestTypeConversionExample(t *testing.T) {
t.Errorf("tc2 %v != %v", expected, tc2)
}

tc2.Name = CustomStringType("hi2")
hi2 := CustomStringType("hi2")
tc2.Name = hi2
tc2.PersonJSON = Person{FName: "Jane", LName: "Doe"}
_update(dbmap, tc2)

Expand All @@ -1744,10 +1745,99 @@ func TestTypeConversionExample(t *testing.T) {
t.Errorf("tc3 %v != %v", expected, tc3)
}

// Test that the Person argument to Select goes through the
// type converter
var holder TypeConversionExample
personJSON := Person{FName: "Jane", LName: "Doe"}
_, err := dbmap.Select(context.Background(),
holder,
"select * from type_conv_test where personjson = ? and name = ?",
personJSON, hi2)
if err != nil {
t.Errorf("Select failed: %s", err)
}

err = dbmap.SelectOne(context.Background(),
&holder,
"select * from type_conv_test where personjson = ? and name = ?",
personJSON, hi2)
if err != nil {
t.Errorf("Select failed: %s", err)
}

_, err = dbmap.SelectInt(context.Background(),
"select id from type_conv_test where personjson = ? and name = ?",
personJSON, hi2)
if err != nil {
t.Errorf("Select failed: %s", err)
}

_, err = dbmap.SelectInt(context.Background(),
"select id from type_conv_test where personjson = ? and name = ?",
personJSON, hi2)
if err != nil {
t.Errorf("Select failed: %s", err)
}

_, err = dbmap.SelectNullInt(context.Background(),
"select id from type_conv_test where personjson = ? and name = ?",
personJSON, hi2)
if err != nil {
t.Errorf("Select failed: %s", err)
}

_, err = dbmap.SelectFloat(context.Background(),
"select id * 1.2 from type_conv_test where personjson = ? and name = ?",
personJSON, hi2)
if err != nil {
t.Errorf("Select failed: %s", err)
}

_, err = dbmap.SelectNullFloat(context.Background(),
"select id * 1.2 from type_conv_test where personjson = ? and name = ?",
personJSON, hi2)
if err != nil {
t.Errorf("Select failed: %s", err)
}

_, err = dbmap.SelectStr(context.Background(),
"select name from type_conv_test where personjson = ? and name = ?",
personJSON, hi2)
if err != nil {
t.Errorf("Select failed: %s", err)
}

_, err = dbmap.SelectNullStr(context.Background(),
"select name from type_conv_test where personjson = ? and name = ?",
personJSON, hi2)
if err != nil {
t.Errorf("Select failed: %s", err)
}

_, err = dbmap.QueryContext(context.Background(),
"select name from type_conv_test where personjson = ? and name = ?",
personJSON, hi2)
if err != nil {
t.Errorf("Select failed: %s", err)
}

row := dbmap.QueryRowContext(context.Background(),
"select name from type_conv_test where personjson = ?",
personJSON)
if row == nil || row.Err() != nil {
t.Errorf("QueryRowContext failed: %s", row.Err())
}

_, err = dbmap.ExecContext(context.Background(),
"select name from type_conv_test where personjson = ?",
personJSON)
if err != nil {
t.Errorf("Select failed: %s", err)
}

if _del(dbmap, tc) != 1 {
t.Errorf("Did not delete row with Id: %d", tc.Id)
}

}

func TestWithEmbeddedStruct(t *testing.T) {
Expand Down
Loading