Skip to content

Commit

Permalink
Merge pull request #179 from huandu/feature/cte-table-for-update-delete
Browse files Browse the repository at this point in the history
Automatically ref names of CTETables in DELETE and UPDATE statements
  • Loading branch information
huandu authored Nov 6, 2024
2 parents 134c901 + 96c9b25 commit bb320aa
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 29 deletions.
8 changes: 4 additions & 4 deletions cte.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,12 @@ func (cteb *CTEBuilder) TableNames() []string {
return tableNames
}

// tableNamesForSelect returns a list of table names which should be automatically added to FROM clause.
// It's not public, as this feature is designed only for SelectBuilder right now.
func (cteb *CTEBuilder) tableNamesForSelect() []string {
// tableNamesForFrom returns a list of table names which should be automatically added to FROM clause.
// It's not public, as this feature is designed only for SelectBuilder/UpdateBuilder/DeleteBuilder right now.
func (cteb *CTEBuilder) tableNamesForFrom() []string {
cnt := 0

// It's rare that the ShouldAddToTableList() returns true.
// ShouldAddToTableList() unlikely returns true.
// Count it before allocating any memory for better performance.
for _, query := range cteb.queries {
if query.ShouldAddToTableList() {
Expand Down
37 changes: 37 additions & 0 deletions cte_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,43 @@ func ExampleCTEBuilder() {
// [users valid_users]
}

func ExampleCTEBuilder_update() {
builder := With(
CTETable("users", "user_id").As(
Select("user_id").From("vip_users"),
),
).Update("orders").Set(
"orders.transport_fee = 0",
).Where(
"users.user_id = orders.user_id",
)

sqlForMySQL, _ := builder.BuildWithFlavor(MySQL)
sqlForPostgreSQL, _ := builder.BuildWithFlavor(PostgreSQL)

fmt.Println(sqlForMySQL)
fmt.Println(sqlForPostgreSQL)

// Output:
// WITH users (user_id) AS (SELECT user_id FROM vip_users) UPDATE orders, users SET orders.transport_fee = 0 WHERE users.user_id = orders.user_id
// WITH users (user_id) AS (SELECT user_id FROM vip_users) UPDATE orders FROM users SET orders.transport_fee = 0 WHERE users.user_id = orders.user_id
}

func ExampleCTEBuilder_delete() {
sql := With(
CTETable("users", "user_id").As(
Select("user_id").From("cheaters"),
),
).DeleteFrom("awards").Where(
"users.user_id = awards.user_id",
).String()

fmt.Println(sql)

// Output:
// WITH users (user_id) AS (SELECT user_id FROM cheaters) DELETE FROM awards, users WHERE users.user_id = awards.user_id
}

func TestCTEBuilder(t *testing.T) {
a := assert.New(t)
cteb := newCTEBuilder()
Expand Down
51 changes: 40 additions & 11 deletions delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ type DeleteBuilder struct {
whereClauseProxy *whereClauseProxy
whereClauseExpr string

cteBuilder string
table string
cteBuilderVar string
cteBuilder *CTEBuilder

tables []string
orderByCols []string
order string
limit int
Expand All @@ -60,24 +62,48 @@ type DeleteBuilder struct {
var _ Builder = new(DeleteBuilder)

// DeleteFrom sets table name in DELETE.
func DeleteFrom(table string) *DeleteBuilder {
return DefaultFlavor.NewDeleteBuilder().DeleteFrom(table)
func DeleteFrom(table ...string) *DeleteBuilder {
return DefaultFlavor.NewDeleteBuilder().DeleteFrom(table...)
}

// With sets WITH clause (the Common Table Expression) before DELETE.
func (db *DeleteBuilder) With(builder *CTEBuilder) *DeleteBuilder {
db.marker = deleteMarkerAfterWith
db.cteBuilder = db.Var(builder)
db.cteBuilderVar = db.Var(builder)
db.cteBuilder = builder
return db
}

// DeleteFrom sets table name in DELETE.
func (db *DeleteBuilder) DeleteFrom(table string) *DeleteBuilder {
db.table = Escape(table)
func (db *DeleteBuilder) DeleteFrom(table ...string) *DeleteBuilder {
db.tables = table
db.marker = deleteMarkerAfterDeleteFrom
return db
}

// TableNames returns all table names in this DELETE statement.
func (db *DeleteBuilder) TableNames() []string {
var additionalTableNames []string

if db.cteBuilder != nil {
additionalTableNames = db.cteBuilder.tableNamesForFrom()
}

var tableNames []string

if len(db.tables) > 0 && len(additionalTableNames) > 0 {
tableNames = make([]string, len(db.tables)+len(additionalTableNames))
copy(tableNames, db.tables)
copy(tableNames[len(db.tables):], additionalTableNames)
} else if len(db.tables) > 0 {
tableNames = db.tables
} else if len(additionalTableNames) > 0 {
tableNames = additionalTableNames
}

return tableNames
}

// Where sets expressions of WHERE in DELETE.
func (db *DeleteBuilder) Where(andExpr ...string) *DeleteBuilder {
if len(andExpr) == 0 || estimateStringsBytes(andExpr) == 0 {
Expand Down Expand Up @@ -146,17 +172,20 @@ func (db *DeleteBuilder) Build() (sql string, args []interface{}) {
// BuildWithFlavor returns compiled DELETE string and args with flavor and initial args.
// They can be used in `DB#Query` of package `database/sql` directly.
func (db *DeleteBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) {

buf := newStringBuilder()
db.injection.WriteTo(buf, deleteMarkerInit)

if db.cteBuilder != "" {
buf.WriteLeadingString(db.cteBuilder)
if db.cteBuilder != nil {
buf.WriteLeadingString(db.cteBuilderVar)
db.injection.WriteTo(buf, deleteMarkerAfterWith)
}

if len(db.table) > 0 {
tableNames := db.TableNames()

if len(tableNames) > 0 {
buf.WriteLeadingString("DELETE FROM ")
buf.WriteString(db.table)
buf.WriteStrings(tableNames, ", ")
}

db.injection.WriteTo(buf, deleteMarkerAfterDeleteFrom)
Expand Down
4 changes: 2 additions & 2 deletions select.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@ func Select(col ...string) *SelectBuilder {
return DefaultFlavor.NewSelectBuilder().Select(col...)
}

// TableNames returns all table names in a SELECT.
// TableNames returns all table names in this SELECT statement.
func (sb *SelectBuilder) TableNames() []string {
var additionalTableNames []string

if sb.cteBuilder != nil {
additionalTableNames = sb.cteBuilder.tableNamesForSelect()
additionalTableNames = sb.cteBuilder.tableNamesForFrom()
}

var tableNames []string
Expand Down
70 changes: 58 additions & 12 deletions update.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ type UpdateBuilder struct {
whereClauseProxy *whereClauseProxy
whereClauseExpr string

cteBuilder string
table string
cteBuilderVar string
cteBuilder *CTEBuilder

tables []string
assignments []string
orderByCols []string
order string
Expand All @@ -63,24 +65,46 @@ type UpdateBuilder struct {
var _ Builder = new(UpdateBuilder)

// Update sets table name in UPDATE.
func Update(table string) *UpdateBuilder {
return DefaultFlavor.NewUpdateBuilder().Update(table)
func Update(table ...string) *UpdateBuilder {
return DefaultFlavor.NewUpdateBuilder().Update(table...)
}

// With sets WITH clause (the Common Table Expression) before UPDATE.
func (ub *UpdateBuilder) With(builder *CTEBuilder) *UpdateBuilder {
ub.marker = updateMarkerAfterWith
ub.cteBuilder = ub.Var(builder)
ub.cteBuilderVar = ub.Var(builder)
ub.cteBuilder = builder
return ub
}

// Update sets table name in UPDATE.
func (ub *UpdateBuilder) Update(table string) *UpdateBuilder {
ub.table = Escape(table)
func (ub *UpdateBuilder) Update(table ...string) *UpdateBuilder {
ub.tables = table
ub.marker = updateMarkerAfterUpdate
return ub
}

// TableNames returns all table names in this UPDATE statement.
func (ub *UpdateBuilder) TableNames() (tableNames []string) {
var additionalTableNames []string

if ub.cteBuilder != nil {
additionalTableNames = ub.cteBuilder.tableNamesForFrom()
}

if len(ub.tables) > 0 && len(additionalTableNames) > 0 {
tableNames = make([]string, len(ub.tables)+len(additionalTableNames))
copy(tableNames, ub.tables)
copy(tableNames[len(ub.tables):], additionalTableNames)
} else if len(ub.tables) > 0 {
tableNames = ub.tables
} else if len(additionalTableNames) > 0 {
tableNames = additionalTableNames
}

return tableNames
}

// Set sets the assignments in SET.
func (ub *UpdateBuilder) Set(assignment ...string) *UpdateBuilder {
ub.assignments = assignment
Expand Down Expand Up @@ -212,14 +236,36 @@ func (ub *UpdateBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{
buf := newStringBuilder()
ub.injection.WriteTo(buf, updateMarkerInit)

if ub.cteBuilder != "" {
buf.WriteLeadingString(ub.cteBuilder)
if ub.cteBuilder != nil {
buf.WriteLeadingString(ub.cteBuilderVar)
ub.injection.WriteTo(buf, updateMarkerAfterWith)
}

if len(ub.table) > 0 {
buf.WriteLeadingString("UPDATE ")
buf.WriteString(ub.table)
switch flavor {
case MySQL:
// CTE table names should be written after UPDATE keyword in MySQL.
tableNames := ub.TableNames()

if len(tableNames) > 0 {
buf.WriteLeadingString("UPDATE ")
buf.WriteStrings(tableNames, ", ")
}

default:
if len(ub.tables) > 0 {
buf.WriteLeadingString("UPDATE ")
buf.WriteStrings(ub.tables, ", ")

// For ISO SQL, CTE table names should be written after FROM keyword.
if ub.cteBuilder != nil {
cteTableNames := ub.cteBuilder.tableNamesForFrom()

if len(cteTableNames) > 0 {
buf.WriteLeadingString("FROM ")
buf.WriteStrings(cteTableNames, ", ")
}
}
}
}

ub.injection.WriteTo(buf, updateMarkerAfterUpdate)
Expand Down

0 comments on commit bb320aa

Please sign in to comment.