diff --git a/src/main/java/com/actiontech/dble/plan/optimizer/JoinStrategyChooser.java b/src/main/java/com/actiontech/dble/plan/optimizer/JoinStrategyChooser.java index 938110368..93a8c8d96 100644 --- a/src/main/java/com/actiontech/dble/plan/optimizer/JoinStrategyChooser.java +++ b/src/main/java/com/actiontech/dble/plan/optimizer/JoinStrategyChooser.java @@ -15,6 +15,7 @@ import com.actiontech.dble.plan.node.TableNode; import com.actiontech.dble.plan.util.PlanUtil; import com.actiontech.dble.util.StringUtil; import com.google.common.base.Strings; +import org.jetbrains.annotations.Nullable; import java.util.*; @@ -149,25 +150,14 @@ public class JoinStrategyChooser { if (isSmallTable((TableNode) node) || canDoAsMerge(joinNode) || !innerJoin) { return; } - List joinFilter = joinNode.getJoinFilter(); - for (ItemFuncEqual itemFuncEqual : joinFilter) { - List arguments = itemFuncEqual.arguments(); - Item item = arguments.stream().filter(argument -> !StringUtil.equals(getTableName((TableNode) node), argument.getTableName())).findFirst().get(); - PlanNode dependNode = nodeMap.get(item.getTableName()); - if (isSmallTable((TableNode) dependNode) && innerJoin) { - joinNode.setStrategy(Strategy.ALWAYS_NEST_LOOP); - node.setNestLoopFilters(new ArrayList<>()); - node.setNestLoopDependNode(dependNode); - List nodeList = Optional.ofNullable(dependNode.getNestLoopDependOnNodeList()).orElse(new ArrayList<>()); - nodeList.add(nodeList.size(), node); - dependNode.setNestLoopDependOnNodeList(nodeList); - return; - } - } joinNode.setStrategy(JoinNode.Strategy.ALWAYS_NEST_LOOP); + PlanNode dependedNode = findDependedNode(node, innerJoin); + if (Objects.isNull(dependedNode)) { + return; + } + handlerDependedNode(dependedNode, node); node.setNestLoopFilters(new ArrayList<>()); - node.setNestLoopDependNode(findDependNode(node)); - + node.setNestLoopDependNode(dependedNode); } private String getTableName(TableNode node) { @@ -178,20 +168,47 @@ public class JoinStrategyChooser { return node.getTableName(); } - private PlanNode findDependNode(PlanNode node) { + private PlanNode findDependedNode(PlanNode node, boolean innerJoin) { JoinNode joinNode = (JoinNode) node.getParent(); - String firstTableName = null; + PlanNode firstNode = null; List joinFilter = joinNode.getJoinFilter(); for (ItemFuncEqual itemFuncEqual : joinFilter) { List arguments = itemFuncEqual.arguments(); - String tableName = arguments.get(0).getTableName(); - firstTableName = Optional.ofNullable(firstTableName).orElse(tableName); + Item item = arguments.stream().filter(argument -> !StringUtil.equals(getTableName((TableNode) node), argument.getTableName())).findFirst().get(); + PlanNode dependedNode = nodeMap.get(item.getTableName()); + if (Objects.isNull(firstNode)) { + firstNode = dependedNode; + } + if (isSmallTable((TableNode) dependedNode) && innerJoin) { + return dependedNode; + } } - PlanNode dependNode = nodeMap.get(firstTableName); - List nodeList = Optional.ofNullable(dependNode.getNestLoopDependOnNodeList()).orElse(new ArrayList<>()); + return firstNode; + } + + @Nullable + private void handlerDependedNode(PlanNode dependedNode, PlanNode node) { + setNestLoopDependOnNodeList(dependedNode, node); + PlanNode parent = dependedNode.getParent(); + setNestLoopDependOnNodeList(dependedNode, node); + boolean isCurrentNode = true; + while (Objects.nonNull(parent)) { + if (!(parent instanceof JoinNode) || !canDoAsMerge((JoinNode) parent)) { + break; + } + dependedNode = parent; + isCurrentNode = false; + parent = parent.getParent(); + } + if (!isCurrentNode) { + setNestLoopDependOnNodeList(dependedNode, node); + } + } + + private void setNestLoopDependOnNodeList(PlanNode dependedNode, PlanNode node) { + List nodeList = Optional.ofNullable(dependedNode.getNestLoopDependOnNodeList()).orElse(new ArrayList<>()); nodeList.add(nodeList.size(), node); - dependNode.setNestLoopDependOnNodeList(nodeList); - return nodeMap.get(firstTableName); + dependedNode.setNestLoopDependOnNodeList(nodeList); } private boolean buildNodeMap(JoinNode joinNode) {