From 5abd231a6ee16ee15efb8a4af815a5744ad34835 Mon Sep 17 00:00:00 2001 From: L-maple Date: Thu, 23 Jan 2025 14:28:40 +0800 Subject: [PATCH] core: enhance partition prune when comparing partition key with constant of different types --- pkg/planner/core/integration_test.go | 113 +++++++++++++++++++ pkg/planner/core/rule_partition_processor.go | 63 +++++++++-- 2 files changed, 168 insertions(+), 8 deletions(-) diff --git a/pkg/planner/core/integration_test.go b/pkg/planner/core/integration_test.go index 4528bff107fc9..a1070af01ccb5 100644 --- a/pkg/planner/core/integration_test.go +++ b/pkg/planner/core/integration_test.go @@ -175,6 +175,119 @@ func TestPartitionPruningForEQ(t *testing.T) { require.Equal(t, 0, res[0]) } +func TestCast4PartitionPruning(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + tk.MustExec("use test") + tk.MustExec(`drop table if exists t`) + tk.MustExec(`drop table if exists t_hash`) + tk.MustExec(`drop table if exists t_sub`) + tk.MustExec(`create table t(a int, b int, c int) partition by range(a) ( + partition p1 values less than (100), + partition p2 values less than (200), + partition pm values less than (MAXVALUE));`) + + // test between castIntAsReal(int) and real + tk.MustQuery(`explain select * from t where a between "123" and "199";`).Check( + testkit.Rows("TableReader_7 8000.00 root partition:p2 data:Selection_6", + "└─Selection_6 8000.00 cop[tikv] ge(cast(test.t.a, double BINARY), 123), le(cast(test.t.a, double BINARY), 199)", + " └─TableFullScan_5 10000.00 cop[tikv] table:t keep order:false, stats:pseudo")) + // MySQL explain: + //+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+ + //| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra | + //+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+ 
+ //| 1 | SIMPLE | t | p2 | ALL | NULL | NULL | NULL | NULL | 1 | 100.00 | Using where | + //+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+ + + // test between castIntAsDecimal(int) and decimal + tk.MustQuery(`explain select * from t where a between 123.12 and 199.99;`).Check( + testkit.Rows("TableReader_7 8000.00 root partition:p2 data:Selection_6", + "└─Selection_6 8000.00 cop[tikv] ge(cast(test.t.a, decimal(10,0) BINARY), 123.12), le(cast(test.t.a, decimal(10,0) BINARY), 199.99)", + " └─TableFullScan_5 10000.00 cop[tikv] table:t keep order:false, stats:pseudo")) + // MySQL explain: + //+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+ + //| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra | + //+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+ + //| 1 | SIMPLE | t | p2 | ALL | NULL | NULL | NULL | NULL | 1 | 100.00 | Using where | + //+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+ + + // test between castIntAsReal(int) and real + tk.MustQuery(`explain select * from t where a between "123.12" and "199.99";`).Check( + testkit.Rows("TableReader_7 8000.00 root partition:p2 data:Selection_6", + "└─Selection_6 8000.00 cop[tikv] ge(cast(test.t.a, double BINARY), 123.12), le(cast(test.t.a, double BINARY), 199.99)", + " └─TableFullScan_5 10000.00 cop[tikv] table:t keep order:false, stats:pseudo")) + // MySQL explain: + //+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+ + //| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra | + 
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+ + //| 1 | SIMPLE | t | p2 | ALL | NULL | NULL | NULL | NULL | 1 | 100.00 | Using where | + //+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+ + + // test between castIntAsReal(int) and real + tk.MustQuery(`explain select * from t where a between "ddd" and "99";`).Check( + testkit.Rows("TableReader_7 8000.00 root partition:p1 data:Selection_6", + "└─Selection_6 8000.00 cop[tikv] ge(cast(test.t.a, double BINARY), 0), le(cast(test.t.a, double BINARY), 99)", + " └─TableFullScan_5 10000.00 cop[tikv] table:t keep order:false, stats:pseudo")) + // MySQL explain: + //+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+ + //| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra | + //+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+ + //| 1 | SIMPLE | t | p1 | ALL | NULL | NULL | NULL | NULL | 1 | 100.00 | Using where | + //+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+ + + // test between castIntAsReal(int) and real / between castIntAsReal(int) and real + tk.MustQuery(`explain select * from t where a between "123.12" and cast("199.99" as decimal);`).Check( + testkit.Rows("TableReader_7 8000.00 root partition:p2,pm data:Selection_6", + "└─Selection_6 8000.00 cop[tikv] ge(cast(test.t.a, double BINARY), 123.12), le(cast(test.t.a, double BINARY), 200)", + " └─TableFullScan_5 10000.00 cop[tikv] table:t keep order:false, stats:pseudo")) + // MySQL explain: + //+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+ + //| id | select_type | table | partitions | 
type | possible_keys | key | key_len | ref | rows | filtered | Extra | + //+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+ + //| 1 | SIMPLE | t | p2,pm | ALL | NULL | NULL | NULL | NULL | 1 | 100.00 | Using where | + //+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+ + + tk.MustExec(`CREATE TABLE t_hash(a int, b int) PARTITION BY HASH(a) PARTITIONS 6`) + tk.MustExec(`insert into t_hash values(1, 1), (10, 10), (26, 26)`) + tk.MustQuery(`select * from t_hash where a = '1'`).Check(testkit.Rows("1 1")) + tk.MustQuery(`explain select * from t_hash where a = '1'`).Check(testkit.Rows( + "TableReader_7 10.00 root partition:p1 data:Selection_6", + "└─Selection_6 10.00 cop[tikv] eq(test.t_hash.a, 1)", + " └─TableFullScan_5 10000.00 cop[tikv] table:t_hash keep order:false, stats:pseudo")) + // MySQL explain: + //+----+-------------+--------+------------+------+---------------+------+---------+------+------+----------+-------------+ + //| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra | + //+----+-------------+--------+------------+------+---------------+------+---------+------+------+----------+-------------+ + //| 1 | SIMPLE | t_hash | p1 | ALL | NULL | NULL | NULL | NULL | 1 | 100.00 | Using where | + //+----+-------------+--------+------------+------+---------------+------+---------+------+------+----------+-------------+ + + tk.MustExec(`create table t_ts (report_updated timestamp) partition by range(unix_timestamp(report_updated)) ( + partition p1 values less than (1732982400), -- 2024-12-01 00:00:00 + partition p2 values less than (1733068800), -- 2024-12-02 00:00:00 + partition pm values less than (MAXVALUE));`) + tk.MustExec("insert into t_ts values('2024-11-30 00:00:00'), ('2024-12-01 00:00:00'), ('2024-12-02 00:00:00')") + tk.MustQuery("select * from t_ts where 
report_updated = '2024-12-01 00:00:00'").Check(testkit.Rows("2024-12-01 00:00:00")) + tk.MustQuery("explain select * from t_ts where report_updated = 20241201").Check(testkit.Rows( + "TableReader_7 10.00 root partition:p2 data:Selection_6", + "└─Selection_6 10.00 cop[tikv] eq(test.t_ts.report_updated, 2024-12-01 00:00:00)", + " └─TableFullScan_5 10000.00 cop[tikv] table:t_ts keep order:false, stats:pseudo")) + tk.MustQuery("explain select * from t_ts where report_updated = '2024-12-01 00:00:00'").Check(testkit.Rows( + "TableReader_7 10.00 root partition:p2 data:Selection_6", + "└─Selection_6 10.00 cop[tikv] eq(test.t_ts.report_updated, 2024-12-01 00:00:00.000000)", + " └─TableFullScan_5 10000.00 cop[tikv] table:t_ts keep order:false, stats:pseudo")) + tk.MustQuery("explain select * from t_ts where report_updated > unix_timestamp('2008-05-01 00:00:00')").Check(testkit.Rows( + "TableReader_7 8000.00 root partition:all data:Selection_6", + "└─Selection_6 8000.00 cop[tikv] gt(cast(test.t_ts.report_updated, double BINARY), 1.2095712e+09)", + " └─TableFullScan_5 10000.00 cop[tikv] table:t_ts keep order:false, stats:pseudo")) + //MysQL explain: + //+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+ + //| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra | + //+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+ + //| 1 | SIMPLE | t_ts | p1,p2,pm | ALL | NULL | NULL | NULL | NULL | 3 | 33.33 | Using where | + //+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+ +} + func TestNotReadOnlySQLOnTiFlash(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) diff --git a/pkg/planner/core/rule_partition_processor.go b/pkg/planner/core/rule_partition_processor.go index 
c5e8d8aa9923b..a0fefe149d7f1 100644
--- a/pkg/planner/core/rule_partition_processor.go
+++ b/pkg/planner/core/rule_partition_processor.go
@@ -1593,10 +1593,23 @@ func (p *rangePruner) extractDataForPrune(sctx base.PlanContext, expr expression
 		if arg1, ok := op.GetArgs()[1].(*expression.Constant); ok {
 			col, con = arg0, arg1
 		}
-	} else if arg0, ok := op.GetArgs()[1].(*expression.Column); ok && arg0.ID == p.col.ID {
-		if arg1, ok := op.GetArgs()[0].(*expression.Constant); ok {
+	} else if arg1, ok := op.GetArgs()[1].(*expression.Column); ok && arg1.ID == p.col.ID {
+		if arg0, ok := op.GetArgs()[0].(*expression.Constant); ok {
 			ret.op = opposite(ret.op)
-			col, con = arg0, arg1
+			col, con = arg1, arg0
+		}
+	} else if sarg0, ok := op.GetArgs()[0].(*expression.ScalarFunction); ok && sarg0.FuncName.L == ast.Cast {
+		if arg0, ok := sarg0.GetArgs()[0].(*expression.Column); ok && arg0.ID == p.col.ID {
+			if arg1, ok := op.GetArgs()[1].(*expression.Constant); ok {
+				col, con = arg0, arg1
+			}
+		}
+	} else if sarg1, ok := op.GetArgs()[1].(*expression.ScalarFunction); ok && sarg1.FuncName.L == ast.Cast {
+		if arg1, ok := sarg1.GetArgs()[0].(*expression.Column); ok && arg1.ID == p.col.ID {
+			if arg0, ok := op.GetArgs()[0].(*expression.Constant); ok {
+				ret.op = opposite(ret.op)
+				col, con = arg1, arg0
+			}
 		}
 	}
 	if col == nil || con == nil {
@@ -1606,6 +1619,14 @@ func (p *rangePruner) extractDataForPrune(sctx base.PlanContext, expr expression
 	// Current expression is 'col op const'
 	var constExpr expression.Expression
 	if p.partFn != nil {
+		// If arg0 or arg1 is ScalarFunction, just skip it.
+		// Maybe more complicated cases would be considered in the future.
+		_, ok1 := op.GetArgs()[0].(*expression.ScalarFunction)
+		_, ok2 := op.GetArgs()[1].(*expression.ScalarFunction)
+		if ok1 || ok2 {
+			return ret, false
+		}
+
 		// If the partition function is not monotone, only EQ condition can be pruning.
 		if p.monotonous == monotoneModeInvalid && ret.op != ast.EQ {
 			return ret, false
@@ -1626,11 +1647,56 @@ func (p *rangePruner) extractDataForPrune(sctx base.PlanContext, expr expression
 		// If the partition expression is col, use constExpr.
 		constExpr = con
 	}
-	c, isNull, err := constExpr.EvalInt(sctx.GetExprCtx().GetEvalCtx(), chunk.Row{})
-	if err == nil && !isNull {
-		ret.c = c
-		ret.unsigned = mysql.HasUnsignedFlag(constExpr.GetType(sctx.GetExprCtx().GetEvalCtx()).GetFlag())
-		return ret, true
+	evalCtx := sctx.GetExprCtx().GetEvalCtx()
+
+	// toBound converts a non-integer constant to the int64 bound used for
+	// pruning. int64(f) truncates toward zero, so `col < 14.5` would prune as
+	// `col < 14` and `col > -1.5` as `col > -1`, skipping the partitions that
+	// hold 14 and -1. Those two cases are widened by one; every other
+	// operator keeps the truncated value, exactly as before.
+	toBound := func(f float64) (int64, bool) {
+		// NaN and values outside the int64 range cannot be converted reliably.
+		if !(f >= -(1<<63) && f < 1<<63) {
+			return 0, false
+		}
+		b := int64(f)
+		switch {
+		case ret.op == ast.LT && float64(b) < f:
+			b++
+		case ret.op == ast.GT && float64(b) > f:
+			b--
+		}
+		return b, true
+	}
+
+	var (
+		bound int64
+		valid bool
+	)
+	switch constExpr.GetType(evalCtx).EvalType() {
+	case types.ETInt:
+		v, isNull, err := constExpr.EvalInt(evalCtx, chunk.Row{})
+		bound, valid = v, err == nil && !isNull
+	case types.ETReal:
+		f, isNull, err := constExpr.EvalReal(evalCtx, chunk.Row{})
+		if err == nil && !isNull {
+			bound, valid = toBound(f)
+		}
+	case types.ETDecimal:
+		d, isNull, err := constExpr.EvalDecimal(evalCtx, chunk.Row{})
+		// Check isNull before using d: it may be nil for a NULL constant.
+		if err == nil && !isNull {
+			if f, convErr := d.ToFloat64(); convErr == nil {
+				bound, valid = toBound(f)
+			}
+		}
+	default:
+		// Other eval types are not used for pruning here.
+	}
+	if valid {
+		ret.c = bound
+		ret.unsigned = mysql.HasUnsignedFlag(constExpr.GetType(evalCtx).GetFlag())
+		return ret, true
 	}
 	return ret, false
 }