Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: enhance partition prune when comparing partition key with constant of different types #59155

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions pkg/planner/core/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,119 @@ func TestPartitionPruningForEQ(t *testing.T) {
require.Equal(t, 0, res[0])
}

func TestCast4PartitionPruning(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)

tk.MustExec("use test")
tk.MustExec(`drop table if exists t`)
tk.MustExec(`drop table if exists t_hash`)
tk.MustExec(`drop table if exists t_sub`)
tk.MustExec(`create table t(a int, b int, c int) partition by range(a) (
partition p1 values less than (100),
partition p2 values less than (200),
partition pm values less than (MAXVALUE));`)

// test between castIntAsReal(int) and real
tk.MustQuery(`explain select * from t where a between "123" and "199";`).Check(
testkit.Rows("TableReader_7 8000.00 root partition:p2 data:Selection_6",
"└─Selection_6 8000.00 cop[tikv] ge(cast(test.t.a, double BINARY), 123), le(cast(test.t.a, double BINARY), 199)",
" └─TableFullScan_5 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"))
// MySQL explain:
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| 1 | SIMPLE | t | p2 | ALL | NULL | NULL | NULL | NULL | 1 | 100.00 | Using where |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+

// test between castIntAsDecimal(int) and decimal
tk.MustQuery(`explain select * from t where a between 123.12 and 199.99;`).Check(
testkit.Rows("TableReader_7 8000.00 root partition:p2 data:Selection_6",
"└─Selection_6 8000.00 cop[tikv] ge(cast(test.t.a, decimal(10,0) BINARY), 123.12), le(cast(test.t.a, decimal(10,0) BINARY), 199.99)",
" └─TableFullScan_5 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"))
// MySQL explain:
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| 1 | SIMPLE | t | p2 | ALL | NULL | NULL | NULL | NULL | 1 | 100.00 | Using where |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+

// test between castIntAsReal(int) and real
tk.MustQuery(`explain select * from t where a between "123.12" and "199.99";`).Check(
testkit.Rows("TableReader_7 8000.00 root partition:p2 data:Selection_6",
"└─Selection_6 8000.00 cop[tikv] ge(cast(test.t.a, double BINARY), 123.12), le(cast(test.t.a, double BINARY), 199.99)",
" └─TableFullScan_5 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"))
// MySQL explain:
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| 1 | SIMPLE | t | p2 | ALL | NULL | NULL | NULL | NULL | 1 | 100.00 | Using where |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+

// test between castIntAsReal(int) and real
tk.MustQuery(`explain select * from t where a between "ddd" and "99";`).Check(
testkit.Rows("TableReader_7 8000.00 root partition:p1 data:Selection_6",
"└─Selection_6 8000.00 cop[tikv] ge(cast(test.t.a, double BINARY), 0), le(cast(test.t.a, double BINARY), 99)",
" └─TableFullScan_5 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"))
// MySQL explain:
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| 1 | SIMPLE | t | p1 | ALL | NULL | NULL | NULL | NULL | 1 | 100.00 | Using where |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+

// test between castIntAsReal(int) and real / between castIntAsReal(int) and real
tk.MustQuery(`explain select * from t where a between "123.12" and cast("199.99" as decimal);`).Check(
testkit.Rows("TableReader_7 8000.00 root partition:p2,pm data:Selection_6",
"└─Selection_6 8000.00 cop[tikv] ge(cast(test.t.a, double BINARY), 123.12), le(cast(test.t.a, double BINARY), 200)",
" └─TableFullScan_5 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"))
// MySQL explain:
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| 1 | SIMPLE | t | p2,pm | ALL | NULL | NULL | NULL | NULL | 1 | 100.00 | Using where |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+

tk.MustExec(`CREATE TABLE t_hash(a int, b int) PARTITION BY HASH(a) PARTITIONS 6`)
tk.MustExec(`insert into t_hash values(1, 1), (10, 10), (26, 26)`)
tk.MustQuery(`select * from t_hash where a = '1'`).Check(testkit.Rows("1 1"))
tk.MustQuery(`explain select * from t_hash where a = '1'`).Check(testkit.Rows(
"TableReader_7 10.00 root partition:p1 data:Selection_6",
"└─Selection_6 10.00 cop[tikv] eq(test.t_hash.a, 1)",
" └─TableFullScan_5 10000.00 cop[tikv] table:t_hash keep order:false, stats:pseudo"))
// MySQL explain:
//+----+-------------+--------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra |
//+----+-------------+--------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| 1 | SIMPLE | t_hash | p1 | ALL | NULL | NULL | NULL | NULL | 1 | 100.00 | Using where |
//+----+-------------+--------+------------+------+---------------+------+---------+------+------+----------+-------------+

tk.MustExec(`create table t_ts (report_updated timestamp) partition by range(unix_timestamp(report_updated)) (
partition p1 values less than (1732982400), -- 2024-12-01 00:00:00
partition p2 values less than (1733068800), -- 2024-12-02 00:00:00
partition pm values less than (MAXVALUE));`)
tk.MustExec("insert into t_ts values('2024-11-30 00:00:00'), ('2024-12-01 00:00:00'), ('2024-12-02 00:00:00')")
tk.MustQuery("select * from t_ts where report_updated = '2024-12-01 00:00:00'").Check(testkit.Rows("2024-12-01 00:00:00"))
tk.MustQuery("explain select * from t_ts where report_updated = 20241201").Check(testkit.Rows(
"TableReader_7 10.00 root partition:p2 data:Selection_6",
"└─Selection_6 10.00 cop[tikv] eq(test.t_ts.report_updated, 2024-12-01 00:00:00)",
" └─TableFullScan_5 10000.00 cop[tikv] table:t_ts keep order:false, stats:pseudo"))
tk.MustQuery("explain select * from t_ts where report_updated = '2024-12-01 00:00:00'").Check(testkit.Rows(
"TableReader_7 10.00 root partition:p2 data:Selection_6",
"└─Selection_6 10.00 cop[tikv] eq(test.t_ts.report_updated, 2024-12-01 00:00:00.000000)",
" └─TableFullScan_5 10000.00 cop[tikv] table:t_ts keep order:false, stats:pseudo"))
tk.MustQuery("explain select * from t_ts where report_updated > unix_timestamp('2008-05-01 00:00:00')").Check(testkit.Rows(
"TableReader_7 8000.00 root partition:all data:Selection_6",
"└─Selection_6 8000.00 cop[tikv] gt(cast(test.t_ts.report_updated, double BINARY), 1.2095712e+09)",
" └─TableFullScan_5 10000.00 cop[tikv] table:t_ts keep order:false, stats:pseudo"))
//MysQL explain:
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| 1 | SIMPLE | t_ts | p1,p2,pm | ALL | NULL | NULL | NULL | NULL | 3 | 33.33 | Using where |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
}

func TestNotReadOnlySQLOnTiFlash(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
Expand Down
63 changes: 55 additions & 8 deletions pkg/planner/core/rule_partition_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1593,10 +1593,23 @@ func (p *rangePruner) extractDataForPrune(sctx base.PlanContext, expr expression
if arg1, ok := op.GetArgs()[1].(*expression.Constant); ok {
col, con = arg0, arg1
}
} else if arg0, ok := op.GetArgs()[1].(*expression.Column); ok && arg0.ID == p.col.ID {
if arg1, ok := op.GetArgs()[0].(*expression.Constant); ok {
} else if arg1, ok := op.GetArgs()[1].(*expression.Column); ok && arg1.ID == p.col.ID {
if arg0, ok := op.GetArgs()[0].(*expression.Constant); ok {
ret.op = opposite(ret.op)
col, con = arg0, arg1
col, con = arg1, arg0
}
} else if sarg0, ok := op.GetArgs()[0].(*expression.ScalarFunction); ok && sarg0.FuncName.L == ast.Cast {
if arg0, ok := sarg0.GetArgs()[0].(*expression.Column); ok && arg0.ID == p.col.ID {
if arg1, ok := op.GetArgs()[1].(*expression.Constant); ok {
col, con = arg0, arg1
}
}
} else if sarg1, ok := op.GetArgs()[1].(*expression.ScalarFunction); ok && sarg1.FuncName.L == ast.Cast {
if arg1, ok := sarg1.GetArgs()[0].(*expression.Column); ok && arg1.ID == p.col.ID {
if arg0, ok := op.GetArgs()[0].(*expression.Constant); ok {
ret.op = opposite(ret.op)
col, con = arg1, arg0
}
}
}
if col == nil || con == nil {
Expand All @@ -1606,6 +1619,14 @@ func (p *rangePruner) extractDataForPrune(sctx base.PlanContext, expr expression
// Current expression is 'col op const'
var constExpr expression.Expression
if p.partFn != nil {
// If arg0 or arg1 is ScalarFunction, just skip it.
// Maybe more complicated cases would be considered in the future.
_, ok1 := op.GetArgs()[0].(*expression.ScalarFunction)
_, ok2 := op.GetArgs()[1].(*expression.ScalarFunction)
if ok1 || ok2 {
return ret, false
}

// If the partition function is not monotone, only EQ condition can be pruning.
if p.monotonous == monotoneModeInvalid && ret.op != ast.EQ {
return ret, false
Expand All @@ -1626,11 +1647,37 @@ func (p *rangePruner) extractDataForPrune(sctx base.PlanContext, expr expression
// If the partition expression is col, use constExpr.
constExpr = con
}
c, isNull, err := constExpr.EvalInt(sctx.GetExprCtx().GetEvalCtx(), chunk.Row{})
if err == nil && !isNull {
ret.c = c
ret.unsigned = mysql.HasUnsignedFlag(constExpr.GetType(sctx.GetExprCtx().GetEvalCtx()).GetFlag())
return ret, true
evalType := constExpr.GetType(sctx.GetExprCtx().GetEvalCtx()).EvalType()
if evalType == types.ETInt {
c, isNull, err := constExpr.EvalInt(sctx.GetExprCtx().GetEvalCtx(), chunk.Row{})
if err == nil && !isNull {
ret.c = c
ret.unsigned = mysql.HasUnsignedFlag(constExpr.GetType(sctx.GetExprCtx().GetEvalCtx()).GetFlag())
return ret, true
}
} else if evalType == types.ETReal {
f, isNull, err := constExpr.EvalReal(sctx.GetExprCtx().GetEvalCtx(), chunk.Row{})
c := int64(f)
if err == nil && !isNull {
ret.c = c
ret.unsigned = mysql.HasUnsignedFlag(constExpr.GetType(sctx.GetExprCtx().GetEvalCtx()).GetFlag())
return ret, true
}
} else if evalType == types.ETDecimal {
d, isNull, err := constExpr.EvalDecimal(sctx.GetExprCtx().GetEvalCtx(), chunk.Row{})
if err != nil {
return ret, false
}
f, err := d.ToFloat64()
if err != nil {
return ret, false
}
if err == nil && !isNull {
ret.c = int64(f)
ret.unsigned = mysql.HasUnsignedFlag(constExpr.GetType(sctx.GetExprCtx().GetEvalCtx()).GetFlag())
return ret, true
}
} else {
}
return ret, false
}
Expand Down