Skip to content

Commit

Permalink
Merge pull request #180 from huandu/bug/cond-misuse
Browse files Browse the repository at this point in the history
Avoid stack overflow when Cond is misused
  • Loading branch information
huandu authored Nov 6, 2024
2 parents bb320aa + eb375d5 commit 5067ee7
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 6 deletions.
12 changes: 8 additions & 4 deletions args.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type Args struct {
// The default flavor used by `Args#Compile`
Flavor Flavor

indexBase int
argValues []interface{}
namedArgs map[string]int
sqlNamedArgs map[string]int
Expand Down Expand Up @@ -47,7 +48,7 @@ func (args *Args) Add(arg interface{}) string {
}

func (args *Args) add(arg interface{}) int {
idx := len(args.argValues)
idx := len(args.argValues) + args.indexBase

switch a := arg.(type) {
case sql.NamedArg:
Expand Down Expand Up @@ -164,7 +165,7 @@ func (args *Args) compileNamed(ctx *argsCompileContext, format string) string {
format = format[i+1:]

if p, ok := args.namedArgs[name]; ok {
format, _ = args.compileSuccessive(ctx, format, p)
format, _ = args.compileSuccessive(ctx, format, p-args.indexBase)
}

return format
Expand All @@ -181,14 +182,17 @@ func (args *Args) compileDigits(ctx *argsCompileContext, format string, offset i
format = format[i:]

if pointer, err := strconv.Atoi(digits); err == nil {
return args.compileSuccessive(ctx, format, pointer)
return args.compileSuccessive(ctx, format, pointer-args.indexBase)
}

return format, offset
}

func (args *Args) compileSuccessive(ctx *argsCompileContext, format string, offset int) (string, int) {
if offset >= len(args.argValues) {
if offset < 0 || offset >= len(args.argValues) {
ctx.WriteString("/* INVALID ARG $")
ctx.WriteString(strconv.Itoa(offset))
ctx.WriteString(" */")
return format, offset
}

Expand Down
2 changes: 1 addition & 1 deletion args_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestArgs(t *testing.T) {
cases := map[string][]interface{}{
"abc ? def\n[123]": {"abc $? def", 123},
"abc ? def\n[456]": {"abc $0 def", 456},
"abc def\n[]": {"abc $1 def", 123},
"abc /* INVALID ARG $1 */ def\n[]": {"abc $1 def", 123},
"abc def \n[]": {"abc ${unknown} def ", 123},
"abc $ def\n[]": {"abc $$ def", 123},
"abcdef$\n[]": {"abcdef$", 123},
Expand Down
14 changes: 13 additions & 1 deletion cond.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ const (
opNOT = "NOT "
)

const minIndexBase = 256

// Cond provides several helper methods to build conditions.
type Cond struct {
Args *Args
Expand All @@ -19,7 +21,17 @@ type Cond struct {
// NewCond returns a new Cond.
func NewCond() *Cond {
return &Cond{
Args: &Args{},
Args: &Args{
// Based on the discussion in #174, users may call this method to create
// `Cond` for building various conditions, which is a misuse, but we
// cannot completely prevent this error. To facilitate users in
// identifying the issue when they make mistakes and to avoid
// unexpected stackoverflows, the base index for `Args` is
// deliberately set to a larger non-zero value here. This can
// significantly reduce the likelihood of issues and allows for
// timely error notification to users.
indexBase: minIndexBase,
},
}
}

Expand Down
13 changes: 13 additions & 0 deletions cond_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,16 @@ func TestCondExpr(t *testing.T) {
a.Equal(actual, expected)
}
}

func TestCondMisuse(t *testing.T) {
a := assert.New(t)

cond := NewCond()
sb := Select("*").
From("t1").
Where(cond.Equal("a", 123))
sql, args := sb.Build()

a.Equal(sql, "SELECT * FROM t1 WHERE /* INVALID ARG $256 */")
a.Equal(args, nil)
}

0 comments on commit 5067ee7

Please sign in to comment.