From d7b6a4c8a37c5efc44dec32c07e4af8ed2b3fb92 Mon Sep 17 00:00:00 2001 From: lxfeng1997 <33981743+lxfeng1997@users.noreply.github.com> Date: Sun, 15 Dec 2024 21:03:47 +0800 Subject: [PATCH] optimize: only inserted fields (#719) * only inserted fields * It is suspected that it is not used, so restore the previous code --------- Co-authored-by: JayLiu <38887641+luky116@users.noreply.github.com> --- pkg/datasource/sql/exec/at/base_executor.go | 43 +++++++++++++++++++ pkg/datasource/sql/exec/at/insert_executor.go | 14 ++++-- .../sql/exec/at/insert_executor_test.go | 4 +- 3 files changed, 56 insertions(+), 5 deletions(-) diff --git a/pkg/datasource/sql/exec/at/base_executor.go b/pkg/datasource/sql/exec/at/base_executor.go index 75f0cab56..05f440577 100644 --- a/pkg/datasource/sql/exec/at/base_executor.go +++ b/pkg/datasource/sql/exec/at/base_executor.go @@ -23,6 +23,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "seata.apache.org/seata-go/pkg/datasource/sql/undo" "strings" "github.com/arana-db/parser/ast" @@ -187,6 +188,48 @@ func (b *baseExecutor) buildRecordImages(rowsi driver.Rows, tableMetaData *types return &types.RecordImage{TableName: tableMetaData.TableName, Rows: rowImages, SQLType: sqlType}, nil } +func (b *baseExecutor) getNeedColumns(meta *types.TableMeta, columns []string, dbType types.DBType) []string { + var needUpdateColumns []string + if undo.UndoConfig.OnlyCareUpdateColumns && columns != nil && len(columns) > 0 { + needUpdateColumns = columns + if !b.containsPKByName(meta, columns) { + pkNames := meta.GetPrimaryKeyOnlyName() + if pkNames != nil && len(pkNames) > 0 { + for _, name := range pkNames { + needUpdateColumns = append(needUpdateColumns, name) + } + } + } + // todo If it contains onUpdate columns, add onUpdate columns + } else { + needUpdateColumns = meta.ColumnNames + } + + for i := range needUpdateColumns { + needUpdateColumns[i] = AddEscape(needUpdateColumns[i], dbType) + } + return needUpdateColumns +} + +func (b *baseExecutor) containsPKByName(meta *types.TableMeta, columns []string) bool { + pkColumnNameList := meta.GetPrimaryKeyOnlyName() + if len(pkColumnNameList) == 0 { + return false + } + + matchCounter := 0 + for _, column := range columns { + for _, pkName := range pkColumnNameList { + if strings.EqualFold(pkName, column) || + strings.EqualFold(pkName, strings.ToLower(column)) { + matchCounter++ + } + } + } + + return matchCounter == len(pkColumnNameList) +} + func getSqlNullValue(value interface{}) interface{} { if value == nil { return nil diff --git a/pkg/datasource/sql/exec/at/insert_executor.go b/pkg/datasource/sql/exec/at/insert_executor.go index a05da5062..ae7bac71d 100644 --- a/pkg/datasource/sql/exec/at/insert_executor.go +++ b/pkg/datasource/sql/exec/at/insert_executor.go @@ -184,9 +184,17 @@ func (i *insertExecutor) buildAfterImageSQL(ctx context.Context) (string, []driv } // build check sql sb := strings.Builder{} - sb.WriteString("SELECT * FROM " + tableName) - whereSQL := i.buildWhereConditionByPKs(pkColumnNameList, len(pkValuesMap[pkColumnNameList[0]]), "mysql", maxInSize) - sb.WriteString(" WHERE " + whereSQL + " ") + suffix := strings.Builder{} + var insertColumns []string + + for _, column := range i.parserCtx.InsertStmt.Columns { + insertColumns = append(insertColumns, column.Name.O) + } + sb.WriteString("SELECT " + strings.Join(i.getNeedColumns(meta, insertColumns, types.DBTypeMySQL), ", ")) + suffix.WriteString(" FROM " + tableName) + whereSQL := i.buildWhereConditionByPKs(pkColumnNameList, rowSize, "mysql", maxInSize) + suffix.WriteString(" WHERE " + whereSQL + " ") + sb.WriteString(suffix.String()) return sb.String(), i.buildPKParams(pkRowImages, pkColumnNameList), nil } diff --git a/pkg/datasource/sql/exec/at/insert_executor_test.go b/pkg/datasource/sql/exec/at/insert_executor_test.go index fa899df11..742249bc4 100644 --- a/pkg/datasource/sql/exec/at/insert_executor_test.go +++ b/pkg/datasource/sql/exec/at/insert_executor_test.go @@ -78,7 +78,7 @@ func TestBuildSelectSQLByInsert(t *testing.T) { }, }, - expectQuery: "SELECT * FROM user WHERE (`id`) IN ((?),(?)) ", + expectQuery: "SELECT id, name FROM user WHERE (`id`) IN ((?),(?)) ", expectQueryArgs: []driver.Value{int64(19), int64(21)}, }, { @@ -107,7 +107,7 @@ func TestBuildSelectSQLByInsert(t *testing.T) { }, }, }, - expectQuery: "SELECT * FROM user WHERE (`user_id`) IN ((?)) ", + expectQuery: "SELECT user_id, name FROM user WHERE (`user_id`) IN ((?)) ", expectQueryArgs: []driver.Value{int64(20)}, }, }