diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala index 3263976927b1e3c1a155577cc2e7a39fd12c8da8..89771cf9740a0a93febbd1fc1438b75472d5e00f 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala @@ -25,24 +25,24 @@ import com.huawei.boostkit.spark.util.{ModifyUtilAdaptor, PhysicalPlanSelector} import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, SparkPlugin} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} -import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, Literal, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, IntegerLiteral, LessThan, LessThanOrEqual, Literal, NamedExpression, Rank, RowNumber, SortOrder, WindowExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Partial, PartialMerge} -import org.apache.spark.sql.catalyst.optimizer.{DelayCartesianProduct, HeuristicJoinReorder, RewriteSelfJoinInInPredicate} +import org.apache.spark.sql.catalyst.optimizer.{DelayCartesianProduct, RewriteSelfJoinInInPredicate} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, AdaptiveSparkPlanExec, BroadcastQueryStageExec, OmniAQEShuffleReadExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.{DummyLogicalPlan, ExtendedAggUtils, HashAggregateExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.execution.window.{TopNPushDownForWindow, WindowExec} +import org.apache.spark.sql.execution.window.{WindowExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.catalyst.planning.PhysicalAggregation import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener import org.apache.spark.sql.execution.aggregate.PushOrderedLimitThroughAgg -import nova.hetu.omniruntime.memory.MemoryManager import org.apache.spark.sql.util.ShimUtil +import nova.hetu.omniruntime.memory.MemoryManager import scala.collection.mutable.ListBuffer @@ -59,12 +59,34 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) val enableRowShuffle: Boolean = columnarConf.enableRowShuffle val columnsThreshold: Int = columnarConf.columnsThreshold val enableColumnarDataWritingCommand: Boolean = columnarConf.enableColumnarDataWritingCommand + val enableColumnarTopNSort: Boolean = columnarConf.enableColumnarTopNSort + val topNSortThreshold: Int = columnarConf.topNSortThreshold - def checkBhjRightChild(x: Any): Boolean = { - x match { - case _: ColumnarFilterExec | _: ColumnarConditionProjectExec => true + private def checkBhjRightChild(plan: Any): Boolean = + plan match { + case _: ColumnarFilterExec => true + case _: ColumnarConditionProjectExec => true case _ => false } + + def 
isTopNExpression(expr: Expression): Boolean = expr match { + case Alias(child, _) => isTopNExpression(child) + case WindowExpression(_: Rank, _) => true + case _ => false + } + + def isStrictTopN(expr: Expression): Boolean = expr match { + case Alias(child, _) => isStrictTopN(child) + case WindowExpression(_: RowNumber, _) => true + case _ => false + } + + private def splitConjunctivePredicates(condition: Expression): Seq[Expression] = { + condition match { + case And(cond1, cond2) => + splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2) + case other => other :: Nil + } } def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan = { @@ -158,6 +180,65 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) ColumnarProjectExec(plan.projectList, child) } case plan: FilterExec => + if(enableColumnarTopNSort) { + plan.transform { + case f@FilterExec(condition, + w@WindowExec(Seq(windowExpression), _, orderSpec, sort: SortExec)) + if orderSpec.nonEmpty && isTopNExpression(windowExpression) => + var topn = Int.MaxValue + val nonTopNConditions = splitConjunctivePredicates(condition).filter { + case LessThan(e: NamedExpression, IntegerLiteral(n)) + if e.exprId == windowExpression.exprId => + topn = Math.min(topn, n - 1) + false + case GreaterThan(IntegerLiteral(n), e: NamedExpression) + if e.exprId == windowExpression.exprId => + topn = Math.min(topn, n - 1) + false + case LessThanOrEqual(e: NamedExpression, IntegerLiteral(n)) + if e.exprId == windowExpression.exprId => + topn = Math.min(topn, n) + false + case EqualTo(e: NamedExpression, IntegerLiteral(n)) + if n == 1 && e.exprId == windowExpression.exprId => + topn = 1 + false + case EqualTo(IntegerLiteral(n), e: NamedExpression) + if n == 1 && e.exprId == windowExpression.exprId => + topn = 1 + false + case GreaterThanOrEqual(IntegerLiteral(n), e: NamedExpression) + if e.exprId == windowExpression.exprId => + topn = Math.min(topn, n) + false + case _ => true + } + // topn <= SQLConf.get.topNPushDownForWindowThreshold 100. 
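+                // topn must stay within topNSortThreshold (spark.omni.sql.columnar.topN.threshold, default 100) for the rewrite below to apply.
+                // strictTopN distinguishes ROW_NUMBER (return exactly topn rows per partition key) from RANK (rows tying at the boundary are kept).
+                // buildCheck() is probed in a try/catch so that plans the Omni operator cannot handle fall back to the ordinary columnar filter path instead of failing at runtime.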
+ val strictTopN = isStrictTopN(windowExpression) + val omniSupport: Boolean = try { + ColumnarTopNSortExec(topn, strictTopN, w.partitionSpec, w.orderSpec, sort.global, sort.child).buildCheck() + true + } catch { + case _: Throwable => false + } + if (topn > 0 && topn <= topNSortThreshold && omniSupport) { + val topNSortExec = ColumnarTopNSortExec( + topn, strictTopN, w.partitionSpec, w.orderSpec, sort.global, replaceWithColumnarPlan(sort.child)) + logInfo(s"Columnar Processing for ${topNSortExec.getClass} is currently supported.") + val newCondition = if (nonTopNConditions.isEmpty) { + Literal.TrueLiteral + } else { + nonTopNConditions.reduce(And) + } + val window = ColumnarWindowExec(w.windowExpression, w.partitionSpec, w.orderSpec, topNSortExec) + return ColumnarFilterExec(newCondition, window) + } else { + logInfo{s"topn: ${topn} is bigger than topNSortThreshold: ${topNSortThreshold}."} + val child = replaceWithColumnarPlan(f.child) + return ColumnarFilterExec(f.condition, child) + } + } + } val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") ColumnarFilterExec(plan.condition, child) @@ -788,9 +869,7 @@ class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging { extensions.injectPlannerStrategy(_ => ShuffleJoinStrategy) extensions.injectOptimizerRule(_ => RewriteSelfJoinInInPredicate) extensions.injectOptimizerRule(_ => DelayCartesianProduct) - extensions.injectOptimizerRule(_ => HeuristicJoinReorder) extensions.injectQueryStagePrepRule(session => DedupLeftSemiJoinAQE(session)) - extensions.injectQueryStagePrepRule(_ => TopNPushDownForWindow) extensions.injectQueryStagePrepRule(session => FallbackBroadcastExchange(session)) extensions.injectQueryStagePrepRule(session => PushOrderedLimitThroughAgg(session)) ModifyUtilAdaptor.injectRule(extensions) diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala index aae1d2c86acae75d37a37f27e73d6491eb16caa8..5c638d39eeaf680f9abb321909f95c5696a59687 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala @@ -65,8 +65,6 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { def enableShareBroadcastJoinNestedTable: Boolean = conf.getConf(ENABLE_SHARE_BROADCAST_JOIN_NESTED_TABLE) - def enableHeuristicJoinReorder: Boolean = conf.getConf(ENABLE_HEURISTIC_JOIN_REORDER) - def enableDelayCartesianProduct: Boolean = conf.getConf(ENABLE_DELAY_CARTESIAN_PRODUCT) def enableColumnarFileScan: Boolean = conf.getConf(ENABLE_COLUMNAR_FILE_SCAN) @@ -147,9 +145,7 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { def enableGlobalColumnarLimit : Boolean = conf.getConf(ENABLE_GLOBAL_COLUMNAR_LIMIT) - def topNPushDownForWindowThreshold: Int = conf.getConf(TOP_N_PUSH_DOWN_FOR_WINDOW_THRESHOLD) - - def topNPushDownForWindowEnable: Boolean = conf.getConf(TOP_N_PUSH_DOWN_FOR_WINDOW_ENABLE) + def topNSortThreshold: Int = conf.getConf(TOP_N_THRESHOLD) def pushOrderedLimitThroughAggEnable: Boolean = conf.getConf(PUSH_ORDERED_LIMIT_THROUGH_AGG_ENABLE) @@ -309,12 +305,6 @@ object ColumnarPluginConfig { .booleanConf .createWithDefault(true) - val 
ENABLE_HEURISTIC_JOIN_REORDER = buildConf("spark.sql.heuristicJoinReorder.enabled") - .internal() - .doc("enable or disable heuristic join reorder") - .booleanConf - .createWithDefault(true) - val ENABLE_DELAY_CARTESIAN_PRODUCT = buildConf("spark.sql.enableDelayCartesianProduct.enabled") .internal() .doc("enable or disable delay cartesian product") @@ -538,16 +528,11 @@ object ColumnarPluginConfig { .booleanConf .createWithDefault(true) - val TOP_N_PUSH_DOWN_FOR_WINDOW_THRESHOLD = buildConf("spark.sql.execution.topNPushDownForWindow.threshold") + val TOP_N_THRESHOLD = buildConf("spark.omni.sql.columnar.topN.threshold") .internal() .intConf .createWithDefault(100) - val TOP_N_PUSH_DOWN_FOR_WINDOW_ENABLE = buildConf("spark.sql.execution.topNPushDownForWindow.enabled") - .internal() - .booleanConf - .createWithDefault(true) - val PUSH_ORDERED_LIMIT_THROUGH_AGG_ENABLE = buildConf("spark.omni.sql.columnar.pushOrderedLimitThroughAggEnable.enabled") .internal() .booleanConf diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorder.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DelayCartesianProduct.scala similarity index 51% rename from omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorder.scala rename to omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DelayCartesianProduct.scala index f0dd04487fff7420a86e501c64a5345d0794738b..835faa5dd32d6b5fdb8f1cf4bc9abd8ef5be35ee 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorder.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DelayCartesianProduct.scala @@ -17,22 +17,15 @@ package org.apache.spark.sql.catalyst.optimizer -import scala.annotation.tailrec -import scala.collection.mutable - import com.huawei.boostkit.spark.ColumnarPluginConfig - import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, EqualNullSafe, EqualTo, Expression, IsNotNull, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.{And, Expression, PredicateHelper} import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.util.sideBySide - - - /** * Move all cartesian products to the root of the plan */ @@ -162,165 +155,7 @@ object DelayCartesianProduct extends Rule[LogicalPlan] with PredicateHelper { } } -/** - * Firstly, Heuristic reorder join need to execute small joins with filters - * , which can reduce intermediate results - */ -object HeuristicJoinReorder extends Rule[LogicalPlan] - with PredicateHelper with JoinSelectionHelper { - - /** - * Join a list of plans together and push down the conditions into them. - * The joined plan are picked from left to right, thus the final result is a left-deep tree. - * - * @param input a list of LogicalPlans to inner join and the type of inner join. - * @param conditions a list of condition for join. 
- */ - @tailrec - final def createReorderJoin(input: Seq[(LogicalPlan, InnerLike)], conditions: Seq[Expression]) - : LogicalPlan = { - assert(input.size >= 2) - if (input.size == 2) { - val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin) - val ((leftPlan, leftJoinType), (rightPlan, rightJoinType)) = (input(0), input(1)) - val innerJoinType = (leftJoinType, rightJoinType) match { - case (Inner, Inner) => Inner - case (_, _) => Cross - } - // Set the join node ordered so that we don't need to transform them again. - val orderJoin = OrderedJoin(leftPlan, rightPlan, innerJoinType, joinConditions.reduceLeftOption(And)) - if (others.nonEmpty) { - Filter(others.reduceLeft(And), orderJoin) - } else { - orderJoin - } - } else { - val (left, _) :: rest = input.toList - val candidates = rest.filter { planJoinPair => - val plan = planJoinPair._1 - // 1. it has join conditions with the left node - // 2. it has a filter - // 3. it can be broadcast - val isEqualJoinCondition = conditions.flatMap { - case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => None - case EqualNullSafe(l, r) if l.references.isEmpty || r.references.isEmpty => None - case e@EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, plan) => Some(e) - case e@EqualTo(l, r) if canEvaluate(l, plan) && canEvaluate(r, left) => Some(e) - case e@EqualNullSafe(l, r) if canEvaluate(l, left) && canEvaluate(r, plan) => Some(e) - case e@EqualNullSafe(l, r) if canEvaluate(l, plan) && canEvaluate(r, left) => Some(e) - case _ => None - }.nonEmpty - - val hasFilter = plan match { - case f: Filter if hasValuableCondition(f.condition) => true - case Project(_, f: Filter) if hasValuableCondition(f.condition) => true - case _ => false - } - isEqualJoinCondition && hasFilter - } - val (right, innerJoinType) = if (candidates.nonEmpty) { - candidates.minBy(_._1.stats.sizeInBytes) - } else { - rest.head - } - - val joinedRefs = left.outputSet ++ right.outputSet - val selectedJoinConditions = mutable.HashSet.empty[Expression] - val (joinConditions, others) = conditions.partition { e => - // If there are semantically equal conditions, they should come from two different joins. - // So we should not put them into one join. - if (!selectedJoinConditions.contains(e.canonicalized) && e.references.subsetOf(joinedRefs) - && canEvaluateWithinJoin(e)) { - selectedJoinConditions.add(e.canonicalized) - true - } else { - false - } - } - // Set the join node ordered so that we don't need to transform them again. - val joined = OrderedJoin(left, right, innerJoinType, joinConditions.reduceLeftOption(And)) - - // should not have reference to same logical plan - createReorderJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others) - } - } - - private def hasValuableCondition(condition: Expression): Boolean = { - val conditions = splitConjunctivePredicates(condition) - !conditions.forall(_.isInstanceOf[IsNotNull]) - } - - def apply(plan: LogicalPlan): LogicalPlan = { - if (ColumnarPluginConfig.getSessionConf.enableHeuristicJoinReorder) { - val newPlan = plan.transform { - case p@ExtractFiltersAndInnerJoinsByIgnoreProjects(input, conditions) - if input.size > 2 && conditions.nonEmpty => - val reordered = createReorderJoin(input, conditions) - if (p.sameOutput(reordered)) { - reordered - } else { - // Reordering the joins have changed the order of the columns. - // Inject a projection to make sure we restore to the expected ordering. 
- Project(p.output, reordered) - } - } - - // After reordering is finished, convert OrderedJoin back to Join - val result = newPlan.transformDown { - case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond, JoinHint.NONE) - } - if (!result.resolved) { - // In some special cases related to subqueries, we find that after reordering, - val comparedPlans = sideBySide(plan.treeString, result.treeString).mkString("\n") - logWarning("The structural integrity of the plan is broken, falling back to the " + - s"original plan. == Comparing two plans ===\n$comparedPlans") - plan - } else { - result - } - } else { - plan - } - } -} - -/** - * This is different from [[ExtractFiltersAndInnerJoins]] in that it can collect filters and - * inner joins by ignoring projects on top of joins, which are produced by column pruning. - */ -private object ExtractFiltersAndInnerJoinsByIgnoreProjects extends PredicateHelper { - - /** - * Flatten all inner joins, which are next to each other. - * Return a list of logical plans to be joined with a boolean for each plan indicating if it - * was involved in an explicit cross join. Also returns the entire list of join conditions for - * the left-deep tree. - */ - def flattenJoin(plan: LogicalPlan, parentJoinType: InnerLike = Inner) - : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match { - case Join(left, right, joinType: InnerLike, cond, hint) if hint == JoinHint.NONE => - val (plans, conditions) = flattenJoin(left, joinType) - (plans ++ Seq((right, joinType)), conditions ++ - cond.toSeq.flatMap(splitConjunctivePredicates)) - case Filter(filterCondition, j@Join(_, _, _: InnerLike, _, hint)) if hint == JoinHint.NONE => - val (plans, conditions) = flattenJoin(j) - (plans, conditions ++ splitConjunctivePredicates(filterCondition)) - case Project(projectList, child) - if projectList.forall(_.isInstanceOf[Attribute]) => flattenJoin(child) - - case _ => (Seq((plan, parentJoinType)), Seq.empty) - } - - def unapply(plan: LogicalPlan): Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])] - = plan match { - case f@Filter(_, Join(_, _, _: InnerLike, _, _)) => - Some(flattenJoin(f)) - case j@Join(_, _, _, _, hint) if hint == JoinHint.NONE => - Some(flattenJoin(j)) - case _ => None - } -} private object ExtractFiltersAndInnerJoinsForBushy extends PredicateHelper { diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/AbstractUnsafeRowSorter.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/AbstractUnsafeRowSorter.java deleted file mode 100644 index 9ddbd2bd135d97ef215c67802497dfb78788ab16..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/AbstractUnsafeRowSorter.java +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution; - -import java.io.IOException; - -import scala.collection.Iterator; -import scala.math.Ordering; - -import com.google.common.annotations.VisibleForTesting; - -import org.apache.spark.sql.types.StructType; -import org.apache.spark.util.collection.unsafe.sort.RecordComparator; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; - -public abstract class AbstractUnsafeRowSorter -{ - protected final StructType schema; - - /** - * If positive, forces records to be spilled to disk at the give frequency (measured in numbers of records). - * This is only intended to be used in tests. - * */ - protected int testSpillFrequency = 0; - - AbstractUnsafeRowSorter(final StructType schema) { - this.schema = schema; - } - - // This flag makes sure the cleanupResource() has been called. - // After the cleanup work, iterator.next should always return false. - // Downstream operator triggers the resource cleanup while they found there's no need to keep the iterator anymore. - // See more detail in SPARK-21492. - boolean isReleased = false; - - public abstract void insertRow(UnsafeRow row) throws IOException; - - public abstract Iterator sort() throws IOException; - - public abstract Iterator sort(Iterator inputIterator) throws IOException; - - /** - * @return the peak memory used so far, in bytes. - * */ - public abstract long getPeakMemoryUsage(); - - /** - * @return the total amount of time spent sorting data (in-memory only). - * */ - public abstract long getSortTimeNanos(); - - public abstract void cleanupResources(); - - /** - * Foreces spills to occur every 'frequency' records. Only for use in tests. - * */ - @VisibleForTesting - void setTestSpillFrequency(int frequency) { - assert frequency > 0 : "Frequency must be positive"; - testSpillFrequency = frequency; - } - - static final class RowComparator extends RecordComparator { - private final Ordering ordering; - private final UnsafeRow row1; - private final UnsafeRow row2; - - RowComparator(Ordering ordering, int numFields) { - this.row1 = new UnsafeRow(numFields); - this.row2 = new UnsafeRow(numFields); - this.ordering = ordering; - } - - @Override - public int compare( - Object baseObj1, - long baseOff1, - int baseLen1, - Object baseObj2, - long baseOff2, - int baseLen2) { - // Note that since ordering doesn't need the total length of the record, we just pass 0 int the row. 
- row1.pointTo(baseObj1, baseOff1, 0); - row2.pointTo(baseObj2, baseOff2, 0); - return ordering.compare(row1, row2); - } - } -} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala index 410ce4127fe5038682e5195a7ec5e4f586439ab1..675c00af1eb531f54bd0cb08ebb84c3a8e40cb5e 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala @@ -155,11 +155,74 @@ case class ColumnarGlobalLimitExec(limit: Int, child: SparkPlan, offset: Int = 0 copy(child = newChild) def buildCheck(): Unit = { - if (offset > 0) { - throw new UnsupportedOperationException("ColumnarGlobalLimitExec doesn't support offset greater than 0.") - } child.output.foreach(attr => sparkTypeToOmniType(attr.dataType, attr.metadata)) } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val addInputTime = longMetric("addInputTime") + val omniCodegenTime = longMetric("omniCodegenTime") + val getOutputTime = longMetric("getOutputTime") + val numOutputRows = longMetric("numOutputRows") + val numOutputVecBatches = longMetric("numOutputVecBatches") + + child.executeColumnar().mapPartitions { iter => + + val startCodegen = System.nanoTime() + val limitOperatorFactory = new OmniLimitOperatorFactory(limit, offset) + val limitOperator = limitOperatorFactory.createOperator + omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + + // close operator + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + limitOperator.close() + limitOperatorFactory.close() + }) + + val localSchema = this.schema + new Iterator[ColumnarBatch] { + private var results: java.util.Iterator[VecBatch] = _ + + override def hasNext: Boolean = { + while ((results == null || !results.hasNext) && iter.hasNext) { + val batch = iter.next() + val input = transColBatchToOmniVecs(batch) + val vecBatch = new VecBatch(input, batch.numRows()) + val startInput = System.nanoTime() + limitOperator.addInput(vecBatch) + addInputTime += NANOSECONDS.toMillis(System.nanoTime() - startInput) + + val startGetOp = System.nanoTime() + results = limitOperator.getOutput + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + } + if (results == null) { + false + } else { + val startGetOp: Long = System.nanoTime() + val hasNext = results.hasNext + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + hasNext + } + } + + override def next(): ColumnarBatch = { + val startGetOp = System.nanoTime() + val vecBatch = results.next() + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( + vecBatch.getRowCount, localSchema, false) + vectors.zipWithIndex.foreach { case (vector, i) => + vector.reset() + vector.setVec(vecBatch.getVectors()(i)) + } + numOutputRows += vecBatch.getRowCount + numOutputVecBatches += 1 + vecBatch.close() + new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) + } + } + } + } } case class ColumnarTakeOrderedAndProjectExec( @@ -168,7 +231,7 @@ case class ColumnarTakeOrderedAndProjectExec( projectList: Seq[NamedExpression], child: SparkPlan, offset: Int = 0) - extends UnaryExecNode { + extends 
UnaryExecNodeShim(sortOrder, projectList) { override def supportsColumnar: Boolean = true @@ -218,9 +281,6 @@ case class ColumnarTakeOrderedAndProjectExec( } def buildCheck(): Unit = { - if (offset > 0) { - throw new UnsupportedOperationException("ColumnarTakeOrderedAndProjectExec doesn't support offset greater than 0.") - } genSortParam(child.output, sortOrder) val projectEqualChildOutput = projectList == child.output var omniInputTypes: Array[DataType] = null @@ -245,9 +305,9 @@ case class ColumnarTakeOrderedAndProjectExec( } else { val (sourceTypes, ascending, nullFirsts, sortColsExp) = genSortParam(child.output, sortOrder) - def computeTopN(iter: Iterator[ColumnarBatch], schema: StructType): Iterator[ColumnarBatch] = { + def computeTopN(iter: Iterator[ColumnarBatch], schema: StructType, offset: Int): Iterator[ColumnarBatch] = { val startCodegen = System.nanoTime() - val topNOperatorFactory = new OmniTopNWithExprOperatorFactory(sourceTypes, limit, sortColsExp, ascending, nullFirsts, + val topNOperatorFactory = new OmniTopNWithExprOperatorFactory(sourceTypes, limit, offset, sortColsExp, ascending, nullFirsts, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val topNOperator = topNOperatorFactory.createOperator longMetric("omniCodegenTime") += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) @@ -265,7 +325,7 @@ case class ColumnarTakeOrderedAndProjectExec( } else { val localTopK: RDD[ColumnarBatch] = { child.executeColumnar().mapPartitionsWithIndexInternal { (_, iter) => - computeTopN(iter, this.child.schema) + computeTopN(iter, this.child.schema, 0) } } @@ -302,7 +362,7 @@ case class ColumnarTakeOrderedAndProjectExec( } singlePartitionRDD.mapPartitions { iter => // TopN = omni-top-n + omni-project - val topN: Iterator[ColumnarBatch] = computeTopN(iter, this.child.schema) + val topN: Iterator[ColumnarBatch] = computeTopN(iter, this.child.schema, offset) if (!projectEqualChildOutput) { dealPartitionData(null, null, addInputTime, omniCodegenTime, getOutputTime, omniInputTypes, omniExpressions, topN, this.schema) @@ -313,8 +373,6 @@ case class ColumnarTakeOrderedAndProjectExec( } } - override def outputOrdering: Seq[SortOrder] = sortOrder - override def outputPartitioning: Partitioning = SinglePartition override def simpleString(maxFields: Int): String = { diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala deleted file mode 100644 index 0ddf89b8c1c3d36b63e7bebd2f9b12e0b1a7f385..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ /dev/null @@ -1,307 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.util.concurrent.TimeUnit._ -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS -import org.apache.spark.sql.execution.UnsafeExternalRowSorter.PrefixComputer -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.collection.unsafe.sort.PrefixComparator - - -/** - * Base class of [[SortExec]] and [[TopNSortExec]]. All subclasses of this class need to override - * their own sorter which inherits from [[org.apache.spark.sql.execution.AbstractUnsafeRowSorter]] - * to perform corresponding sorting. - * - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - * @param testSpillFrequency Method for configuring periodic spilling in unit tests. - * If set, will spill every 'frequency' records. - * */ -abstract class SortExecBase( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan, - testSpillFrequency: Int = 0) - extends UnaryExecNode with BlockingOperatorWithCodegen { - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder - - // sort performed is local within a given partition so will retain - // child operator's partitioning - override def outputPartitioning: Partitioning = child.outputPartitioning - - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder):: Nil else UnspecifiedDistribution :: Nil - - private val enableRadixSort = conf.enableRadixSort - - override lazy val metrics = Map( - "sortTime" -> SQLMetrics.createTimingMetric(sparkContext, "sort time"), - "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"), - "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size") - ) - - protected val sorterClassName: String - - protected def newSorterInstance( - ordering: Ordering[InternalRow], - prefixComparator: PrefixComparator, - prefixComputer: PrefixComputer, - pageSize: Long, - canSortFullyWIthPrefix: Boolean): AbstractUnsafeRowSorter - - private[sql] var rowSorter: AbstractUnsafeRowSorter = _ - - /** - * This method gets invoked only once for each SortExec instance to initialize - * an AbstractUnsafeRowSorter, both 'plan.execute' and code generation are using it. 
- * In the code generation code path, we need to call this function outside the class - * so we should make it public - * */ - def createSorter(): AbstractUnsafeRowSorter = { - val ordering = RowOrdering.create(sortOrder, output) - - // THe comparator for comparing prefix - val boundSortExpression = BindReferences.bindReference(sortOrder.head, output) - val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) - - val canSortFullyWIthPrefix = sortOrder.length == 1 && - SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression) - - // The generator for prefix - val prefixExpr = SortPrefix(boundSortExpression) - val prefixProjection = UnsafeProjection.create(Seq(prefixExpr)) - val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { - private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix - override def computePrefix(row: InternalRow): - UnsafeExternalRowSorter.PrefixComputer.Prefix = { - val prefix = prefixProjection.apply(row) - result.isNull = prefix.isNullAt(0) - result.value = if (result.isNull) prefixExpr.nullValue else prefix.getLong(0) - result - } - } - - val pageSize = SparkEnv.get.memoryManager.pageSizeBytes - rowSorter = newSorterInstance(ordering, prefixComparator, prefixComputer, - pageSize, canSortFullyWIthPrefix) - - if (testSpillFrequency > 0) { - rowSorter.setTestSpillFrequency(testSpillFrequency) - } - rowSorter - } - - protected override def doExecute(): RDD[InternalRow] = { - val peakMemory = longMetric("peakMemory") - val spillSize = longMetric("spillSize") - val sortTime = longMetric("sortTime") - - child.execute().mapPartitionsInternal { iter => - val sorter = createSorter() - val metrics = TaskContext.get().taskMetrics() - - // Remember spill data size of this task before execute this operator, - // so that we can figure out how many bytes we spilled for this operator. - val spillSizeBefore = metrics.memoryBytesSpilled - val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) - sortTime += NANOSECONDS.toMillis(sorter.getSortTimeNanos) - peakMemory += sorter.getPeakMemoryUsage - spillSize += metrics.memoryBytesSpilled - spillSizeBefore - metrics.incPeakExecutionMemory(sorter.getPeakMemoryUsage) - - sortedIterator - } - } - - override def usedInputs: AttributeSet = AttributeSet(Seq.empty) - - override def inputRDDs(): Seq[RDD[InternalRow]] = { - child.asInstanceOf[CodegenSupport].inputRDDs - } - - // Name of sorter variable used in codegen - private var sorterVariable: String = _ - - override protected def doProduce(ctx: CodegenContext): String = { - val needToSort = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, - "needToSort", v => s"$v = true;") - - // Initalize the class member variables. This includes the instance of the Sorter - // and the iterator to return sorted rows. 
- val thisPlan = ctx.addReferenceObj("plan", this) - // Inline mutable state since not many Sort operations in a task - sorterVariable = ctx.addMutableState(sorterClassName, "sorter", - v => s"$v = $thisPlan.createSorter();", forceInline = true) - val metrics = ctx.addMutableState(classOf[TaskMetrics].getName, "metrics", - v => s"$v = org.apache.spark.TaskContext.get().taskMetrics();", forceInline = true) - val sortedIterator = ctx.addMutableState("scala.collection.Iterator", - "sortedIter", forceInline = true) - - val addToSorter = ctx.freshName("addToSorter") - val addToSorterFuncName = ctx.addNewFunction(addToSorter, - s""" - | private void $addToSorter() throws java.io.IOException { - | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} - | } - """.stripMargin.trim) - - val outputRow = ctx.freshName("outputRow") - val peakMemory = metricTerm(ctx, "peakMemory") - val spillSize = metricTerm(ctx, "spillSize") - val spillSizeBefore = ctx.freshName("spillSizeBefore") - val sortTime = metricTerm(ctx, "sortTime") - s""" - | if ($needToSort) { - | long $spillSizeBefore = $metrics.memoryBytesSpilled(); - | $addToSorterFuncName(); - | $sortedIterator = $sorterVariable.sort(); - | $sortTime.add($sorterVariable.getSortTimeNanos() / $NANOS_PER_MILLIS); - | $peakMemory.add($sorterVariable.getPeakMemoryUsage()); - | $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore); - | $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage()); - | $needToSort = false; - | } - | - | while ($limitNotReachedCond $sortedIterator.hasNext()) { - | UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next(); - | ${consume(ctx, null, outputRow)} - | if (shouldStop()) return; - | } - """.stripMargin.trim - } - - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - s""" - | ${row.code} - | $sorterVariable.insertRow((UnsafeRow)${row.value}); - """.stripMargin - } - - /** - * In BaseSortExec, we overwrites cleanupResources to close AbstractUnsafeRowSorter. - * */ - - override protected[sql] def cleanupResources(): Unit = { - if (rowSorter != null) { - // There's possible for rowSorter is null here, for example, in the scenario of empty - // iterator in the current task, the downstream physical node(like SortMergeJoinExec) will - // trigger cleanupResources before rowSorter initialized in createSorter - rowSorter.cleanupResources() - } - super.cleanupResources() - } -} - - -/** - * Performs (external) sorting - * */ -case class SortExec( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan, - testSpillFrequency: Int = 0) - extends SortExecBase(sortOrder, global, child, testSpillFrequency) { - private val enableRadixSort = conf.enableRadixSort - - - override val sorterClassName: String = classOf[UnsafeExternalRowSorter].getName - - override def newSorterInstance( - ordering: Ordering[InternalRow], - prefixComparator: PrefixComparator, - prefixComputer: PrefixComputer, - pageSize: Long, - canSortFullyWIthPrefix: Boolean): UnsafeExternalRowSorter = { - UnsafeExternalRowSorter.create( - schema, - ordering, - prefixComparator, - prefixComputer, - pageSize, - enableRadixSort && canSortFullyWIthPrefix) - } - - override def createSorter(): UnsafeExternalRowSorter = { - super.createSorter().asInstanceOf[UnsafeExternalRowSorter] - } - - override protected def withNewChildInternal(newChild: SparkPlan): SortExec = { - copy(child = newChild) - } -} - -/** - * Performs topN sort - * - * @param strictTopN when true it strictly returns n results. 
This param distinguishes - * [[RowNumber]] from [[Rank]]. [[RowNumber]] corresponds to true - * and [[Rank]] corresponds to false. - * @param partitionSpec partitionSpec of [[org.apache.spark.sql.execution.window.WindowExec]] - * @param sortOrder orderSpec of [[org.apache.spark.sql.execution.window.WindowExec]] - * */ -case class TopNSortExec( - n: Int, - strictTopN: Boolean, - partitionSpec: Seq[Expression], - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan) - extends SortExecBase(sortOrder, global, child, 0) { - - override val sorterClassName: String = classOf[UnsafeTopNRowSorter].getName - - override def newSorterInstance( - ordering: Ordering[InternalRow], - prefixComparator: PrefixComparator, - prefixComputer: PrefixComputer, - pageSize: Long, - canSortFullyWIthPrefix: Boolean): UnsafeTopNRowSorter = { - val partitionSpecProjection = UnsafeProjection.create(partitionSpec, output) - UnsafeTopNRowSorter.create( - n, - strictTopN, - schema, - partitionSpecProjection, - ordering, - prefixComparator, - prefixComputer, - pageSize, - canSortFullyWIthPrefix) - } - - override def createSorter(): UnsafeTopNRowSorter = { - super.createSorter().asInstanceOf[UnsafeTopNRowSorter] - } - - override protected def withNewChildInternal(newChild: SparkPlan): TopNSortExec = { - copy(child = newChild) - } -} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/TopNSortExec.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/TopNSortExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..ece8ac9d35c58d28760473d11805e7c41b8c3041 --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/TopNSortExec.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ + +/** + * Performs topN sort + * + * @param strictTopN when true it strictly returns n results. This param distinguishes + * [[RowNumber]] from [[Rank]]. [[RowNumber]] corresponds to true + * and [[Rank]] corresponds to false. 
+ * @param partitionSpec partitionSpec of [[org.apache.spark.sql.execution.window.WindowExec]] + * @param sortOrder orderSpec of [[org.apache.spark.sql.execution.window.WindowExec]] + * */ +case class TopNSortExec( + n: Int, + strictTopN: Boolean, + partitionSpec: Seq[Expression], + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan) extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException("unsupported topn sort exec") + } + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = { + copy(child = newChild) + } +} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java deleted file mode 100644 index b36a424d22f54fa629e8dfd774c7d503ee75362c..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution; - -import java.io.IOException; -import java.util.function.Supplier; - -import scala.collection.Iterator; -import scala.math.Ordering; - -import org.apache.spark.SparkEnv; -import org.apache.spark.TaskContext; -import org.apache.spark.internal.config.package$; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.unsafe.Platform; -import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; -import org.apache.spark.util.collection.unsafe.sort.RecordComparator; -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter; -import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; - -public final class UnsafeExternalRowSorter extends AbstractUnsafeRowSorter { - private long numRowsInserted = 0; - private final UnsafeExternalRowSorter.PrefixComputer prefixComputer; - private final UnsafeExternalSorter sorter; - - public abstract static class PrefixComputer { - public static class Prefix { - // Key prefix value, or the null prefix value if isNull = true - public long value; - - // Whether the key is null - public boolean isNull; - } - - /** - * Computes prefix for the given row. For efficiency, the object may be reused in - * further calls to a given PrefixComputer. 
- * */ - public abstract Prefix computePrefix(InternalRow row); - } - - public static UnsafeExternalRowSorter createWithRecordComparator( - StructType schema, - Supplier recordComparatorSupplier, - PrefixComparator prefixComparator, - UnsafeExternalRowSorter.PrefixComputer prefixComputer, - long pageSizeBytes, - boolean canUseRadixSort) throws IOException { - return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator, - prefixComputer, pageSizeBytes, canUseRadixSort); - } - - public static UnsafeExternalRowSorter create( - StructType schema, - Ordering ordering, - PrefixComparator prefixComparator, - UnsafeExternalRowSorter.PrefixComputer prefixComputer, - long pageSizeBytes, - boolean canUseRadixSort) throws IOException { - Supplier recordComparatorSupplier = () -> new RowComparator(ordering, schema.length()); - return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator, - prefixComputer, pageSizeBytes, canUseRadixSort); - } - - private UnsafeExternalRowSorter( - StructType schema, - Supplier recordComparatorSupplier, - PrefixComparator prefixComparator, - UnsafeExternalRowSorter.PrefixComputer prefixComputer, - long pageSizeBytes, - boolean canUseRadixSort) { - super(schema); - this.prefixComputer = prefixComputer; - final SparkEnv sparkEnv = SparkEnv.get(); - final TaskContext taskContext = TaskContext.get(); - sorter = UnsafeExternalSorter.create( - taskContext.taskMemoryManager(), - sparkEnv.blockManager(), - sparkEnv.serializerManager(), - taskContext, - recordComparatorSupplier, - prefixComparator, - (int) (long) sparkEnv.conf().get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()), - pageSizeBytes, - (int) sparkEnv.conf().get( - package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()), - canUseRadixSort); - } - - @Override - public void insertRow(UnsafeRow row) throws IOException { - final PrefixComputer.Prefix prefix = prefixComputer.computePrefix(row); - sorter.insertRecord( - row.getBaseObject(), - row.getBaseOffset(), - row.getSizeInBytes(), - prefix.value, - prefix.isNull); - numRowsInserted++; - if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) { - sorter.spill(); - } - } - - @Override - public long getPeakMemoryUsage() { - return sorter.getPeakMemoryUsedBytes(); - } - - @Override - public long getSortTimeNanos() { - return sorter.getSortTimeNanos(); - } - - @Override - public void cleanupResources() { - isReleased = true; - sorter.cleanupResources(); - } - - @Override - public Iterator sort() throws IOException { - try { - final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); - if (!sortedIterator.hasNext()) { - // Since we won't ever call next() on an empty iterator, we need to clean up resources - // here in order to prevent memory leaks. - cleanupResources(); - } - return new RowIterator() { - private final int numFields = schema.length(); - private UnsafeRow row = new UnsafeRow(numFields); - - @Override - public boolean advanceNext() { - try { - if (!isReleased && sortedIterator.hasNext()) { - sortedIterator.loadNext(); - row.pointTo( - sortedIterator.getBaseObject(), - sortedIterator.getBaseOffset(), - sortedIterator.getRecordLength()); - // Here is the initial buf ifx in SPARK-9364: the bug fix of use-after-free bug - // when returning the last row from an iterator. 
For example, in - // [[GroupedIterator]], we still use the last row after traversing the iterator - // in 'fetchNextGroupIterator' - if (!sortedIterator.hasNext()) { - row = row.copy(); // so that we don't have dangling pointers to freed page - cleanupResources(); - } - return true; - } else { - row = null; // so that we don't keep reference to the base object - return false; - } - } catch (IOException e) { - cleanupResources(); - // Scala iterators don't declare any checked exceptions, so we need to use this hack - // to re-throw the exception. - Platform.throwException(e); - } - throw new RuntimeException("Exception should have been re-thrown in next()"); - } - - @Override - public UnsafeRow getRow() { return row; } - }.toScala(); - } catch (IOException e) { - cleanupResources(); - throw e; - } - } - - @Override - public Iterator sort(Iterator inputIterator) throws IOException { - while (inputIterator.hasNext()) { - insertRow(inputIterator.next()); - } - return sort(); - } -} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/UnsafeTopNRowSorter.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/UnsafeTopNRowSorter.java deleted file mode 100644 index 6a27c8edfa16042201f37addc0d0e0783fa81d5c..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/UnsafeTopNRowSorter.java +++ /dev/null @@ -1,256 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution; - -import java.io.IOException; -import java.util.*; -import java.util.function.Supplier; - -import scala.collection.Iterator; -import scala.math.Ordering; - -import org.apache.spark.TaskContext; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.execution.topnsort.UnsafeInMemoryTopNSorter; -import org.apache.spark.sql.execution.topnsort.UnsafePartitionedTopNSorter; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.unsafe.Platform; -import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; -import org.apache.spark.util.collection.unsafe.sort.RecordComparator; -import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; - -public final class UnsafeTopNRowSorter extends AbstractUnsafeRowSorter { - - private final UnsafePartitionedTopNSorter partitionedTopNSorter; - - // partition key - private final UnsafeProjection partitionSpecProjection; - - // order(rank) key - private final UnsafeExternalRowSorter.PrefixComputer prefixComputer; - - private long totalSortTimeNanos = 0L; - private final long timeNanosBeforeInsertRow; - - public static UnsafeTopNRowSorter create( - int n, - boolean strictTopN, - StructType schema, - UnsafeProjection partitionSpecProjection, - Ordering orderingOfRankKey, - PrefixComparator prefixComparator, - UnsafeExternalRowSorter.PrefixComputer prefixComputer, - long pageSizeBytes, - boolean canSortFullyWithPrefix) { - Supplier recordComparatorSupplier = - () -> new RowComparator(orderingOfRankKey, schema.length()); - return new UnsafeTopNRowSorter( - n, strictTopN, schema, partitionSpecProjection, recordComparatorSupplier, - prefixComparator, prefixComputer, pageSizeBytes, canSortFullyWithPrefix); - } - - private UnsafeTopNRowSorter( - int n, - boolean strictTopN, - StructType schema, - UnsafeProjection partitionSpecProjection, - Supplier recordComparatorSupplier, - PrefixComparator prefixComparator, - UnsafeExternalRowSorter.PrefixComputer prefixComputer, - long pageSizeBytes, - boolean canSortFullyWithPrefix) { - super(schema); - this.prefixComputer = prefixComputer; - final TaskContext taskContext = TaskContext.get(); - this.partitionSpecProjection = partitionSpecProjection; - this.partitionedTopNSorter = UnsafePartitionedTopNSorter.create( - n, - strictTopN, - taskContext.taskMemoryManager(), - taskContext, - recordComparatorSupplier, - prefixComparator, - pageSizeBytes, - canSortFullyWithPrefix); - timeNanosBeforeInsertRow = System.nanoTime(); - } - - @Override - public void insertRow(UnsafeRow row) throws IOException { - final UnsafeExternalRowSorter.PrefixComputer.Prefix prefix = prefixComputer.computePrefix(row); - UnsafeRow partKey = partitionSpecProjection.apply(row); - partitionedTopNSorter.insertRow(partKey, row, prefix.value); - } - - /** - * Return the peak memory used so far, in bytes. - * */ - @Override - public long getPeakMemoryUsage() { - return partitionedTopNSorter.getPeakMemoryUsedBytes(); - } - - /** - * @return the total amount of time spent sorting data (in-memory only). 
- * */ - @Override - public long getSortTimeNanos() { - return totalSortTimeNanos; - } - - @Override - public Iterator sort() throws IOException - { - try { - Map partKeyToSorter = - partitionedTopNSorter.getPartKeyToSorter(); - if (partKeyToSorter.isEmpty()) { - // Since we won't ever call next() on an empty iterator, we need to clean up resources - // here in order to prevent memory leaks. - cleanupResources(); - return emptySortedIterator(); - } - - Queue sortedIteratorsForPartitions = new LinkedList<>(); - for (Map.Entry entry : partKeyToSorter.entrySet()) { - final UnsafeInMemoryTopNSorter topNSorter = entry.getValue(); - final UnsafeSorterIterator unsafeSorterIterator = topNSorter.getSortedIterator(); - - sortedIteratorsForPartitions.add(new RowIterator() - { - private final int numFields = schema.length(); - private UnsafeRow row = new UnsafeRow(numFields); - - @Override - public boolean advanceNext() - { - try { - if (!isReleased && unsafeSorterIterator.hasNext()) { - unsafeSorterIterator.loadNext(); - row.pointTo( - unsafeSorterIterator.getBaseObject(), - unsafeSorterIterator.getBaseOffset(), - unsafeSorterIterator.getRecordLength()); - // Here is the initial buf ifx in SPARK-9364: the bug fix of use-after-free bug - // when returning the last row from an iterator. For example, in - // [[GroupedIterator]], we still use the last row after traversing the iterator - // in 'fetchNextGroupIterator' - if (!unsafeSorterIterator.hasNext()) { - row = row.copy(); // so that we don't have dangling pointers to freed page - topNSorter.freeMemory(); - } - return true; - } - else { - row = null; // so that we don't keep reference to the base object - return false; - } - } catch (IOException e) { - topNSorter.freeMemory(); - // Scala iterators don't declare any checked exceptions, so we need to use this hack - // to re-throw the exception. - Platform.throwException(e); - } - throw new RuntimeException("Exception should have been re-thrown in next()"); - } - - @Override - public UnsafeRow getRow() - { - return row; - } - }); - } - - // Update total sort time. - if (totalSortTimeNanos == 0L) { - totalSortTimeNanos = System.nanoTime() - timeNanosBeforeInsertRow; - } - final ChainedIterator chainedIterator = new ChainedIterator(sortedIteratorsForPartitions); - return chainedIterator.toScala(); - } catch (Exception e) { - cleanupResources(); - throw e; - } - } - - private Iterator emptySortedIterator() { - return new RowIterator() { - @Override - public boolean advanceNext() { - return false; - } - - @Override - public UnsafeRow getRow() { - return null; - } - }.toScala(); - } - - /** - * Chain multiple UnsafeSorterIterators from PartSorterMap as single one. 
- * */ - private static final class ChainedIterator extends RowIterator { - private final Queue iterators; - private RowIterator current; - private UnsafeRow row; - - ChainedIterator(Queue iterators) { - assert iterators.size() > 0; - this.iterators = iterators; - this.current = iterators.remove(); - } - - @Override - public boolean advanceNext() { - boolean result = this.current.advanceNext(); - while(!result && !this.iterators.isEmpty()) { - this.current = iterators.remove(); - result = this.current.advanceNext(); - } - if (!result) { - this.row = null; - } else { - this.row = (UnsafeRow) this.current.getRow(); - } - return result; - } - - @Override - public UnsafeRow getRow() { - return row; - } - } - - @Override - public Iterator sort(Iterator inputIterator) throws IOException { - while (inputIterator.hasNext()) { - insertRow(inputIterator.next()); - } - return sort(); - } - - @Override - public void cleanupResources() { - isReleased = true; - partitionedTopNSorter.cleanupResources(); - } -} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/topnsort/UnsafeInMemoryTopNSorter.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/topnsort/UnsafeInMemoryTopNSorter.java deleted file mode 100644 index 7b14bb6694eec58c48cef5a96aa6626ff22ec431..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/topnsort/UnsafeInMemoryTopNSorter.java +++ /dev/null @@ -1,272 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.topnsort; - -import org.apache.spark.TaskContext; -import org.apache.spark.memory.MemoryConsumer; -import org.apache.spark.memory.TaskMemoryManager; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.unsafe.UnsafeAlignedOffset; -import org.apache.spark.unsafe.array.LongArray; -import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; - -public final class UnsafeInMemoryTopNSorter { - - private final MemoryConsumer consumer; - private final TaskMemoryManager memoryManager; - private final UnsafePartitionedTopNSorter.TopNSortComparator sortComparator; - - /** - * Within this buffer, position {@code 2 * i} holds a pointer to the record at index {@code i}, - * while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. - * - * Only part of the array will be used to store the pointers, the rest part is preserved as temporary buffer for sorting. - */ - private LongArray array; - - /** - * The position in the sort buffer where new records can be inserted. - */ - private int nextEmptyPos = 0; - - // Top n. 
- private final int n; - private final boolean strictTopN; - - // The capacity of array. - private final int capacity; - private static final int MIN_ARRAY_CAPACITY = 64; - - public UnsafeInMemoryTopNSorter( - final int n, - final boolean strictTopN, - final MemoryConsumer consumer, - final TaskMemoryManager memoryManager, - final UnsafePartitionedTopNSorter.TopNSortComparator sortComparator) { - this.n = n; - this.strictTopN = strictTopN; - this.consumer = consumer; - this.memoryManager = memoryManager; - this.sortComparator = sortComparator; - this.capacity = Math.max(MIN_ARRAY_CAPACITY, Integer.highestOneBit(n) << 1); - // The size of Long array is equal to twice capacity because each item consists of a prefix and a pointer. - this.array = consumer.allocateArray(capacity << 1); - } - - /** - * Free the memory used by pointer array - */ - public void freeMemory() { - if (consumer != null) { - if (array != null) { - consumer.freeArray(array); - } - array = null; - } - nextEmptyPos = 0; - } - - public long getMemoryUsage() { - if (array == null) { - return 0L; - } - return array.size() * 8; - } - - public int insert(UnsafeRow row, long prefix) { - if (nextEmptyPos < n) { - return insertIntoArray(nextEmptyPos -1, row, prefix); - } else { - // reach n candidates - final int compareResult = nthRecordCompareTo(row, prefix); - if (compareResult < 0) { - // skip this record - return -1; - } - else if (compareResult == 0) { - if (strictTopN) { - // For rows that have duplicate values, skip it if this is strict TopN (e.g. RowNumber). - return -1; - } - // append record - checkForInsert(); - array.set((nextEmptyPos << 1) + 1, prefix); - return nextEmptyPos++; - } - else { - checkForInsert(); - // The record at position n -1 should be excluded, so we start comparing with record at position n - 2. - final int insertPosition = insertIntoArray(n - 2, row, prefix); - if (strictTopN || insertPosition == n - 1 || hasDistinctTopN()) { - nextEmptyPos = n; - } - // For other cases, 'nextEmptyPos' will move to the next empty position in 'insertIntoArray()'. - // e.g. given rank <= 4, and we already have 1, 2, 6, 6, so 'nextEmptyPos' is 4. - // If the new row is 3, then values in the array will be 1, 2, 3, 6, 6, and 'nextEmptyPos' will be 5. - return insertPosition; - } - } - } - - public void updateRecordPointer(int position, long pointer) { - array.set(position << 1, pointer); - } - - private int insertIntoArray(int position, UnsafeRow row, long prefix) { - // find insert position - while (position >= 0 && sortComparator.compare(array.get(position << 1), array.get((position << 1) + 1), row, prefix) > 0) { - --position; - } - final int insertPos = position + 1; - - // move records between 'insertPos' and 'nextEmptyPos' to next positions - for (int i = nextEmptyPos; i > insertPos; --i) { - int src = (i - 1) << 1; - int dst = i << 1; - array.set(dst, array.get(src)); - array.set(dst + 1, array.get(src + 1)); - } - - // Insert prefix of this row. Note that the address will be inserted by 'updateRecordPointer()' - // after we get its address from 'taskMemoryManager' - array.set((insertPos << 1) + 1, prefix); - ++nextEmptyPos; - return insertPos; - } - - private void checkForInsert() { - if (nextEmptyPos >= capacity) { - throw new IllegalStateException("No space for new record.\n" + - "For RANK expressions with TOP-N filter(e.g. 
rk <= 100), we maintain a fixed capacity " + - "array for TOP-N sorting for each partition, and if there are too many same rankings, " + - "the result that needs to be retained will exceed the capacity of the array.\n" + - "Please consider using ROW_NUMBER expression or disabling TOP-N sorting by setting " + - "saprk.sql.execution.topNPushDownFOrWindow.enabled to false."); - } - } - - private int nthRecordCompareTo(UnsafeRow row, long prefix) { - int nthPos = n - 1; - return sortComparator.compare(array.get(nthPos << 1), array.get((nthPos << 1) + 1), row, prefix); - } - - private boolean hasDistinctTopN() { - int nthPosition = (n - 1) << 1; - return sortComparator.compare(array.get(nthPosition), array.get(nthPosition + 1), // nth record - array.get(nthPosition + 2), array.get(nthPosition + 3)) // (n + 1)th record - != 0; // not eq - } - - /** - * This is copied from - * {@link org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.SortedIterator}. - * */ - public final class TopNSortedIterator extends UnsafeSorterIterator implements Cloneable { - private final int numRecords; - private int position; - private int offset; - private Object baseObject; - private long baseOffset; - private long keyPrefix; - private int recordLength; - private long currentPageNumber; - private final TaskContext taskContext = TaskContext.get(); - - private TopNSortedIterator(int numRecords, int offset) { - this.numRecords = numRecords; - this.position = 0; - this.offset = offset; - } - - public TopNSortedIterator clone() { - TopNSortedIterator iter = new TopNSortedIterator(numRecords, offset); - iter.position = position; - iter.baseObject = baseObject; - iter.baseOffset = baseOffset; - iter.keyPrefix = keyPrefix; - iter.recordLength = recordLength; - iter.currentPageNumber = currentPageNumber; - return iter; - } - - @Override - public int getNumRecords() { - return numRecords; - } - - @Override - public boolean hasNext() { - return position / 2 < numRecords; - } - - @Override - public void loadNext() { - // Kill the task in case it has been marked as killed. This logic is from - // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order - // to avoid performance overhead. This check is added here in 'loadNext()' instead of in - // 'hasNext()' because it's technically possible for the caller to be relying on - // 'getNumRecords()' instead of 'hasNext()' to know when to stop. - if (taskContext != null) { - taskContext.killTaskIfInterrupted(); - } - // This pointer points to a 4-byte record length, followed by the record's bytes - final long recordPointer = array.get(offset + position); - currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer); - int uaoSize = UnsafeAlignedOffset.getUaoSize(); - baseObject = memoryManager.getPage(recordPointer); - // Skip over record length - baseOffset = memoryManager.getOffsetInPage(recordPointer) + uaoSize; - recordLength = UnsafeAlignedOffset.getSize(baseObject, baseOffset - uaoSize); - keyPrefix = array.get(offset + position + 1); - position += 2; - } - - @Override - public Object getBaseObject() { - return baseObject; - } - - @Override - public long getBaseOffset() { - return baseOffset; - } - - @Override - public long getCurrentPageNumber() { - return currentPageNumber; - } - - @Override - public int getRecordLength() { - return recordLength; - } - - @Override - public long getKeyPrefix() { - return keyPrefix; - } - } - - /** - * Return an iterator over record pointers in sorted order. 
For efficiency, all calls to - * {@code next()} will return the same mutable object. - * */ - public UnsafeSorterIterator getSortedIterator() { - return new TopNSortedIterator(nextEmptyPos, 0); - } -} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/topnsort/UnsafePartitionedTopNSorter.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/topnsort/UnsafePartitionedTopNSorter.java deleted file mode 100644 index 57941aefb4fc8a3234c0ab22b6d45294ae09c639..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/topnsort/UnsafePartitionedTopNSorter.java +++ /dev/null @@ -1,263 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.topnsort; - -import java.util.*; -import java.util.function.Supplier; - -import com.google.common.annotations.VisibleForTesting; - -import org.apache.spark.TaskContext; -import org.apache.spark.memory.MemoryConsumer; -import org.apache.spark.memory.TaskMemoryManager; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.UnsafeAlignedOffset; -import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.util.collection.unsafe.sort.*; - -/** - * Partitioned top n sorter based on {@link org.apache.spark.sql.execution.topnsort.UnsafeInMemoryTopNSorter}. - * The implementation mostly refers to {@link UnsafeExternalSorter}. - * */ -public final class UnsafePartitionedTopNSorter extends MemoryConsumer { - private final TaskMemoryManager taskMemoryManager; - private TopNSortComparator sortComparator; - - /** - * Memory pages that hold the records being sorted. The pages in this list are freed when - * spilling, although in principle we could recycle these pages across spills (on the other hand, - * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager itself). 
- * */ - private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<>(); - private final Map<UnsafeRow, UnsafeInMemoryTopNSorter> partToSorters = new LinkedHashMap<>(); - - private final int n; - private final boolean strictTopN; - private MemoryBlock currentPage = null; - private long pageCursor = -1; - private long peakMemoryUsedBytes = 0; - - public static UnsafePartitionedTopNSorter create( - int n, - boolean strictTopN, - TaskMemoryManager taskMemoryManager, - TaskContext taskContext, - Supplier<RecordComparator> recordComparatorSupplier, - PrefixComparator prefixComparator, - long pageSizeBytes, - boolean canSortFullyWithPrefix) { - assert n > 0 : "Top n must be positive"; - assert recordComparatorSupplier != null; - return new UnsafePartitionedTopNSorter(n, strictTopN, taskMemoryManager, taskContext, - recordComparatorSupplier, prefixComparator, pageSizeBytes, canSortFullyWithPrefix); - } - - private UnsafePartitionedTopNSorter( - int n, - boolean strictTopN, - TaskMemoryManager taskMemoryManager, - TaskContext taskContext, - Supplier<RecordComparator> recordComparatorSupplier, - PrefixComparator prefixComparator, - long pageSizeBytes, - boolean canSortFullyWithPrefix) { - super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode()); - this.n = n; - this.strictTopN = strictTopN; - this.taskMemoryManager = taskMemoryManager; - this.sortComparator = new TopNSortComparator(recordComparatorSupplier.get(), - prefixComparator, taskMemoryManager, canSortFullyWithPrefix); - - // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at - // the end of the task. This is necessary to avoid memory leaks when the downstream operator - // does not fully consume the sorter's output (e.g. sort followed by limit). - taskContext.addTaskCompletionListener(context -> { - cleanupResources(); - }); - } - - @Override - public long spill(long size, MemoryConsumer trigger) { - throw new UnsupportedOperationException("Spill is unsupported operation in topN in-memory sorter"); - } - - /** - * Return the total memory usage of this sorter, including the data pages and the sorter's pointer array. - * */ - private long getMemoryUsage() { - long totalPageSize = 0; - for (MemoryBlock page : allocatedPages) { - totalPageSize += page.size(); - } - for (UnsafeInMemoryTopNSorter sorter : partToSorters.values()) { - totalPageSize += sorter.getMemoryUsage(); - } - return totalPageSize; - } - - private void updatePeakMemoryUsed() { - long mem = getMemoryUsage(); - if (mem > peakMemoryUsedBytes) { - peakMemoryUsedBytes = mem; - } - } - - /** - * Return the peak memory used so far, in bytes. - * */ - public long getPeakMemoryUsedBytes() { - updatePeakMemoryUsed(); - return peakMemoryUsedBytes; - } - - @VisibleForTesting - public int getNumberOfAllocatedPages() { - return allocatedPages.size(); - } - - /** - * Free this sorter's data pages. - * - * @return the number of bytes freed. - * */ - private long freeMemory() { - updatePeakMemoryUsed(); - long memoryFreed = 0; - for (MemoryBlock block : allocatedPages) { - memoryFreed += block.size(); - freePage(block); - } - allocatedPages.clear(); - currentPage = null; - pageCursor = 0; - for (UnsafeInMemoryTopNSorter sorter: partToSorters.values()) { - memoryFreed += sorter.getMemoryUsage(); - sorter.freeMemory(); - } - partToSorters.clear(); - sortComparator = null; - return memoryFreed; - } - - /** - * Frees this sorter's in-memory data structures and cleans up its spill files.
- * */ - public void cleanupResources() { - synchronized (this) { - freeMemory(); - } - } - - /** - * Allocates an additional page in order to insert an additional record. This will request - * additional memory from the memory manager and spill if the requested memory can not be obtained. - * - * @param required the required space in the data page, in bytes, including space for storing the record size - * */ - private void acquireNewPageIfNecessary(int required) { - if (currentPage == null || - pageCursor + required > currentPage.getBaseOffset() + currentPage.size()) { - currentPage = allocatePage(required); - pageCursor = currentPage.getBaseOffset(); - allocatedPages.add(currentPage); - } - } - - public void insertRow(UnsafeRow partKey, UnsafeRow row, long prefix) { - UnsafeInMemoryTopNSorter sorter = - partToSorters.computeIfAbsent( - partKey, - k -> new UnsafeInMemoryTopNSorter(n, strictTopN, this, taskMemoryManager, sortComparator) - ); - final int position = sorter.insert(row, prefix); - if (position >= 0) { - final int uaoSize = UnsafeAlignedOffset.getUaoSize(); - // Need 4 or 8 bytes to store the record length. - final int length = row.getSizeInBytes(); - final int required = length + uaoSize; - acquireNewPageIfNecessary(required); - - final Object base = currentPage.getBaseObject(); - final long recordAddress = - taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); - UnsafeAlignedOffset.putSize(base, pageCursor, length); - pageCursor += uaoSize; - Platform.copyMemory(row.getBaseObject(), row.getBaseOffset(), base, pageCursor, length); - pageCursor += length; - - sorter.updateRecordPointer(position, recordAddress); - } - } - - public Map getPartKeyToSorter() { - return partToSorters; - } - - static final class TopNSortComparator { - private final RecordComparator recordComparator; - private final PrefixComparator prefixComparator; - private final TaskMemoryManager memoryManager; - private final boolean needCompareFully; - - TopNSortComparator( - RecordComparator recordComparator, - PrefixComparator prefixComparator, - TaskMemoryManager memoryManager, - boolean canSortFullyWithPrefix) { - this.recordComparator = recordComparator; - this.prefixComparator = prefixComparator; - this.memoryManager = memoryManager; - this.needCompareFully = !canSortFullyWithPrefix; - } - - public int compare(long pointer1, long prefix1, long pointer2, long prefix2) { - final int prefixComparisonResult = prefixComparator.compare(prefix1, prefix2); - if (needCompareFully && prefixComparisonResult == 0) { - final int uaoSize = UnsafeAlignedOffset.getUaoSize(); - final Object baseObject1 = memoryManager.getPage(pointer1); - final long baseOffset1 = memoryManager.getOffsetInPage(pointer1) + uaoSize; - final int baseLength1 = UnsafeAlignedOffset.getSize(baseObject1, baseOffset1 - uaoSize); - final Object baseObject2 = memoryManager.getPage(pointer2); - final long baseOffset2 = memoryManager.getOffsetInPage(pointer2) + uaoSize; - final int baseLength2 = UnsafeAlignedOffset.getSize(baseObject2, baseOffset2 - uaoSize); - return recordComparator.compare(baseObject1, baseOffset1, baseLength1, baseObject2, - baseOffset2, baseLength2); - } else { - return prefixComparisonResult; - } - } - - public int compare(long pointer, long prefix1, UnsafeRow row, long prefix2) { - final int prefixComparisonResult = prefixComparator.compare(prefix1, prefix2); - if (needCompareFully && prefixComparisonResult == 0) { - final int uaoSize = UnsafeAlignedOffset.getUaoSize(); - final Object baseObject1 = 
memoryManager.getPage(pointer); - final long baseOffset1 = memoryManager.getOffsetInPage(pointer) + uaoSize; - final int baseLength1 = UnsafeAlignedOffset.getSize(baseObject1, baseOffset1 - uaoSize); - final Object baseObject2 = row.getBaseObject(); - final long baseOffset2 = row.getBaseOffset(); - final int baseLength2 = row.getSizeInBytes(); - return recordComparator.compare(baseObject1, baseOffset1, baseLength1, baseObject2, - baseOffset2, baseLength2); - } else { - return prefixComparisonResult; - } - } - } -} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/window/TopNPushDownForWindow.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/window/TopNPushDownForWindow.scala deleted file mode 100644 index d53c6e0286c21e026c5073335e96a5a00010a71a..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/window/TopNPushDownForWindow.scala +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.window; - -import com.huawei.boostkit.spark.ColumnarPluginConfig -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{FilterExec, SortExec, SparkPlan, TopNSortExec} - -object TopNPushDownForWindow extends Rule[SparkPlan] with PredicateHelper { - override def apply(plan: SparkPlan): SparkPlan = { - if (!ColumnarPluginConfig.getConf.topNPushDownForWindowEnable) { - return plan - } - - plan.transform { - case f @ FilterExec(condition, - w @ WindowExec(Seq(windowExpression), _, orderSpec, sort: SortExec)) - if orderSpec.nonEmpty && isTopNExpression(windowExpression) => - var topn = Int.MaxValue - val nonTopNConditions = splitConjunctivePredicates(condition).filter { - case LessThan(e: NamedExpression, IntegerLiteral(n)) - if e.exprId == windowExpression.exprId => - topn = Math.min(topn, n - 1) - false - case LessThanOrEqual(e: NamedExpression, IntegerLiteral(n)) - if e.exprId == windowExpression.exprId => - topn = Math.min(topn, n) - false - case GreaterThan(IntegerLiteral(n), e: NamedExpression) - if e.exprId == windowExpression.exprId => - topn = Math.min(topn, n - 1) - false - case GreaterThanOrEqual(IntegerLiteral(n), e: NamedExpression) - if e.exprId == windowExpression.exprId => - topn = Math.min(topn, n) - false - case EqualTo(e: NamedExpression, IntegerLiteral(n)) - if n == 1 && e.exprId == windowExpression.exprId => - topn = 1 - false - case EqualTo(IntegerLiteral(n), e: NamedExpression) - if n == 1 && e.exprId == windowExpression.exprId => - topn = 1 - false - case _ => true - } - - // topn <= SQLConf.get.topNPushDownForWindowThreshold 100. - if (topn> 0 && topn <= ColumnarPluginConfig.getConf.topNPushDownForWindowThreshold) { - val strictTopN = isStrictTopN(windowExpression) - val topNSortExec = TopNSortExec( - topn, strictTopN, w.partitionSpec, w.orderSpec, sort.global, sort.child) - val newCondition = if (nonTopNConditions.isEmpty) { - Literal.TrueLiteral - } else { - nonTopNConditions.reduce(And) - } - FilterExec(newCondition, w.copy(child = topNSortExec)) - } else { - f - } - } - } - - private def isTopNExpression(e: Expression): Boolean = e match { - case Alias(child, _) => isTopNExpression(child) - case WindowExpression(windowFunction, _) - if windowFunction.isInstanceOf[Rank] => true - case _ => false - } - - private def isStrictTopN(e: Expression): Boolean = e match { - case Alias(child, _) => isStrictTopN(child) - case WindowExpression(windowFunction, _) => windowFunction.isInstanceOf[RowNumber] - } -} diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark32-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark32-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala new file mode 100644 index 0000000000000000000000000000000000000000..1052e70bf1f692d6b4a8e1440f593a468309f40d --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark32-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala @@ -0,0 +1,27 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions.{NamedExpression, SortOrder} + +abstract class UnaryExecNodeShim(sortOrder: Seq[SortOrder], projectList: Seq[NamedExpression]) + extends UnaryExecNode { + + override def outputOrdering: Seq[SortOrder] = sortOrder +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala index 465123d8730d39656d242bd4785e22680312ce05..14fd3117ca8dfd38c1014a63a11a88e856990bae 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.FileFormatWriter.ConcurrentOutputWriterSpec -import org.apache.spark.sql.execution.{ColumnarProjectExec, ColumnarSortExec, OmniColumnarToRowExec, ProjectExec, SQLExecution, SortExec, SparkPlan, UnsafeExternalRowSorter} +import org.apache.spark.sql.execution.{ColumnarProjectExec, ColumnarSortExec, OmniColumnarToRowExec, ProjectExec, SQLExecution, SortExec, SparkPlan} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StringType import org.apache.spark.sql.vectorized.ColumnarBatch diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala new file mode 100644 index 0000000000000000000000000000000000000000..1052e70bf1f692d6b4a8e1440f593a468309f40d --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala @@ -0,0 +1,27 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions.{NamedExpression, SortOrder} + +abstract class UnaryExecNodeShim(sortOrder: Seq[SortOrder], projectList: Seq[NamedExpression]) + extends UnaryExecNode { + + override def outputOrdering: Seq[SortOrder] = sortOrder +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark34-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark34-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala new file mode 100644 index 0000000000000000000000000000000000000000..8e49f1d388f776a969f917f4d0a6493b1242a320 --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark34-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions.{NamedExpression, SortOrder} + +abstract class UnaryExecNodeShim(sortOrder: Seq[SortOrder], projectList: Seq[NamedExpression]) + extends OrderPreservingUnaryExecNode { + + override def outputExpressions: Seq[NamedExpression] = projectList + + override def orderingExpressions: Seq[SortOrder] = sortOrder +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark35-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark35-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala new file mode 100644 index 0000000000000000000000000000000000000000..8e49f1d388f776a969f917f4d0a6493b1242a320 --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark35-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions.{NamedExpression, SortOrder} + +abstract class UnaryExecNodeShim(sortOrder: Seq[SortOrder], projectList: Seq[NamedExpression]) + extends OrderPreservingUnaryExecNode { + + override def outputExpressions: Seq[NamedExpression] = projectList + + override def orderingExpressions: Seq[SortOrder] = sortOrder +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala index ec03b275363ec70940db54d610ea30b5972906c6..403a1baed983d35df6cb21813ba4ff1201443307 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala @@ -55,108 +55,15 @@ class HiveResourceSuite extends SparkFunSuite { } test("queryBySparkSql-HiveDataSource") { - runner.runQuery("q1", 1) - runner.runQuery("q2", 1) - runner.runQuery("q3", 1) - runner.runQuery("q4", 1) - runner.runQuery("q5", 1) - runner.runQuery("q6", 1) - runner.runQuery("q7", 1) - runner.runQuery("q8", 1) - runner.runQuery("q9", 1) - runner.runQuery("q10", 1) - runner.runQuery("q11", 1) - runner.runQuery("q12", 1) - runner.runQuery("q13", 1) - runner.runQuery("q14a", 1) - runner.runQuery("q14b", 1) - runner.runQuery("q15", 1) - runner.runQuery("q16", 1) - runner.runQuery("q17", 1) - runner.runQuery("q18", 1) runner.runQuery("q19", 1) - runner.runQuery("q20", 1) - runner.runQuery("q21", 1) - runner.runQuery("q22", 1) - runner.runQuery("q23a", 1) - runner.runQuery("q23b", 1) - runner.runQuery("q24a", 1) - runner.runQuery("q24b", 1) - runner.runQuery("q25", 1) - runner.runQuery("q26", 1) - runner.runQuery("q27", 1) - runner.runQuery("q28", 1) - runner.runQuery("q29", 1) - runner.runQuery("q30", 1) - runner.runQuery("q31", 1) - runner.runQuery("q32", 1) - runner.runQuery("q33", 1) - runner.runQuery("q34", 1) - runner.runQuery("q35", 1) - runner.runQuery("q36", 1) - runner.runQuery("q37", 1) - runner.runQuery("q38", 1) - runner.runQuery("q39a", 1) - runner.runQuery("q39b", 1) - runner.runQuery("q40", 1) - runner.runQuery("q41", 1) - runner.runQuery("q42", 1) - runner.runQuery("q43", 1) - runner.runQuery("q44", 1) - runner.runQuery("q45", 1) - runner.runQuery("q46", 1) runner.runQuery("q47", 1) - runner.runQuery("q48", 1) - runner.runQuery("q49", 1) - runner.runQuery("q50", 1) - runner.runQuery("q51", 1) - runner.runQuery("q52", 1) runner.runQuery("q53", 1) - runner.runQuery("q54", 1) - runner.runQuery("q55", 1) - runner.runQuery("q56", 1) - runner.runQuery("q57", 1) - runner.runQuery("q58", 1) - runner.runQuery("q59", 
1) - runner.runQuery("q60", 1) - runner.runQuery("q61", 1) - runner.runQuery("q62", 1) runner.runQuery("q63", 1) - runner.runQuery("q64", 1) - runner.runQuery("q65", 1) - runner.runQuery("q66", 1) - runner.runQuery("q67", 1) - runner.runQuery("q68", 1) - runner.runQuery("q69", 1) - runner.runQuery("q70", 1) runner.runQuery("q71", 1) - runner.runQuery("q72", 1) - runner.runQuery("q73", 1) - runner.runQuery("q74", 1) - runner.runQuery("q75", 1) - runner.runQuery("q76", 1) - runner.runQuery("q77", 1) - runner.runQuery("q78", 1) runner.runQuery("q79", 1) - runner.runQuery("q80", 1) - runner.runQuery("q81", 1) runner.runQuery("q82", 1) - runner.runQuery("q83", 1) runner.runQuery("q84", 1) - runner.runQuery("q85", 1) - runner.runQuery("q86", 1) - runner.runQuery("q87", 1) - runner.runQuery("q88", 1) runner.runQuery("q89", 1) - runner.runQuery("q90", 1) - runner.runQuery("q91", 1) - runner.runQuery("q92", 1) - runner.runQuery("q93", 1) - runner.runQuery("q94", 1) - runner.runQuery("q95", 1) - runner.runQuery("q96", 1) - runner.runQuery("q97", 1) - runner.runQuery("q98", 1) runner.runQuery("q99", 1) } diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala deleted file mode 100644 index d8d7d0bd97807cb6bdcaad024e54232687b43937..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Project} -import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.util.sideBySide - -trait HeuristicJoinReorderPlanTestBase extends PlanTest { - - def outputsOf(plans: LogicalPlan*): Seq[Attribute] = { - plans.map(_.output).reduce(_ ++ _) - } - - def assertEqualJoinPlans( - optimizer: RuleExecutor[LogicalPlan], - originalPlan: LogicalPlan, - groundTruthBestPlan: LogicalPlan): Unit = { - val analyzed = originalPlan.analyze - val optimized = optimizer.execute(analyzed) - val expected = EliminateResolvedHint.apply(groundTruthBestPlan.analyze) - - assert(equivalentOutput(analyzed, expected)) - assert(equivalentOutput(analyzed, optimized)) - - compareJoinOrder(optimized, expected) - } - - protected def equivalentOutput(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { - normalizeExprIds(plan1).output == normalizeExprIds(plan2).output - } - - protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan): Unit = { - val normalized1 = normalizePlan(normalizeExprIds(plan1)) - val normalized2 = normalizePlan(normalizeExprIds(plan2)) - if (!sameJoinPlan(normalized1, normalized2)) { - fail( - s""" - |== FAIL: Plans do not match === - |${sideBySide( - rewriteNameFromAttrNullability(normalized1).treeString, - rewriteNameFromAttrNullability(normalized2).treeString).mkString("\n")} - """.stripMargin) - } - } - - private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { - (plan1, plan2) match { - case (j1: Join, j2: Join) => - (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right) - && j1.hint.leftHint == j2.hint.leftHint && j1.hint.rightHint == j2.hint.rightHint) || - (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left) - && j1.hint.leftHint == j2.hint.rightHint && j1.hint.rightHint == j2.hint.leftHint) - case (p1: Project, p2: Project) => - p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child) - case _ => - plan1 == plan2 - } - } -} diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala deleted file mode 100644 index c7ea9bd95ad953b136ff9f5f196a0e0bace24027..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} -import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.logical.ColumnStat -import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} - -class HeuristicJoinReorderSuite - extends HeuristicJoinReorderPlanTestBase with StatsEstimationTestBase { - - private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - attr("t1.k-1-2") -> rangeColumnStat(2, 0), - attr("t1.v-1-10") -> rangeColumnStat(10, 0), - attr("t2.k-1-5") -> rangeColumnStat(5, 0), - attr("t3.v-1-100") -> rangeColumnStat(100, 0), - attr("t4.k-1-2") -> rangeColumnStat(2, 0), - attr("t4.v-1-10") -> rangeColumnStat(10, 0), - attr("t5.k-1-5") -> rangeColumnStat(5, 0), - attr("t5.v-1-5") -> rangeColumnStat(5, 0) - )) - - private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) - private val nameToColInfo: Map[String, (Attribute, ColumnStat)] = - columnInfo.map(kv => kv._1.name -> kv) - - private val t1 = StatsTestPlan( - outputList = Seq("t1.k-1-2", "t1.v-1-10").map(nameToAttr), - rowCount = 1000, - size = Some(1000 * (8 + 4 + 4)), - attributeStats = AttributeMap(Seq("t1.k-1-2", "t1.v-1-10").map(nameToColInfo))) - - private val t2 = StatsTestPlan( - outputList = Seq("t2.k-1-5").map(nameToAttr), - rowCount = 20, - size = Some(20 * (8 + 4)), - attributeStats = AttributeMap(Seq("t2.k-1-5").map(nameToColInfo))) - - private val t3 = StatsTestPlan( - outputList = Seq("t3.v-1-100").map(nameToAttr), - rowCount = 100, - size = Some(100 * (8 + 4)), - attributeStats = AttributeMap(Seq("t3.v-1-100").map(nameToColInfo))) - - test("reorder 3 tables") { - val originalPlan = - t1.join(t2).join(t3) - .where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && - (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) - - val analyzed = originalPlan.analyze - val optimized = HeuristicJoinReorder.apply(analyzed).select(outputsOf(t1, t2, t3): _*) - val expected = - t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) - .join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) - .select(outputsOf(t1, t2, t3): _*) - - assert(equivalentOutput(analyzed, expected)) - assert(equivalentOutput(analyzed, optimized)) - - compareJoinOrder(optimized, expected) - } -} diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala index ec03b275363ec70940db54d610ea30b5972906c6..403a1baed983d35df6cb21813ba4ff1201443307 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala +++ 
b/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala @@ -55,108 +55,15 @@ class HiveResourceSuite extends SparkFunSuite { } test("queryBySparkSql-HiveDataSource") { - runner.runQuery("q1", 1) - runner.runQuery("q2", 1) - runner.runQuery("q3", 1) - runner.runQuery("q4", 1) - runner.runQuery("q5", 1) - runner.runQuery("q6", 1) - runner.runQuery("q7", 1) - runner.runQuery("q8", 1) - runner.runQuery("q9", 1) - runner.runQuery("q10", 1) - runner.runQuery("q11", 1) - runner.runQuery("q12", 1) - runner.runQuery("q13", 1) - runner.runQuery("q14a", 1) - runner.runQuery("q14b", 1) - runner.runQuery("q15", 1) - runner.runQuery("q16", 1) - runner.runQuery("q17", 1) - runner.runQuery("q18", 1) runner.runQuery("q19", 1) - runner.runQuery("q20", 1) - runner.runQuery("q21", 1) - runner.runQuery("q22", 1) - runner.runQuery("q23a", 1) - runner.runQuery("q23b", 1) - runner.runQuery("q24a", 1) - runner.runQuery("q24b", 1) - runner.runQuery("q25", 1) - runner.runQuery("q26", 1) - runner.runQuery("q27", 1) - runner.runQuery("q28", 1) - runner.runQuery("q29", 1) - runner.runQuery("q30", 1) - runner.runQuery("q31", 1) - runner.runQuery("q32", 1) - runner.runQuery("q33", 1) - runner.runQuery("q34", 1) - runner.runQuery("q35", 1) - runner.runQuery("q36", 1) - runner.runQuery("q37", 1) - runner.runQuery("q38", 1) - runner.runQuery("q39a", 1) - runner.runQuery("q39b", 1) - runner.runQuery("q40", 1) - runner.runQuery("q41", 1) - runner.runQuery("q42", 1) - runner.runQuery("q43", 1) - runner.runQuery("q44", 1) - runner.runQuery("q45", 1) - runner.runQuery("q46", 1) runner.runQuery("q47", 1) - runner.runQuery("q48", 1) - runner.runQuery("q49", 1) - runner.runQuery("q50", 1) - runner.runQuery("q51", 1) - runner.runQuery("q52", 1) runner.runQuery("q53", 1) - runner.runQuery("q54", 1) - runner.runQuery("q55", 1) - runner.runQuery("q56", 1) - runner.runQuery("q57", 1) - runner.runQuery("q58", 1) - runner.runQuery("q59", 1) - runner.runQuery("q60", 1) - runner.runQuery("q61", 1) - runner.runQuery("q62", 1) runner.runQuery("q63", 1) - runner.runQuery("q64", 1) - runner.runQuery("q65", 1) - runner.runQuery("q66", 1) - runner.runQuery("q67", 1) - runner.runQuery("q68", 1) - runner.runQuery("q69", 1) - runner.runQuery("q70", 1) runner.runQuery("q71", 1) - runner.runQuery("q72", 1) - runner.runQuery("q73", 1) - runner.runQuery("q74", 1) - runner.runQuery("q75", 1) - runner.runQuery("q76", 1) - runner.runQuery("q77", 1) - runner.runQuery("q78", 1) runner.runQuery("q79", 1) - runner.runQuery("q80", 1) - runner.runQuery("q81", 1) runner.runQuery("q82", 1) - runner.runQuery("q83", 1) runner.runQuery("q84", 1) - runner.runQuery("q85", 1) - runner.runQuery("q86", 1) - runner.runQuery("q87", 1) - runner.runQuery("q88", 1) runner.runQuery("q89", 1) - runner.runQuery("q90", 1) - runner.runQuery("q91", 1) - runner.runQuery("q92", 1) - runner.runQuery("q93", 1) - runner.runQuery("q94", 1) - runner.runQuery("q95", 1) - runner.runQuery("q96", 1) - runner.runQuery("q97", 1) - runner.runQuery("q98", 1) runner.runQuery("q99", 1) } diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala deleted file mode 100644 index 
d8d7d0bd97807cb6bdcaad024e54232687b43937..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Project} -import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.util.sideBySide - -trait HeuristicJoinReorderPlanTestBase extends PlanTest { - - def outputsOf(plans: LogicalPlan*): Seq[Attribute] = { - plans.map(_.output).reduce(_ ++ _) - } - - def assertEqualJoinPlans( - optimizer: RuleExecutor[LogicalPlan], - originalPlan: LogicalPlan, - groundTruthBestPlan: LogicalPlan): Unit = { - val analyzed = originalPlan.analyze - val optimized = optimizer.execute(analyzed) - val expected = EliminateResolvedHint.apply(groundTruthBestPlan.analyze) - - assert(equivalentOutput(analyzed, expected)) - assert(equivalentOutput(analyzed, optimized)) - - compareJoinOrder(optimized, expected) - } - - protected def equivalentOutput(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { - normalizeExprIds(plan1).output == normalizeExprIds(plan2).output - } - - protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan): Unit = { - val normalized1 = normalizePlan(normalizeExprIds(plan1)) - val normalized2 = normalizePlan(normalizeExprIds(plan2)) - if (!sameJoinPlan(normalized1, normalized2)) { - fail( - s""" - |== FAIL: Plans do not match === - |${sideBySide( - rewriteNameFromAttrNullability(normalized1).treeString, - rewriteNameFromAttrNullability(normalized2).treeString).mkString("\n")} - """.stripMargin) - } - } - - private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { - (plan1, plan2) match { - case (j1: Join, j2: Join) => - (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right) - && j1.hint.leftHint == j2.hint.leftHint && j1.hint.rightHint == j2.hint.rightHint) || - (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left) - && j1.hint.leftHint == j2.hint.rightHint && j1.hint.rightHint == j2.hint.leftHint) - case (p1: Project, p2: Project) => - p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child) - case _ => - plan1 == plan2 - } - } -} diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala 
b/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala deleted file mode 100644 index c7ea9bd95ad953b136ff9f5f196a0e0bace24027..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} -import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.logical.ColumnStat -import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} - -class HeuristicJoinReorderSuite - extends HeuristicJoinReorderPlanTestBase with StatsEstimationTestBase { - - private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - attr("t1.k-1-2") -> rangeColumnStat(2, 0), - attr("t1.v-1-10") -> rangeColumnStat(10, 0), - attr("t2.k-1-5") -> rangeColumnStat(5, 0), - attr("t3.v-1-100") -> rangeColumnStat(100, 0), - attr("t4.k-1-2") -> rangeColumnStat(2, 0), - attr("t4.v-1-10") -> rangeColumnStat(10, 0), - attr("t5.k-1-5") -> rangeColumnStat(5, 0), - attr("t5.v-1-5") -> rangeColumnStat(5, 0) - )) - - private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) - private val nameToColInfo: Map[String, (Attribute, ColumnStat)] = - columnInfo.map(kv => kv._1.name -> kv) - - private val t1 = StatsTestPlan( - outputList = Seq("t1.k-1-2", "t1.v-1-10").map(nameToAttr), - rowCount = 1000, - size = Some(1000 * (8 + 4 + 4)), - attributeStats = AttributeMap(Seq("t1.k-1-2", "t1.v-1-10").map(nameToColInfo))) - - private val t2 = StatsTestPlan( - outputList = Seq("t2.k-1-5").map(nameToAttr), - rowCount = 20, - size = Some(20 * (8 + 4)), - attributeStats = AttributeMap(Seq("t2.k-1-5").map(nameToColInfo))) - - private val t3 = StatsTestPlan( - outputList = Seq("t3.v-1-100").map(nameToAttr), - rowCount = 100, - size = Some(100 * (8 + 4)), - attributeStats = AttributeMap(Seq("t3.v-1-100").map(nameToColInfo))) - - test("reorder 3 tables") { - val originalPlan = - t1.join(t2).join(t3) - .where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && - (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) - - val analyzed = originalPlan.analyze - val optimized = HeuristicJoinReorder.apply(analyzed).select(outputsOf(t1, t2, t3): _*) - val expected = - t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) - .join(t3, Inner, 
Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) - .select(outputsOf(t1, t2, t3): _*) - - assert(equivalentOutput(analyzed, expected)) - assert(equivalentOutput(analyzed, optimized)) - - compareJoinOrder(optimized, expected) - } -} diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala index ec03b275363ec70940db54d610ea30b5972906c6..403a1baed983d35df6cb21813ba4ff1201443307 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala @@ -55,108 +55,15 @@ class HiveResourceSuite extends SparkFunSuite { } test("queryBySparkSql-HiveDataSource") { - runner.runQuery("q1", 1) - runner.runQuery("q2", 1) - runner.runQuery("q3", 1) - runner.runQuery("q4", 1) - runner.runQuery("q5", 1) - runner.runQuery("q6", 1) - runner.runQuery("q7", 1) - runner.runQuery("q8", 1) - runner.runQuery("q9", 1) - runner.runQuery("q10", 1) - runner.runQuery("q11", 1) - runner.runQuery("q12", 1) - runner.runQuery("q13", 1) - runner.runQuery("q14a", 1) - runner.runQuery("q14b", 1) - runner.runQuery("q15", 1) - runner.runQuery("q16", 1) - runner.runQuery("q17", 1) - runner.runQuery("q18", 1) runner.runQuery("q19", 1) - runner.runQuery("q20", 1) - runner.runQuery("q21", 1) - runner.runQuery("q22", 1) - runner.runQuery("q23a", 1) - runner.runQuery("q23b", 1) - runner.runQuery("q24a", 1) - runner.runQuery("q24b", 1) - runner.runQuery("q25", 1) - runner.runQuery("q26", 1) - runner.runQuery("q27", 1) - runner.runQuery("q28", 1) - runner.runQuery("q29", 1) - runner.runQuery("q30", 1) - runner.runQuery("q31", 1) - runner.runQuery("q32", 1) - runner.runQuery("q33", 1) - runner.runQuery("q34", 1) - runner.runQuery("q35", 1) - runner.runQuery("q36", 1) - runner.runQuery("q37", 1) - runner.runQuery("q38", 1) - runner.runQuery("q39a", 1) - runner.runQuery("q39b", 1) - runner.runQuery("q40", 1) - runner.runQuery("q41", 1) - runner.runQuery("q42", 1) - runner.runQuery("q43", 1) - runner.runQuery("q44", 1) - runner.runQuery("q45", 1) - runner.runQuery("q46", 1) runner.runQuery("q47", 1) - runner.runQuery("q48", 1) - runner.runQuery("q49", 1) - runner.runQuery("q50", 1) - runner.runQuery("q51", 1) - runner.runQuery("q52", 1) runner.runQuery("q53", 1) - runner.runQuery("q54", 1) - runner.runQuery("q55", 1) - runner.runQuery("q56", 1) - runner.runQuery("q57", 1) - runner.runQuery("q58", 1) - runner.runQuery("q59", 1) - runner.runQuery("q60", 1) - runner.runQuery("q61", 1) - runner.runQuery("q62", 1) runner.runQuery("q63", 1) - runner.runQuery("q64", 1) - runner.runQuery("q65", 1) - runner.runQuery("q66", 1) - runner.runQuery("q67", 1) - runner.runQuery("q68", 1) - runner.runQuery("q69", 1) - runner.runQuery("q70", 1) runner.runQuery("q71", 1) - runner.runQuery("q72", 1) - runner.runQuery("q73", 1) - runner.runQuery("q74", 1) - runner.runQuery("q75", 1) - runner.runQuery("q76", 1) - runner.runQuery("q77", 1) - runner.runQuery("q78", 1) runner.runQuery("q79", 1) - runner.runQuery("q80", 1) - runner.runQuery("q81", 1) runner.runQuery("q82", 1) - runner.runQuery("q83", 1) runner.runQuery("q84", 1) - runner.runQuery("q85", 1) - runner.runQuery("q86", 1) - runner.runQuery("q87", 1) - 
runner.runQuery("q88", 1) runner.runQuery("q89", 1) - runner.runQuery("q90", 1) - runner.runQuery("q91", 1) - runner.runQuery("q92", 1) - runner.runQuery("q93", 1) - runner.runQuery("q94", 1) - runner.runQuery("q95", 1) - runner.runQuery("q96", 1) - runner.runQuery("q97", 1) - runner.runQuery("q98", 1) runner.runQuery("q99", 1) } diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala deleted file mode 100644 index d8d7d0bd97807cb6bdcaad024e54232687b43937..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-package org.apache.spark.sql.catalyst.optimizer
-
-import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Project}
-import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.catalyst.util.sideBySide
-
-trait HeuristicJoinReorderPlanTestBase extends PlanTest {
-
-  def outputsOf(plans: LogicalPlan*): Seq[Attribute] = {
-    plans.map(_.output).reduce(_ ++ _)
-  }
-
-  def assertEqualJoinPlans(
-      optimizer: RuleExecutor[LogicalPlan],
-      originalPlan: LogicalPlan,
-      groundTruthBestPlan: LogicalPlan): Unit = {
-    val analyzed = originalPlan.analyze
-    val optimized = optimizer.execute(analyzed)
-    val expected = EliminateResolvedHint.apply(groundTruthBestPlan.analyze)
-
-    assert(equivalentOutput(analyzed, expected))
-    assert(equivalentOutput(analyzed, optimized))
-
-    compareJoinOrder(optimized, expected)
-  }
-
-  protected def equivalentOutput(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
-    normalizeExprIds(plan1).output == normalizeExprIds(plan2).output
-  }
-
-  protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan): Unit = {
-    val normalized1 = normalizePlan(normalizeExprIds(plan1))
-    val normalized2 = normalizePlan(normalizeExprIds(plan2))
-    if (!sameJoinPlan(normalized1, normalized2)) {
-      fail(
-        s"""
-           |== FAIL: Plans do not match ===
-           |${sideBySide(
-          rewriteNameFromAttrNullability(normalized1).treeString,
-          rewriteNameFromAttrNullability(normalized2).treeString).mkString("\n")}
-         """.stripMargin)
-    }
-  }
-
-  private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
-    (plan1, plan2) match {
-      case (j1: Join, j2: Join) =>
-        (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)
-          && j1.hint.leftHint == j2.hint.leftHint && j1.hint.rightHint == j2.hint.rightHint) ||
-          (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)
-            && j1.hint.leftHint == j2.hint.rightHint && j1.hint.rightHint == j2.hint.leftHint)
-      case (p1: Project, p2: Project) =>
-        p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child)
-      case _ =>
-        plan1 == plan2
-    }
-  }
-}
diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala
deleted file mode 100644
index c7ea9bd95ad953b136ff9f5f196a0e0bace24027..0000000000000000000000000000000000000000
--- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.optimizer
-
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
-import org.apache.spark.sql.catalyst.plans.Inner
-import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
-import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan}
-
-class HeuristicJoinReorderSuite
-  extends HeuristicJoinReorderPlanTestBase with StatsEstimationTestBase {
-
-  private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
-    attr("t1.k-1-2") -> rangeColumnStat(2, 0),
-    attr("t1.v-1-10") -> rangeColumnStat(10, 0),
-    attr("t2.k-1-5") -> rangeColumnStat(5, 0),
-    attr("t3.v-1-100") -> rangeColumnStat(100, 0),
-    attr("t4.k-1-2") -> rangeColumnStat(2, 0),
-    attr("t4.v-1-10") -> rangeColumnStat(10, 0),
-    attr("t5.k-1-5") -> rangeColumnStat(5, 0),
-    attr("t5.v-1-5") -> rangeColumnStat(5, 0)
-  ))
-
-  private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1)
-  private val nameToColInfo: Map[String, (Attribute, ColumnStat)] =
-    columnInfo.map(kv => kv._1.name -> kv)
-
-  private val t1 = StatsTestPlan(
-    outputList = Seq("t1.k-1-2", "t1.v-1-10").map(nameToAttr),
-    rowCount = 1000,
-    size = Some(1000 * (8 + 4 + 4)),
-    attributeStats = AttributeMap(Seq("t1.k-1-2", "t1.v-1-10").map(nameToColInfo)))
-
-  private val t2 = StatsTestPlan(
-    outputList = Seq("t2.k-1-5").map(nameToAttr),
-    rowCount = 20,
-    size = Some(20 * (8 + 4)),
-    attributeStats = AttributeMap(Seq("t2.k-1-5").map(nameToColInfo)))
-
-  private val t3 = StatsTestPlan(
-    outputList = Seq("t3.v-1-100").map(nameToAttr),
-    rowCount = 100,
-    size = Some(100 * (8 + 4)),
-    attributeStats = AttributeMap(Seq("t3.v-1-100").map(nameToColInfo)))
-
-  test("reorder 3 tables") {
-    val originalPlan =
-      t1.join(t2).join(t3)
-        .where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
-          (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
-
-    val analyzed = originalPlan.analyze
-    val optimized = HeuristicJoinReorder.apply(analyzed).select(outputsOf(t1, t2, t3): _*)
-    val expected =
-      t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
-        .join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
-        .select(outputsOf(t1, t2, t3): _*)
-
-    assert(equivalentOutput(analyzed, expected))
-    assert(equivalentOutput(analyzed, optimized))
-
-    compareJoinOrder(optimized, expected)
-  }
-}
diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarLimitExecSuit.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarLimitExecSuit.scala
index 5944618785ba67a869cba691d5ef223b7c13c045..aa35668839af8685b615f304b84b11a80220441b 100644
--- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarLimitExecSuit.scala
+++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarLimitExecSuit.scala
@@ -66,6 +66,21 @@ class ColumnarLimitExecSuit extends ColumnarSparkPlanTest {
     assert(result.count() == 0)
   }
 
+  test("limit with offset and global limit columnar exec") {
+    val result = spark.sql("SELECT y FROM right WHERE x in " +
+      "(SELECT a FROM left WHERE a = 4 LIMIT 1 OFFSET 1)")
+    val plan = result.queryExecution.executedPlan
+    assert(plan.find(_.isInstanceOf[ColumnarLocalLimitExec]).isEmpty,
+      s"not match ColumnarLocalLimitExec, real plan: ${plan}")
+    assert(plan.find(_.isInstanceOf[LocalLimitExec]).isEmpty,
+      s"real plan: ${plan}")
+    assert(plan.find(_.isInstanceOf[ColumnarGlobalLimitExec]).isDefined,
+      s"not match ColumnarGlobalLimitExec, real plan: ${plan}")
+    assert(plan.find(_.isInstanceOf[GlobalLimitExec]).isEmpty,
+      s"real plan: ${plan}")
+    assert(result.count() == 0)
+  }
+
   test("limit with rollback global limit to row-based exec") {
     spark.conf.set("spark.omni.sql.columnar.globalLimit", false)
     val result = spark.sql("SELECT a FROM left WHERE a in " +
diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala
index 076a3959df1943881989e1e84d80e19ea25f94bc..b7370cd5dcec7e6df7fcf4fef8a16eff6ef41873 100644
--- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala
+++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala
@@ -56,108 +56,15 @@ class HiveResourceSuite extends SparkFunSuite {
   }
 
   test("queryBySparkSql-HiveDataSource") {
-    runner.runQuery("q1", 1)
-    runner.runQuery("q2", 1)
-    runner.runQuery("q3", 1)
-    runner.runQuery("q4", 1)
-    runner.runQuery("q5", 1)
-    runner.runQuery("q6", 1)
-    runner.runQuery("q7", 1)
-    runner.runQuery("q8", 1)
-    runner.runQuery("q9", 1)
-    runner.runQuery("q10", 1)
-    runner.runQuery("q11", 1)
-    runner.runQuery("q12", 1)
-    runner.runQuery("q13", 1)
-    runner.runQuery("q14a", 1)
-    runner.runQuery("q14b", 1)
-    runner.runQuery("q15", 1)
-    runner.runQuery("q16", 1)
-    runner.runQuery("q17", 1)
-    runner.runQuery("q18", 1)
     runner.runQuery("q19", 1)
-    runner.runQuery("q20", 1)
-    runner.runQuery("q21", 1)
-    runner.runQuery("q22", 1)
-    runner.runQuery("q23a", 1)
-    runner.runQuery("q23b", 1)
-    runner.runQuery("q24a", 1)
-    runner.runQuery("q24b", 1)
-    runner.runQuery("q25", 1)
-    runner.runQuery("q26", 1)
-    runner.runQuery("q27", 1)
-    runner.runQuery("q28", 1)
-    runner.runQuery("q29", 1)
-    runner.runQuery("q30", 1)
-    runner.runQuery("q31", 1)
-    runner.runQuery("q32", 1)
-    runner.runQuery("q33", 1)
-    runner.runQuery("q34", 1)
-    runner.runQuery("q35", 1)
-    runner.runQuery("q36", 1)
-    runner.runQuery("q37", 1)
-    runner.runQuery("q38", 1)
-    runner.runQuery("q39a", 1)
-    runner.runQuery("q39b", 1)
-    runner.runQuery("q40", 1)
-    runner.runQuery("q41", 1)
-    runner.runQuery("q42", 1)
-    runner.runQuery("q43", 1)
-    runner.runQuery("q44", 1)
-    runner.runQuery("q45", 1)
-    runner.runQuery("q46", 1)
     runner.runQuery("q47", 1)
-    runner.runQuery("q48", 1)
-    runner.runQuery("q49", 1)
-    runner.runQuery("q50", 1)
-    runner.runQuery("q51", 1)
-    runner.runQuery("q52", 1)
     runner.runQuery("q53", 1)
-    runner.runQuery("q54", 1)
-    runner.runQuery("q55", 1)
runner.runQuery("q56", 1) - runner.runQuery("q57", 1) - runner.runQuery("q58", 1) - runner.runQuery("q59", 1) - runner.runQuery("q60", 1) - runner.runQuery("q61", 1) - runner.runQuery("q62", 1) runner.runQuery("q63", 1) - runner.runQuery("q64", 1) - runner.runQuery("q65", 1) - runner.runQuery("q66", 1) - runner.runQuery("q67", 1) - runner.runQuery("q68", 1) - runner.runQuery("q69", 1) - runner.runQuery("q70", 1) runner.runQuery("q71", 1) - runner.runQuery("q72", 1) - runner.runQuery("q73", 1) - runner.runQuery("q74", 1) - runner.runQuery("q75", 1) - runner.runQuery("q76", 1) - runner.runQuery("q77", 1) - runner.runQuery("q78", 1) runner.runQuery("q79", 1) - runner.runQuery("q80", 1) - runner.runQuery("q81", 1) runner.runQuery("q82", 1) - runner.runQuery("q83", 1) runner.runQuery("q84", 1) - runner.runQuery("q85", 1) - runner.runQuery("q86", 1) - runner.runQuery("q87", 1) - runner.runQuery("q88", 1) runner.runQuery("q89", 1) - runner.runQuery("q90", 1) - runner.runQuery("q91", 1) - runner.runQuery("q92", 1) - runner.runQuery("q93", 1) - runner.runQuery("q94", 1) - runner.runQuery("q95", 1) - runner.runQuery("q96", 1) - runner.runQuery("q97", 1) - runner.runQuery("q98", 1) runner.runQuery("q99", 1) } diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala deleted file mode 100644 index d8d7d0bd97807cb6bdcaad024e54232687b43937..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-package org.apache.spark.sql.catalyst.optimizer
-
-import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Project}
-import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.catalyst.util.sideBySide
-
-trait HeuristicJoinReorderPlanTestBase extends PlanTest {
-
-  def outputsOf(plans: LogicalPlan*): Seq[Attribute] = {
-    plans.map(_.output).reduce(_ ++ _)
-  }
-
-  def assertEqualJoinPlans(
-      optimizer: RuleExecutor[LogicalPlan],
-      originalPlan: LogicalPlan,
-      groundTruthBestPlan: LogicalPlan): Unit = {
-    val analyzed = originalPlan.analyze
-    val optimized = optimizer.execute(analyzed)
-    val expected = EliminateResolvedHint.apply(groundTruthBestPlan.analyze)
-
-    assert(equivalentOutput(analyzed, expected))
-    assert(equivalentOutput(analyzed, optimized))
-
-    compareJoinOrder(optimized, expected)
-  }
-
-  protected def equivalentOutput(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
-    normalizeExprIds(plan1).output == normalizeExprIds(plan2).output
-  }
-
-  protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan): Unit = {
-    val normalized1 = normalizePlan(normalizeExprIds(plan1))
-    val normalized2 = normalizePlan(normalizeExprIds(plan2))
-    if (!sameJoinPlan(normalized1, normalized2)) {
-      fail(
-        s"""
-           |== FAIL: Plans do not match ===
-           |${sideBySide(
-          rewriteNameFromAttrNullability(normalized1).treeString,
-          rewriteNameFromAttrNullability(normalized2).treeString).mkString("\n")}
-         """.stripMargin)
-    }
-  }
-
-  private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
-    (plan1, plan2) match {
-      case (j1: Join, j2: Join) =>
-        (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)
-          && j1.hint.leftHint == j2.hint.leftHint && j1.hint.rightHint == j2.hint.rightHint) ||
-          (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)
-            && j1.hint.leftHint == j2.hint.rightHint && j1.hint.rightHint == j2.hint.leftHint)
-      case (p1: Project, p2: Project) =>
-        p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child)
-      case _ =>
-        plan1 == plan2
-    }
-  }
-}
diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala
deleted file mode 100644
index c7ea9bd95ad953b136ff9f5f196a0e0bace24027..0000000000000000000000000000000000000000
--- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.optimizer
-
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
-import org.apache.spark.sql.catalyst.plans.Inner
-import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
-import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan}
-
-class HeuristicJoinReorderSuite
-  extends HeuristicJoinReorderPlanTestBase with StatsEstimationTestBase {
-
-  private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
-    attr("t1.k-1-2") -> rangeColumnStat(2, 0),
-    attr("t1.v-1-10") -> rangeColumnStat(10, 0),
-    attr("t2.k-1-5") -> rangeColumnStat(5, 0),
-    attr("t3.v-1-100") -> rangeColumnStat(100, 0),
-    attr("t4.k-1-2") -> rangeColumnStat(2, 0),
-    attr("t4.v-1-10") -> rangeColumnStat(10, 0),
-    attr("t5.k-1-5") -> rangeColumnStat(5, 0),
-    attr("t5.v-1-5") -> rangeColumnStat(5, 0)
-  ))
-
-  private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1)
-  private val nameToColInfo: Map[String, (Attribute, ColumnStat)] =
-    columnInfo.map(kv => kv._1.name -> kv)
-
-  private val t1 = StatsTestPlan(
-    outputList = Seq("t1.k-1-2", "t1.v-1-10").map(nameToAttr),
-    rowCount = 1000,
-    size = Some(1000 * (8 + 4 + 4)),
-    attributeStats = AttributeMap(Seq("t1.k-1-2", "t1.v-1-10").map(nameToColInfo)))
-
-  private val t2 = StatsTestPlan(
-    outputList = Seq("t2.k-1-5").map(nameToAttr),
-    rowCount = 20,
-    size = Some(20 * (8 + 4)),
-    attributeStats = AttributeMap(Seq("t2.k-1-5").map(nameToColInfo)))
-
-  private val t3 = StatsTestPlan(
-    outputList = Seq("t3.v-1-100").map(nameToAttr),
-    rowCount = 100,
-    size = Some(100 * (8 + 4)),
-    attributeStats = AttributeMap(Seq("t3.v-1-100").map(nameToColInfo)))
-
-  test("reorder 3 tables") {
-    val originalPlan =
-      t1.join(t2).join(t3)
-        .where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
-          (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
-
-    val analyzed = originalPlan.analyze
-    val optimized = HeuristicJoinReorder.apply(analyzed).select(outputsOf(t1, t2, t3): _*)
-    val expected =
-      t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
-        .join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
-        .select(outputsOf(t1, t2, t3): _*)
-
-    assert(equivalentOutput(analyzed, expected))
-    assert(equivalentOutput(analyzed, optimized))
-
-    compareJoinOrder(optimized, expected)
-  }
-}
diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarLimitExecSuit.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarLimitExecSuit.scala
index 5944618785ba67a869cba691d5ef223b7c13c045..aa35668839af8685b615f304b84b11a80220441b 100644
--- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarLimitExecSuit.scala
+++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarLimitExecSuit.scala
@@ -66,6 +66,21 @@ class ColumnarLimitExecSuit extends ColumnarSparkPlanTest {
     assert(result.count() == 0)
   }
 
+  test("limit with offset and global limit columnar exec") {
+    val result = spark.sql("SELECT y FROM right WHERE x in " +
+      "(SELECT a FROM left WHERE a = 4 LIMIT 1 OFFSET 1)")
+    val plan = result.queryExecution.executedPlan
+    assert(plan.find(_.isInstanceOf[ColumnarLocalLimitExec]).isEmpty,
+      s"not match ColumnarLocalLimitExec, real plan: ${plan}")
+    assert(plan.find(_.isInstanceOf[LocalLimitExec]).isEmpty,
+      s"real plan: ${plan}")
+    assert(plan.find(_.isInstanceOf[ColumnarGlobalLimitExec]).isDefined,
+      s"not match ColumnarGlobalLimitExec, real plan: ${plan}")
+    assert(plan.find(_.isInstanceOf[GlobalLimitExec]).isEmpty,
+      s"real plan: ${plan}")
+    assert(result.count() == 0)
+  }
+
   test("limit with rollback global limit to row-based exec") {
     spark.conf.set("spark.omni.sql.columnar.globalLimit", false)
     val result = spark.sql("SELECT a FROM left WHERE a in " +