diff --git a/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java b/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java index 41d990d2f80dbb95ff7098eb7c9aee3fec60f98f..81bd75ec5812a47c05e94a7b2b305a63b257b378 100644 --- a/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java +++ b/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java @@ -121,6 +121,8 @@ public final class SystemSessionProperties public static final String ENABLE_INTERMEDIATE_AGGREGATIONS = "enable_intermediate_aggregations"; public static final String PUSH_AGGREGATION_THROUGH_JOIN = "push_aggregation_through_join"; public static final String PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN = "push_partial_aggregation_through_join"; + public static final String PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN_SELECTIVITY = "push_partial_aggregation_through_join_selectivity"; + public static final String PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN_OUT_TO_IN_RATIO = "push_partial_aggregation_through_join_out_to_in_ratio"; public static final String PARSE_DECIMAL_LITERALS_AS_DOUBLE = "parse_decimal_literals_as_double"; public static final String FORCE_SINGLE_NODE_OUTPUT = "force_single_node_output"; public static final String FILTER_AND_PROJECT_MIN_OUTPUT_PAGE_SIZE = "filter_and_project_min_output_page_size"; @@ -627,6 +629,16 @@ public final class SystemSessionProperties "Push partial aggregations below joins", false, false), + doubleProperty( + PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN_SELECTIVITY, + "Push partial aggregations below joins when join selectivity is more than this config value", + 0.8D, + false), + doubleProperty( + PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN_OUT_TO_IN_RATIO, + "Push partial aggregations below joins when out to in records ratio for pushed aggregated node is less than equal to this config value", + 0.7D, + false), booleanProperty( PARSE_DECIMAL_LITERALS_AS_DOUBLE, "Parse decimal literals as DOUBLE instead of DECIMAL", @@ -1452,6 +1464,16 @@ public final class SystemSessionProperties return session.getSystemProperty(PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN, Boolean.class); } + public static double getPushAggregationThroughJoinSelectivity(Session session) + { + return session.getSystemProperty(PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN_SELECTIVITY, Double.class); + } + + public static double getPushAggregationThroughJoinOutToInRatio(Session session) + { + return session.getSystemProperty(PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN_OUT_TO_IN_RATIO, Double.class); + } + public static boolean isParseDecimalLiteralsAsDouble(Session session) { return session.getSystemProperty(PARSE_DECIMAL_LITERALS_AS_DOUBLE, Boolean.class); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java index 9df53f30c0b5d13ced1b2a7a56334bcb84df57cc..13ece1f9bd440bdd6b66c70dca2d103c116b9e20 100755 --- a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java @@ -121,6 +121,7 @@ import io.prestosql.sql.planner.iterative.rule.PushLimitThroughProject; import io.prestosql.sql.planner.iterative.rule.PushLimitThroughSemiJoin; import io.prestosql.sql.planner.iterative.rule.PushLimitThroughUnion; import io.prestosql.sql.planner.iterative.rule.PushOffsetThroughProject; +import io.prestosql.sql.planner.iterative.rule.PushPartialAggregationProjectionsThroughJoin; import io.prestosql.sql.planner.iterative.rule.PushPartialAggregationThroughExchange; import io.prestosql.sql.planner.iterative.rule.PushPartialAggregationThroughJoin; import io.prestosql.sql.planner.iterative.rule.PushPredicateIntoTableScan; @@ -761,6 +762,7 @@ public class PlanOptimizers costCalculator, ImmutableSet.of( new PushPartialAggregationThroughJoin(), + new PushPartialAggregationProjectionsThroughJoin(), new PushPartialAggregationThroughExchange(metadata), new PruneJoinColumns()))); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPartialAggregationProjectionsThroughJoin.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPartialAggregationProjectionsThroughJoin.java new file mode 100644 index 0000000000000000000000000000000000000000..d1091a2be51398480d064b97668db476a3d5f99a --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPartialAggregationProjectionsThroughJoin.java @@ -0,0 +1,183 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.spi.plan.AggregationNode; +import io.prestosql.spi.plan.Assignments; +import io.prestosql.spi.plan.JoinNode; +import io.prestosql.spi.plan.PlanNode; +import io.prestosql.spi.plan.ProjectNode; +import io.prestosql.spi.plan.Symbol; +import io.prestosql.spi.relation.RowExpression; +import io.prestosql.spi.relation.VariableReferenceExpression; +import io.prestosql.sql.planner.SymbolsExtractor; +import io.prestosql.sql.planner.TypeProvider; +import io.prestosql.sql.planner.plan.AssignmentUtils; + +import java.util.HashSet; +import java.util.List; +import java.util.Map; + +import static io.prestosql.sql.planner.plan.Patterns.aggregation; +import static io.prestosql.sql.planner.plan.Patterns.join; +import static io.prestosql.sql.planner.plan.Patterns.project; +import static io.prestosql.sql.planner.plan.Patterns.source; + +public class PushPartialAggregationProjectionsThroughJoin + extends PushPartialAggregationThroughJoin +{ + private static final Capture JOIN_NODE = Capture.newCapture(); + private static final Capture PROJECT_NODE = Capture.newCapture(); + + private static final Pattern PATTERN = aggregation() + .matching(PushPartialAggregationProjectionsThroughJoin::isSupportedAggregationNode) + .with(source().matching(project().capturedAs(PROJECT_NODE) + .with(source().matching(join().capturedAs(JOIN_NODE))))); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(AggregationNode aggregationNode, Captures captures, Context context) + { + ProjectNode projectNode = captures.get(PROJECT_NODE); + JoinNode joinNode = captures.get(JOIN_NODE); + + if (joinNode.getType() != JoinNode.Type.INNER) { + return Result.empty(); + } + + // Check if Project can be pushed down through join + // Check if aggregations can be pushed down through join + // First push Project through Join, then apply the rule + Assignments assignments = projectNode.getAssignments(); + Assignments.Builder leftAssignments = Assignments.builder(); + Assignments.Builder rightAssignments = Assignments.builder(); + HashSet leftSymbolSet = new HashSet<>(joinNode.getLeft().getOutputSymbols()); + HashSet rightSymbolSet = new HashSet<>(joinNode.getRight().getOutputSymbols()); + for (Map.Entry assignment : assignments.entrySet()) { + List symbols = SymbolsExtractor.extractAll(assignment.getValue()); + if (symbols.size() == 0) { + return Result.empty(); + } + if (leftSymbolSet.containsAll(symbols)) { + leftAssignments.put(assignment.getKey(), assignment.getValue()); + } + else if (rightSymbolSet.containsAll(symbols)) { + rightAssignments.put(assignment.getKey(), assignment.getValue()); + } + else { + return Result.empty(); + } + } + TypeProvider typeProvider = context.getSymbolAllocator().getTypes(); + for (Map.Entry df : joinNode.getDynamicFilters().entrySet()) { + if (leftSymbolSet.contains(df.getValue())) { + leftAssignments.put(df.getValue(), new VariableReferenceExpression(df.getValue().getName(), typeProvider.get(df.getValue()))); + } + else if (rightSymbolSet.contains(df.getValue())) { + rightAssignments.put(df.getValue(), new VariableReferenceExpression(df.getValue().getName(), typeProvider.get(df.getValue()))); + } + } + if (joinNode.getFilter().isPresent()) { + List symbolsList = SymbolsExtractor.extractAll(joinNode.getFilter().get()); + for (Symbol symbol : symbolsList) { + if (leftSymbolSet.contains(symbol)) { + leftAssignments.putAll(AssignmentUtils.identityAssignments(typeProvider, symbol)); + } + else if (rightSymbolSet.contains(symbol)) { + rightAssignments.putAll(AssignmentUtils.identityAssignments(typeProvider, symbol)); + } + } + } + for (JoinNode.EquiJoinClause clause : joinNode.getCriteria()) { + if (leftSymbolSet.contains(clause.getLeft())) { + Assignments assignments1 = AssignmentUtils.identityAssignments(typeProvider, clause.getLeft()); + leftAssignments.putAll(assignments1); + + Assignments assignments2 = AssignmentUtils.identityAssignments(typeProvider, clause.getRight()); + rightAssignments.putAll(assignments2); + } + else if (rightSymbolSet.contains(clause.getRight())) { + Assignments assignments1 = AssignmentUtils.identityAssignments(typeProvider, clause.getRight()); + leftAssignments.putAll(assignments1); + + Assignments assignments2 = AssignmentUtils.identityAssignments(typeProvider, clause.getLeft()); + rightAssignments.putAll(assignments2); + } + } + + if (joinNode.getLeftHashSymbol().isPresent()) { + if (leftSymbolSet.contains(joinNode.getLeftHashSymbol().get())) { + Assignments assignments1 = AssignmentUtils.identityAssignments(typeProvider, joinNode.getLeftHashSymbol().get()); + leftAssignments.putAll(assignments1); + } + } + if (joinNode.getRightHashSymbol().isPresent()) { + if (rightSymbolSet.contains(joinNode.getRightHashSymbol().get())) { + Assignments assignments1 = AssignmentUtils.identityAssignments(typeProvider, joinNode.getRightHashSymbol().get()); + rightAssignments.putAll(assignments1); + } + } + + PlanNode leftNode = joinNode.getLeft(); + Assignments build = leftAssignments.build(); + if (build.size() > 0) { + leftNode = new ProjectNode(context.getIdAllocator().getNextId(), joinNode.getLeft(), build); + } + + PlanNode rightNode = joinNode.getRight(); + build = rightAssignments.build(); + if (build.size() > 0) { + rightNode = new ProjectNode(context.getIdAllocator().getNextId(), joinNode.getRight(), build); + } + JoinNode newJoinNode = new JoinNode(joinNode.getId(), + joinNode.getType(), + leftNode, + rightNode, + joinNode.getCriteria(), + projectNode.getOutputSymbols(), + joinNode.getFilter(), + joinNode.getLeftHashSymbol(), + joinNode.getRightHashSymbol(), + joinNode.getDistributionType(), + joinNode.isSpillable(), + joinNode.getDynamicFilters()); + AggregationNode newAggrNode = new AggregationNode(aggregationNode.getId(), + newJoinNode, + aggregationNode.getAggregations(), + aggregationNode.getGroupingSets(), + aggregationNode.getPreGroupedSymbols(), + aggregationNode.getStep(), + aggregationNode.getHashSymbol(), + aggregationNode.getGroupIdSymbol(), + aggregationNode.getAggregationType(), + aggregationNode.getFinalizeSymbol()); + + if (allAggregationsOn(newAggrNode.getAggregations(), newJoinNode.getLeft().getOutputSymbols())) { + return pushPartialToLeftChild(newAggrNode, newJoinNode, context); + } + if (allAggregationsOn(newAggrNode.getAggregations(), newJoinNode.getRight().getOutputSymbols())) { + return pushPartialToRightChild(newAggrNode, newJoinNode, context); + } + + return Result.empty(); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.java index 74d630c15ed2013fbf1853ef2e4ff98ff3678731..114a32f10fb3c389800c4c2561437172e763484c 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.java @@ -17,6 +17,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Streams; import io.prestosql.Session; +import io.prestosql.SystemSessionProperties; +import io.prestosql.cost.AggregationStatsRule; +import io.prestosql.cost.PlanNodeStatsEstimate; import io.prestosql.matching.Capture; import io.prestosql.matching.Captures; import io.prestosql.matching.Pattern; @@ -53,7 +56,7 @@ public class PushPartialAggregationThroughJoin .matching(PushPartialAggregationThroughJoin::isSupportedAggregationNode) .with(source().matching(join().capturedAs(JOIN_NODE))); - private static boolean isSupportedAggregationNode(AggregationNode aggregationNode) + protected static boolean isSupportedAggregationNode(AggregationNode aggregationNode) { // Don't split streaming aggregations if (aggregationNode.isStreamable()) { @@ -90,16 +93,16 @@ public class PushPartialAggregationThroughJoin // TODO: leave partial aggregation above Join? if (allAggregationsOn(aggregationNode.getAggregations(), joinNode.getLeft().getOutputSymbols())) { - return Result.ofPlanNode(pushPartialToLeftChild(aggregationNode, joinNode, context)); + return pushPartialToLeftChild(aggregationNode, joinNode, context); } if (allAggregationsOn(aggregationNode.getAggregations(), joinNode.getRight().getOutputSymbols())) { - return Result.ofPlanNode(pushPartialToRightChild(aggregationNode, joinNode, context)); + return pushPartialToRightChild(aggregationNode, joinNode, context); } return Result.empty(); } - private static boolean allAggregationsOn(Map aggregations, List symbols) + protected static boolean allAggregationsOn(Map aggregations, List symbols) { Set inputs = aggregations.values().stream() .map(SymbolsExtractor::extractAll) @@ -108,20 +111,75 @@ public class PushPartialAggregationThroughJoin return symbols.containsAll(inputs); } - private PlanNode pushPartialToLeftChild(AggregationNode node, JoinNode child, Context context) + private double getOutToInApplicabilityRatio(Context context) + { + return SystemSessionProperties.getPushAggregationThroughJoinOutToInRatio(context.getSession()); + } + + private double getJoinSelectivityRatio(Context context) + { + return SystemSessionProperties.getPushAggregationThroughJoinSelectivity(context.getSession()); + } + + private boolean isAggrNodeNotReduceOutputRows(Context context, AggregationNode pushedAggregation) + { + PlanNodeStatsEstimate sourceStats = context.getStatsProvider().getStats(pushedAggregation.getSource()); + PlanNodeStatsEstimate aggrNodeStats = AggregationStatsRule.groupBy(sourceStats, pushedAggregation.getGroupingKeys(), pushedAggregation.getAggregations()); + if (sourceStats.isOutputRowCountUnknown() || aggrNodeStats.isOutputRowCountUnknown()) { + return true; + } + return aggrNodeStats.getOutputRowCount() / sourceStats.getOutputRowCount() > getOutToInApplicabilityRatio(context); + } + + private boolean hasHighSelectivityForJoin(JoinNode joinNode, Context context, boolean isAggrPushedToLeft) + { + PlanNodeStatsEstimate childStats; + if (isAggrPushedToLeft) { + childStats = context.getStatsProvider().getStats(joinNode.getLeft()); + } + else { + childStats = context.getStatsProvider().getStats(joinNode.getRight()); + } + PlanNodeStatsEstimate joinStats = context.getStatsProvider().getStats(joinNode); + if (joinStats.isOutputRowCountUnknown() || childStats.isOutputRowCountUnknown()) { + return false; + } + + return joinStats.getOutputRowCount() / childStats.getOutputRowCount() >= getJoinSelectivityRatio(context); + } + + protected Result pushPartialToLeftChild(AggregationNode node, JoinNode child, Context context) { Set joinLeftChildSymbols = ImmutableSet.copyOf(child.getLeft().getOutputSymbols()); List groupingSet = getPushedDownGroupingSet(node, joinLeftChildSymbols, intersection(getJoinRequiredSymbols(child), joinLeftChildSymbols)); AggregationNode pushedAggregation = replaceAggregationSource(node, child.getLeft(), groupingSet); - return pushPartialToJoin(node, child, pushedAggregation, child.getRight(), context); + + // Apply only if can reduce the record count + if (isAggrNodeNotReduceOutputRows(context, pushedAggregation)) { + return Result.empty(); + } + + if (hasHighSelectivityForJoin(child, context, true)) { + return Result.ofPlanNode(pushPartialToJoin(node, child, pushedAggregation, child.getRight(), context)); + } + return Result.empty(); } - private PlanNode pushPartialToRightChild(AggregationNode node, JoinNode child, Context context) + protected Result pushPartialToRightChild(AggregationNode node, JoinNode child, Context context) { Set joinRightChildSymbols = ImmutableSet.copyOf(child.getRight().getOutputSymbols()); List groupingSet = getPushedDownGroupingSet(node, joinRightChildSymbols, intersection(getJoinRequiredSymbols(child), joinRightChildSymbols)); AggregationNode pushedAggregation = replaceAggregationSource(node, child.getRight(), groupingSet); - return pushPartialToJoin(node, child, child.getLeft(), pushedAggregation, context); + + // Apply only if can reduce the record count + if (isAggrNodeNotReduceOutputRows(context, pushedAggregation)) { + return Result.empty(); + } + + if (hasHighSelectivityForJoin(child, context, false)) { + return Result.ofPlanNode(pushPartialToJoin(node, child, child.getLeft(), pushedAggregation, context)); + } + return Result.empty(); } private Set getJoinRequiredSymbols(JoinNode node) diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushPartialAggregationProjectionsThroughJoin.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushPartialAggregationProjectionsThroughJoin.java new file mode 100644 index 0000000000000000000000000000000000000000..81d360942689f3340e59c6aa3eabd388b2935975 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushPartialAggregationProjectionsThroughJoin.java @@ -0,0 +1,148 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.spi.plan.Assignments; +import io.prestosql.spi.plan.JoinNode.EquiJoinClause; +import io.prestosql.sql.planner.assertions.PlanMatchPattern; +import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static io.prestosql.SystemSessionProperties.PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN; +import static io.prestosql.spi.plan.AggregationNode.Step.PARTIAL; +import static io.prestosql.spi.plan.JoinNode.Type.INNER; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.aggregation; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.functionCall; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.join; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.project; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.singleGroupingSet; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values; +import static io.prestosql.sql.planner.iterative.rule.test.PlanBuilder.expression; + +public class TestPushPartialAggregationProjectionsThroughJoin + extends BaseRuleTest +{ + @Test + public void testPushesPartialAggregationProjectionsThroughJoin() + { + tester().assertThat(new PushPartialAggregationProjectionsThroughJoin()) + .setSystemProperty(PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN, "true") + .on(p -> { + p.symbol("LEFT_AGGR", BIGINT); + return p.aggregation(ab -> ab + .source(p.project(Assignments.builder() + .put(p.symbol("EXPR"), p.rowExpression("LEFT_AGGR * 10")) + .put(p.symbol("LEFT_GROUP_BY"), p.rowExpression("LEFT_GROUP_BY")) + .put(p.symbol("RIGHT_GROUP_BY"), p.rowExpression("RIGHT_GROUP_BY")) + .build(), + p.join( + INNER, + p.valuesLong(10, p.symbol("LEFT_EQUI"), p.symbol("LEFT_NON_EQUI"), p.symbol("LEFT_GROUP_BY"), p.symbol("LEFT_AGGR"), p.symbol("LEFT_HASH")), + p.valuesLong(100, p.symbol("RIGHT_EQUI"), p.symbol("RIGHT_NON_EQUI"), p.symbol("RIGHT_GROUP_BY"), p.symbol("RIGHT_HASH")), + ImmutableList.of(new EquiJoinClause(p.symbol("LEFT_EQUI"), p.symbol("RIGHT_EQUI"))), + ImmutableList.of(p.symbol("LEFT_GROUP_BY"), p.symbol("LEFT_AGGR"), p.symbol("RIGHT_GROUP_BY")), + Optional.of(p.rowExpression("LEFT_NON_EQUI = RIGHT_NON_EQUI")), + Optional.of(p.symbol("LEFT_HASH")), + Optional.of(p.symbol("RIGHT_HASH"))))) + .addAggregation(p.symbol("AVG", DOUBLE), expression("AVG(EXPR)"), ImmutableList.of(DOUBLE)) + .singleGroupingSet(p.symbol("LEFT_GROUP_BY"), p.symbol("RIGHT_GROUP_BY")) + .step(PARTIAL)); + }) + .matches(project(ImmutableMap.of( + "LEFT_GROUP_BY", PlanMatchPattern.expression("LEFT_GROUP_BY"), + "RIGHT_GROUP_BY", PlanMatchPattern.expression("RIGHT_GROUP_BY"), + "AVG", PlanMatchPattern.expression("AVG")), + join(INNER, ImmutableList.of(equiJoinClause("LEFT_EQUI", "RIGHT_EQUI")), + Optional.of("LEFT_NON_EQUI = RIGHT_NON_EQUI"), + aggregation( + singleGroupingSet("LEFT_EQUI", "LEFT_NON_EQUI", "LEFT_GROUP_BY", "LEFT_HASH"), + ImmutableMap.of(Optional.of("AVG"), functionCall("avg", ImmutableList.of("EXPR"))), + ImmutableMap.of(), + Optional.empty(), + PARTIAL, + project(ImmutableMap.of( + "LEFT_EQUI", PlanMatchPattern.expression("LEFT_EQUI"), + "LEFT_NON_EQUI", PlanMatchPattern.expression("LEFT_NON_EQUI"), + "LEFT_GROUP_BY", PlanMatchPattern.expression("LEFT_GROUP_BY"), + "EXPR", PlanMatchPattern.expression("LEFT_AGGR * 10"), + "LEFT_HASH", PlanMatchPattern.expression("LEFT_HASH")), + values("LEFT_EQUI", "LEFT_NON_EQUI", "LEFT_GROUP_BY", "LEFT_AGGR", "LEFT_HASH"))), + project(ImmutableMap.of( + "RIGHT_GROUP_BY", PlanMatchPattern.expression("RIGHT_GROUP_BY"), + "RIGHT_NON_EQUI", PlanMatchPattern.expression("RIGHT_NON_EQUI"), + "RIGHT_EQUI", PlanMatchPattern.expression("RIGHT_EQUI"), + "RIGHT_HASH", PlanMatchPattern.expression("RIGHT_HASH")), + values("RIGHT_EQUI", "RIGHT_NON_EQUI", "RIGHT_GROUP_BY", "RIGHT_HASH"))))); + } + + @Test + public void testPushesPartialAggregationProjectionsThroughJoinRight() + { + tester().assertThat(new PushPartialAggregationProjectionsThroughJoin()) + .setSystemProperty(PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN, "true") + .on(p -> { + p.symbol("RIGHT_AGGR", BIGINT); + return p.aggregation(ab -> ab + .source(p.project(Assignments.builder() + .put(p.symbol("EXPR"), p.rowExpression("RIGHT_AGGR * 10")) + .put(p.symbol("LEFT_GROUP_BY"), p.rowExpression("LEFT_GROUP_BY")) + .put(p.symbol("RIGHT_GROUP_BY"), p.rowExpression("RIGHT_GROUP_BY")) + .build(), + p.join( + INNER, + p.valuesLong(10, p.symbol("LEFT_EQUI"), p.symbol("LEFT_NON_EQUI"), p.symbol("LEFT_GROUP_BY"), p.symbol("LEFT_HASH")), + p.valuesLong(100, p.symbol("RIGHT_EQUI"), p.symbol("RIGHT_NON_EQUI"), p.symbol("RIGHT_GROUP_BY"), p.symbol("RIGHT_AGGR"), p.symbol("RIGHT_HASH")), + ImmutableList.of(new EquiJoinClause(p.symbol("LEFT_EQUI"), p.symbol("RIGHT_EQUI"))), + ImmutableList.of(p.symbol("LEFT_GROUP_BY"), p.symbol("RIGHT_AGGR"), p.symbol("RIGHT_GROUP_BY")), + Optional.of(p.rowExpression("LEFT_NON_EQUI = RIGHT_NON_EQUI")), + Optional.of(p.symbol("LEFT_HASH")), + Optional.of(p.symbol("RIGHT_HASH"))))) + .addAggregation(p.symbol("AVG", DOUBLE), expression("AVG(EXPR)"), ImmutableList.of(DOUBLE)) + .singleGroupingSet(p.symbol("LEFT_GROUP_BY"), p.symbol("RIGHT_GROUP_BY")) + .step(PARTIAL)); + }) + .matches(project(ImmutableMap.of( + "LEFT_GROUP_BY", PlanMatchPattern.expression("LEFT_GROUP_BY"), + "RIGHT_GROUP_BY", PlanMatchPattern.expression("RIGHT_GROUP_BY"), + "AVG", PlanMatchPattern.expression("AVG")), + join(INNER, ImmutableList.of(equiJoinClause("LEFT_EQUI", "RIGHT_EQUI")), + Optional.of("LEFT_NON_EQUI = RIGHT_NON_EQUI"), + project(ImmutableMap.of( + "LEFT_GROUP_BY", PlanMatchPattern.expression("LEFT_GROUP_BY"), + "LEFT_NON_EQUI", PlanMatchPattern.expression("LEFT_NON_EQUI"), + "LEFT_EQUI", PlanMatchPattern.expression("LEFT_EQUI"), + "LEFT_HASH", PlanMatchPattern.expression("LEFT_HASH")), + values("LEFT_EQUI", "LEFT_NON_EQUI", "LEFT_GROUP_BY", "LEFT_HASH")), + aggregation( + singleGroupingSet("RIGHT_EQUI", "RIGHT_NON_EQUI", "RIGHT_GROUP_BY", "RIGHT_HASH"), + ImmutableMap.of(Optional.of("AVG"), functionCall("avg", ImmutableList.of("EXPR"))), + ImmutableMap.of(), + Optional.empty(), + PARTIAL, + project(ImmutableMap.of( + "RIGHT_EQUI", PlanMatchPattern.expression("RIGHT_EQUI"), + "RIGHT_NON_EQUI", PlanMatchPattern.expression("RIGHT_NON_EQUI"), + "RIGHT_GROUP_BY", PlanMatchPattern.expression("RIGHT_GROUP_BY"), + "EXPR", PlanMatchPattern.expression("RIGHT_AGGR * 10"), + "RIGHT_HASH", PlanMatchPattern.expression("RIGHT_HASH")), + values("RIGHT_EQUI", "RIGHT_NON_EQUI", "RIGHT_GROUP_BY", "RIGHT_AGGR", "RIGHT_HASH")))))); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushPartialAggregationThroughJoin.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushPartialAggregationThroughJoin.java index 0f66476309c18da9155f7c94ddf125cfac3ba19a..39d31b59b523045b770f140b2ac18f26b501cb26 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushPartialAggregationThroughJoin.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushPartialAggregationThroughJoin.java @@ -48,8 +48,8 @@ public class TestPushPartialAggregationThroughJoin .source( p.join( INNER, - p.values(p.symbol("LEFT_EQUI"), p.symbol("LEFT_NON_EQUI"), p.symbol("LEFT_GROUP_BY"), p.symbol("LEFT_AGGR"), p.symbol("LEFT_HASH")), - p.values(p.symbol("RIGHT_EQUI"), p.symbol("RIGHT_NON_EQUI"), p.symbol("RIGHT_GROUP_BY"), p.symbol("RIGHT_HASH")), + p.valuesLong(10, p.symbol("LEFT_EQUI"), p.symbol("LEFT_NON_EQUI"), p.symbol("LEFT_GROUP_BY"), p.symbol("LEFT_AGGR"), p.symbol("LEFT_HASH")), + p.valuesLong(100, p.symbol("RIGHT_EQUI"), p.symbol("RIGHT_NON_EQUI"), p.symbol("RIGHT_GROUP_BY"), p.symbol("RIGHT_HASH")), ImmutableList.of(new EquiJoinClause(p.symbol("LEFT_EQUI"), p.symbol("RIGHT_EQUI"))), ImmutableList.of(p.symbol("LEFT_GROUP_BY"), p.symbol("LEFT_AGGR"), p.symbol("RIGHT_GROUP_BY")), Optional.of(castToRowExpression("LEFT_NON_EQUI <= RIGHT_NON_EQUI")), diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java index 145eb196992cac46c80673280053de4a4635e06e..ef3df74526798393e7894df2df6748d429a8885f 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java @@ -96,6 +96,7 @@ import io.prestosql.sql.relational.OriginalExpressionUtils; import io.prestosql.sql.relational.SqlToRowExpressionTranslator; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.FunctionCall; +import io.prestosql.sql.tree.LongLiteral; import io.prestosql.sql.tree.NodeRef; import io.prestosql.sql.tree.NullLiteral; import io.prestosql.testing.TestingHandle; @@ -235,6 +236,14 @@ public class PlanBuilder nElements(rows, row -> nElements(columns.length, cell -> OriginalExpressionUtils.castToRowExpression(new NullLiteral())))); } + public ValuesNode valuesLong(int rows, Symbol... columns) + { + return values( + idAllocator.getNextId(), + ImmutableList.copyOf(columns), + nElements(rows, row -> nElements(columns.length, cell -> OriginalExpressionUtils.castToRowExpression(new LongLiteral(String.valueOf(cell)))))); + } + public ValuesNode values(List columns, List> rows) { return values(idAllocator.getNextId(), columns, rows);