diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewShufflePartition.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewShufflePartition.scala
new file mode 100644
index 0000000000000000000000000000000000000000..1af8af40cca58e43b4cad1744e0141c21b745934
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewShufflePartition.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.sql.execution.{CoalescedPartitionSpec, ShufflePartitionSpec, SparkPlan}
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, ShuffleOrigin, ValidateRequirements}
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Detects skewed partitions of ENSURE_REQUIREMENTS shuffles and splits them
+ * into smaller chunks, so that a few oversized reduce partitions do not
+ * dominate the running time of a stage.
+ */
+case class OptimizeSkewShufflePartition(ensureRequirements: EnsureRequirements, parallelism: Int)
+  extends AQEShuffleReadRule {
+
+  override val supportedShuffleOrigins: Seq[ShuffleOrigin] = Seq(ENSURE_REQUIREMENTS)
+
+  // Splits every partition larger than targetSize; partitions at or below the
+  // target are kept as single CoalescedPartitionSpecs.
+  private def optimizeSkewedPartitions(
+      shuffleId: Int,
+      bytesByPartitionId: Array[Long],
+      targetSize: Long): Seq[ShufflePartitionSpec] = {
+    logWarning("Enter OptimizeSkewShufflePartition optimizeSkewedPartitions")
+    val smallPartitionFactor =
+      conf.getConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR)
+    bytesByPartitionId.indices.flatMap { reduceIndex =>
+      val bytes = bytesByPartitionId(reduceIndex)
+      if (bytes > targetSize) {
+        val newPartitionSpec = ShufflePartitionsUtil.createSkewPartitionSpecs(
+          shuffleId, reduceIndex, targetSize, smallPartitionFactor)
+        if (newPartitionSpec.isEmpty) {
+          CoalescedPartitionSpec(reduceIndex, reduceIndex + 1, bytes) :: Nil
+        } else {
+          logWarning(s"For shuffle $shuffleId, partition $reduceIndex is skewed, " +
+            s"split it into ${newPartitionSpec.get.size} parts.")
+          newPartitionSpec.get
+        }
+      } else {
+        CoalescedPartitionSpec(reduceIndex, reduceIndex + 1, bytes) :: Nil
+      }
+    }
+  }
+
+  private def tryOptimizeSkewedPartitions(shuffle: ShuffleQueryStageExec): SparkPlan = {
+    logWarning(s"Enter OptimizeSkewShufflePartition tryOptimizeSkewedPartitions, " +
+      s"shuffleOrigin ${shuffle.shuffle.shuffleOrigin}")
+    val mapStats = shuffle.mapStats
+    if (mapStats.isEmpty) {
+      return shuffle
+    }
+    val partitionCount = mapStats.get.bytesByPartitionId.length
+    if (partitionCount == 0) {
+      return shuffle
+    }
+    // With fewer partitions than the desired parallelism, target twice the mean
+    // partition size (at least 64 MiB); otherwise use a fixed 1 GiB threshold.
+    val partitionMean = mapStats.get.bytesByPartitionId.sum / partitionCount
+    val advisorySize = if (partitionCount < parallelism) {
+      math.max(partitionMean * 2L, 64L * 1024 * 1024)
+    } else {
+      1L * 1024 * 1024 * 1024
+    }
+    // All partitions already fit within the advisory size: nothing to split.
+    if (mapStats.get.bytesByPartitionId.forall(_ <= advisorySize)) {
+      return shuffle
+    }
+
+    val newPartitionsSpec = optimizeSkewedPartitions(
+      mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize)
+    // Inject a shuffle read only if at least one partition was actually split.
+    if (newPartitionsSpec.length == mapStats.get.bytesByPartitionId.length) {
+      shuffle
+    } else {
+      AQEShuffleReadExec(shuffle, newPartitionsSpec)
+    }
+  }
+
+  private def optimizedSkewPlan(plan: SparkPlan): SparkPlan = plan.transformUp {
+    // isSupported already guarantees the shuffle origin is ENSURE_REQUIREMENTS,
+    // so no further origin check is needed here.
+    case stage: ShuffleQueryStageExec if isSupported(stage.shuffle) =>
+      tryOptimizeSkewedPartitions(stage)
+  }
+
+  override def apply(plan: SparkPlan): SparkPlan = {
+    logWarning("Enter OptimizeSkewShufflePartition")
+    if (!conf.getConf(SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED)) {
+      return plan
+    }
+
+    val optimized = optimizedSkewPlan(plan)
+    logWarning("ensureRequirements.requiredDistribution.isDefined: " +
+      s"${ensureRequirements.requiredDistribution.isDefined}")
+    // Splitting partitions may break the distribution required by downstream
+    // operators; if so, let EnsureRequirements re-plan the exchanges.
+    val requirementSatisfied = if (ensureRequirements.requiredDistribution.isDefined) {
+      ValidateRequirements.validate(optimized, ensureRequirements.requiredDistribution.get)
+    } else {
+      ValidateRequirements.validate(optimized)
+    }
+    if (requirementSatisfied) {
+      logWarning("OptimizeSkewShufflePartition optimized")
+      optimized
+    } else {
+      logWarning("OptimizeSkewShufflePartition ensureRequirements.apply(optimized)")
+      ensureRequirements.apply(optimized)
+    }
+  }
+}
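
Reviewer note: a minimal, self-contained sketch of the advisory-size heuristic used by tryOptimizeSkewedPartitions above, with a worked example. The object name AdvisorySizeSketch, the main method, and the sample sizes are illustrative only and not part of the patch.

object AdvisorySizeSketch {
  // Mirrors the patch: with fewer partitions than the desired parallelism,
  // target twice the mean partition size but never less than 64 MiB;
  // otherwise target a fixed 1 GiB.
  def advisorySize(bytesByPartitionId: Array[Long], parallelism: Int): Long = {
    val partitionCount = bytesByPartitionId.length
    val partitionMean = bytesByPartitionId.sum / partitionCount
    if (partitionCount < parallelism) {
      math.max(partitionMean * 2L, 64L * 1024 * 1024)
    } else {
      1L * 1024 * 1024 * 1024
    }
  }

  def main(args: Array[String]): Unit = {
    // Three 10 MiB partitions plus one 500 MiB straggler, parallelism 200:
    // the mean is ~132.5 MiB, so the advisory size is ~265 MiB and only the
    // 500 MiB partition exceeds it and would be split.
    val bytes = Array(10L, 10L, 10L, 500L).map(_ * 1024 * 1024)
    println(advisorySize(bytes, parallelism = 200)) // 277872640 bytes (~265 MiB)
  }
}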