diff --git a/cte.go b/cte.go index 8599aba..4ef19d7 100644 --- a/cte.go +++ b/cte.go @@ -13,6 +13,11 @@ func With(tables ...*CTETableBuilder) *CTEBuilder { return DefaultFlavor.NewCTEBuilder().With(tables...) } +// WithRecursive creates a new recursive CTE builder with default flavor. +func WithRecursive(tables ...*CTETableBuilder) *CTEBuilder { + return DefaultFlavor.NewCTEBuilder().WithRecursive(tables...) +} + func newCTEBuilder() *CTEBuilder { return &CTEBuilder{ args: &Args{}, @@ -22,6 +27,7 @@ func newCTEBuilder() *CTEBuilder { // CTEBuilder is a CTE (Common Table Expression) builder. type CTEBuilder struct { + recursive bool tableNames []string tableBuilderVars []string @@ -49,6 +55,12 @@ func (cteb *CTEBuilder) With(tables ...*CTETableBuilder) *CTEBuilder { return cteb } +// WithRecursive sets the CTE name and columns and turns on the RECURSIVE keyword. +func (cteb *CTEBuilder) WithRecursive(tables ...*CTETableBuilder) *CTEBuilder { + cteb.With(tables...).recursive = true + return cteb +} + // Select creates a new SelectBuilder to build a SELECT statement using this CTE. func (cteb *CTEBuilder) Select(col ...string) *SelectBuilder { sb := cteb.args.Flavor.NewSelectBuilder() @@ -73,6 +85,9 @@ func (cteb *CTEBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{} if len(cteb.tableBuilderVars) > 0 { buf.WriteLeadingString("WITH ") + if cteb.recursive { + buf.WriteString("RECURSIVE ") + } buf.WriteStrings(cteb.tableBuilderVars, ", ") } diff --git a/cte_test.go b/cte_test.go index 75a15ed..ae3bbc6 100644 --- a/cte_test.go +++ b/cte_test.go @@ -30,6 +30,28 @@ func ExampleWith() { // WITH users (id, name) AS (SELECT id, name FROM users WHERE name IS NOT NULL), devices AS (SELECT device_id FROM devices) SELECT users.id, orders.id, devices.device_id FROM users, devices JOIN orders ON users.id = orders.user_id AND devices.device_id = orders.device_id } +func ExampleWithRecursive() { + sb := WithRecursive( + CTETable("source_accounts", "id", "parent_id").As( + UnionAll( + Select("p.id", "p.parent_id"). + From("accounts AS p"). + Where("p.id = 2"), // Show orders for account 2 and all its child accounts + Select("c.id", "c.parent_id"). + From("accounts AS c"). + Join("source_accounts AS sa", "c.parent_id = sa.id"), + ), + ), + ).Select("o.id", "o.date", "o.amount"). + From("orders AS o"). + Join("source_accounts", "o.account_id = source_accounts.id") + + fmt.Println(sb) + + // Output: + // WITH RECURSIVE source_accounts (id, parent_id) AS ((SELECT p.id, p.parent_id FROM accounts AS p WHERE p.id = 2) UNION ALL (SELECT c.id, c.parent_id FROM accounts AS c JOIN source_accounts AS sa ON c.parent_id = sa.id)) SELECT o.id, o.date, o.amount FROM orders AS o JOIN source_accounts ON o.account_id = source_accounts.id +} + func ExampleCTEBuilder() { usersBuilder := Select("id", "name", "level").From("users") usersBuilder.Where( @@ -82,3 +104,27 @@ func TestCTEBuilder(t *testing.T) { sql = ctetb.String() a.Equal(sql, "/* table init */ t (a, b) /* after table */ AS (SELECT a, b FROM t) /* after table as */") } + +func TestRecursiveCTEBuilder(t *testing.T) { + a := assert.New(t) + cteb := newCTEBuilder() + cteb.recursive = true + ctetb := newCTETableBuilder() + cteb.SQL("/* init */") + cteb.With(ctetb) + cteb.SQL("/* after with */") + + ctetb.SQL("/* table init */") + ctetb.Table("t", "a", "b") + ctetb.SQL("/* after table */") + + ctetb.As(Select("a", "b").From("t")) + ctetb.SQL("/* after table as */") + + sql, args := cteb.Build() + a.Equal(sql, "/* init */ WITH RECURSIVE /* table init */ t (a, b) /* after table */ AS (SELECT a, b FROM t) /* after table as */ /* after with */") + a.Assert(args == nil) + + sql = ctetb.String() + a.Equal(sql, "/* table init */ t (a, b) /* after table */ AS (SELECT a, b FROM t) /* after table as */") +}