diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 470a1bb..8ffd290 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -16,8 +16,8 @@ jobs: strategy: matrix: containerGoVer: - - "1.20" - - "1.21-rc" + - "1.21" + - "1.22" runs-on: ubuntu-latest container: golang:${{ matrix.containerGoVer }} services: @@ -57,8 +57,8 @@ jobs: strategy: matrix: containerGoVer: - - "1.20" - - "1.21-rc" + - "1.21" + - "1.22" runs-on: ubuntu-latest container: golang:${{ matrix.containerGoVer }} steps: @@ -76,8 +76,7 @@ jobs: strategy: matrix: containerGoVer: - - "1.20" - - "1.21-rc" + - "latest" runs-on: ubuntu-latest container: golang:${{ matrix.containerGoVer }} steps: diff --git a/db.go b/db.go index d6e50a8..c29bfca 100644 --- a/db.go +++ b/db.go @@ -597,6 +597,11 @@ func (m *DbMap) Select(ctx context.Context, i interface{}, query string, args .. expandSliceArgs(&query, args...) } + args, err := m.convertArgs(args...) + if err != nil { + return nil, err + } + return hookedselect(ctx, m, m, i, query, args...) } @@ -607,6 +612,11 @@ func (m *DbMap) ExecContext(ctx context.Context, query string, args ...interface expandSliceArgs(&query, args...) } + args, err := m.convertArgs(args...) + if err != nil { + return nil, err + } + if m.logger != nil { now := time.Now() defer m.trace(now, query, args...) @@ -620,6 +630,11 @@ func (m *DbMap) SelectInt(ctx context.Context, query string, args ...interface{} expandSliceArgs(&query, args...) } + args, err := m.convertArgs(args...) + if err != nil { + return 0, err + } + return SelectInt(ctx, m, query, args...) } @@ -629,6 +644,11 @@ func (m *DbMap) SelectNullInt(ctx context.Context, query string, args ...interfa expandSliceArgs(&query, args...) } + args, err := m.convertArgs(args...) + if err != nil { + return sql.NullInt64{}, err + } + return SelectNullInt(ctx, m, query, args...) } @@ -638,6 +658,11 @@ func (m *DbMap) SelectFloat(ctx context.Context, query string, args ...interface expandSliceArgs(&query, args...) } + args, err := m.convertArgs(args...) + if err != nil { + return 0, err + } + return SelectFloat(ctx, m, query, args...) } @@ -647,6 +672,11 @@ func (m *DbMap) SelectNullFloat(ctx context.Context, query string, args ...inter expandSliceArgs(&query, args...) } + args, err := m.convertArgs(args...) + if err != nil { + return sql.NullFloat64{}, err + } + return SelectNullFloat(ctx, m, query, args...) } @@ -656,6 +686,11 @@ func (m *DbMap) SelectStr(ctx context.Context, query string, args ...interface{} expandSliceArgs(&query, args...) } + args, err := m.convertArgs(args...) + if err != nil { + return "", err + } + return SelectStr(ctx, m, query, args...) } @@ -665,6 +700,11 @@ func (m *DbMap) SelectNullStr(ctx context.Context, query string, args ...interfa expandSliceArgs(&query, args...) } + args, err := m.convertArgs(args...) + if err != nil { + return sql.NullString{}, err + } + return SelectNullStr(ctx, m, query, args...) } @@ -674,6 +714,11 @@ func (m *DbMap) SelectOne(ctx context.Context, holder interface{}, query string, expandSliceArgs(&query, args...) } + args, err := m.convertArgs(args...) + if err != nil { + return err + } + return SelectOne(ctx, m, m, holder, query, args...) } @@ -798,6 +843,11 @@ func (m *DbMap) QueryRowContext(ctx context.Context, query string, args ...inter expandSliceArgs(&query, args...) } + args, err := m.convertArgs(args...) + if err != nil { + return nil + } + if m.logger != nil { now := time.Now() defer m.trace(now, query, args...) @@ -811,6 +861,11 @@ func (m *DbMap) QueryContext(ctx context.Context, q string, args ...interface{}) expandSliceArgs(&q, args...) } + args, err := m.convertArgs(args...) + if err != nil { + return nil, err + } + if m.logger != nil { now := time.Now() defer m.trace(now, q, args...) @@ -829,6 +884,22 @@ func (m *DbMap) trace(started time.Time, query string, args ...interface{}) { } } +// convertArgs passes each argument through the TypeConverter, if any, +// and returns the result (which may be identical to the input). +func (m *DbMap) convertArgs(args ...interface{}) ([]interface{}, error) { + if m.TypeConverter == nil { + return args, nil + } + for i, arg := range args { + converted, err := m.TypeConverter.ToDb(arg) + if err != nil { + return nil, err + } + args[i] = converted + } + return args, nil +} + type stringer interface { ToStringSlice() []string } diff --git a/gorp_test.go b/gorp_test.go index 668ab51..b135fff 100644 --- a/gorp_test.go +++ b/gorp_test.go @@ -1734,7 +1734,8 @@ func TestTypeConversionExample(t *testing.T) { t.Errorf("tc2 %v != %v", expected, tc2) } - tc2.Name = CustomStringType("hi2") + hi2 := CustomStringType("hi2") + tc2.Name = hi2 tc2.PersonJSON = Person{FName: "Jane", LName: "Doe"} _update(dbmap, tc2) @@ -1744,10 +1745,106 @@ func TestTypeConversionExample(t *testing.T) { t.Errorf("tc3 %v != %v", expected, tc3) } + d := dbmap.Dialect + pj := d.QuoteField("PersonJSON") + id := d.QuoteField("Id") + name := d.QuoteField("Name") + bv0 := d.BindVar(0) + bv1 := d.BindVar(1) + + // Test that the Person argument to Select goes through the + // type converter + var holder TypeConversionExample + personJSON := Person{FName: "Jane", LName: "Doe"} + _, err := dbmap.Select(context.Background(), + holder, + `select * from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + + err = dbmap.SelectOne(context.Background(), + &holder, + `select * from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + + _, err = dbmap.SelectInt(context.Background(), + `select `+id+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + + _, err = dbmap.SelectInt(context.Background(), + `select `+id+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + + _, err = dbmap.SelectNullInt(context.Background(), + `select `+id+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + + _, err = dbmap.SelectFloat(context.Background(), + `select `+id+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + + _, err = dbmap.SelectNullFloat(context.Background(), + `select `+id+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + + _, err = dbmap.SelectStr(context.Background(), + `select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + + _, err = dbmap.SelectNullStr(context.Background(), + `select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + + _, err = dbmap.QueryContext(context.Background(), + `select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + + row := dbmap.QueryRowContext(context.Background(), + `select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if row == nil || row.Err() != nil { + t.Errorf(`QueryRowContext failed: %s`, row.Err()) + } + + _, err = dbmap.ExecContext(context.Background(), + `select `+name+` from type_conv_test where `+pj+`=`+bv0+` and `+name+`=`+bv1, + personJSON, hi2) + if err != nil { + t.Errorf(`Select failed: %s`, err) + } + if _del(dbmap, tc) != 1 { t.Errorf("Did not delete row with Id: %d", tc.Id) } - } func TestWithEmbeddedStruct(t *testing.T) {