Skip to content

Commit

Permalink
Support sort merge join
Browse files Browse the repository at this point in the history
  • Loading branch information
csynineyang committed Jul 3, 2024
1 parent 1e525f4 commit a9a5c67
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 91 deletions.
68 changes: 29 additions & 39 deletions pkg/dataset/sort_merge_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package dataset

import (
"bytes"
"io"
"sync"
)
Expand All @@ -32,6 +33,7 @@ import (

import (
"github.com/arana-db/arana/pkg/mysql/rows"
"github.com/arana-db/arana/pkg/mysql"
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/runtime/ast"
"github.com/arana-db/arana/pkg/util/log"
Expand Down Expand Up @@ -77,8 +79,8 @@ func NewSortMergeJoin(joinType ast.JoinType, joinColumn *JoinColumn, outer proto
}

fields := make([]proto.Field, 0, len(outerFields)+len(innerFields))
fields = append(fields, outerFields...)
fields = append(fields, innerFields...)
fields = append(fields, outerFields...)

if joinType == ast.RightJoin {
outer, inner = inner, outer
Expand Down Expand Up @@ -249,6 +251,10 @@ func (j *JoinColumn) Column() string {
return ""
}

func (j *JoinColumn) SetColumn(column string) {
j.column = column
}

func (s *SortMergeJoin) Close() error {
return nil
}
Expand All @@ -263,16 +269,12 @@ func (s *SortMergeJoin) Next() (proto.Row, error) {
outerRow, innerRow proto.Row
)

if s.LastRow() != nil {
outerRow = s.LastRow()
} else {
outerRow, err = s.getOuterRow()
if err != nil {
return nil, err
}
outerRow, err = s.getOuterRow()
if err != nil {
return nil, err
}

innerRow, err = s.getInnerRow(outerRow)
innerRow, err = s.getInnerRow()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -319,6 +321,7 @@ func (s *SortMergeJoin) innerJoin(outerRow proto.Row, innerRow proto.Row) (proto
if res, err := s.equalCompare(outerRow, innerRow, outerValue); err != nil {
return nil, err
} else {
s.SetLastInnerRow(innerRow)
return res, nil
}
}
Expand Down Expand Up @@ -469,12 +472,6 @@ func (s *SortMergeJoin) rightJoin(outerRow proto.Row, innerRow proto.Row) (proto
}

func (s *SortMergeJoin) getOuterRow() (proto.Row, error) {
nextOuterRow := s.NextOuterRow()
if nextOuterRow != nil {
s.ResetNextOuterRow()
return nextOuterRow, nil
}

leftRow, err := s.outer.Next()
if err != nil && errors.Is(err, io.EOF) {
return nil, nil
Expand All @@ -486,21 +483,7 @@ func (s *SortMergeJoin) getOuterRow() (proto.Row, error) {
return leftRow, nil
}

func (s *SortMergeJoin) getInnerRow(outerRow proto.Row) (proto.Row, error) {
if outerRow != nil {
outerValue, err := outerRow.(proto.KeyedRow).Get(s.joinColumn.Column())
if err != nil {
return nil, err
}

if s.DescartesFlag() {
innerRow := s.EqualValue(outerValue.String())
if innerRow != nil {
return innerRow, nil
}
}
}

func (s *SortMergeJoin) getInnerRow() (proto.Row, error) {
lastInnerRow := s.LastInnerRow()
if lastInnerRow != nil {
s.ResetLastInnerRow()
Expand All @@ -518,19 +501,22 @@ func (s *SortMergeJoin) getInnerRow(outerRow proto.Row) (proto.Row, error) {
return rightRow, nil
}

func (s *SortMergeJoin) resGenerate(leftRow proto.Row, rightRow proto.Row) proto.Row {
func (s *SortMergeJoin) resGenerate(rightRow proto.Row, leftRow proto.Row) proto.Row {
var (
leftValue []proto.Value
rightValue []proto.Value
res []proto.Value
realFields []proto.Field
)

if leftRow == nil && rightRow == nil {
return nil
}

leftFields, _ := s.outer.Fields()
rightFields, _ := s.inner.Fields()
leftFields, _ := s.inner.Fields()
realFields = append(realFields, leftFields[:(len(leftFields) - 1)]...)
rightFields, _ := s.outer.Fields()
realFields = append(realFields, rightFields[:(len(rightFields) - 1)]...)

leftValue = make([]proto.Value, len(leftFields))
rightValue = make([]proto.Value, len(rightFields))
Expand Down Expand Up @@ -560,12 +546,16 @@ func (s *SortMergeJoin) resGenerate(leftRow proto.Row, rightRow proto.Row) proto
}
}

res = append(res, leftValue...)
res = append(res, rightValue...)

fields, _ := s.Fields()
res = append(res, leftValue[:(len(leftValue)-1)]...)
res = append(res, rightValue[:(len(rightValue)-1)]...)

return rows.NewBinaryVirtualRow(fields, res)
var b bytes.Buffer
row := rows.NewTextVirtualRow(realFields, res)
_, err := row.WriteTo(&b)
if err != nil {
return nil
}
return mysql.NewTextRow(realFields, b.Bytes())
}

func (s *SortMergeJoin) equalCompare(outerRow proto.Row, innerRow proto.Row, outerValue proto.Value) (proto.Row, error) {
Expand Down Expand Up @@ -614,7 +604,7 @@ func (s *SortMergeJoin) greaterCompare(outerRow proto.Row) (proto.Row, proto.Row
innerRow = s.EqualValue(outerValue.String())
} else {
s.setDescartesFlag(NotDescartes)
innerRow, err = s.getInnerRow(outerRow)
innerRow, err = s.getInnerRow()
if err != nil {
return nil, nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/proto/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ func NewTableMetadata(name string, columnMetadataList []*ColumnMetadata, indexMe
}
}
for _, indexMetadata := range indexMetadataList {
indexName := strings.ToLower(indexMetadata.Name)
tma.Indexes[indexName] = indexMetadata
tma.Indexes[indexMetadata.ColumnName] = indexMetadata
}

return tma
Expand All @@ -66,7 +65,8 @@ type ColumnMetadata struct {
}

type IndexMetadata struct {
Name string
ColumnName string
Name string
}

var _defaultSchemaLoader SchemaLoader
Expand Down
69 changes: 51 additions & 18 deletions pkg/runtime/optimize/dml/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,12 +464,28 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt
if err != nil {
return nil, err
}
var tbLeft0 = tableLeft.Suffix()
if shardsLeft != nil {
_, tbLeft0 = shardsLeft.Smallest()
}
leftTblMeta, err := loadMetadataByTable(ctx, tbLeft0)
if err != nil {
return nil, err
}

join := from.Joins[0]
dbRight, aliasRight, tableRight, shardsRight, err := compute(join.Target)
if err != nil {
return nil, err
}
var tbRight0 = tableRight.Suffix()
if shardsRight != nil {
_, tbRight0 = shardsRight.Smallest()
}
rightTblMeta, err := loadMetadataByTable(ctx, tbRight0)
if err != nil {
return nil, err
}

// one db
if dbLeft == dbRight && shardsLeft == nil && shardsRight == nil {
Expand Down Expand Up @@ -517,6 +533,7 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt
return nil, errors.Errorf("not found buildKey or probeKey")
}

shouldSortMerge := shouldSortMergeJoin(leftTblMeta, rightTblMeta, leftKey, rightKey)
rewriteToSingle := func(tableSource ast.TableSourceItem, shards map[string][]string, onKey string) (proto.Plan, error) {
selectStmt := &ast.SelectStatement{
Select: stmt.Select,
Expand Down Expand Up @@ -594,30 +611,42 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt
return nil, err
}

setPlan := func(plan *dml.HashJoinPlan, buildPlan, probePlan proto.Plan, buildKey, probeKey string) {
plan.BuildKey = buildKey
plan.ProbeKey = probeKey
plan.BuildPlan = buildPlan
plan.ProbePlan = probePlan
}
var tmpPlan proto.Plan

if join.Typ == ast.InnerJoin {
setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey)
hashJoinPlan.IsFilterProbeRow = true
if shouldSortMerge {
tmpPlan = &dml.SortMergeJoin{
Stmt: stmt,
LeftQuery: leftPlan,
RightQuery: rightPlan,
JoinType: join.Typ,
LeftKey: leftKey,
RightKey: rightKey,
}
} else {
hashJoinPlan.IsFilterProbeRow = false
if join.Typ == ast.LeftJoin {
hashJoinPlan.IsReversedColumn = true
setPlan(hashJoinPlan, rightPlan, leftPlan, rightKey, leftKey)
} else if join.Typ == ast.RightJoin {
setPlan := func(plan *dml.HashJoinPlan, buildPlan, probePlan proto.Plan, buildKey, probeKey string) {
plan.BuildKey = buildKey
plan.ProbeKey = probeKey
plan.BuildPlan = buildPlan
plan.ProbePlan = probePlan
}

if join.Typ == ast.InnerJoin {
setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey)
hashJoinPlan.IsFilterProbeRow = true
} else {
return nil, errors.New("not support Join Type")
hashJoinPlan.IsFilterProbeRow = false
if join.Typ == ast.LeftJoin {
hashJoinPlan.IsReversedColumn = true
setPlan(hashJoinPlan, rightPlan, leftPlan, rightKey, leftKey)
} else if join.Typ == ast.RightJoin {
setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey)
} else {
return nil, errors.New("not support Join Type")
}
}
}

var tmpPlan proto.Plan
tmpPlan = hashJoinPlan
tmpPlan = hashJoinPlan
}

var (
analysis selectResult
Expand Down Expand Up @@ -700,6 +729,10 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt
return tmpPlan, nil
}

func shouldSortMergeJoin(leftTblMeta, rightTblMeta *proto.TableMetadata, leftKey, rightKey string) bool {
return leftTblMeta.Indexes[leftKey] != nil && rightTblMeta.Indexes[rightKey] != nil
}

func getSelectFlag(ru *rule.Rule, stmt *ast.SelectStatement) (flag uint32) {
switch len(stmt.From) {
case 1:
Expand Down
63 changes: 35 additions & 28 deletions pkg/runtime/plan/dml/rename.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package dml

import (
"context"
"fmt"
)

import (
Expand Down Expand Up @@ -52,38 +51,46 @@ func (rp RenamePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result
}

convFields := func(fields []proto.Field) []proto.Field {
if len(rp.RenameList) != len(fields) {
panic(fmt.Sprintf("the length of field doesn't match: expect=%d, actual=%d!", len(rp.RenameList), len(fields)))
}
//if len(rp.RenameList) != len(fields) {
// panic(fmt.Sprintf("the length of field doesn't match: expect=%d, actual=%d!", len(rp.RenameList), len(fields)))
//}

var renames map[int]struct{}
newFields := make([]proto.Field, 0, len(fields))
for i := 0; i < len(rp.RenameList); i++ {
rename := rp.RenameList[i]
name := fields[i].Name()
if rename == name {
continue
}
if renames == nil {
renames = make(map[int]struct{})
}
renames[i] = struct{}{}
f := *(fields[i].(*mysql.Field))
f.SetName(rp.RenameList[i])
f.SetOrgName(rp.RenameList[i])
newFields = append(newFields, &f)
}

if len(renames) < 1 {
return fields
}
// var renames map[int]struct{}
// for i := 0; i < len(rp.RenameList); i++ {
// rename := rp.RenameList[i]
// name := fields[i].Name()
// if rename == name {
// continue
// }
// if renames == nil {
// renames = make(map[int]struct{})
// }
// renames[i] = struct{}{}
// }

newFields := make([]proto.Field, 0, len(fields))
for i := 0; i < len(fields); i++ {
if _, ok := renames[i]; ok {
f := *(fields[i].(*mysql.Field))
f.SetName(rp.RenameList[i])
f.SetOrgName(rp.RenameList[i])
newFields = append(newFields, &f)
} else {
newFields = append(newFields, fields[i])
}
}
// if len(renames) < 1 {
// return fields
// }

// newFields := make([]proto.Field, 0, len(fields))
// for i := 0; i < len(fields); i++ {
// if _, ok := renames[i]; ok {
// f := *(fields[i].(*mysql.Field))
// f.SetName(rp.RenameList[i])
// f.SetOrgName(rp.RenameList[i])
// newFields = append(newFields, &f)
// } else {
// newFields = append(newFields, fields[i])
// }
// }
return newFields
}

Expand Down
Loading

0 comments on commit a9a5c67

Please sign in to comment.