Skip to content

Commit

Permalink
[BugFix] Add exhausted check for MultiJoinBinder (backport #54917) (#…
Browse files Browse the repository at this point in the history
…54934)

Signed-off-by: shuming.li <[email protected]>
Co-authored-by: shuming.li <[email protected]>
  • Loading branch information
mergify[bot] and LiShuMing authored Jan 15, 2025
1 parent 1ac2fef commit 75e5068
Show file tree
Hide file tree
Showing 8 changed files with 269 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3954,6 +3954,10 @@ public String getAnalyzeForMV() {
return analyzeTypeForMV;
}

public void setAnalyzeForMv(String analyzeTypeForMV) {
this.analyzeTypeForMV = analyzeTypeForMV;
}

public boolean isEnableBigQueryLog() {
return enableBigQueryLog;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ public class GroupExpression {
private boolean isUnused = false;

private Optional<Boolean> isAppliedMVRules = Optional.empty();
// all mv rewrite rules
private static final List<Rule> ALL_MV_REWRITE_RULES = RuleSet.getRewriteRulesByType(RuleSetType.ALL_MV_REWRITE);

public GroupExpression(Operator op, List<Group> inputs) {
this.op = op;
Expand Down Expand Up @@ -347,9 +349,8 @@ public String debugString(String headlineIndent, String detailIndent) {
}

public boolean hasAppliedMVRules() {
if (!isAppliedMVRules.isPresent()) {
final List<Rule> mvRules = RuleSet.getRewriteRulesByType(RuleSetType.ALL_MV_REWRITE);
isAppliedMVRules = Optional.of(mvRules.stream().anyMatch(rule -> hasRuleApplied(rule)));
if (isAppliedMVRules.isEmpty()) {
isAppliedMVRules = Optional.of(ALL_MV_REWRITE_RULES.stream().anyMatch(rule -> hasRuleApplied(rule)));
}
return isAppliedMVRules.get();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,12 @@ public long optimizerElapsedMs() {
return optimizerTimer.elapsed(TimeUnit.MILLISECONDS);
}

public Stopwatch getStopwatch(RuleType ruleType) {
return ruleWatchMap.computeIfAbsent(ruleType, (k) -> Stopwatch.createStarted());
}

public boolean ruleExhausted(RuleType ruleType) {
Stopwatch watch = ruleWatchMap.computeIfAbsent(ruleType, (k) -> Stopwatch.createStarted());
Stopwatch watch = getStopwatch(ruleType);
long elapsed = watch.elapsed(TimeUnit.MILLISECONDS);
long timeLimit = Math.min(sessionVariable.getOptimizerMaterializedViewTimeLimitMillis(),
sessionVariable.getOptimizerExecuteTimeout());
Expand Down
116 changes: 80 additions & 36 deletions fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/Binder.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,23 @@

package com.starrocks.sql.optimizer.rule;

import com.google.common.base.Stopwatch;
import com.google.common.collect.Lists;
import com.starrocks.common.profile.Tracers;
import com.starrocks.qe.SessionVariable;
import com.starrocks.sql.optimizer.Group;
import com.starrocks.sql.optimizer.GroupExpression;
import com.starrocks.sql.optimizer.OptExpression;
import com.starrocks.sql.optimizer.OptimizerContext;
import com.starrocks.sql.optimizer.operator.OperatorType;
import com.starrocks.sql.optimizer.operator.pattern.Pattern;

import java.util.List;
import java.util.concurrent.TimeUnit;

// Used to extract matched expression from GroupExpression
public class Binder {

private final OptimizerContext optimizerContext;
private final Pattern pattern;
private final GroupExpression groupExpression;
// binder status
Expand All @@ -36,7 +41,7 @@ public class Binder {
private final List<Integer> groupExpressionIndex;

// `multiJoinBinder` is used for MULTI_JOIN pattern and is stateless so can be used in recursive.
private final MultiJoinBinder multiJoinBinder = new MultiJoinBinder();
private final MultiJoinBinder multiJoinBinder;
// `isPatternWithoutChildren` is to mark whether the input pattern can be used for zero child optimization.
private final boolean isPatternWithoutChildren;
// `nextIdx` marks the current idx which iterates calling `next()` method and it's used for MULTI_JOIN pattern
Expand All @@ -50,11 +55,14 @@ public class Binder {
* @param groupExpression Search this for binding. Because GroupExpression's inputs are groups,
* several Expressions matched the pattern should be bound from it
*/
public Binder(Pattern pattern, GroupExpression groupExpression) {
public Binder(OptimizerContext optimizerContext, Pattern pattern,
GroupExpression groupExpression, Stopwatch stopwatch) {
this.optimizerContext = optimizerContext;
this.pattern = pattern;
this.groupExpression = groupExpression;
this.groupExpressionIndex = Lists.newArrayList(0);

this.multiJoinBinder = new MultiJoinBinder(optimizerContext, stopwatch);
// MULTI_JOIN is a special pattern which can contain children groups if the input group expression
// is not a scan node.
this.isPatternWithoutChildren = pattern.isPatternMultiJoin()
Expand Down Expand Up @@ -89,8 +97,8 @@ public OptExpression next() {
this.groupTraceKey = 0;

// Match with the next groupExpression of the last group node
int lastNode = this.groupExpressionIndex.size() - 1;
int lastNodeIndex = this.groupExpressionIndex.get(lastNode);
final int lastNode = this.groupExpressionIndex.size() - 1;
final int lastNodeIndex = this.groupExpressionIndex.get(lastNode);
this.groupExpressionIndex.set(lastNode, lastNodeIndex + 1);

expression = match(pattern, groupExpression);
Expand Down Expand Up @@ -119,11 +127,14 @@ private OptExpression match(Pattern pattern, GroupExpression groupExpression) {
int patternIndex = 0;
int groupExpressionIndex = 0;

while (patternIndex < pattern.children().size() && groupExpressionIndex < groupExpression.getInputs().size()) {
final int patternSize = pattern.children().size();
final int geSize = groupExpression.getInputs().size();
while (patternIndex < patternSize && groupExpressionIndex < geSize) {
trace();
Group group = groupExpression.getInputs().get(groupExpressionIndex);
Pattern childPattern = pattern.childAt(patternIndex);
OptExpression opt = match(childPattern, extractGroupExpression(childPattern, group));

final Group group = groupExpression.getInputs().get(groupExpressionIndex);
final Pattern childPattern = pattern.childAt(patternIndex);
final OptExpression opt = match(childPattern, extractGroupExpression(childPattern, group));

if (opt == null) {
return null;
Expand All @@ -132,8 +143,7 @@ private OptExpression match(Pattern pattern, GroupExpression groupExpression) {
}

if (!(childPattern.isPatternMultiLeaf() &&
groupExpression.getInputs().size() - groupExpressionIndex >
pattern.children().size() - patternIndex)) {
geSize - groupExpressionIndex > patternSize - patternIndex)) {
patternIndex++;
}

Expand All @@ -154,7 +164,7 @@ private void trace() {
* extract GroupExpression by groupExpressionIndex
*/
private GroupExpression extractGroupExpression(Pattern pattern, Group group) {
int valueIndex = groupExpressionIndex.get(groupTraceKey);
final int valueIndex = groupExpressionIndex.get(groupTraceKey);
if (pattern.isPatternLeaf() || pattern.isPatternMultiLeaf()) {
if (valueIndex > 0) {
groupExpressionIndex.remove(groupTraceKey);
Expand All @@ -180,6 +190,22 @@ private GroupExpression extractGroupExpression(Pattern pattern, Group group) {
* binding state and check the expression at the same time. But MULTI_JOIN could enumerate the GE without any check
*/
private class MultiJoinBinder {
private final SessionVariable sessionVariable;
// Stopwatch to void infinite loop
private final Stopwatch watch;
// Time limit for the entire optimization
private final long timeLimit;
// to avoid stop watch costing too much time, only check exhausted every CHECK_EXHAUSTED_INTERVAL times
private static final int CHECK_EXHAUSTED_INTERVAL = 1000;
private long loopCount = 0;

public MultiJoinBinder(OptimizerContext optimizerContext, Stopwatch stopwatch) {
this.sessionVariable = optimizerContext.getSessionVariable();
this.watch = stopwatch;
this.timeLimit = Math.min(sessionVariable.getOptimizerMaterializedViewTimeLimitMillis(),
sessionVariable.getOptimizerExecuteTimeout());
}

public OptExpression match(GroupExpression ge) {
// 1. Check if the entire tree is MULTI_JOIN
// 2. Enumerate GE
Expand All @@ -190,26 +216,47 @@ public OptExpression match(GroupExpression ge) {
return enumerate(ge);
}

/**
* Check whether the binder is exhausted.
*/
private boolean exhausted() {
if (loopCount++ % CHECK_EXHAUSTED_INTERVAL == 0) {
final long elapsed = watch.elapsed(TimeUnit.MILLISECONDS);
final boolean exhausted = elapsed > timeLimit;
if (exhausted) {
Tracers.log(Tracers.Module.MV, args ->
String.format("[MV TRACE] MultiJoinBinder %s exhausted(loop:%s)\n", this, loopCount));
}
return exhausted;
}
return false;
}

private OptExpression enumerate(GroupExpression ge) {
List<OptExpression> resultInputs = Lists.newArrayList();
final List<OptExpression> resultInputs = Lists.newArrayList();
final int geSize = ge.getInputs().size();

int groupExpressionIndex = 0;
while (groupExpressionIndex < ge.getInputs().size()) {
while (groupExpressionIndex < geSize) {
// to avoid infinite loop
if (exhausted()) {
return null;
}
trace();

Group group = ge.getInputs().get(groupExpressionIndex);
GroupExpression nextGroupExpression = extractGroupExpression(group);
final Group group = ge.getInputs().get(groupExpressionIndex);
final GroupExpression nextGroupExpression = extractGroupExpression(group);
// avoid recursive
if (nextGroupExpression == null || !isMultiJoinOp(nextGroupExpression)) {
return null;
}

OptExpression opt = enumerate(nextGroupExpression);
final OptExpression opt = enumerate(nextGroupExpression);
if (opt == null) {
return null;
} else {
resultInputs.add(opt);
}

groupExpressionIndex++;
}

Expand All @@ -222,14 +269,14 @@ private GroupExpression extractGroupExpression(Group group) {
groupExpressionIndex.remove(groupTraceKey);
return null;
}
List<GroupExpression> groupExpressions = group.getLogicalExpressions();
final List<GroupExpression> groupExpressions = group.getLogicalExpressions();
GroupExpression next = groupExpressions.get(valueIndex);
if (nextIdx == 0) {
return next;
}

// shortcut for no child group expression
if (Pattern.ALL_SCAN_TYPES.contains(next.getOp().getOpType()) && valueIndex > 0) {
if (valueIndex > 0 && Pattern.ALL_SCAN_TYPES.contains(next.getOp().getOpType())) {
groupExpressionIndex.remove(groupTraceKey);
return null;
}
Expand All @@ -241,8 +288,13 @@ private GroupExpression extractGroupExpression(Group group) {
// NOTE: To avoid iterating all children group expressions which may take a lot of time, only iterate
// group expressions which are already rewritten by mv except the first iteration, so can be used for
// nested mv rewritten.
int geSize = groupExpressions.size();
final int geSize = groupExpressions.size();
while (++valueIndex < geSize) {
if (exhausted()) {
groupExpressionIndex.remove(groupTraceKey);
return null;
}

next = group.getLogicalExpressions().get(valueIndex);
if (next.hasAppliedMVRules()) {
groupExpressionIndex.set(groupTraceKey, valueIndex);
Expand Down Expand Up @@ -271,9 +323,13 @@ private boolean isMultiJoinRecursive(GroupExpression ge) {
if (!isMultiJoinOp(ge)) {
return false;
}

for (int i = 0; i < ge.getInputs().size(); i++) {
Group child = ge.inputAt(i);
final int geSize = ge.getInputs().size();
for (int i = 0; i < geSize; i++) {
// to avoid infinite loop
if (exhausted()) {
return false;
}
final Group child = ge.inputAt(i);
if (isMultiJoinRecursive(child.getFirstLogicalExpression())) {
continue;
}
Expand All @@ -295,16 +351,4 @@ private boolean hasRewrittenMvScan(Group g) {
.anyMatch(ge -> ge.getOp().getOpType() == OperatorType.LOGICAL_OLAP_SCAN && ge.hasAppliedMVRules());
}
}

/**
* Extract a expression from GroupExpression which match the given pattern once
*
* @param pattern Bound expression should be matched
* @param groupExpression Search this for binding. Because GroupExpression's inputs are groups,
* several Expressions matched the pattern should be bound from it
*/
public static OptExpression bind(Pattern pattern, GroupExpression groupExpression) {
Binder binder = new Binder(pattern, groupExpression);
return binder.next();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package com.starrocks.sql.optimizer.task;

import com.google.common.base.Stopwatch;
import com.google.common.collect.Lists;
import com.starrocks.common.Pair;
import com.starrocks.common.profile.Timer;
Expand All @@ -22,6 +23,7 @@
import com.starrocks.sql.common.StarRocksPlannerException;
import com.starrocks.sql.optimizer.GroupExpression;
import com.starrocks.sql.optimizer.OptExpression;
import com.starrocks.sql.optimizer.OptimizerContext;
import com.starrocks.sql.optimizer.OptimizerTraceUtil;
import com.starrocks.sql.optimizer.operator.pattern.Pattern;
import com.starrocks.sql.optimizer.rule.Binder;
Expand Down Expand Up @@ -66,11 +68,13 @@ public void execute() {
return;
}
// Apply rule and get all new OptExpressions
Pattern pattern = rule.getPattern();
Binder binder = new Binder(pattern, groupExpression);
final Pattern pattern = rule.getPattern();
final OptimizerContext optimizerContext = context.getOptimizerContext();
final Stopwatch ruleStopWatch = optimizerContext.getStopwatch(rule.type());
final Binder binder = new Binder(optimizerContext, pattern, groupExpression, ruleStopWatch);
final List<OptExpression> newExpressions = Lists.newArrayList();
final List<OptExpression> extractExpressions = Lists.newArrayList();
OptExpression extractExpr = binder.next();
List<OptExpression> newExpressions = Lists.newArrayList();
List<OptExpression> extractExpressions = Lists.newArrayList();
while (extractExpr != null) {
// Check if the rule has exhausted or not to avoid optimization time exceeding the limit.:
// 1. binder.next() may be infinite loop if something is wrong.
Expand Down
Loading

0 comments on commit 75e5068

Please sign in to comment.