diff --git a/omnioperator/omniop-deploy-tool/command/command_line_tmp b/omnioperator/omniop-deploy-tool/command/command_line_tmp new file mode 100644 index 0000000000000000000000000000000000000000..e1c4d24c4db9228c894d2f64b537d43f64582f1b --- /dev/null +++ b/omnioperator/omniop-deploy-tool/command/command_line_tmp @@ -0,0 +1,22 @@ +spark-sql --archives hdfs://{hostname}:9000/user/{user}/{omni_tar_name}.tar.gz#omni \ +--deploy-mode client \ +--driver-cores 8 \ +--driver-memory 40g \ +--master yarn \ +--executor-cores 12 \ +--executor-memory 5g \ +--conf spark.memory.offHeap.enabled=true \ +--conf spark.memory.offHeap.size=35g \ +--num-executors 24 \ +--conf spark.executor.extraJavaOptions='-XX:+UseG1GC' \ +--conf spark.locality.wait=0 \ +--conf spark.network.timeout=600 \ +--conf spark.serializer=org.apache.spark.serializer.KryoSerializer \ +--conf spark.sql.adaptive.enabled=true \ +--conf spark.sql.adaptive.skewedJoin.enabled=true \ +--conf spark.sql.autoBroadcastJoinThreshold=100M \ +--conf spark.sql.broadcastTimeout=600 \ +--conf spark.sql.shuffle.partitions=600 \ +--conf spark.sql.orc.impl=native \ +--conf spark.task.cpus=1 \ +--properties-file {spark_conf_path}/{conf_file_name} \ No newline at end of file diff --git a/omnioperator/omniop-deploy-tool/conf/config b/omnioperator/omniop-deploy-tool/conf/config new file mode 100644 index 0000000000000000000000000000000000000000..a5023e413f5ba1bcee8e574aaef58f07bdfe1d24 --- /dev/null +++ b/omnioperator/omniop-deploy-tool/conf/config @@ -0,0 +1,13 @@ +#!/usr/bin/env bash +# omnioperator期望的spark版本号,当前支持的spark版本有:3.1.1 3.3.1 3.4.3 3.5.2 +expect_spark_version=3.3.1 +# omnioperator版本号 +omnioperator_version=1.8.0 +# 部署目录 +target_path=/opt/omni-operator +# 是否安装sve版本的omnioperator,默认为false(小写字母) +sve_flag=false +# spark的conf目录 +spark_conf_path=/usr/local/spark/conf +# 执行完脚本后是否进行功能验证 +omni_check=true \ No newline at end of file diff --git a/omnioperator/omniop-deploy-tool/conf/omni.conf b/omnioperator/omniop-deploy-tool/conf/omni.conf new file mode 100644 index 0000000000000000000000000000000000000000..d4392f5de8c45a258365aa6134ed8ac3455f2a98 --- /dev/null +++ b/omnioperator/omniop-deploy-tool/conf/omni.conf @@ -0,0 +1,28 @@ +# <----Spark template----> +#数学运算中小数舍入模式,默认为DOWN。HALF_UP表示向最接近数字方向舍入,如果与两个相邻数字的距离相等,则向上舍入,就是通常讲的四舍五入。DOWN表示截断,即向零方向舍入。 +RoundingRule=DOWN +#Decimal操作结果是否检查溢出,默认为CHECK_RESCALE。CHECK_RESCALE表示检查溢出,NOT_CHECK_RESCALE表示不检查溢出。 +CheckReScaleRule=CHECK_RESCALE +#Replace操作中,对待空字符是否替换,默认为NOT_REPLACE。REPLAEC表示替换,NOT_REPLACE表示不替换。 +#例如,InputStr="apple", ReplaceStr="*", SearchStr="",openLooKeng会将字母中间的空字符替换,得到OutputStr="*a*p*p*l*e*"。Spark则不替换,得到OutputStr="apple"。 +EmptySearchStrReplaceRule=NOT_REPLACE +#Decimal转Double过程中,C++直接转换或先转为字符串再进行转换,默认为CONVERT_WITH_STRING。CAST表示直接转换,CONVERT_WITH_STRING表示先转为字符串再进行转换。 +CastDecimalToDoubleRule=CONVERT_WITH_STRING +#Substr操作中,负数索引超出最小索引,直接返回空串或仍继续取字符串,默认为INTERCEPT_FROM_BEYOND。EMPTY_STRING表示返回空串,INTERCEPT_FROM_BEYOND表示继续取字符串。 +#例如,str="apple", strLength=5, startIndex=-7, subStringLength=3。 字符串长度为5,从索引-7的位置取3个字符。"apple"长度为5,最小负数索引为-4,由于-7小于-4,OLK直接返回空串,Spark则仍从-7的位置取3个字符后仍继续取字符串,直到取到值"a"后返回 +NegativeStartIndexOutOfBoundsRule=INTERCEPT_FROM_BEYOND +#是否支持ContainerVector,默认为NOT_SUPPORT。SUPPORT表示支持,NOT_SUPPORT表示不支持。 +SupportContainerVecRule=NOT_SUPPORT +#字符串转Date过程中,是否支持降低精度,默认为ALLOW_REDUCED_PRECISION。NOT_ALLOW_REDUCED_PRECISION表示不允许降低精度,ALLOW_REDUCED_PRECISION表示允许降低精度。 +#例如,openLooKeng必须完整书写ISO日期扩展格式,不能省略Month和Day,如1996-02-08。Spark支持省略Month和Day,如1996-02-28, 1996-02, 1996都支持。 
+StringToDateFormatRule=ALLOW_REDUCED_PRECISION +#VectorBatch是否包含filter column,默认为NO_EXPR。NO_EXPR表示不包含filter column,EXPR_FILTER表示包含filter column。 +SupportExprFilterRule=EXPR_FILTER +#在substr运算时,默认为IS_SUPPORT,为IS_NOT_SUPPORT时,表示不支持startIndex=0时从第一个元素开始取,因为默认起始索引从1开始,若起始索引为0,默认返回空字符串,为IS_SUPPORT时,表示支持substr函数在startIndex=0时支持从第一个元素开始取。 +ZeroStartIndexSupportRule=IS_SUPPORT +#表达式是否校验, +ExpressionVerifyRule=NOT_VERIFY + +# <----Other properties----> +# 是否开启codegen函数批处理,默认关闭 +enableBatchExprEvaluate=false diff --git a/omnioperator/omniop-deploy-tool/conf/omnioperator_tmp.conf b/omnioperator/omniop-deploy-tool/conf/omnioperator_tmp.conf new file mode 100644 index 0000000000000000000000000000000000000000..948d219ddf4b8bbec39b09fd7968457b07155035 --- /dev/null +++ b/omnioperator/omniop-deploy-tool/conf/omnioperator_tmp.conf @@ -0,0 +1,29 @@ +spark.master yarn +spark.eventLog.enabled true +spark.eventLog.dir hdfs://{host_name}:9000/spark2-history +spark.eventLog.compress true +spark.history.fs.logDirectory hdfs://{host_name}:9000/spark2-history +spark.sql.optimizer.runtime.bloomFilter.enabled true +spark.driverEnv.LD_LIBRARY_PATH {target_path}/lib +spark.driverEnv.LD_PRELOAD {target_path}/lib/libjemalloc.so.2 +spark.driverEnv.OMNI_HOME {target_path} +spark.driver.extraClassPath {target_path}/lib/boostkit-omniop-spark-{spark_version}-{omni_version}-aarch64.jar:{target_path}/lib/boostkit-omniop-bindings-{omni_version}-aarch64.jar:{target_path}/lib/dependencies/protobuf-java-3.15.8.jar:{target_path}/lib/dependencies/boostkit-omniop-native-reader-{spark_version}-{omni_version}.jar +spark.driver.extraLibraryPath {target_path}/lib +spark.driver.defaultJavaOptions -Djava.library.path={target_path}/lib +spark.executorEnv.LD_LIBRARY_PATH ${PWD}/omni/{omni_package_name}/lib +spark.executorEnv.LD_PRELOAD ${PWD}/omni/{omni_package_name}/lib/libjemalloc.so.2 +spark.executorEnv.MALLOC_CONF narenas:2 +spark.executorEnv.OMNI_HOME ${PWD}/omni/{omni_package_name} +spark.executor.extraClassPath ${PWD}/omni/{omni_package_name}/lib/boostkit-omniop-spark-{spark_version}-{omni_version}-aarch64.jar:${PWD}/omni/{omni_package_name}/lib/boostkit-omniop-bindings-{omni_version}-aarch64.jar:${PWD}/omni/{omni_package_name}/lib/dependencies/protobuf-java-3.15.8.jar:${PWD}/omni/{omni_package_name}/lib/dependencies/boostkit-omniop-native-reader-{spark_version}-{omni_version}.jar +spark.executor.extraLibraryPath ${PWD}/omni/{omni_package_name}/lib +spark.omni.sql.columnar.fusion false +spark.shuffle.manager org.apache.spark.shuffle.sort.OmniColumnarShuffleManager +spark.sql.codegen.wholeStage false +spark.sql.extensions com.huawei.boostkit.spark.ColumnarPlugin +spark.omni.sql.columnar.RewriteSelfJoinInInPredicate true +spark.sql.execution.filterMerge.enabled true +spark.omni.sql.columnar.dedupLeftSemiJoin true +spark.omni.sql.columnar.radixSort.enabled true +spark.executorEnv.MALLOC_CONF tcache:false +spark.sql.adaptive.coalescePartitions.minPartitionNum 200 +spark.sql.join.columnar.preferShuffledHashJoin true \ No newline at end of file diff --git a/omnioperator/omniop-deploy-tool/deploy.sh b/omnioperator/omniop-deploy-tool/deploy.sh new file mode 100644 index 0000000000000000000000000000000000000000..7e01bffd2bb1dff87d3191d20596e010c229d622 --- /dev/null +++ b/omnioperator/omniop-deploy-tool/deploy.sh @@ -0,0 +1,417 @@ +#!/bin/bash +tool_root_dir=$(cd $(dirname "$0")||exit 1;pwd) +packages_dir="$tool_root_dir/omnioperator" +host_name=$(hostname) +user=$(whoami) + +#环境变量生效 +dos2unix ./conf/config +source ./conf/config + 
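Aside (not part of the patch): deploy.sh sources conf/config here and only validates some of its values later (for example, sve_flag inside unzip_package). A fail-fast guard placed right after `source ./conf/config` could look like the sketch below. The supported Spark versions are taken from the comment in conf/config, and the checks themselves are illustrative assumptions rather than existing functions in this patch.

```shell
# Illustrative sketch only -- not part of the patch.
# Variable names (expect_spark_version, sve_flag, spark_conf_path) are the ones
# defined in conf/config; the version list mirrors the comment in that file.
supported_versions="3.1.1 3.3.1 3.4.3 3.5.2"
if ! echo "${supported_versions}" | grep -wq "${expect_spark_version}"; then
    echo "ERROR: expect_spark_version=${expect_spark_version} is not one of: ${supported_versions}"
    exit 1
fi
if [ "${sve_flag}" != "true" ] && [ "${sve_flag}" != "false" ]; then
    echo "ERROR: sve_flag must be 'true' or 'false' (lowercase), got '${sve_flag}'"
    exit 1
fi
if [ ! -d "${spark_conf_path}" ]; then
    echo "ERROR: spark_conf_path=${spark_conf_path} does not exist"
    exit 1
fi
```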
+################################## 函数 ################################################ +function_start() { + echo "" + echo "$1" +} + +function_end() { + echo "$1" + echo "" +} + +check_spark_version(){ + function_start "------Start checking spark version------" + + # 获取当前系统的spark版本 + full_spark_version=$(spark-sql --version 2>&1 | grep 'version' | head -n 1 | sed -E 's/.*version ([0-9]+\.[0-9]+\.[0-9])+.*/\1/') + spark_version=$(spark-sql --version 2>&1 | grep 'version' | head -n 1 | sed -E 's/.*version ([0-9]+\.[0-9]+)\.[0-9]+.*/\1/') + + if [ $spark_version == $(echo ${expect_spark_version} | cut -d'.' -f1,2) ]; then + if [ $full_spark_version == ${expect_spark_version} ]; then + echo "INFO: Spark version is right. Expect Spark version is consistent with the SYSTEM spark version." + else + echo "INFO: Spark version is $full_spark_version." + echo "INFO: Omni expected spark version is ${expect_spark_version}." + fi + else + echo "ERROR: Spark version is wrong! Expect spark version is consistent with the SYSTEM spark version." + echo "Spark version is $full_spark_version." + echo "Omni expected spark version is ${expect_spark_version}." + exit 1 + fi + + function_end "------Finish checking spark version------" +} + +check_cpu_model(){ + function_start "------Start checking cpu model------" + + cpu_model=$(lscpu | grep "Model name" | sed 's/Model name:[[:space:]]*//') + if [ $cpu_model == "Kunpeng-920" ] && [ ${sve_flag} == true ]; then + echo "ERROR: Kunpeng-920 don't support omnioperator-SVE version!" + exit 1 + else + echo "INFO: Check over." + fi + + function_end "------Finish checking cpu model------" +} + +generate_dir(){ + function_start "------Start creating target dir------" + + # 检查用户是否有权限新建该文件夹 + if [ ! -w "$(dirname "$target_path")" ]; then + echo "Error: You do not have permission to create a directory in $(dirname "$target_path")." + exit 1 + fi + + # 提示用户确认输入 + read -p "Target path is: "$target_path", please confirm create it or not. (y/n): " confirm + + # 检查用户确认的输入 + if [[ "$confirm" == "y" || "$confirm" == "Y" ]]; then + # 检查目录是否已存在 + if [ -d "$target_path" ]; then + read -p "dir "$target_path" is existed, do you want to replace it? (y/n): " delete_confirm + if [[ "$delete_confirm" == "y" || "$delete_confirm" == "Y" ]]; then + rm -rf "$target_path" + echo "INFO: The old directory has been replaced: "$target_path"。" + else + echo "INFO: Process exit." + exit 1 + fi + fi + + # 创建目录 + mkdir -p "${target_path}/lib" "${target_path}/conf" + echo "INFO: dir "$target_path" has been created。" + else + echo "INFO: Process exit." + echo "If you want to install OmniOperator in another dir," + echo "Please change configuration in ${tool_root_dir}/conf/config" + exit 1 + fi + + function_end "------Finish creating target dir------" +} + +unzip_package(){ + function_start "------Start unziping package------" + # 检查系统类型 + os_type=$(grep -i "^NAME=" /etc/os-release | awk -F= '{print tolower($2)}' | tr -d '"') + + if [ "$os_type" != "openeuler" ] && [ "$os_type" != "centos" ]; then + echo "Error: do not support: $os_type" + exit 1 + fi + + if [ "$os_type" == "centos" ] && [ ${sve_flag} == true ]; then + echo "Error: CentOS don't support 'SVE' version." + exit 1 + fi + + if [ ${sve_flag} != true ] && [ ${sve_flag} != false ]; then + echo "Error: sve_flag is not a boolean value." + exit 1 + fi + + # 创建unzip_file,检查 是否有权限 及 文件是否存在 + if [ -f "${packages_dir}/unzip_file" ]; then + rm -rf ${packages_dir}/unzip_file + fi + + if [ ! 
-w "$(dirname "${packages_dir}")" ]; then + echo "Error: You do not have permission to create a directory in $(dirname "${packages_dir}")." + exit 1 + else + mkdir $packages_dir/unzip_file + fi + unzip_file="$packages_dir/unzip_file" + cd $packages_dir + + # 检查文件是否存在 + missing_files=() + if [[ ! -f "${packages_dir}/Dependency_library_${os_type}.zip" ]]; then + missing_files+=("Dependency_library_${os_type}.zip") + fi + + if [[ ! -f "${packages_dir}/BoostKit-omniop_${omnioperator_version}.zip" ]]; then + missing_files+=("BoostKit-omniop_${omnioperator_version}.zip") + fi + + if [[ ! -f "${packages_dir}/boostkit-omniop-spark-${expect_spark_version}-${omnioperator_version}-aarch64.zip" ]]; then + missing_files+=("boostkit-omniop-spark-${expect_spark_version}-${omnioperator_version}-aarch64.zip") + fi + + if [[ ${#missing_files[@]} -gt 0 ]]; then + echo "ERROR: The following packages are missing in ${packages_dir}:" + for file in "${missing_files[@]}"; do + echo "- ${packages_dir}/$file" + done + exit 1 + fi + + #解压Dependency_library + if [ ${sve_flag} == false ]; then + unzip -q Dependency_library_${os_type}.zip #大小写忽略 + mv ${packages_dir}/Dependency_library_${os_type}/* ${unzip_file} + else + unzip -q Dependency_library_${os_type}-sve.zip #大小写忽略 + mv ${packages_dir}/Dependency_library_${os_type}-sve/* ${unzip_file} + fi + find . -type d -name "Dependency_library_${os_type}*" ! -name '*.zip' -exec rm -rf {} + + echo "Info: Dependency_library_${os_type}.zip has been unzipped" + + #解压BoostKit-omniop_${omnioperator_version}.zip + if [ ${sve_flag} == false ]; then + unzip -q BoostKit-omniop_${omnioperator_version}.zip boostkit-omniop-operator-${omnioperator_version}-aarch64-${os_type}.tar.gz + tar -zxf boostkit-omniop-operator-${omnioperator_version}-aarch64-${os_type}.tar.gz -C ${packages_dir} + else + unzip -q BoostKit-omniop_${omnioperator_version}.zip boostkit-omniop-operator-${omnioperator_version}-aarch64-${os_type}-sve.tar.gz + tar -zxf boostkit-omniop-operator-${omnioperator_version}-aarch64-${os_type}-sve.tar.gz -C ${packages_dir} + fi + mv ${packages_dir}/boostkit-omniop-operator-${omnioperator_version}-aarch64/* ${unzip_file} + rm -rf boostkit-omniop-operator-${omnioperator_version}-aarch64* + echo "Info: BoostKit-omniop_${omnioperator_version}.zip has been unzipped" + + #解压boostkit-omniop-spark-${expect_spark_version}-${omnioperator_version}-aarch64.zip + if [ ${sve_flag} == false ]; then + unzip -q boostkit-omniop-spark-${expect_spark_version}-${omnioperator_version}-aarch64.zip boostkit-omniop-spark-${expect_spark_version}-${omnioperator_version}-aarch64-${os_type}.zip + unzip -q boostkit-omniop-spark-${expect_spark_version}-${omnioperator_version}-aarch64-${os_type}.zip -d ${unzip_file} + rm -rf boostkit-omniop-spark-${expect_spark_version}-${omnioperator_version}-aarch64-${os_type}.zip + else + unzip -q boostkit-omniop-spark-${expect_spark_version}-${omnioperator_version}-aarch64.zip boostkit-omniop-spark-${expect_spark_version}-${omnioperator_version}-aarch64-${os_type}-sve.zip + unzip -q boostkit-omniop-spark-${expect_spark_version}-${omnioperator_version}-aarch64-${os_type}-sve.zip -d ${unzip_file} + rm -rf boostkit-omniop-spark-${expect_spark_version}-${omnioperator_version}-aarch64-${os_type}-sve.zip + fi + + cd ${unzip_file} + tar -zxf dependencies.tar.gz + rm -rf dependencies.tar.gz + cd ${packages_dir} + echo "Info: boostkit-omniop-spark-${expect_spark_version}-${omnioperator_version}-aarch64.zip has been unzipped" + + mv ${unzip_file}/* ${target_path}/lib + rm -rf 
${unzip_file} + echo "Info: all unzipped files have been moved to ${target_path}/lib" + + function_end "------Finish unziping package------" +} + +generate_omniconf(){ + cp ${tool_root_dir}/conf/omni.conf ${target_path}/conf/ +} + +upload_hdfs(){ + function_start "------Start uploading hdfs------" + username=$(whoami) + hdfs_dir="/user/${username}" + + # 如果hdfs上用户目录不存在,新增一个 + hdfs dfs -test -e ${hdfs_dir} + if [ $? -eq 1 ]; then + echo "Info: hdfs: ${hdfs_dir} is not exist, creating now" + hdfs dfs -mkdir ${hdfs_dir} + fi + + # 打包omnioperator,准备传到hdfs ${hdfs_dir},存在同名旧文件会直接覆盖 + tar -czf ${target_path}.tar.gz -C $(dirname "$target_path") $(basename "$target_path") + + # 检查 HDFS 路径下是否存在同名文件,如存在,则先删除再推送omnioperator包到hdfs + hdfs dfs -test -e ${hdfs_dir}/${target_path##*/}.tar.gz + if [ $? -eq 0 ]; then + hadoop fs -rm ${hdfs_dir}/${target_path##*/}.tar.gz + fi + hadoop fs -put ${target_path}.tar.gz ${hdfs_dir} + + hdfs dfs -test -e ${hdfs_dir}/${target_path##*/}.tar.gz + if [ $? -eq 0 ]; then + echo "Info: successfully upload hdfs" + else + echo "ERROR: ${target_path##*/}.tar.gz didn't exist in hdfs." + exit 1 + fi + + function_end "------Finish uploading hdfs------" +} + +generate_spark_defaults(){ + function_start "------Start generating spark defaults------" + + tmp="${tool_root_dir}/conf/omnioperator_tmp.conf" + target_dir="${spark_conf_path}" + + # 检查 omni.conf 文件是否存在 + if [ ! -f "$tmp" ]; then + echo "Error: $tmp does not exist!" + exit 1 + fi + + conf_file_name="${omnioperator_${expect_spark_version}_${omnioperator_version}.conf}" + conf_file_path="$(dirname "$tmp")/${conf_file_name}" + cp $tmp $conf_file_path + + # 替换文件中的内容 + sed -i "s|{host_name}|${host_name}|g" "$conf_file_path" + sed -i "s|{target_path}|${target_path}|g" "$conf_file_path" + sed -i "s|{spark_version}|${expect_spark_version}|g" "$conf_file_path" + sed -i "s|{omni_version}|${omnioperator_version}|g" "$conf_file_path" + sed -i "s|{omni_package_name}|${target_path##*/}|g" "$conf_file_path" + + if [ -f "$target_dir/${conf_file_name}" ]; then + echo "INFO: File ${conf_file_name} has been existed in target path:${target_dir}. Replace by new one..." + rm "$target_dir/${conf_file_name}" + fi + + # 移动修改后的文件到目标目录 + mv "$conf_file_path" "$target_dir" + echo "File ${conf_file_name} has been moved to $target_dir" + + function_end "------Finish generating spark defaults------" +} + +generate_command_line(){ + function_start "------Start generating command line------" + command_line_tmp="${tool_root_dir}/command/command_line_tmp" + + # 检查 command_line 文件是否存在 + if [ ! -f "$command_line_tmp" ]; then + echo "Error: $command_line_tmp does not exist!" + exit 1 + fi + + # 生成目标文件副本 + command_line=$(dirname "$command_line_tmp")/command_line_${expect_spark_version}_${omnioperator_version} + cp ${command_line_tmp} ${command_line} + + # 替换文件中的内容 + sed -i "s|{hostname}|${host_name}|g" "$command_line" + sed -i "s|{user}|${user}|g" "$command_line" + sed -i "s|{omni_tar_name}|${target_path##*/}|g" "$command_line" + sed -i "s|{conf_file_name}|"omnioperator_${expect_spark_version}_${omnioperator_version}.conf"|g" "$command_line" + sed -i "s|{spark_conf_path}|${spark_conf_path}|g" "$command_line" + + function_end "------Finish generating command line------" + + if [ ${sve_flag} == true ]; then + echo "Deployment Successful. OmniOperator version: ${omnioperator_version}-sve; Spark version: ${expect_spark_version}." + else + echo "Deployment Successful. OmniOperator version: ${omnioperator_version}; Spark version: ${expect_spark_version}." 
+ fi +} + +check_omni_function(){ + function_start "------Start checking omni function------" + + omni_start_command=$(cat "${tool_root_dir}/command/command_line_${expect_spark_version}_${omnioperator_version}") + + # 启动OmniOp后,执行下面sql建表,执行SQL82,验证OmniOp是否生效 + result=$(echo " + CREATE DATABASE IF NOT EXISTS test_db; + + USE test_db; + + CREATE TABLE IF NOT EXISTS item ( + i_item_id INT, + i_item_desc STRING, + i_current_price DECIMAL(10, 2), + i_manufact_id INT, + i_item_sk INT + ); + + CREATE TABLE IF NOT EXISTS inventory ( + inv_item_sk INT, + inv_quantity_on_hand INT, + inv_date_sk INT + ); + + CREATE TABLE IF NOT EXISTS date_dim ( + d_date_sk INT, + d_date STRING + ); + + CREATE TABLE IF NOT EXISTS store_sales ( + ss_item_sk INT, + ss_item_id INT, + ss_quantity INT + ); + + INSERT INTO item (i_item_id, i_item_desc, i_current_price, i_manufact_id, i_item_sk) + VALUES + (1, 'Item A', 80.00, 512, 1), + (2, 'Item B', 90.00, 409, 2), + (3, 'Item Omni', 100.00, 677, 3), + (4, 'Item C', 95.00, 16, 4); + + INSERT INTO inventory (inv_item_sk, inv_quantity_on_hand, inv_date_sk) + VALUES + (1, 200, 1), + (2, 150, 2), + (3, 300, 3), + (4, 250, 4); + + INSERT INTO date_dim (d_date_sk, d_date) + VALUES + (1, '1998-06-29'), + (2, '1998-07-01'), + (3, '1998-08-01'), + (4, '1998-08-29'); + + INSERT INTO store_sales (ss_item_sk, ss_item_id, ss_quantity) + VALUES + (1, 1, 50), + (2, 2, 60), + (3, 3, 70), + (4, 4, 80); + + set spark.sql.adaptive.enabled=false; + EXPLAIN SELECT i_item_id, i_item_desc, i_current_price + FROM item, inventory, date_dim, store_sales + WHERE i_current_price BETWEEN 76 AND 106 + AND inv_item_sk = i_item_sk + AND d_date_sk = inv_date_sk + AND d_date BETWEEN CAST('1998-06-29' AS DATE) AND CAST('1998-08-29' AS DATE) + AND i_manufact_id IN (512, 409, 677, 16) + AND inv_quantity_on_hand BETWEEN 100 AND 500 + AND ss_item_sk = i_item_sk + GROUP BY i_item_id, i_item_desc, i_current_price + ORDER BY i_item_id + LIMIT 100; + + DROP TABLE IF EXISTS store_sales; + DROP TABLE IF EXISTS date_dim; + DROP TABLE IF EXISTS inventory; + DROP TABLE IF EXISTS item; + + DROP DATABASE IF EXISTS test_db; + " | $omni_start_command 2>&1) + + if [ $? -ne 0 ]; then + echo "ERROR: Error occurred during spark-sql execution." + echo "Error details: ${result}" + exit 1 + fi + + if echo "result" | grep -q "Omni"; then + echo "INFO: Omnioperator is effective." + else + echo "ERROR: Omnioperator is NOT effective." + fi + function_end "------Finish checking omni function------" +} + +################################## 执行 ################################################ +check_spark_version +check_cpu_model +generate_dir +unzip_package +generate_omniconf +upload_hdfs +generate_spark_defaults +generate_command_line +if [ ${omni_check,,} == true ]; then + check_omni_function +fi +echo "-----------ALL Finish----------" +exit 0 \ No newline at end of file diff --git a/omnioperator/omniop-deploy-tool/readme.md b/omnioperator/omniop-deploy-tool/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..00dada22ce759ae39c6f44c6efbdaaa89c70b8aa --- /dev/null +++ b/omnioperator/omniop-deploy-tool/readme.md @@ -0,0 +1,34 @@ +# OmniOperator自动部署工具使用说明: + +## 1. 修改配置文件./conf/config +1.1. `expect_spark_version`:需要部署OmniOperator包对应的spark版本,而不是填系统当前的spark版本 +例如:当前系统的spark版本为3.4.1, 但OmniOperator的包对应spark版本是3.4.3,则expect_spark_version应填3.4.3,否则会报错找不到对应的包 + +1.2. `omnioperator_version`:OmniOperator的版本 + +1.3. `target_path`:OmniOperator的部署路径 + +1.4. 
`sve_flag`:是否选择sve版本(centos系统暂不支持,cpu为鲲鹏920的机器暂不支持) +若机器不支持OmniOperator SVE版本,而选择SVE版本安装,将安装失败 + +## 2. 获取OmniOperator安装包 +获取: + +BoostKit-omniop_{omni_version}.zip +boostkit-omniop-spark-{spark_version}-{omni_version}-aarch64.zip +Dependency_library_{os_type}.zip + +三个包,并放置在脚本的根目录/omnioperator文件夹下 + +## 3. 脚本执行 +在脚本的根目录下,执行: +```shell +bash deploy.sh +``` + +## 4. 获取命令行 +在脚本的根目录下,执行: +```shell +vim ./command/command_line +``` +并复制文件中的内容,粘贴到机器上,执行即可 \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/build.sh b/omnioperator/omniop-native-reader/cpp/build.sh index c21dba905a8ba6ba17c7d0448405ab82ba81d981..280ad87f706ab882f38fbb27c7a0c90f1c56cab3 100644 --- a/omnioperator/omniop-native-reader/cpp/build.sh +++ b/omnioperator/omniop-native-reader/cpp/build.sh @@ -22,7 +22,7 @@ if [ -z "$OMNI_HOME" ]; then OMNI_HOME=/opt fi -export OMNI_INCLUDE_PATH=$OMNI_HOME/lib/include +export OMNI_INCLUDE_PATH=$OMNI_HOME/lib/include:$OMNI_HOME/lib/lib/APL/include export OMNI_INCLUDE_PATH=$OMNI_INCLUDE_PATH:$OMNI_HOME/lib export CPLUS_INCLUDE_PATH=$OMNI_INCLUDE_PATH:$CPLUS_INCLUDE_PATH echo "OMNI_INCLUDE_PATH=$OMNI_INCLUDE_PATH" diff --git a/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt b/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt index 346db130b87edc7658b3c5bd1b9acdbc6acb102e..29c23253f5e65b0a00374b58572b9df4025db8fc 100644 --- a/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt +++ b/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt @@ -1,18 +1,20 @@ include_directories(SYSTEM "/user/local/include") -set (PROJ_TARGET native_reader) +set(PROJ_TARGET native_reader) -set (SOURCE_FILES +set(SOURCE_FILES jni/OrcColumnarBatchJniWriter.cpp jni/OrcColumnarBatchJniReader.cpp jni/jni_common.cpp jni/ParquetColumnarBatchJniReader.cpp + jni/ParquetColumnarBatchJniWriter.cpp parquet/ParquetReader.cpp parquet/ParquetColumnReader.cpp parquet/ParquetTypedRecordReader.cpp parquet/ParquetDecoder.cpp parquet/ParquetExpression.cpp + parquet/ParquetWriter.cpp common/UriInfo.cc orcfile/OrcFileOverride.cc orcfile/OrcHdfsFileOverride.cc @@ -29,40 +31,44 @@ set (SOURCE_FILES arrowadapter/HdfsAdapter.cc arrowadapter/LocalfsAdapter.cc common/JulianGregorianRebase.cpp - common/TimeRebaseInfo.cpp) + common/TimeRebaseInfo.cpp + common/PredicateUtil.cpp) #Find required protobuf package find_package(Protobuf REQUIRED) -if(PROTOBUF_FOUND) +if (PROTOBUF_FOUND) message(STATUS "protobuf library found") -else() +else () message(FATAL_ERROR "protobuf library is needed but cant be found") -endif() +endif () +include_directories($ENV{OMNI_HOME}/lib/lib/APL/include) include_directories(${Protobuf_INCLUDE_DIRS}) include_directories(${CMAKE_CURRENT_BINARY_DIR}) -add_library (${PROJ_TARGET} SHARED ${SOURCE_FILES} ${PROTO_SRCS} ${PROTO_HDRS} ${PROTO_SRCS_VB} ${PROTO_HDRS_VB}) +add_library(${PROJ_TARGET} SHARED ${SOURCE_FILES} ${PROTO_SRCS} ${PROTO_HDRS} ${PROTO_SRCS_VB} ${PROTO_HDRS_VB}) find_package(Arrow REQUIRED) find_package(Parquet REQUIRED) find_package(ArrowDataset REQUIRED) +find_package(nlohmann_json 3.7.3 REQUIRED) #JNI target_include_directories(${PROJ_TARGET} PUBLIC $ENV{JAVA_HOME}/include) target_include_directories(${PROJ_TARGET} PUBLIC $ENV{JAVA_HOME}/include/linux) target_include_directories(${PROJ_TARGET} PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) -target_link_libraries (${PROJ_TARGET} PUBLIC +target_link_libraries(${PROJ_TARGET} PUBLIC Arrow::arrow_shared ArrowDataset::arrow_dataset_shared Parquet::parquet_shared + nlohmann_json::nlohmann_json orc - 
boostkit-omniop-vector-1.7.0-aarch64 + boostkit-omniop-vector-1.8.0-aarch64 hdfs - ) +) set_target_properties(${PROJ_TARGET} PROPERTIES - LIBRARY_OUTPUT_DIRECTORY ${root_directory}/releases + LIBRARY_OUTPUT_DIRECTORY ${root_directory}/releases ) install(TARGETS ${PROJ_TARGET} DESTINATION lib) diff --git a/omnioperator/omniop-native-reader/cpp/src/common/PredicateCondition.h b/omnioperator/omniop-native-reader/cpp/src/common/PredicateCondition.h new file mode 100644 index 0000000000000000000000000000000000000000..06751c1fa590160a699d19c03bc56ebe4e8a3788 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/common/PredicateCondition.h @@ -0,0 +1,357 @@ +/** +* Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICATECONDITION_H +#define PREDICATECONDITION_H + +#include +#include +#include +#include +#include "vector/vector_common.h" +#include "util/bit_util.h" +#include "xsimd/xsimd.hpp" +#include "PredicateOperatorType.h" + +using omniruntime::BitUtil; +using omniruntime::mem::AlignedBuffer; +using omniruntime::vec::BaseVector; +using omniruntime::vec::Vector; +using omniruntime::vec::unsafe::UnsafeVector; +using omniruntime::vec::unsafe::UnsafeBaseVector; +using omniruntime::exception::OmniException; + +namespace common { + + template + bool equalTo(const V &left, const V &right) { + return left == right; + } + + template + bool greater(const V &left, const V &right) { + return left > right; + } + + template + bool greaterEqual(const V &left, const V &right) { + return left >= right; + } + + template + bool less(const V &left, const V &right) { + return left < right; + } + + template + bool lessEqual(const V &left, const V &right) { + return left <= right; + } + + template + xsimd::batch_bool batchEqualTo(const xsimd::batch &left, const xsimd::batch &right) { + return xsimd::eq(left, right); + } + + template + xsimd::batch_bool batchGreater(const xsimd::batch &left, const xsimd::batch &right) { + return xsimd::gt(left, right); + } + + template + xsimd::batch_bool batchGreaterEqual(const xsimd::batch &left, const xsimd::batch &right) { + return xsimd::ge(left, right); + } + + template + xsimd::batch_bool batchLess(const xsimd::batch &left, const xsimd::batch &right) { + return xsimd::lt(left, right); + } + + template + xsimd::batch_bool batchLessEqual(const xsimd::batch &left, const xsimd::batch &right) { + return xsimd::le(left, right); + } + + class PredicateCondition { + public: + explicit PredicateCondition(const PredicateOperatorType &opType) { + this->op = opType; + } + + virtual ~PredicateCondition() = default; + + virtual uint8_t *compute(std::vector &vecBatch) = 0; + + virtual void init(const int32_t size) { + bitSize = size; + bitMarkBuf 
= std::make_unique>(BitUtil::Nbytes(size) + 8); + bitMark = bitMarkBuf->GetBuffer(); + } + + virtual bool isAllNull(int32_t columnIndex) { + return false; + } + + virtual bool isAllNotNull(int32_t columnIndex) { + return false; + } + + void buildNullColumns(int32_t columnCount) { + for (int32_t i = 0; i < columnCount; i++) { + if (isAllNull(i)) { + isAllNullColumns.insert(i); + } + if (isAllNotNull(i)) { + isAllNotNullColumns.insert(i); + } + } + } + + std::set &getIsAllNullColumns() { + return isAllNullColumns; + } + + std::set &getIsAllNotNullColumns() { + return isAllNotNullColumns; + } + + protected: + PredicateOperatorType op; + int32_t bitSize = 0; + std::unique_ptr> bitMarkBuf; + uint8_t *bitMark = nullptr; + std::set isAllNullColumns; + std::set isAllNotNullColumns; + }; + + template + class LeafPredicateCondition final : public PredicateCondition { + public: + LeafPredicateCondition(const PredicateOperatorType &opType, int32_t index, T value) + : PredicateCondition(opType), index(index), value(value) { + } + + template (*BATCH_OP)(const xsimd::batch &, const xsimd::batch &), bool (*OP)(const T &, const T &)> + void computeCompare(BaseVector *vector, int32_t vectorSize) { + xsimd::batch broadcast = xsimd::batch(value); + int32_t step = static_cast(xsimd::batch::size); + Vector *realVector = reinterpret_cast *>(vector); + T *values = UnsafeVector::GetRawValues(realVector); + int32_t index = 0; + for (; index + step <= vectorSize; index += step) { + xsimd::batch valuesBatch = xsimd::load_aligned(values + index); + xsimd::batch_bool result = BATCH_OP(valuesBatch, broadcast); + uint64_t mask = result.mask(); + BitUtil::StoreBits(reinterpret_cast(bitMark), index, mask, 64); + } + for (; index < vectorSize; index++) { + bool result = OP(values[index], value); + BitUtil::SetBit(bitMark, index, result); + } + } + + uint8_t *compute(std::vector &vecBatch) override { + auto vector = vecBatch[index]; + auto vectorSize = vector->GetSize(); + switch (op) { + case TRUE: { + errno_t opTrueRet = memset_s(bitMark, BitUtil::Nbytes(bitSize), -1, BitUtil::Nbytes(bitSize)); + if (UNLIKELY(opTrueRet != EOK)) { + throw OmniException("OPERATOR_RUNTIME_ERROR", "LeafPredicateCondition TRUE memset_s fail."); + } + break; + } + case EQUAL_TO: { + computeCompare(vector, vectorSize); + break; + } + case GREATER_THAN: { + computeCompare(vector, vectorSize); + break; + } + case GREATER_THAN_OR_EQUAL: { + computeCompare(vector, vectorSize); + break; + } + case LESS_THAN: { + computeCompare(vector, vectorSize); + break; + } + case LESS_THAN_OR_EQUAL: { + computeCompare(vector, vectorSize); + break; + } + case IS_NOT_NULL: { + uint8_t *nulls = reinterpret_cast(UnsafeBaseVector::GetNulls(vector)); + int32_t step = static_cast(xsimd::batch::size); + int32_t byteLen = BitUtil::Nbytes(vectorSize); + int32_t index = 0; + for (; index + step <= byteLen; index += step) { + xsimd::batch valuesBatch = xsimd::load_aligned(nulls + index); + xsimd::batch result = xsimd::bitwise_not(valuesBatch); + result.store_aligned(bitMark + index); + } + for (; index < byteLen; index++) { + bitMark[index] = ~nulls[index]; + } + break; + } + case IS_NULL: { + errno_t opIsNullRet = memcpy_s(bitMark, BitUtil::Nbytes(vectorSize), + UnsafeBaseVector::GetNulls(vector), BitUtil::Nbytes(vectorSize)); + if (UNLIKELY(opIsNullRet != EOK)) { + throw OmniException("OPERATOR_RUNTIME_ERROR", "LeafPredicateCondition IS_NULL memcpy_s fail."); + } + break; + } + default: + throw OmniException("OPERATOR_RUNTIME_ERROR", + "LeafPredicateCondition UnSupport 
OperatorType: " + std::to_string(op)); + } + return bitMark; + } + + bool isAllNull(int32_t columnIndex) override { + return columnIndex == index && op == IS_NULL; + } + + bool isAllNotNull(int32_t columnIndex) override { + return columnIndex == index && op == IS_NOT_NULL; + } + + private: + int32_t index; + T value; + }; + + class NotPredicateCondition : public PredicateCondition { + public: + explicit NotPredicateCondition(std::unique_ptr child) : PredicateCondition(NOT), + child(std::move(child)) { + } + + void init(const int32_t size) override { + PredicateCondition::init(size); + child->init(size); + } + + uint8_t *compute(std::vector &vecBatch) override { + auto vectorSize = vecBatch[0]->GetSize(); + uint8_t *childResult = child->compute(vecBatch); + int32_t step = static_cast(xsimd::batch::size); + int32_t byteLen = BitUtil::Nbytes(vectorSize); + int32_t index = 0; + for (; index + step <= byteLen; index += step) { + xsimd::batch valuesBatch = xsimd::load_aligned(childResult + index); + xsimd::batch result = xsimd::bitwise_not(valuesBatch); + result.store_aligned(bitMark + index); + } + for (; index < byteLen; index++) { + bitMark[index] = ~childResult[index]; + } + return bitMark; + } + + private: + std::unique_ptr child; + }; + + class AndPredicateCondition : public PredicateCondition { + public: + AndPredicateCondition(std::unique_ptr left, std::unique_ptr right) + : PredicateCondition(AND), left(std::move(left)), right(std::move(right)) { + } + + void init(const int32_t size) override { + PredicateCondition::init(size); + left->init(size); + right->init(size); + } + + uint8_t *compute(std::vector &vecBatch) override { + auto vectorSize = vecBatch[0]->GetSize(); + uint8_t *leftResult = left->compute(vecBatch); + uint8_t *rightResult = right->compute(vecBatch); + int32_t step = static_cast(xsimd::batch::size); + int32_t byteLen = BitUtil::Nbytes(vectorSize); + int32_t index = 0; + for (; index + step <= byteLen; index += step) { + xsimd::batch leftBatch = xsimd::load_aligned(leftResult + index); + xsimd::batch rightBatch = xsimd::load_aligned(rightResult + index); + xsimd::batch result = xsimd::bitwise_and(leftBatch, rightBatch); + result.store_aligned(bitMark + index); + } + for (; index < byteLen; index++) { + bitMark[index] = leftResult[index] & rightResult[index]; + } + return bitMark; + } + + bool isAllNull(int32_t columnIndex) override { + return left->isAllNull(columnIndex) | right->isAllNull(columnIndex); + } + + bool isAllNotNull(int32_t columnIndex) override { + return left->isAllNotNull(columnIndex) | right->isAllNotNull(columnIndex); + } + + private: + std::unique_ptr left; + std::unique_ptr right; + }; + + class OrPredicateCondition : public PredicateCondition { + public: + OrPredicateCondition(std::unique_ptr left, std::unique_ptr right) + : PredicateCondition(OR), left(std::move(left)), right(std::move(right)) { + } + + void init(const int32_t size) override { + PredicateCondition::init(size); + left->init(size); + right->init(size); + } + + uint8_t *compute(std::vector &vecBatch) override { + auto vectorSize = vecBatch[0]->GetSize(); + uint8_t *leftResult = left->compute(vecBatch); + uint8_t *rightResult = right->compute(vecBatch); + int32_t step = static_cast(xsimd::batch::size); + int32_t byteLen = BitUtil::Nbytes(vectorSize); + int32_t index = 0; + for (; index + step <= byteLen; index += step) { + xsimd::batch leftBatch = xsimd::load_aligned(leftResult + index); + xsimd::batch rightBatch = xsimd::load_aligned(rightResult + index); + xsimd::batch result = 
xsimd::bitwise_or(leftBatch, rightBatch); + result.store_aligned(bitMark + index); + } + for (; index < byteLen; index++) { + bitMark[index] = leftResult[index] | rightResult[index]; + } + return bitMark; + } + + private: + std::unique_ptr left; + std::unique_ptr right; + }; +} + +#endif //PREDICATECONDITION_H diff --git a/omnioperator/omniop-native-reader/cpp/src/common/PredicateOperatorType.h b/omnioperator/omniop-native-reader/cpp/src/common/PredicateOperatorType.h new file mode 100644 index 0000000000000000000000000000000000000000..fceb8da119967f0feb7ae3a4a465b1e4afae6e67 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/common/PredicateOperatorType.h @@ -0,0 +1,39 @@ +/** +* Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICATEOPERATORTYPE_H +#define PREDICATEOPERATORTYPE_H + +namespace common { + enum PredicateOperatorType { + TRUE = 0, + EQUAL_TO = 1, + GREATER_THAN = 2, + GREATER_THAN_OR_EQUAL = 3, + LESS_THAN = 4, + LESS_THAN_OR_EQUAL = 5, + IS_NOT_NULL = 6, + IS_NULL = 7, + OR = 8, + AND = 9, + NOT = 10 + }; +} + +#endif //PREDICATEOPERATORTYPE_H diff --git a/omnioperator/omniop-native-reader/cpp/src/common/PredicateUtil.cpp b/omnioperator/omniop-native-reader/cpp/src/common/PredicateUtil.cpp new file mode 100644 index 0000000000000000000000000000000000000000..215ce7ecebdda12a7f1925ccb8b9a14eb49725e8 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/common/PredicateUtil.cpp @@ -0,0 +1,342 @@ +/** +* Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "PredicateUtil.h" +#include +#include + +using namespace omniruntime::vec; + +namespace common { + std::unique_ptr buildLeafPredicateCondition(PredicateOperatorType &op, + nlohmann::json &jsonCondition) { + using namespace omniruntime::type; + int32_t index = jsonCondition["index"].get(); + std::string value = jsonCondition["value"]; + DataTypeId typeId = jsonCondition["dataType"].get(); + switch (typeId) { + case OMNI_SHORT: { + return std::make_unique>(op, index, static_cast(stoi(value))); + } + case OMNI_INT: { + return std::make_unique>(op, index, stoi(value)); + } + case OMNI_LONG: { + return std::make_unique>(op, index, stol(value)); + } + case OMNI_DOUBLE: { + return std::make_unique>(op, index, stod(value)); + } + case OMNI_DATE32: { + return std::make_unique>(op, index, stoi(value)); + } + case OMNI_BOOLEAN: { + return std::make_unique>(op, index, value == "true" ? 1 : 0); + } + default: { + throw OmniException("OPERATOR_RUNTIME_ERROR", "buildLeafPredicateCondition UnSupport DataTypeId: " + std::to_string(typeId)); + } + } + } + + std::unique_ptr buildPredicateCondition(nlohmann::json &jsonCondition) { + PredicateOperatorType op = jsonCondition["op"].get(); + switch (op) { + case TRUE: + case EQUAL_TO: + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + case IS_NOT_NULL: + case IS_NULL: { + return buildLeafPredicateCondition(op, jsonCondition); + } + case OR: { + auto orLeft = buildPredicateCondition(jsonCondition["left"]); + auto orRight = buildPredicateCondition(jsonCondition["right"]); + return std::make_unique(std::move(orLeft), std::move(orRight)); + } + case AND: { + auto andLeft = buildPredicateCondition(jsonCondition["left"]); + auto andRight = buildPredicateCondition(jsonCondition["right"]); + return std::make_unique(std::move(andLeft), std::move(andRight)); + } + case NOT: { + auto child = buildPredicateCondition(jsonCondition["child"]); + return std::make_unique(std::move(child)); + } + default: { + throw OmniException("OPERATOR_RUNTIME_ERROR", "buildPredicateCondition UnSupport PredicateOperatorType: " + std::to_string(op)); + } + } + } + + std::unique_ptr BuildVecPredicateCondition(JNIEnv *env, jobject jsonObj, int32_t columnCount) { + if (!env->CallBooleanMethod(jsonObj, jsonMethodHas, env->NewStringUTF("vecPredicateCondition"))) { + return nullptr; + } + + auto condition = static_cast(env->CallObjectMethod(jsonObj, jsonMethodString, + env->NewStringUTF("vecPredicateCondition"))); + auto conditionPtr = env->GetStringUTFChars(condition, JNI_FALSE); + std::string conditionStr(conditionPtr); + auto jsonCondition = nlohmann::json::parse(conditionStr); + env->ReleaseStringUTFChars(condition, conditionPtr); + auto predicate= buildPredicateCondition(jsonCondition); + predicate->buildNullColumns(columnCount); + return predicate; + } + + struct BitMaskIndex { + BitMaskIndex() { + for (int i = 0; i < (1 << N); i++) { + int32_t startIndex = i * (N + 1); + int32_t index = startIndex; + for (int bit = 0; bit < N; bit++) { + if (i & (1 << bit)) { + memo_[++index] = bit; + } + } + memo_[startIndex] = index - startIndex; + } + } + + const inline uint8_t* operator[](size_t i) const { + return memo_ + (i * (N + 1)); + } + + private: + static constexpr int N = 8; + uint8_t memo_[(1 << N) * (N + 1)]{0}; + }; + + const BitMaskIndex bitMaskIndex; + const int32_t batchStep = 8; + + template + void SetFlatVectorValue(int32_t rowCount, BaseVector *baseVector, BaseVector *selectedBaseVector, + const uint8_t *bitMark, bool 
isAllNull, bool isAllNotNull) + { + int32_t index = 0; + int32_t j = 0; + uint8_t mask; + // 该列过滤出来都是全部为空的情况 + if (isAllNull) { + auto *nulls = reinterpret_cast(UnsafeBaseVector::GetNulls(selectedBaseVector)); + memset_s(nulls, BitUtil::Nbytes(selectedBaseVector->GetSize()), -1, BitUtil::Nbytes(selectedBaseVector->GetSize())); + return; + } + // 该列过滤出来都是全部不为空的情况 + if (isAllNotNull) { + auto *srcValues = UnsafeVector::GetRawValues(static_cast(baseVector)); + auto *destValues = UnsafeVector::GetRawValues(static_cast(selectedBaseVector)); + for (; j + batchStep <= rowCount; j += batchStep) { + mask = bitMark[j >> 3]; + if (mask == 0) { + continue; + } + if (mask == 255) { + memcpy_s(destValues + index, batchStep * sizeof(RAW_DATA_TYPE), srcValues + j, batchStep * sizeof(RAW_DATA_TYPE)); + index += batchStep; + continue; + } + const uint8_t *maskArr = bitMaskIndex[mask]; + for (int i = 1; i <= *maskArr; i++) { + auto offset = j + maskArr[i]; + auto value = static_cast(baseVector)->GetValue(offset); + static_cast(selectedBaseVector)->SetValue(index++, value); + } + } + for (; j < rowCount; j++) { + if (BitUtil::IsBitSet(bitMark, j)) { + auto value = static_cast(baseVector)->GetValue(j); + static_cast(selectedBaseVector)->SetValue(index++, value); + } + } + return; + } + // 其他情况 + for (; j + batchStep <= rowCount; j += batchStep) { + mask = bitMark[j >> 3]; + if (mask == 0) { + continue; + } + const uint8_t *maskArr = bitMaskIndex[mask]; + for (int i = 1; i <= *maskArr; i++) { + auto offset = j + maskArr[i]; + if (UNLIKELY(baseVector->IsNull(offset))) { + static_cast(selectedBaseVector)->SetNull(index++); + } else { + auto value = static_cast(baseVector)->GetValue(offset); + static_cast(selectedBaseVector)->SetValue(index++, value); + } + } + } + for (; j < rowCount; j++) { + if (BitUtil::IsBitSet(bitMark, j)) { + if (UNLIKELY(baseVector->IsNull(j))) { + static_cast(selectedBaseVector)->SetNull(index++); + } else { + auto value = static_cast(baseVector)->GetValue(j); + static_cast(selectedBaseVector)->SetValue(index++, value); + } + } + } + } + + void SetStringVectorValue(int32_t rowCount, Vector> *baseVector, + Vector> *selectedBaseVector, const uint8_t *bitMark, bool isAllNull, + bool isAllNotNull) + { + int32_t index = 0; + int32_t j = 0; + uint8_t mask; + // 该列过滤出来都是全部为空的情况 + if (isAllNull) { + for (; j < selectedBaseVector->GetSize(); j++) { + selectedBaseVector->SetNull(j); + } + return; + } + // 该列过滤出来都是全部不为空的情况 + if (isAllNotNull) { + for (; j + batchStep <= rowCount; j += batchStep) { + mask = bitMark[j >> 3]; + if (mask == 0) { + continue; + } + const uint8_t *maskArr = bitMaskIndex[mask]; + for (int i = 1; i <= *maskArr; i++) { + auto offset = j + maskArr[i]; + auto value = baseVector->GetValue(offset); + selectedBaseVector->SetValue(index++, value); + } + } + for (; j < rowCount; j++) { + if (BitUtil::IsBitSet(bitMark, j)) { + auto value = baseVector->GetValue(j); + selectedBaseVector->SetValue(index++, value); + } + } + return; + } + // 其他情况 + for (; j + batchStep <= rowCount; j += batchStep) { + mask = bitMark[j >> 3]; + if (mask == 0) { + continue; + } + const uint8_t *maskArr = bitMaskIndex[mask]; + for (int i = 1; i <= *maskArr; i++) { + auto offset = j + maskArr[i]; + if (UNLIKELY(baseVector->IsNull(offset))) { + selectedBaseVector->SetNull(index++); + } else { + auto value = baseVector->GetValue(offset); + selectedBaseVector->SetValue(index++, value); + } + } + } + for (; j < rowCount; j++) { + if (BitUtil::IsBitSet(bitMark, j)) { + if (UNLIKELY(baseVector->IsNull(j))) { + 
selectedBaseVector->SetNull(index++); + } else { + auto value = baseVector->GetValue(j); + selectedBaseVector->SetValue(index++, value); + } + } + } + } + + bool GetFlatBaseVectorsFromBitMark(std::vector &baseVectors, std::vector &result, + uint8_t *bitMark, int32_t vectorSize, const std::set& isNullSet, const std::set& isNotNullSet) { + int32_t resultSize = BitUtil::CountBits(reinterpret_cast(bitMark), 0, vectorSize); + if (resultSize == vectorSize) { + // all selected, filtering is not required. + return false; + } + if (UNLIKELY(baseVectors.empty())) { + return false; + } + int32_t rowCount = baseVectors[0]->GetSize(); + int32_t encodingType = baseVectors[0]->GetEncoding(); + if (UNLIKELY(encodingType == OMNI_DICTIONARY)) { + throw omniruntime::exception::OmniException("UNSUPPORTED_ERROR", "OMNI_DICTIONARY is unsupported."); + } + result.resize(baseVectors.size(), nullptr); + for (int32_t i = 0; i < baseVectors.size(); i++) { + auto baseVector = baseVectors[i]; + auto dataType = baseVector->GetTypeId(); + auto selectedBaseVector = VectorHelper::CreateVector(OMNI_FLAT, dataType, static_cast(resultSize)); + auto isAllNull = isNullSet.count(i); + auto isAllNotNull = isNotNullSet.count(i); + switch (dataType) { + case OMNI_INT: + case OMNI_DATE32: { + SetFlatVectorValue, int32_t>(rowCount, baseVector, selectedBaseVector, bitMark, + isAllNull, isAllNotNull); + break; + } + case OMNI_SHORT: { + SetFlatVectorValue, int16_t>(rowCount, baseVector, selectedBaseVector, bitMark, + isAllNull, isAllNotNull); + break; + } + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DECIMAL64: { + SetFlatVectorValue, int64_t>(rowCount, baseVector, selectedBaseVector, bitMark, + isAllNull, isAllNotNull); + break; + } + case OMNI_DOUBLE: { + SetFlatVectorValue, double>(rowCount, baseVector, selectedBaseVector, bitMark, + isAllNull, isAllNotNull); + break; + } + case OMNI_BOOLEAN: { + SetFlatVectorValue, bool>(rowCount, baseVector, selectedBaseVector, bitMark, + isAllNull, isAllNotNull); + break; + } + case OMNI_DECIMAL128: { + SetFlatVectorValue, Decimal128>(rowCount, baseVector, selectedBaseVector, + bitMark, isAllNull, isAllNotNull); + break; + } + case OMNI_VARCHAR: + case OMNI_CHAR: { + SetStringVectorValue(rowCount, dynamic_cast> *>(baseVector), + dynamic_cast> *>(selectedBaseVector), bitMark, + isAllNull, isAllNotNull); + break; + } + default: { + LogError("No such %d type support", dataType); + throw omniruntime::exception::OmniException("OPERATOR_RUNTIME_ERROR", + "unsupported selectivity type: " + std::to_string(static_cast(dataType))); + } + } + result[i] = selectedBaseVector; + } + return true; + } +} diff --git a/omnioperator/omniop-native-reader/cpp/src/common/PredicateUtil.h b/omnioperator/omniop-native-reader/cpp/src/common/PredicateUtil.h new file mode 100644 index 0000000000000000000000000000000000000000..550430293a4a19aad0e36e5834b4a610755c23b2 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/common/PredicateUtil.h @@ -0,0 +1,34 @@ +/** +* Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICATEUTIL_H +#define PREDICATEUTIL_H + +#include "PredicateCondition.h" +#include "jni/jni_common.h" +#include + +namespace common { + std::unique_ptr BuildVecPredicateCondition(JNIEnv *env, jobject jsonObj, int32_t columnCount); + + bool GetFlatBaseVectorsFromBitMark(std::vector &baseVectors, std::vector &result, + uint8_t *bitMark, int32_t vectorSize, const std::set& isNullSet, const std::set& isNotNullSet); +} + +#endif //PREDICATEUTIL_H diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp index f2806e63519db543f3f627e4a6de0a5753111f5b..f202a5cbf9a643f9dcb3661c730f69acac4a24b1 100644 --- a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp +++ b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp @@ -23,6 +23,7 @@ #include "jni_common.h" #include "common/UriInfo.h" #include "common/JulianGregorianRebase.h" +#include "common/PredicateUtil.h" using namespace omniruntime::vec; using namespace omniruntime::type; @@ -317,8 +318,9 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniRea } std::unique_ptr julianPtr = common::BuildJulianGregorianRebase(env, jsonObj); + std::unique_ptr predicate = common::BuildVecPredicateCondition(env, jsonObj, arrLen); - std::unique_ptr rowReader = readerPtr->createRowReader(rowReaderOpts, julianPtr); + std::unique_ptr rowReader = readerPtr->createRowReader(rowReaderOpts, julianPtr, predicate); return (jlong)(rowReader.release()); JNI_FUNC_END(runtimeExceptionClass) } @@ -328,6 +330,11 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniRea jobject jObj, jlong rowReader, jlong batchSize) { JNI_FUNC_START + omniruntime::reader::OmniRowReaderImpl *rowReaderPtr = (omniruntime::reader::OmniRowReaderImpl*)rowReader; + auto predicate = rowReaderPtr->getPredicatePtr(); + if (predicate != nullptr) { + predicate->init((int32_t)batchSize); + } batchLen = (uint64_t)batchSize; return (jlong)0xffff; JNI_FUNC_END(runtimeExceptionClass) @@ -340,6 +347,62 @@ inline void FindLastNotEmpty(const char *chars, long &len) } } +void clearRecordBatch(std::vector &recordBatch) +{ + for (auto vec : recordBatch) { + delete vec; + } + recordBatch.clear(); +} + +uint64_t filterData(uint8_t *bitMark, std::vector *recordBatch, int32_t vectorSize, + const std::set& isNullSet, const std::set& isNotNullSet) +{ + std::vector resultBatch; + if (common::GetFlatBaseVectorsFromBitMark(*recordBatch, resultBatch, bitMark, vectorSize, isNullSet, isNotNullSet)) { + clearRecordBatch(*recordBatch); + *recordBatch = std::move(resultBatch); + return (*recordBatch)[0]->GetSize(); + } + // 失败返回原始的 + clearRecordBatch(resultBatch); + return vectorSize; +} + +bool readAndFilterData(omniruntime::reader::OmniRowReaderImpl *rowReaderPtr, jint *omniTypeId, + std::vector *recordBatch, uint64_t &batchRowSize) +{ + batchRowSize = rowReaderPtr->next(recordBatch, omniTypeId, batchLen); + auto predicateCondition = rowReaderPtr->getPredicatePtr(); + if (batchRowSize == 0 || 
predicateCondition == nullptr) { + return false; + } + try { + uint8_t *bitMark = predicateCondition->compute(*recordBatch); + int32_t vectorSize = (*recordBatch)[0]->GetSize(); + // 数据被全部过滤完,需要重新读取 + if (omniruntime::BitUtil::CountBits(reinterpret_cast(bitMark), 0, vectorSize) == 0) { + clearRecordBatch(*recordBatch); + return true; + } + batchRowSize = filterData(bitMark, recordBatch, vectorSize, predicateCondition->getIsAllNullColumns(), + predicateCondition->getIsAllNotNullColumns()); + } catch (const std::exception &e) { + LogError("filterData fail: %s", e.what()); + } + return false; +} + +uint64_t processNext(omniruntime::reader::OmniRowReaderImpl *rowReaderPtr, jint *omniTypeId, + std::vector *recordBatch) { + uint64_t batchRowSize = 0; + bool needReadAgain = readAndFilterData(rowReaderPtr, omniTypeId, recordBatch, batchRowSize); + while (needReadAgain) { + needReadAgain = readAndFilterData(rowReaderPtr, omniTypeId, recordBatch, batchRowSize); + } + return batchRowSize; +} + JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_recordReaderNext(JNIEnv *env, jobject jObj, jlong rowReader, jlong batch, jintArray typeId, jlongArray vecNativeId) { @@ -354,7 +417,7 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniRea throw std::runtime_error("Types should not be null"); } int32_t arrLen = (int32_t) env->GetArrayLength(typeId); - uint64_t batchRowSize = rowReaderPtr->next(&recordBatch, ptr, batchLen); + uint64_t batchRowSize = processNext(rowReaderPtr, ptr, &recordBatch); if (batchRowSize != 0) { int32_t vecCnt = recordBatch.size(); if (UNLIKELY(vecCnt != arrLen)) { diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniWriter.cpp b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniWriter.cpp index 24a94ac0f6abfb278c2d1229be8588040159ab1e..ea72d6e6a4c4e41f79998554ce67e24ba8a12cda 100644 --- a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniWriter.cpp +++ b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniWriter.cpp @@ -157,7 +157,7 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_write_jni_OrcColumnarBatchJniWr } template -void WriteVector(BaseVector *vec, ColumnVectorBatch *fieldBatch, bool isSplitWrite = false, long startPos = 0, +void WriteVector(JNIEnv *env, BaseVector *vec, ColumnVectorBatch *fieldBatch, bool isSplitWrite = false, long startPos = 0, long endPos = 0) { using T = typename NativeType::type; @@ -180,9 +180,19 @@ void WriteVector(BaseVector *vec, ColumnVectorBatch *fieldBatch, bool isSplitWri } } index = 0; - for (long j = startPos; j < endPos; j++) { - values[index] = vector->GetValue(j); - index++; + if(sizeof(T) == sizeof(values[0])){ + errno_t err; + err = memcpy_s(values, sizeof(T) * (endPos - startPos), + omniruntime::vec::unsafe::UnsafeVector::GetRawValues(vector) + startPos, + sizeof(T) * (endPos - startPos)); + if (err != EOK) { + env->ThrowNew(runtimeExceptionClass,"Get values from vector failed"); + } + } else { + for (long j = startPos; j < endPos; j++) { + values[index] = vector->GetValue(j); + index++; + } } } @@ -214,7 +224,7 @@ void WriteDecimal128VectorBatch(BaseVector *vec, ColumnVectorBatch *fieldBatch, } } -void WriteDecimal64VectorBatch(BaseVector *vec, ColumnVectorBatch *fieldBatch, bool isSplitWrite = false, +void WriteDecimal64VectorBatch(JNIEnv *env, BaseVector *vec, ColumnVectorBatch *fieldBatch, bool isSplitWrite = false, long startPos = 0, long endPos = 0) { auto vector = (Vector *)vec; @@ -235,10 +245,12 @@ 
void WriteDecimal64VectorBatch(BaseVector *vec, ColumnVectorBatch *fieldBatch, b index++; } } - index = 0; - for (long j = startPos; j < endPos; j++) { - values[index] = vector->GetValue(j); - index++; + errno_t err; + err = memcpy_s(values,sizeof(values[0]) * (endPos - startPos), + omniruntime::vec::unsafe::UnsafeVector::GetRawValues(vector) + startPos, + sizeof(vector->GetValue(0)) * (endPos - startPos)); + if (err != EOK) { + env->ThrowNew(runtimeExceptionClass,"Get values from vector failed"); } } @@ -278,17 +290,17 @@ void WriteLongVectorBatch(JNIEnv *env, DataTypeId typeId, BaseVector *baseVector JNI_FUNC_START switch (typeId) { case OMNI_BOOLEAN: - return WriteVector(baseVector, fieldBatch, isSplitWrite, startPos, endPos); + return WriteVector(env, baseVector, fieldBatch, isSplitWrite, startPos, endPos); case OMNI_SHORT: - return WriteVector(baseVector, fieldBatch, isSplitWrite, startPos, endPos); + return WriteVector(env, baseVector, fieldBatch, isSplitWrite, startPos, endPos); case OMNI_INT: - return WriteVector(baseVector, fieldBatch, isSplitWrite, startPos, endPos); + return WriteVector(env, baseVector, fieldBatch, isSplitWrite, startPos, endPos); case OMNI_LONG: - return WriteVector(baseVector, fieldBatch, isSplitWrite, startPos, endPos); + return WriteVector(env, baseVector, fieldBatch, isSplitWrite, startPos, endPos); case OMNI_DATE32: - return WriteVector(baseVector, fieldBatch, isSplitWrite, startPos, endPos); + return WriteVector(env, baseVector, fieldBatch, isSplitWrite, startPos, endPos); case OMNI_DATE64: - return WriteVector(baseVector, fieldBatch, isSplitWrite, startPos, endPos); + return WriteVector(env, baseVector, fieldBatch, isSplitWrite, startPos, endPos); default: env->ThrowNew(runtimeExceptionClass, "DealLongVectorBatch not support for type:" + typeId); } @@ -316,13 +328,13 @@ void WriteVector(JNIEnv *env, long *vecNativeId, int colNums, orc::StructVectorB WriteLongVectorBatch(env, typeId, vec, fieldBatch, isSplitWrite, startPos, endPos); break; case OMNI_DOUBLE: - WriteVector(vec, fieldBatch, isSplitWrite, startPos, endPos); + WriteVector(env, vec, fieldBatch, isSplitWrite, startPos, endPos); break; case OMNI_VARCHAR: WriteVarCharVectorBatch(vec, fieldBatch, isSplitWrite, startPos, endPos); break; case OMNI_DECIMAL64: - WriteDecimal64VectorBatch(vec, fieldBatch, isSplitWrite, startPos, endPos); + WriteDecimal64VectorBatch(env, vec, fieldBatch, isSplitWrite, startPos, endPos); break; case OMNI_DECIMAL128: WriteDecimal128VectorBatch(vec, fieldBatch, isSplitWrite, startPos, endPos); diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniWriter.cpp b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniWriter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ba15149690ce209212c076a15d9dd0b1284026d2 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniWriter.cpp @@ -0,0 +1,203 @@ +/** + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ParquetColumnarBatchJniWriter.h" +#include "jni_common.h" +#include "parquet/ParquetWriter.h" +#include "common/UriInfo.h" +#include "arrow/status.h" +#include +#include + +using namespace omniruntime::writer; +using namespace arrow; + +static constexpr int32_t DECIMAL_PRECISION_INDEX = 0; +static constexpr int32_t DECIMAL_SCALE_INDEX = 1; + +JNIEXPORT void JNICALL Java_com_huawei_boostkit_write_jni_ParquetColumnarBatchJniWriter_initializeWriter( + JNIEnv *env, jobject jObj, jobject jsonObj, jlong writer) +{ + JNI_FUNC_START + jstring uri = (jstring)env->CallObjectMethod(jsonObj, jsonMethodString, env->NewStringUTF("uri")); + const char *uriStr = env->GetStringUTFChars(uri, JNI_FALSE); + std::string uriString(uriStr); + env->ReleaseStringUTFChars(uri, uriStr); + + jstring ugiTemp = (jstring)env->CallObjectMethod(jsonObj, jsonMethodString, env->NewStringUTF("ugi")); + const char *ugi = env->GetStringUTFChars(ugiTemp, JNI_FALSE); + std::string ugiString(ugi); + env->ReleaseStringUTFChars(ugiTemp, ugi); + + jstring schemaTemp = (jstring)env->CallObjectMethod(jsonObj, jsonMethodString, env->NewStringUTF("scheme")); + const char *schema = env->GetStringUTFChars(schemaTemp, JNI_FALSE); + std::string schemaString(schema); + env->ReleaseStringUTFChars(schemaTemp, schema); + + jstring hostTemp = (jstring)env->CallObjectMethod(jsonObj, jsonMethodString, env->NewStringUTF("host")); + const char *host = env->GetStringUTFChars(hostTemp, JNI_FALSE); + std::string hostString(host); + env->ReleaseStringUTFChars(hostTemp, host); + + jstring pathTemp = (jstring)env->CallObjectMethod(jsonObj, jsonMethodString, env->NewStringUTF("path")); + const char *path = env->GetStringUTFChars(pathTemp, JNI_FALSE); + std::string pathString(path); + env->ReleaseStringUTFChars(pathTemp, path); + + jint port = (jint)env->CallIntMethod(jsonObj, jsonMethodInt, env->NewStringUTF("port")); + + UriInfo uriInfo(uriString, schemaString, pathString, hostString, std::to_string(port)); + ParquetWriter *pWriter = (ParquetWriter *)writer; + if (pWriter == nullptr) { + env->ThrowNew(runtimeExceptionClass, "the pWriter is null"); + } + pWriter->InitRecordWriter(uriInfo, ugiString); + JNI_FUNC_END_VOID(runtimeExceptionClass) +} + +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_write_jni_ParquetColumnarBatchJniWriter_initializeSchema( + JNIEnv *env, jobject JObj, jlong writer, jobjectArray fieldNames, jintArray fieldTypes, + jbooleanArray nullables, jobjectArray decimalParam) +{ + JNI_FUNC_START + auto pWriter = std::make_unique(); + auto fieldTypesPtr = env->GetIntArrayElements(fieldTypes, JNI_FALSE); + auto nullablesPtr = env->GetBooleanArrayElements(nullables, JNI_FALSE); + if (fieldTypesPtr == NULL) { + throw std::runtime_error("Parquet type ids should not be null"); + } + auto schemaLength = (int32_t)env->GetArrayLength(fieldTypes); + FieldVector fieldVector; + for (int i = 0; i < schemaLength; i++) { + jint parquetType = fieldTypesPtr[i]; + jboolean nullable = nullablesPtr[i]; + jstring fieldName = (jstring)env->GetObjectArrayElement(fieldNames, i); + const char *cFieldName = 
env->GetStringUTFChars(fieldName, nullptr); + std::shared_ptr writeParquetType; + + auto decimalParamArray = (jintArray)env->GetObjectArrayElement(decimalParam, i); + auto decimalParamArrayPtr = env->GetIntArrayElements(decimalParamArray, JNI_FALSE); + auto precision = decimalParamArrayPtr[DECIMAL_PRECISION_INDEX]; + auto scale = decimalParamArrayPtr[DECIMAL_SCALE_INDEX]; + switch (static_cast(parquetType)) { + case Type::type::DECIMAL: + pWriter->precisions.push_back(precision); + pWriter->scales.push_back(scale); + writeParquetType = decimal128(precision, scale); + break; + case Type::type::BOOL: + writeParquetType = arrow::boolean(); + break; + case Type::type::INT16: + writeParquetType = arrow::int16(); + break; + case Type::type::INT32: + writeParquetType = arrow::int32(); + break; + case Type::type::INT64: + writeParquetType = arrow::int64(); + break; + case Type::type::DATE32: + writeParquetType = arrow::date32(); + break; + case Type::type::DATE64: + writeParquetType = arrow::date64(); + break; + case Type::type::DOUBLE: + writeParquetType = arrow::float64(); + break; + case Type::type::STRING: + writeParquetType = arrow::utf8(); + break; + default: + throw std::invalid_argument("Unsupported parquet type: "+std::to_string(parquetType)); + } + auto t = field(cFieldName, writeParquetType, nullable); + fieldVector.emplace_back(t); + env->ReleaseIntArrayElements(decimalParamArray,decimalParamArrayPtr,JNI_ABORT); + env->ReleaseStringUTFChars(fieldName, cFieldName); + } + auto t = std::make_unique(fieldVector); + if (pWriter == nullptr) { + env->ThrowNew(runtimeExceptionClass, "the pWriter is null"); + } + pWriter->schema_ = std::make_shared(fieldVector); + ParquetWriter *pWriterNew= pWriter.release(); + env->ReleaseIntArrayElements(fieldTypes,fieldTypesPtr,JNI_ABORT); + env->ReleaseBooleanArrayElements(nullables,nullablesPtr,JNI_ABORT); + return (jlong)(pWriterNew); + JNI_FUNC_END(runtimeExceptionClass) +} + +JNIEXPORT void JNICALL +Java_com_huawei_boostkit_write_jni_ParquetColumnarBatchJniWriter_write( + JNIEnv *env, jobject jObj, jlong writer, jlongArray vecNativeId, + jintArray omniTypes, jbooleanArray dataColumnsIds, jint numRows) +{ + JNI_FUNC_START + ParquetWriter *pWriter = (ParquetWriter *)writer; + auto vecNativeIdPtr = env->GetLongArrayElements(vecNativeId, JNI_FALSE); + auto colNums = env->GetArrayLength(vecNativeId); + auto omniTypesPtr = env->GetIntArrayElements(omniTypes, JNI_FALSE); + auto dataColumnsIdsPtr = env->GetBooleanArrayElements(dataColumnsIds, JNI_FALSE); + if (pWriter == nullptr) { + env->ThrowNew(runtimeExceptionClass, "the pWriter is null"); + } + pWriter->write(vecNativeIdPtr, colNums, omniTypesPtr, dataColumnsIdsPtr); + env->ReleaseLongArrayElements(vecNativeId, vecNativeIdPtr, 0); + env->ReleaseIntArrayElements(omniTypes, omniTypesPtr, 0); + env->ReleaseBooleanArrayElements(dataColumnsIds, dataColumnsIdsPtr, 0); + JNI_FUNC_END_VOID(runtimeExceptionClass) +} + +JNIEXPORT void JNICALL Java_com_huawei_boostkit_write_jni_ParquetColumnarBatchJniWriter_splitWrite( + JNIEnv *env, jobject jObj, jlong writer, jlongArray vecNativeId, jintArray omniTypes, jbooleanArray dataColumnsIds, + jlong startPos, jlong endPos) +{ + JNI_FUNC_START + auto vecNativeIdPtr = env->GetLongArrayElements(vecNativeId, JNI_FALSE); + auto colNums = env->GetArrayLength(vecNativeId); + auto omniTypesPtr = env->GetIntArrayElements(omniTypes, JNI_FALSE); + auto dataColumnsIdsPtr = env->GetBooleanArrayElements(dataColumnsIds, JNI_FALSE); + auto writeRows = endPos - startPos; + ParquetWriter 
*pWriter = (ParquetWriter *)writer; + if (pWriter == nullptr) { + env->ThrowNew(runtimeExceptionClass, "the pWriter is null"); + } + pWriter->write(vecNativeIdPtr, colNums, omniTypesPtr, dataColumnsIdsPtr, true, startPos, endPos); + + env->ReleaseLongArrayElements(vecNativeId, vecNativeIdPtr, 0); + env->ReleaseIntArrayElements(omniTypes, omniTypesPtr, 0); + env->ReleaseBooleanArrayElements(dataColumnsIds, dataColumnsIdsPtr, 0); + JNI_FUNC_END_VOID(runtimeExceptionClass) +} + +JNIEXPORT void JNICALL Java_com_huawei_boostkit_write_jni_ParquetColumnarBatchJniWriter_close(JNIEnv *env, jobject jObj, + jlong writer) +{ + JNI_FUNC_START + + ParquetWriter *pWriter = (ParquetWriter *)writer; + if (pWriter == nullptr) { + env->ThrowNew(runtimeExceptionClass, "delete nullptr error for writer"); + } + pWriter->arrow_writer->Close(); + delete pWriter; + JNI_FUNC_END_VOID(runtimeExceptionClass) +} diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniWriter.h b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniWriter.h new file mode 100644 index 0000000000000000000000000000000000000000..8139d51e8c2f9ba5038a8cce46988eb7bb0d69fa --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniWriter.h @@ -0,0 +1,86 @@ +/** + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef OMNI_RUNTIME_PARQUETCOLUMNARBATCHJNIWRITER_H +#define OMNI_RUNTIME_PARQUETCOLUMNARBATCHJNIWRITER_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "common/debug.h" + +#ifdef __cplusplus +extern "C" +{ +#endif + +/* + * Class: com_huawei_boostkit_writer_jni_ParquetColumnarBatchJniWriter + * Method: initializeWriter + * Signature: + */ +JNIEXPORT void JNICALL Java_com_huawei_boostkit_write_jni_ParquetColumnarBatchJniWriter_initializeWriter + (JNIEnv *env, jobject jObj, jobject job, jlong writer); + +/* + * Class: com_huawei_boostkit_writer_jni_ParquetColumnarBatchJniWriter + * Method: initializeSchema + * Signature: + */ +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_write_jni_ParquetColumnarBatchJniWriter_initializeSchema + (JNIEnv *env, jobject jObj, jlong writer, jobjectArray filedNames, jintArray fieldTypes, + jbooleanArray nullables, jobjectArray decimalParam); + +/* + * Class: com_huawei_boostkit_writer_jni_ParquetColumnarBatchJniWriter + * Method: write + * Signature: + */ +JNIEXPORT void JNICALL Java_com_huawei_boostkit_write_jni_ParquetColumnarBatchJniWriter_write( + JNIEnv *env, jobject jObj, jlong writer, jlongArray vecNativeId, + jintArray omniTypes, jbooleanArray dataColumnsIds, jint numRows); + +/* + * Class: com_huawei_boostkit_writer_jni_ParquetColumnarBatchJniWriter + * Method: splitWrite + * Signature: + */ +JNIEXPORT void JNICALL Java_com_huawei_boostkit_write_jni_ParquetColumnarBatchJniWriter_splitWrite( + JNIEnv *env, jobject jObj, jlong writer, jlongArray vecNativeId, jintArray omniTypes, + jbooleanArray dataColumnsIds, jlong startPos, jlong endPos); + +/* + * Class: com_huawei_boostkit_writer_jni_ParquetColumnarBatchJniWriter + * Method: close + * Signature: + */ +JNIEXPORT void JNICALL Java_com_huawei_boostkit_write_jni_ParquetColumnarBatchJniWriter_close(JNIEnv *env, jobject jObj, + jlong writer); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniByteRLE.cc b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniByteRLE.cc index 1f3dfbbc90fe9e8a5689c6c76461910723878b43..d39af57b1a21c5f16f2261cd50eb07711edff9bc 100644 --- a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniByteRLE.cc +++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniByteRLE.cc @@ -19,11 +19,14 @@ #include "OmniRLEv2.hh" #include "OmniColReader.hh" #include "vector/vector_helper.h" +#include "OrcDecodeUtils.hh" + namespace omniruntime::reader { const int MINIMUM_REPEAT = 3; const int MAXIMUM_REPEAT = 127 + MINIMUM_REPEAT; + const uint32_t BITS_OF_BYTE = 8; void OmniBooleanRleDecoder::seek(orc::PositionProvider& location) { OmniByteRleDecoder::seek(location); @@ -35,6 +38,7 @@ namespace omniruntime::reader { if (consumed != 0) { remainingBits = 8 - consumed; OmniByteRleDecoder::next(&lastByte, 1, nullptr); + reversedAndFlipLastByte = bitNotFlip[lastByte]; } } @@ -47,6 +51,7 @@ namespace omniruntime::reader { OmniByteRleDecoder::skip(bytesSkipped); if (numValues % 8 != 0) { OmniByteRleDecoder::next(&lastByte, 1, nullptr); + reversedAndFlipLastByte = bitNotFlip[lastByte]; remainingBits = 8 - (numValues % 8); } else { remainingBits = 0; @@ -54,86 +59,108 @@ namespace omniruntime::reader { } } - void OmniBooleanRleDecoder::next(omniruntime::vec::BaseVector*& omnivec, uint64_t numValues, char* notNull, - const orc::Type* baseTp, int omniTypeId) { - auto dataTypeId = static_cast(omniTypeId); - std::unique_ptr tempOmnivec = makeNewVector(numValues, baseTp, 
dataTypeId); - auto pushOmniVec = tempOmnivec.get(); - switch (dataTypeId) { - case omniruntime::type::OMNI_BOOLEAN: - nextByType(pushOmniVec, numValues, notNull, baseTp, omniTypeId); - break; - case omniruntime::type::OMNI_SHORT: - throw std::runtime_error("OmniBooleanRleDecoder SHORT not finished!!!"); - break; - case omniruntime::type::OMNI_INT: - throw std::runtime_error("OmniBooleanRleDecoder INT not finished!!!"); - break; - case omniruntime::type::OMNI_LONG: - throw std::runtime_error("OmniBooleanRleDecoder LONG not finished!!!"); - break; - case omniruntime::type::OMNI_TIMESTAMP: - throw std::runtime_error("OmniBooleanRleDecoder TIMESTAMP not finished!!!"); - break; - case omniruntime::type::OMNI_DATE32: - throw std::runtime_error("OmniBooleanRleDecoder DATE32 not finished!!!"); - break; - case omniruntime::type::OMNI_DATE64: - throw std::runtime_error("OmniBooleanRleDecoder DATE64 not finished!!!"); - break; - case omniruntime::type::OMNI_DOUBLE: - throw std::runtime_error("OmniBooleanRleDecoder DOUBLE not finished!!!"); - break; - case omniruntime::type::OMNI_CHAR: - throw std::runtime_error("OmniBooleanRleDecoder CHAR not finished!!!"); - case omniruntime::type::OMNI_VARCHAR: - throw std::runtime_error("OmniBooleanRleDecoder VARCHAR not finished!!!"); - case omniruntime::type::OMNI_DECIMAL64: - throw std::runtime_error("OmniBooleanRleDecoder DECIMAL64 should not in here!!!"); - case omniruntime::type::OMNI_DECIMAL128: - throw std::runtime_error("OmniBooleanRleDecoder DECIMAL64 should not in here!!!"); - default: - printf("OmniBooleanRleDecoder switch no process!!!"); + void OmniBooleanRleDecoder::nextNulls(char *data, uint64_t numValues, uint64_t *nulls) { + if (nulls) { + throw std::runtime_error("Not implemented yet for struct type!"); + } + + uint64_t nonNulls = numValues; + + const uint32_t outputBytes = (numValues + 7) / 8; + if (nonNulls == 0) { + ::memset(data, 1, outputBytes); + return; + } + + if (remainingBits >= nonNulls) { + // handle remaining bits, which can cover this round + data[0] = reversedAndFlipLastByte >> (8 - remainingBits) & 0xff >> (8 - nonNulls); + remainingBits -= nonNulls; + } else { + // put the remaining bits, if any, into previousByte. 
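+ // (editorial note, not part of the original change) Worked example of the merge performed below:
+ // if three bits remain and come down as 0b101, previousByte starts as 0b00000101; each newly read
+ // (reversed-and-flipped) byte tmp is emitted as previousByte | (tmp << remainingBits), and the top
+ // remainingBits bits of tmp carry over, so tmp = 0b11110000 emits 0b10000101 and the next
+ // previousByte becomes 0b00000111.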
+ uint8_t previousByte{0}; + if (remainingBits > 0) { + previousByte = reversedAndFlipLastByte >> (8 - remainingBits); + } + + // compute how many whole bytes need to be read + uint64_t bytesRead = (nonNulls - remainingBits + 7) / 8; + OmniByteRleDecoder::next(data, bytesRead, nullptr); + + ReverseAndFlipBytes(reinterpret_cast(data), bytesRead); + reversedAndFlipLastByte = data[bytesRead - 1]; + + // now shift the data in place + if (remainingBits > 0) { + uint64_t nonNullDWords = nonNulls / 64; + for (uint64_t i = 0; i < nonNullDWords; i++) { + uint64_t tmp = reinterpret_cast(data)[i]; + reinterpret_cast(data)[i] = + previousByte | tmp << remainingBits; // previousByte is LSB + previousByte = (tmp >> (64 - remainingBits)) & 0xff; + } + + // shift 8 bits at a time for the remainder + const uint64_t nonNullOutputBytes = (nonNulls + 7) / 8; + for (uint64_t i = nonNullDWords * 8; i < nonNullOutputBytes; ++i) { + uint8_t tmp = data[i]; // already reversed + data[i] = previousByte | tmp << remainingBits; // previousByte is LSB + previousByte = tmp >> (8 - remainingBits); + } + } + remainingBits = bytesRead * 8 + remainingBits - nonNulls; } - omnivec = tempOmnivec.release(); + // clear the most significant bits in the last byte which will be processed in the next round + data[outputBytes - 1] &= 0xff >> (outputBytes * 8 - numValues); + } + + + void OmniBooleanRleDecoder::next(omniruntime::vec::BaseVector *omnivec, uint64_t numValues, + uint64_t *nulls, int omniTypeId) { + switch (omniTypeId) { + case omniruntime::type::OMNI_BOOLEAN: { + auto boolValues = omniruntime::vec::unsafe::UnsafeVector::GetRawValues( + static_cast*>(omnivec)); + return next(boolValues, numValues, nulls); + } + default: + throw std::runtime_error("OmniBooleanRleDecoder does not support type: " + std::to_string(omniTypeId)); + } } - template - void OmniBooleanRleDecoder::nextByType(omniruntime::vec::BaseVector*& omnivec, uint64_t numValues, char* notNull, - const orc::Type* baseTp, int omniTypeId) { - using namespace omniruntime::type; - using T = typename NativeType::type; - auto vec = reinterpret_cast*>(omnivec); + void OmniBooleanRleDecoder::next(char *data, uint64_t numValues, uint64_t *nulls) { + next(data, numValues, nulls); + } + template + void OmniBooleanRleDecoder::next(T *data, uint64_t numValues, uint64_t *nulls) { // next spot to fill in uint64_t position = 0; // use up any remaining bits - if (notNull) { - while(remainingBits > 0 && position < numValues) { - if (notNull[position]) { + if (nulls) { + while (remainingBits > 0 && position < numValues) { + if (!BitUtil::IsBitSet(nulls, position)) { remainingBits -= 1; - vec->SetValue(static_cast(position), static_cast((static_cast(lastByte) >> - remainingBits) & 0x1)); + data[position] = static_cast((static_cast(lastByte) >> remainingBits) & 0x1); } else { - vec->SetNull(static_cast(position)); + data[position] = 0; } position += 1; } } else { - while(remainingBits > 0 && position < numValues) { + while (remainingBits > 0 && position < numValues) { remainingBits -= 1; - vec->SetValue(static_cast(position++), static_cast((static_cast(lastByte) >> - remainingBits) & 0x1)); + data[position++] = static_cast((static_cast(lastByte) >> remainingBits) & 0x1); } } // count the number of nonNulls remaining uint64_t nonNulls = numValues - position; - if (notNull) { - for(uint64_t i = position; i < numValues; ++i) { - if (!notNull[i]) { + if (nulls) { + for (uint64_t i = position; i < numValues; ++i) { + if (BitUtil::IsBitSet(nulls, i)) { nonNulls -= 1; } } @@ -142,35 +169,31 @@ namespace 
omniruntime::reader { // fill in the remaining values if (nonNulls == 0) { while (position < numValues) { - vec->SetNull(static_cast(position++)); + data[position++] = 0; } } else if (position < numValues) { // read the new bytes into the array uint64_t bytesRead = (nonNulls + 7) / 8; - auto *values = reinterpret_cast(omniruntime::vec::VectorHelper::UnsafeGetValues(omnivec)); - OmniByteRleDecoder::next(values + position, bytesRead, nullptr); - lastByte = static_cast(vec->GetValue(position + bytesRead - 1)); + OmniByteRleDecoder::next(reinterpret_cast(data + position), bytesRead, nullptr); + lastByte = data[position + bytesRead - 1]; remainingBits = bytesRead * 8 - nonNulls; // expand the array backwards so that we don't clobber the data - uint64_t bitsLeft = bytesRead * 8- remainingBits; - if (notNull) { - for (int64_t i =static_cast(numValues) - 1; - i >= static_cast(position); --i) { - if (notNull[i]) { + uint64_t bitsLeft = bytesRead * 8 - remainingBits; + if (nulls) { + for (int64_t i = static_cast(numValues) - 1; i >= static_cast(position); --i) { + if (!BitUtil::IsBitSet(nulls, i)) { uint64_t shiftPosn = (-bitsLeft) % 8; - auto value = static_cast(vec->GetValue(position + (bitsLeft - 1) / 8)) >> shiftPosn; - vec->SetValue(static_cast(i), static_cast(value & 0x1)); + data[i] = static_cast((data[position + (bitsLeft - 1) / 8] >> shiftPosn) & 0x1); bitsLeft -= 1; } else { - vec->SetNull(static_cast(i)); + data[i] = 0; } } } else { - for(int64_t i = static_cast(numValues) - 1; + for (int64_t i = static_cast(numValues) - 1; i >= static_cast(position); --i, --bitsLeft) { uint64_t shiftPosn = (-bitsLeft) % 8; - auto value = static_cast(vec->GetValue(position + (bitsLeft - 1) / 8)) >> shiftPosn; - vec->SetValue(static_cast(i), static_cast(value & 0x1)); + data[i] = static_cast((data[position + (bitsLeft - 1) / 8] >> shiftPosn) & 0x1); } } } @@ -181,6 +204,7 @@ namespace omniruntime::reader { ): OmniByteRleDecoder(std::move(input)) { remainingBits = 0; lastByte = 0; + reversedAndFlipLastByte = 0; } OmniBooleanRleDecoder::~OmniBooleanRleDecoder() { @@ -268,8 +292,7 @@ namespace omniruntime::reader { } } - void OmniByteRleDecoder::next(char* data, uint64_t numValues, - char* notNull) { + void OmniByteRleDecoder::next(char* data, uint64_t numValues, char* notNull) { uint64_t position = 0; // skip over null values while (notNull && position < numValues && !notNull[position]) { @@ -329,129 +352,5 @@ namespace omniruntime::reader { } } - void OmniByteRleDecoder::next(omniruntime::vec::BaseVector*& omnivec, uint64_t numValues, char* notNull, - const orc::Type* baseTp, int omniTypeId) { - auto dataTypeId = static_cast(omniTypeId); - std::unique_ptr tempOmnivec = makeNewVector(numValues, baseTp, dataTypeId); - auto pushOmniVec = tempOmnivec.get(); - switch (dataTypeId) { - case omniruntime::type::OMNI_BOOLEAN: - nextByType - (pushOmniVec, numValues, notNull, baseTp); - break; - case omniruntime::type::OMNI_SHORT: - nextByType - (pushOmniVec, numValues, notNull, baseTp); - break; - case omniruntime::type::OMNI_INT: - nextByType - (pushOmniVec, numValues, notNull, baseTp); - break; - case omniruntime::type::OMNI_LONG: - nextByType - (pushOmniVec, numValues, notNull, baseTp); - break; - case omniruntime::type::OMNI_TIMESTAMP: - nextByType - (pushOmniVec, numValues, notNull, baseTp); - break; - case omniruntime::type::OMNI_DATE32: - nextByType - (pushOmniVec, numValues, notNull, baseTp); - break; - case omniruntime::type::OMNI_DATE64: - nextByType - (pushOmniVec, numValues, notNull, baseTp); - break; - 
case omniruntime::type::OMNI_DOUBLE: - nextByType - (pushOmniVec, numValues, notNull, baseTp); - break; - case omniruntime::type::OMNI_CHAR: - throw std::runtime_error("OmniByteRleDecoder CHAR not finished!!!"); - case omniruntime::type::OMNI_VARCHAR: - throw std::runtime_error("OmniByteRleDecoder VARCHAR not finished!!!"); - case omniruntime::type::OMNI_DECIMAL64: - throw std::runtime_error("OmniByteRleDecoder DECIMAL64 not finished!!!"); - case omniruntime::type::OMNI_DECIMAL128: - throw std::runtime_error("OmniByteRleDecoder DECIMAL64 not finished!!!"); - default: - printf("OmniByteRleDecoder swtich no process!!!"); - } - - omnivec = tempOmnivec.release(); - } - - template - void OmniByteRleDecoder::nextByType(omniruntime::vec::BaseVector*& omnivec, uint64_t numValues, char* notNull, - const orc::Type* baseTp) { - using namespace omniruntime::type; - using T = typename NativeType::type; - auto vec = reinterpret_cast*>(omnivec); - - uint64_t position = 0; - // skip over null values - while (notNull && position < numValues && !notNull[position]) { - position += 1; - } - while (position < numValues) { - // if we are out of values, read more - if (remainingValues == 0) { - readHeader(); - } - // how many do we read out of this block? - size_t count = std::min(static_cast(numValues - position), - remainingValues); - uint64_t consumed = 0; - if (repeating) { - if (notNull) { - for(uint64_t i=0; i < count; ++i) { - if (notNull[position + i]) { - vec->SetValue(static_cast(position + i), static_cast(value)); - consumed += 1; - } else { - vec->SetNull(static_cast(position + i)); - } - } - } else { - for (uint64_t i = position; i < position + count; ++i) { - vec->SetValue(static_cast(i), static_cast(value)); - } - consumed = count; - } - } else { - if (notNull) { - for(uint64_t i = 0; i < count; ++i) { - if (notNull[position + i]) { - vec->SetValue(static_cast(position + i), static_cast(readByte())); - consumed += 1; - } else { - vec->SetNull(static_cast(position + i)); - } - } - } else { - uint64_t i = 0; - while (i < count) { - if (bufferStart == bufferEnd) { - nextBuffer(); - } - uint64_t copyBytes = - std::min(static_cast(count - i), - static_cast(bufferEnd - bufferStart)); - vec->SetValues(static_cast(position + i), bufferStart, static_cast(copyBytes)); - bufferStart += copyBytes; - i += copyBytes; - } - consumed = count; - } - } - remainingValues -= consumed; - position += count; - // skip over any null values - while (notNull && position < numValues && !notNull[position]) { - position += 1; - } - } - } //OmniByteRleDecoder end } diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniByteRLE.hh b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniByteRLE.hh index 0ed2153472a82cdc040999fab814b88e593c2627..742abfce3a700ab80b28a13badf4d67c9bb36888 100644 --- a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniByteRLE.hh +++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniByteRLE.hh @@ -43,24 +43,17 @@ namespace omniruntime::reader { */ virtual void next(char* data, uint64_t numValues, char* notNull); - virtual void next(omniruntime::vec::BaseVector*& omnivec, uint64_t numValues, char* notNull, - const orc::Type* baseTp, int omniTypeId); - - template - void nextByType(omniruntime::vec::BaseVector*& omnivec, uint64_t numValues, char* notNull, - const orc::Type* baseTp); - protected: - inline void nextBuffer(); - inline signed char readByte(); - inline void readHeader(); - - std::unique_ptr inputStream; - size_t remainingValues; - char value; - const char* bufferStart; - 
const char* bufferEnd; - bool repeating; + inline void nextBuffer(); + inline signed char readByte(); + inline void readHeader(); + + std::unique_ptr inputStream; + size_t remainingValues; + char value; + const char *bufferStart; + const char *bufferEnd; + bool repeating; }; class OmniBooleanRleDecoder: public OmniByteRleDecoder { @@ -80,19 +73,25 @@ namespace omniruntime::reader { virtual void skip(uint64_t numValues); /** - * Read a number of values into the batch. + * Read nulls flag to data. + */ + void nextNulls(char *data, uint64_t numValues, uint64_t *nulls); + + /** + * Read a number of values into the batch by nulls. */ + void next(omniruntime::vec::BaseVector *omnivec, uint64_t numValues, + uint64_t *nulls, int omniTypeId); - virtual void next(omniruntime::vec::BaseVector*& omnivec, uint64_t numValues, char* notNull, - const orc::Type* baseTp, int omniTypeId); + template + void next(T *data, uint64_t numValues, uint64_t *nulls); - template - void nextByType(omniruntime::vec::BaseVector*& omnivec, uint64_t numValues, char* notNull, - const orc::Type* baseTp, int omniTypeId); + void next(char *data, uint64_t numValues, uint64_t *nulls); protected: - size_t remainingBits; - char lastByte; + size_t remainingBits; + char lastByte; + char reversedAndFlipLastByte; }; } #endif diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniColReader.cc b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniColReader.cc index 1335118385af6673a8c23b6de2a4a6c324998ac0..38b43fd643047d603233e8e90c6adeffc3bbe2c4 100644 --- a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniColReader.cc +++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniColReader.cc @@ -23,8 +23,12 @@ #include "orc/Writer.hh" #include "orc/MemoryPool.hh" #include "util/omni_exception.h" +#include "OrcDecodeUtils.hh" using omniruntime::vec::VectorBatch; +using omniruntime::vec::BaseVector; +using omniruntime::exception::OmniException; +using omniruntime::vec::NullsBuffer; using orc::ColumnReader; using orc::ByteRleDecoder; using orc::Type; @@ -43,13 +47,13 @@ namespace omniruntime::reader { switch (static_cast(kind)) { case orc::proto::ColumnEncoding_Kind_DIRECT: case orc::proto::ColumnEncoding_Kind_DICTIONARY: - return orc::RleVersion_1; + return orc::RleVersion_1; case orc::proto::ColumnEncoding_Kind_DIRECT_V2: case orc::proto::ColumnEncoding_Kind_DICTIONARY_V2: - return orc::RleVersion_2; - default: - throw omniruntime::exception::OmniException("EXPRESSION_NOT_SUPPORT", - "Unknown encoding in omniConvertRleVersion"); + return orc::RleVersion_2; + default: + throw omniruntime::exception::OmniException( + "EXPRESSION_NOT_SUPPORT", "Unknown encoding in omniConvertRleVersion"); } } @@ -57,26 +61,63 @@ namespace omniruntime::reader { RleVersion version, MemoryPool& pool) { switch (static_cast(version)) { case orc::RleVersion_1: - // should not use - throw omniruntime::exception::OmniException("EXPRESSION_NOT_SUPPORT", "RleVersion_1 should not use!!!"); + // should not use + throw omniruntime::exception::OmniException( + "EXPRESSION_NOT_SUPPORT", "RleVersion_1 Not supported yet"); case orc::RleVersion_2: - return std::unique_ptr(new OmniRleDecoderV2(std::move(input), - isSigned, pool)); - default: - throw omniruntime::exception::OmniException("EXPRESSION_NOT_SUPPORT", "Not implemented yet"); + return std::unique_ptr(new OmniRleDecoderV2(std::move(input), isSigned, pool)); + default: + throw omniruntime::exception::OmniException("EXPRESSION_NOT_SUPPORT", "Not implemented yet"); } } - std::unique_ptr 
createOmniBooleanRleDecoder - (std::unique_ptr input) { + std::unique_ptr createOmniBooleanRleDecoder(std::unique_ptr input) { OmniBooleanRleDecoder* decoder = new OmniBooleanRleDecoder(std::move(input)); return std::unique_ptr(reinterpret_cast(decoder)); } - std::unique_ptr createOmniByteRleDecoder - (std::unique_ptr input) { + std::unique_ptr createOmniByteRleDecoder(std::unique_ptr input) { return std::unique_ptr(new OmniByteRleDecoder(std::move(input))); } + + OmniColumnReader::OmniColumnReader(const Type &type, StripeStreams &stripe) + : ColumnReader(type, stripe) { + std::unique_ptr stream = stripe.getStream(columnId, orc::proto::Stream_Kind_PRESENT, true); + if (stream.get()) { + notNullDecoder = std::make_unique(std::move(stream)); + } + } + + uint64_t OmniColumnReader::skip(uint64_t numValues) { + if (notNullDecoder) { + // pass through the values that we want to skip and count how many are non-null + const uint64_t MAX_BUFFER_SIZE = 32768; + uint64_t bufferSize = std::min(MAX_BUFFER_SIZE, numValues); + // buffer, 0: null; 1: non-null + char buffer[MAX_BUFFER_SIZE]; + uint64_t remaining = numValues; + while (remaining > 0) { + uint64_t chunkSize = std::min(remaining, bufferSize); + notNullDecoder->next(buffer, chunkSize, nullptr); + remaining -= chunkSize; + // update non-null count + for (uint64_t i = 0; i < chunkSize; i++) { + if (!buffer[i]) { + // minus null + numValues -= 1; + } + } + } + } + return numValues; + } + + void OmniColumnReader::seekToRowGroup(std::unordered_map &positions) { + if (notNullDecoder) { + notNullDecoder->seek(positions.at(columnId)); + } + } + /** * Create a reader for the given stripe. @@ -88,91 +129,69 @@ namespace omniruntime::reader { case orc::INT: case orc::LONG: case orc::SHORT: - return std::unique_ptr( - new OmniIntegerColumnReader(type, stripe)); + return std::make_unique(type, stripe); case orc::BINARY: case orc::CHAR: case orc::STRING: case orc::VARCHAR: - switch (static_cast(stripe.getEncoding(type.getColumnId()).kind())){ + switch (static_cast(stripe.getEncoding(type.getColumnId()).kind())) { case orc::proto::ColumnEncoding_Kind_DICTIONARY: case orc::proto::ColumnEncoding_Kind_DICTIONARY_V2: - return std::unique_ptr( - new OmniStringDictionaryColumnReader(type, stripe)); + return std::make_unique(type, stripe); case orc::proto::ColumnEncoding_Kind_DIRECT: case orc::proto::ColumnEncoding_Kind_DIRECT_V2: - return std::unique_ptr( - new OmniStringDirectColumnReader(type, stripe)); + return std::make_unique(type, stripe); default: - throw omniruntime::exception::OmniException("EXPRESSION_NOT_SUPPORT", - "omniBuildReader unhandled string encoding"); - } + throw omniruntime::exception::OmniException( + "EXPRESSION_NOT_SUPPORT", "omniBuildReader unhandled string encoding"); + } - case orc::BOOLEAN: - return std::unique_ptr(new OmniBooleanColumnReader(type, stripe)); + case orc::BOOLEAN: + return std::make_unique(type, stripe); - case orc::BYTE: - return std::unique_ptr(new OmniByteColumnReader(type, stripe)); + case orc::BYTE: + return std::make_unique(type, stripe); case orc::STRUCT: - return std::unique_ptr( - new OmniStructColumnReader(type, stripe, julianPtr)); + return std::make_unique(type, stripe, julianPtr); case orc::TIMESTAMP: - return std::unique_ptr - (new OmniTimestampColumnReader(type, stripe, false, julianPtr)); + return std::make_unique(type, stripe, false, julianPtr); case orc::TIMESTAMP_INSTANT: - return std::unique_ptr - (new OmniTimestampColumnReader(type, stripe, true, julianPtr)); + return std::make_unique(type, stripe, 
true, julianPtr); case orc::DECIMAL: - // Is this a Hive 0.11 or 0.12 file? - if (type.getPrecision() == 0) { - return std::unique_ptr - (new OmniDecimalHive11ColumnReader(type, stripe)); - } else if (type.getPrecision() <= - OmniDecimal64ColumnReader::MAX_PRECISION_64) { - return std::unique_ptr - (new OmniDecimal64ColumnReader(type, stripe)); - } else { - return std::unique_ptr - (new OmniDecimal128ColumnReader(type, stripe)); - } + // Is this a Hive 0.11 or 0.12 file? + if (type.getPrecision() == 0) { + return std::make_unique(type, stripe); + } else if (type.getPrecision() <= OmniDecimal64ColumnReader::MAX_PRECISION_64) { + return std::make_unique(type, stripe); + } else { + return std::make_unique(type, stripe); + } - case orc::FLOAT: - case orc::DOUBLE: - return std::unique_ptr( - new OmniDoubleColumnReader(type, stripe)); + case orc::FLOAT: + case orc::DOUBLE: + return std::make_unique(type, stripe); - default: - throw omniruntime::exception::OmniException("EXPRESSION_NOT_SUPPORT", "omniBuildReader unhandled type"); + default: + throw omniruntime::exception::OmniException("EXPRESSION_NOT_SUPPORT", "omniBuildReader unhandled type"); } } - inline void readNulls(ColumnReader* colReader, uint64_t numValues, char* incomingMask, char* nulls, bool& hasNull) { - ByteRleDecoder* decoder = colReader->notNullDecoder.get(); - // TO do 需要将char*转换为bool*数组, 可以优化 - if (decoder) { - decoder->next(nulls, numValues, incomingMask); + inline void readNulls(OmniColumnReader *colReader, uint64_t numValues, uint64_t *incomingNulls, + uint8_t *nulls) { + if (colReader->notNullDecoder) { + colReader->notNullDecoder->nextNulls(reinterpret_cast(nulls), numValues, incomingNulls); // check to see if there are nulls in this batch - for(uint64_t i=0; i < numValues; ++i) { - auto ptr = nulls; - if (!(ptr[i])) { - // To do hasNull is protected - hasNull = true; - return; - } - } - } else if (incomingMask) { - // if we don't have a notNull stream, copy the incomingMask + } else if (incomingNulls) { + // if we don't have a notNull stream, copy the incomingNulls // To do finished - hasNull = true; - memcpy_s(nulls, numValues, incomingMask, numValues); + memcpy_s(reinterpret_cast(nulls), BitUtil::Nbytes(numValues), incomingNulls, + BitUtil::Nbytes(numValues)); return; } - // To do hasNull is protected - hasNull = false; } void scaleInt128(orc::Int128& value, uint32_t scale, uint32_t currentScale) { @@ -233,28 +252,29 @@ namespace omniruntime::reader { * OmniStructColumnReader funcs */ OmniStructColumnReader::OmniStructColumnReader(const Type& type, StripeStreams& stripe, - common::JulianGregorianRebase *julianPtr): ColumnReader(type, stripe) { + common::JulianGregorianRebase *julianPtr): OmniColumnReader(type, stripe) { // count the number of selected sub-columns const std::vector selectedColumns = stripe.getSelectedColumns(); switch (static_cast(stripe.getEncoding(columnId).kind())) { - case orc::proto::ColumnEncoding_Kind_DIRECT: - for(unsigned int i=0; i < type.getSubtypeCount(); ++i) { + case orc::proto::ColumnEncoding_Kind_DIRECT: + for(unsigned int i = 0; i < type.getSubtypeCount(); ++i) { const Type& child = *type.getSubtype(i); if (selectedColumns[static_cast(child.getColumnId())]) { children.push_back(omniBuildReader(child, stripe, julianPtr)); } } break; - case orc::proto::ColumnEncoding_Kind_DIRECT_V2: - case orc::proto::ColumnEncoding_Kind_DICTIONARY: - case orc::proto::ColumnEncoding_Kind_DICTIONARY_V2: + case orc::proto::ColumnEncoding_Kind_DIRECT_V2: + case orc::proto::ColumnEncoding_Kind_DICTIONARY: + 
case orc::proto::ColumnEncoding_Kind_DICTIONARY_V2: default: - throw omniruntime::exception::OmniException("EXPRESSION_NOT_SUPPORT", "Unknown encoding for OmniStructColumnReader"); + throw omniruntime::exception::OmniException( + "EXPRESSION_NOT_SUPPORT", "Unknown encoding for OmniStructColumnReader"); } } uint64_t OmniStructColumnReader::skip(uint64_t numValues) { - numValues = ColumnReader::skip(numValues); + numValues = OmniColumnReader::skip(numValues); for(auto& ptr : children) { ptr->skip(numValues); } @@ -262,17 +282,13 @@ namespace omniruntime::reader { } void OmniStructColumnReader::next(void *&batch, uint64_t numValues, char *notNull, - const orc::Type& baseTp, int* omniTypeId) { + const orc::Type& baseTp, int* omniTypeId) { auto vecs = reinterpret_cast*>(batch); - nextInternal(*vecs, numValues, notNull, baseTp, omniTypeId); - } - - void OmniStructColumnReader::nextEncoded(orc::ColumnVectorBatch& rowBatch, uint64_t numValues, char *notNull) { - throw omniruntime::exception::OmniException("EXPRESSION_NOT_SUPPORT", "OmniStructColumnReader::nextEncoded not finished!!!"); + nextInternal(*vecs, numValues, nullptr, baseTp, omniTypeId); } void OmniStructColumnReader::seekToRowGroup(std::unordered_map& positions) { - ColumnReader::seekToRowGroup(positions); + OmniColumnReader::seekToRowGroup(positions); for(auto& ptr : children) { ptr->seekToRowGroup(positions); @@ -281,34 +297,34 @@ namespace omniruntime::reader { template void OmniStructColumnReader::nextInternal(std::vector &vecs, uint64_t numValues, - char *notNull, const orc::Type& baseTp, int* omniTypeId) { + uint64_t *incomingNulls, const orc::Type& baseTp, int* omniTypeId) { + + if (encoded) { + std::string message("OmniStructColumnReader::nextInternal encoded is not finished!"); + throw omniruntime::exception::OmniException("EXPRESSION_NOT_SUPPORT", message); + } bool hasNull = false; - char nulls[numValues]; - readNulls(this, numValues, notNull, nulls, hasNull); + auto nulls = std::make_shared(numValues); + readNulls(this, numValues, incomingNulls, reinterpret_cast(nulls->GetNulls())); + uint64_t i = 0; - uint64_t i=0; - notNull = hasNull ? nulls : nullptr; for(auto iter = children.begin(); iter != children.end(); ++iter, ++i) { - if (encoded) { - std::string message("OmniStructColumnReader::nextInternal encoded is not finished!!!"); - throw omniruntime::exception::OmniException("EXPRESSION_NOT_SUPPORT", message); + const Type* orcType = baseTp.getSubtype(i); + omniruntime::type::DataTypeId dataTypeId; + if (omniTypeId == nullptr) { + dataTypeId = getDefaultOmniType(orcType); } else { - omniruntime::vec::BaseVector* tempVec = nullptr; - const Type* type = baseTp.getSubtype(i); - if (omniTypeId == nullptr) { - int tempOmniTypeId = getOmniTypeByOrcType(type); - (*iter)->next(reinterpret_cast(tempVec), numValues, notNull, *type, &tempOmniTypeId); - vecs.push_back(tempVec); - } else { - (*iter)->next(reinterpret_cast(tempVec), numValues, notNull, *type, &omniTypeId[i]); - vecs.push_back(tempVec); - } + dataTypeId = static_cast(omniTypeId[i]); } + auto omnivector = omniruntime::reader::makeNewVector(numValues, orcType, dataTypeId); + reinterpret_cast(&(*iter->get()))->next(omnivector.get(), numValues, + hasNull ? 
nulls->GetNulls() : nullptr, dataTypeId); + vecs.push_back(omnivector.release()); } } - omniruntime::type::DataTypeId OmniStructColumnReader::getOmniTypeByOrcType(const Type* type) { - constexpr int32_t OMNI_MAX_DECIMAL64_DIGITS = 18; + omniruntime::type::DataTypeId OmniStructColumnReader::getDefaultOmniType(const Type* type) { + constexpr int32_t OMNI_MAX_DECIMAL64_DIGITS = 18; switch (type->getKind()) { case orc::TypeKind::BOOLEAN: return omniruntime::type::OMNI_BOOLEAN; @@ -337,83 +353,81 @@ namespace omniruntime::reader { return omniruntime::type::OMNI_DECIMAL64; } default: - printf("no getOmniTypeByOrcType type to process!!!"); + throw omniruntime::exception::OmniException( + "EXPRESSION_NOT_SUPPORT", "Not Supported Type: " + type->getKind()); } - - throw omniruntime::exception::OmniException("EXPRESSION_NOT_SUPPORT", - "OmniStructColumnReader::getOmniTypeByOrcType no type!!!"); - return omniruntime::type::OMNI_INVALID; } /** * all next funcs */ - void OmniIntegerColumnReader::next(void*& vec, uint64_t numValues, char *notNull, - const orc::Type& baseTp, int* omniTypeId) { - bool HasNull = false; - char nulls[numValues]; - readNulls(this, numValues, notNull, nulls, HasNull); - rle->next(reinterpret_cast(vec), numValues, HasNull ? nulls : nullptr, &baseTp, - *omniTypeId); - } - - void OmniBooleanColumnReader::next(void*& vec, uint64_t numValues, char *notNull, - const orc::Type& baseTp, int* omniTypeId) { - bool HasNull = false; - char nulls[numValues]; - readNulls(this, numValues, notNull, nulls, HasNull); - rle->next(reinterpret_cast(vec), numValues, HasNull ? nulls : nullptr, &baseTp, - *omniTypeId); - } - - void OmniByteColumnReader::next(void*& vec, uint64_t numValues, char *notNull, - const orc::Type& baseTp, int* omniTypeId) { - bool HasNull = false; - char nulls[numValues]; - readNulls(this, numValues, notNull, nulls, HasNull); - rle->next(reinterpret_cast(vec), numValues, HasNull ? nulls : nullptr, &baseTp, - *omniTypeId); - } - - void OmniTimestampColumnReader::next(void*& omnivec, uint64_t numValues, char *notNull, - const orc::Type& baseTp, int* omniTypeId) { - auto dataTypeId = static_cast(*omniTypeId); + void OmniIntegerColumnReader::next(BaseVector *vec, uint64_t numValues, uint64_t *incomingNulls, + int omniTypeId) { + auto nulls = omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(vec); + readNulls(this, numValues, incomingNulls, nulls); + bool hasNull = vec->HasNull(); + rle->next(vec, numValues, hasNull ? reinterpret_cast(nulls) : nullptr, omniTypeId); + } + + void OmniBooleanColumnReader::next(BaseVector *vec, uint64_t numValues, uint64_t *incomingNulls, + int omniTypeId) { + auto nulls = omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(vec); + readNulls(this, numValues, incomingNulls, nulls); + bool hasNull = vec->HasNull(); + if (omniTypeId != omniruntime::type::OMNI_BOOLEAN) { + throw OmniException("EXPRESSION_NOT_SUPPORT", "Not Supported Type: " + omniTypeId); + } + OmniBooleanRleDecoder *boolDecoder = reinterpret_cast(rle.get()); + boolDecoder->next(vec, numValues, hasNull ? 
reinterpret_cast(nulls) : nullptr, omniTypeId); + } + + void OmniByteColumnReader::next(BaseVector *vec, uint64_t numValues, uint64_t *incomingNulls, + int omniTypeId) { + throw OmniException("EXPRESSION_NOT_SUPPORT", "Not Supported yet."); + } + + void OmniTimestampColumnReader::next(BaseVector *vec, uint64_t numValues, + uint64_t *incomingNulls, int omniTypeId) { + auto nulls = omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(vec); + readNulls(this, numValues, incomingNulls, nulls); + bool hasNull = vec->HasNull(); + auto dataTypeId = static_cast(omniTypeId); switch (dataTypeId) { - case omniruntime::type::OMNI_DATE32: - return nextByType(omnivec, numValues, notNull, baseTp, omniTypeId); - case omniruntime::type::OMNI_DATE64: - return nextByType(omnivec, numValues, notNull, baseTp, omniTypeId); - case omniruntime::type::OMNI_TIMESTAMP: - return nextByType(omnivec, numValues, notNull, baseTp, omniTypeId); + case omniruntime::type::OMNI_DATE32: { + auto intValues = omniruntime::vec::unsafe::UnsafeVector::GetRawValues( + static_cast*>(vec)); + return nextByType(intValues, numValues, hasNull ? reinterpret_cast(nulls) : nullptr); + } + case omniruntime::type::OMNI_DATE64: { + auto longValues = omniruntime::vec::unsafe::UnsafeVector::GetRawValues( + static_cast*>(vec)); + return nextByType(longValues, numValues, hasNull ? reinterpret_cast(nulls) : nullptr); + } + case omniruntime::type::OMNI_TIMESTAMP: { + auto longValues = omniruntime::vec::unsafe::UnsafeVector::GetRawValues( + static_cast*>(vec)); + return nextByType(longValues, numValues, hasNull ? reinterpret_cast(nulls) : nullptr); + } default: - throw omniruntime::exception::OmniException("EXPRESSION_NOT_SUPPORT", - "OmniTimestampColumnReader type not support!!!"); + throw omniruntime::exception::OmniException( + "EXPRESSION_NOT_SUPPORT", "OmniTimestampColumnReader type not support: " + dataTypeId); } } - template - void OmniTimestampColumnReader::nextByType(void*& omnivec, uint64_t numValues, char* notNull, - const orc::Type& baseTp, int* omniTypeId) { - using namespace omniruntime::type; - using T = typename NativeType::type; - auto vec = std::make_unique>(static_cast(numValues)); - - bool HasNull = false; - char nulls[numValues]; - readNulls(this, numValues, notNull, nulls, HasNull); - notNull = HasNull ? 
nulls : nullptr; + template + void OmniTimestampColumnReader::nextByType(T *data, uint64_t numValues, uint64_t *nulls) { int64_t secsBuffer[numValues]; - secondsRle->next(secsBuffer, numValues, notNull); + secondsRle->next(secsBuffer, numValues, nulls); int64_t nanoBuffer[numValues]; - nanoRle->next(nanoBuffer, numValues, notNull); + nanoRle->next(nanoBuffer, numValues, nulls); // Construct the values - for(uint64_t i=0; i < numValues; i++) { - if (notNull == nullptr || notNull[i]) { + for(uint64_t i = 0; i < numValues; i++) { + if (nulls == nullptr || !BitUtil::IsBitSet(nulls, i)) { uint64_t zeros = nanoBuffer[i] & 0x7; nanoBuffer[i] >>= 3; if (zeros != 0) { - for(uint64_t j = 0; j <= zeros; ++j) { + for (uint64_t j = 0; j <= zeros; ++j) { nanoBuffer[i] *= 10; } } @@ -437,143 +451,80 @@ namespace omniruntime::reader { } if (julianPtr != nullptr) { - vec->SetValue(static_cast(i), static_cast( - julianPtr->RebaseJulianToGregorianMicros(secsBuffer[i] * 1000000L + nanoBuffer[i] / 1000L))); + data[i] = static_cast( + julianPtr->RebaseJulianToGregorianMicros(secsBuffer[i] * 1000000L + nanoBuffer[i] / 1000L)); } else { - vec->SetValue(static_cast(i), static_cast(secsBuffer[i] * 1000000L + nanoBuffer[i] / 1000L)); + data[i] = static_cast(secsBuffer[i] * 1000000L + nanoBuffer[i] / 1000L); } - } else { - vec->SetNull(static_cast(i)); } } - - omnivec = vec.release(); } - void OmniDoubleColumnReader::next(void*& omnivec, uint64_t numValues, char* notNull, - const orc::Type& baseTp, int* omniTypeId) { - auto dataTypeId = static_cast(*omniTypeId); - std::unique_ptr tempOmnivec = makeNewVector(numValues, &baseTp, dataTypeId); - auto pushOmniVec = tempOmnivec.get(); + void OmniDoubleColumnReader::next(BaseVector *vec, uint64_t numValues, uint64_t *incomingNulls, + int omniTypeId) { + auto nulls = omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(vec); + readNulls(this, numValues, incomingNulls, nulls); + bool hasNull = vec->HasNull(); + auto dataTypeId = static_cast(omniTypeId); switch (dataTypeId) { - case omniruntime::type::OMNI_BOOLEAN: - nextByType - (pushOmniVec, numValues, notNull, baseTp); - break; - case omniruntime::type::OMNI_SHORT: - nextByType - (pushOmniVec, numValues, notNull, baseTp); - break; - case omniruntime::type::OMNI_INT: - nextByType - (pushOmniVec, numValues, notNull, baseTp); - break; - case omniruntime::type::OMNI_LONG: - nextByType - (pushOmniVec, numValues, notNull, baseTp); - break; - case omniruntime::type::OMNI_TIMESTAMP: - nextByType - (pushOmniVec, numValues, notNull, baseTp); - break; - case omniruntime::type::OMNI_DATE32: - nextByType - (pushOmniVec, numValues, notNull, baseTp); - break; - case omniruntime::type::OMNI_DATE64: - nextByType - (pushOmniVec, numValues, notNull, baseTp); - break; - case omniruntime::type::OMNI_DOUBLE: - nextByType - (pushOmniVec, numValues, notNull, baseTp); - break; - case omniruntime::type::OMNI_CHAR: - throw std::runtime_error("OmniDoubleColumnReader_type CHAR not finished!!!"); - case omniruntime::type::OMNI_VARCHAR: - throw std::runtime_error("OmniDoubleColumnReader_type VARCHAR not finished!!!"); - case omniruntime::type::OMNI_DECIMAL64: - throw std::runtime_error("OmniDoubleColumnReader_type DECIMAL64 not finished!!!"); - case omniruntime::type::OMNI_DECIMAL128: - throw std::runtime_error("OmniDoubleColumnReader_type DECIMAL64 not finished!!!"); + case omniruntime::type::OMNI_DOUBLE: { + auto doubleValues = omniruntime::vec::unsafe::UnsafeVector::GetRawValues( + static_cast*>(vec)); + return nextByType(doubleValues, numValues, 
hasNull ? reinterpret_cast(nulls) : nullptr); + } default: - printf("OmniDoubleColumnReader_type swtich no process!!!"); + throw omniruntime::exception::OmniException( + "EXPRESSION_NOT_SUPPORT", "OmniDoubleColumnReader type not support: " + dataTypeId); } - - omnivec = tempOmnivec.release(); } - template - void OmniDoubleColumnReader::nextByType(omniruntime::vec::BaseVector*& omnivec, uint64_t numValues, char* notNull, - const orc::Type& baseTp) { - bool HasNull = false; - char nulls[numValues]; - readNulls(this, numValues, notNull, nulls, HasNull); - // update the notNull from the parent class - notNull = HasNull ? nulls : nullptr; - - using namespace omniruntime::type; - using T = typename NativeType::type; - - auto vec = reinterpret_cast*>(omnivec); - + template + void OmniDoubleColumnReader::nextByType(T *data, uint64_t numValues, uint64_t *nulls) { if (columnKind == orc::FLOAT) { - if(notNull) { - for(size_t i=0; i < numValues; ++i) { - if(notNull[i]) { - vec->SetValue(static_cast(i), static_cast(readFloat())); - } else { - vec->SetNull(i); + if(nulls) { + for(size_t i = 0; i < numValues; ++i) { + if(!BitUtil::IsBitSet(nulls, i)) { + data[i] = static_cast(readFloat()); } } } else { - for(size_t i=0; i < numValues; ++i) { - vec->SetValue(static_cast(i), static_cast(readFloat())); + for(size_t i = 0; i < numValues; ++i) { + data[i] = static_cast(readFloat()); } } } else { - if (notNull) { - for(size_t i=0; i < numValues; ++i) { - if (notNull[i]) { - vec->SetValue(static_cast(i), static_cast(readDouble())); - } else { - vec->SetNull(i); + if (nulls) { + for(size_t i = 0; i < numValues; ++i) { + if (!BitUtil::IsBitSet(nulls, i)) { + data[i] = static_cast(readDouble()); } } } else { - for(size_t i=0; i < numValues; ++i) { - vec->SetValue(static_cast(i), static_cast(readDouble())); + for(size_t i = 0; i < numValues; ++i) { + data[i] = static_cast(readDouble()); } } } } - void OmniStringDictionaryColumnReader::next(void*& omnivec, uint64_t numValues, char* notNull, - const orc::Type& baseTp, int* omniTypeId) { - bool HasNull = false; - auto newVector = std::make_unique>>(numValues); - - char nulls[numValues]; - readNulls(this, numValues, notNull, nulls, HasNull); - // update the notNull from the parent class - notNull = HasNull ? 
nulls : nullptr; - - - bool is_char = false; - if (baseTp.getKind() == orc::TypeKind::CHAR) { - is_char = true; - } + void OmniStringDictionaryColumnReader::next(BaseVector *vec, uint64_t numValues, + uint64_t *incomingNulls, int omniTypeId) { + auto nulls = omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(vec); + readNulls(this, numValues, incomingNulls, nulls); + bool hasNull = vec->HasNull(); char *blob = dictionary->dictionaryBlob.data(); int64_t *dictionaryOffsets = dictionary->dictionaryOffset.data(); - + auto nullsTrans = reinterpret_cast(nulls); int64_t outputLengths[numValues]; - rle->next(outputLengths, numValues, notNull); + rle->next(outputLengths, numValues, nullsTrans); uint64_t dictionaryCount = dictionary->dictionaryOffset.size() - 1; - if (notNull) { + + auto varcharVector = reinterpret_cast>*>(vec); + if (hasNull) { for(uint64_t i=0; i < numValues; ++i) { - if (notNull[i]) { + if (!BitUtil::IsBitSet(nullsTrans, i)) { int64_t entry = outputLengths[i]; if (entry < 0 || static_cast(entry) >= dictionaryCount ) { throw orc::ParseError("Entry index out of range in StringDictionaryColumn"); @@ -583,13 +534,13 @@ namespace omniruntime::reader { auto len = dictionaryOffsets[entry+1] - dictionaryOffsets[entry]; char* ptr = blob + dictionaryOffsets[entry]; - if (is_char) { + if (isChar) { FindLastNotEmpty(ptr, len); } auto data = std::string_view(ptr, len); - newVector->SetValue(i, data); + varcharVector->SetValue(i, data); } else { - newVector->SetNull(i); + varcharVector->SetNull(i); } } } else { @@ -603,39 +554,28 @@ namespace omniruntime::reader { auto len = dictionaryOffsets[entry+1] - dictionaryOffsets[entry]; char* ptr = blob + dictionaryOffsets[entry]; - if (is_char) { + if (isChar) { FindLastNotEmpty(ptr, len); } auto data = std::string_view(ptr, len); - newVector->SetValue(i, data); + varcharVector->SetValue(i, data); } } - - omnivec = newVector.release(); } - void OmniStringDirectColumnReader::next(void*& omnivec, uint64_t numValues, char* notNull, - const orc::Type& baseTp, int* omniTypeId) { - bool HasNull = false; - auto newVector = std::make_unique>>(numValues); - char nulls[numValues]; - readNulls(this, numValues, notNull, nulls, HasNull); - // update the notNull from the parent class - notNull = HasNull ? nulls : nullptr; + void OmniStringDirectColumnReader::next(BaseVector *vec, uint64_t numValues, + uint64_t *incomingNulls, int omniTypeId) { + auto nulls = omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(vec); + readNulls(this, numValues, incomingNulls, nulls); + bool hasNull = vec->HasNull(); int64_t lengthPtr[numValues]; - - bool is_char = false; - if (baseTp.getKind() == orc::TypeKind::CHAR) { - is_char = true; - } - + auto nullsTrans = reinterpret_cast(nulls); // read the length vector - lengthRle->next(lengthPtr, numValues, notNull); + lengthRle->next(lengthPtr, numValues, nullsTrans); // figure out the total length of data we need from the blob stream - const size_t totalLength = computeSize(lengthPtr, notNull, numValues); + const size_t totalLength = computeSize(lengthPtr, nullsTrans, numValues); // Load data from the blob stream into our buffer until we have enough // to get the rest directly out of the stream's buffer. 
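[Editorial illustration, not part of the patch] The dictionary path above reduces each row to an index into a shared blob: the RLE decoder yields one dictionary entry id per row, and the string value is the slice of dictionaryBlob between dictionaryOffset[entry] and dictionaryOffset[entry + 1], with trailing spaces trimmed for CHAR columns. A minimal standalone sketch of that decode step (function and parameter names here are illustrative only; the real reader additionally skips null rows via the bitmap and writes into a VarcharVector):

    #include <cstdint>
    #include <stdexcept>
    #include <string_view>
    #include <vector>

    // entries[i]: dictionary id for row i; offsets: dictionarySize + 1 cumulative offsets into blob.
    std::vector<std::string_view> DecodeDictionaryBatch(const char *blob, const std::vector<int64_t> &offsets,
                                                        const std::vector<int64_t> &entries, bool trimTrailingSpaces)
    {
        const size_t dictionaryCount = offsets.size() - 1;
        std::vector<std::string_view> out;
        out.reserve(entries.size());
        for (int64_t entry : entries) {
            if (entry < 0 || static_cast<size_t>(entry) >= dictionaryCount) {
                throw std::out_of_range("Entry index out of range in StringDictionaryColumn");
            }
            const char *ptr = blob + offsets[entry];
            int64_t len = offsets[entry + 1] - offsets[entry];
            // CHAR values are space-padded; VARCHAR keeps the stored length.
            if (trimTrailingSpaces) {
                while (len > 0 && ptr[len - 1] == ' ') {
                    --len;
                }
            }
            out.emplace_back(ptr, static_cast<size_t>(len));
        }
        return out;
    }

The returned string_views borrow from the blob, so in this sketch the blob must outlive the decoded batch.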
@@ -660,22 +600,24 @@ namespace omniruntime::reader { lastBufferLength -= moreBytes; } + auto varcharVector = reinterpret_cast>*>(vec); size_t filledSlots = 0; char* tempPtr = ptr; - if (notNull) { + if (hasNull) { while (filledSlots < numValues) { - if (notNull[filledSlots]) { + if (!BitUtil::IsBitSet(nullsTrans, filledSlots)) { //求出长度,如果为char,则需要去除最后的空格 auto len = lengthPtr[filledSlots]; - if (is_char) { + if (isChar) { FindLastNotEmpty(tempPtr, len); } auto data = std::string_view(tempPtr, len); - newVector->SetValue(filledSlots, data); + varcharVector->SetValue(filledSlots, data); tempPtr += lengthPtr[filledSlots]; } else { - newVector->SetNull(filledSlots); + varcharVector->SetNull(filledSlots); } filledSlots += 1; } @@ -683,106 +625,103 @@ namespace omniruntime::reader { while (filledSlots < numValues) { //求出长度,如果为char,则需要去除最后的空格 auto len = lengthPtr[filledSlots]; - if (is_char) { + if (isChar) { FindLastNotEmpty(tempPtr, len); } auto data = std::string_view(tempPtr, len); - newVector->SetValue(filledSlots, data); + varcharVector->SetValue(filledSlots, data); tempPtr += lengthPtr[filledSlots]; filledSlots += 1; } } - - omnivec = newVector.release(); } - void OmniDecimal64ColumnReader::next(void*& omnivec, uint64_t numValues, char* notNull, - const orc::Type& baseTp, int* omniTypeId) { - auto newVector = std::make_unique>(numValues); - bool HasNull = false; - char nulls[numValues]; - readNulls(this, numValues, notNull, nulls, HasNull); - notNull = HasNull ? nulls : nullptr; - // read the next group of scales - int64_t scaleBuffer[numValues]; - scaleDecoder->next(scaleBuffer, numValues, notNull); + void OmniDecimal64ColumnReader::next(BaseVector *vec, uint64_t numValues, + uint64_t *incomingNulls, int omniTypeId) { + auto nulls = omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(vec); + readNulls(this, numValues, incomingNulls, nulls); + bool hasNull = vec->HasNull(); - - if (notNull) { - for(size_t i=0; i < numValues; ++i) { - if (notNull[i]) { - int64_t value = 0; - readInt64(value, static_cast(scaleBuffer[i])); - newVector->SetValue(static_cast(i), static_cast(value)); - } else { - newVector->SetNull(static_cast(i)); + // read the next group of scales + auto nonNullNums = numValues - vec->GetNullCount(); + int64_t scaleBuffer[nonNullNums]; + // read dense scales + scaleDecoder->next(scaleBuffer, nonNullNums, nullptr); + auto nullsTrans = reinterpret_cast(nulls); + if (hasNull) { + auto vector = reinterpret_cast*>(vec); + + int64_t values[nonNullNums]; + ReadValuesBatch(values, nonNullNums); + UnZigZagBatchHEFs8p2(reinterpret_cast(values), nonNullNums); + UnScaleBatch(values, scaleBuffer, nonNullNums); + + int scaleIndex = 0; + for(size_t i = 0; i < numValues; ++i) { + if (!BitUtil::IsBitSet(nullsTrans, i)) { + vector->SetValue(i, values[scaleIndex++]); } } } else { - for(size_t i=0; i < numValues; ++i) { - int64_t value = 0; - readInt64(value, static_cast(scaleBuffer[i])); - newVector->SetValue(static_cast(i), static_cast(value)); - } + // special case for non null + auto values = omniruntime::vec::unsafe::UnsafeVector::GetRawValues( + static_cast*>(vec)); + ReadValuesBatch(values, nonNullNums); + UnZigZagBatchHEFs8p2(reinterpret_cast(values), nonNullNums); + UnScaleBatch(values, scaleBuffer, nonNullNums); } - - omnivec = newVector.release(); } - void OmniDecimal128ColumnReader::next(void*& omnivec, uint64_t numValues, char* notNull, - const orc::Type& baseTp, int* omniTypeId) { - auto newVector = std::make_unique>(numValues); - bool HasNull = false; - char nulls[numValues]; - 
readNulls(this, numValues, notNull, nulls, HasNull); - notNull = HasNull ? nulls : nullptr; + void OmniDecimal128ColumnReader::next(BaseVector *vec, uint64_t numValues, + uint64_t *incomingNulls, int omniTypeId) { + auto nulls = omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(vec); + readNulls(this, numValues, incomingNulls, nulls); + bool hasNull = vec->HasNull(); + // read the next group of scales int64_t scaleBuffer[numValues]; + auto nullsTrans = reinterpret_cast(nulls); + scaleDecoder->next(scaleBuffer, numValues, nullsTrans); - scaleDecoder->next(scaleBuffer, numValues, notNull); - if (notNull) { - for(size_t i=0; i < numValues; ++i) { - if (notNull[i]) { + auto vector = reinterpret_cast*>(vec); + if (hasNull) { + for(size_t i = 0; i < numValues; ++i) { + if (!BitUtil::IsBitSet(nullsTrans, i)) { orc::Int128 value = 0; readInt128(value, static_cast(scaleBuffer[i])); __int128_t dst = value.getHighBits(); dst <<= 64; dst |= value.getLowBits(); - newVector->SetValue(i, omniruntime::type::Decimal128(dst)); - } else { - newVector->SetNull(i); + vector->SetValue(i, omniruntime::type::Decimal128(dst)); } } } else { - for(size_t i=0; i < numValues; ++i) { + for(size_t i = 0; i < numValues; ++i) { orc::Int128 value = 0; readInt128(value, static_cast(scaleBuffer[i])); __int128_t dst = value.getHighBits(); dst <<= 64; dst |= value.getLowBits(); - newVector->SetValue(i, omniruntime::type::Decimal128(dst)); + vector->SetValue(i, omniruntime::type::Decimal128(dst)); } } - - omnivec = newVector.release(); } - void OmniDecimalHive11ColumnReader::next(void*&vec, uint64_t numValues, char* notNull, - const orc::Type& baseTp, int* omniTypeId) { - auto newVector = std::make_unique>(numValues); - bool HasNull = false; - char nulls[numValues]; - readNulls(this, numValues, notNull, nulls, HasNull); - notNull = HasNull ? nulls : nullptr; + void OmniDecimalHive11ColumnReader::next(BaseVector *vec, uint64_t numValues, + uint64_t *incomingNulls, int omniTypeId) { + auto nulls = omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(vec); + readNulls(this, numValues, incomingNulls, nulls); + bool hasNull = vec->HasNull(); // read the next group of scales int64_t scaleBuffer[numValues]; + auto nullsTrans = reinterpret_cast(nulls); + scaleDecoder->next(scaleBuffer, numValues, nullsTrans); - scaleDecoder->next(scaleBuffer, numValues, notNull); - - if (notNull) { + auto vector = reinterpret_cast*>(vec); + if (hasNull) { for (size_t i = 0; i < numValues; ++i) { - if (notNull[i]) { + if (!BitUtil::IsBitSet(nullsTrans, i)) { orc::Int128 value = 0; if (!readInt128(value, static_cast(scaleBuffer[i]))) { if (throwOnOverflow) { @@ -791,17 +730,15 @@ namespace omniruntime::reader { *errorStream << "Warning: " << "Hive 0.11 decimal with more than 38 digits " << "replaced by NULL. \n"; - newVector->SetNull(i); + vector->SetNull(i); } } else { __int128_t dst = value.getHighBits(); dst <<= 64; dst |= value.getLowBits(); - newVector->SetValue(i, omniruntime::type::Decimal128(dst)); + vector->SetValue(i, omniruntime::type::Decimal128(dst)); } - } else { - newVector->SetNull(i); - } + } } } else { for (size_t i = 0; i < numValues; ++i) { @@ -813,25 +750,24 @@ namespace omniruntime::reader { *errorStream << "Warning: " << "Hive 0.11 decimal with more than 38 digits " << "replaced by NULL. 
\n"; - newVector->SetNull(i); + vector->SetNull(i); } } else { __int128_t dst = value.getHighBits(); dst <<= 64; dst |= value.getLowBits(); - newVector->SetValue(i, omniruntime::type::Decimal128(dst)); + vector->SetValue(i, omniruntime::type::Decimal128(dst)); } } } } - OmniIntegerColumnReader::OmniIntegerColumnReader(const Type& type, - StripeStreams& stripe): ColumnReader(type, stripe) { + OmniIntegerColumnReader::OmniIntegerColumnReader(const Type& type, StripeStreams& stripe) + : OmniColumnReader(type, stripe) { RleVersion vers = omniConvertRleVersion(stripe.getEncoding(columnId).kind()); - std::unique_ptr stream = - stripe.getStream(columnId, orc::proto::Stream_Kind_DATA, true); + std::unique_ptr stream = stripe.getStream(columnId, orc::proto::Stream_Kind_DATA, true); if (stream == nullptr) - throw omniruntime::exception::OmniException("EXPRESSION_NOT_SUPPORT", "DATA stream not found in Integer column"); + throw OmniException("EXPRESSION_NOT_SUPPORT", "DATA stream not found in Integer column"); rle = createOmniRleDecoder(std::move(stream), true, vers, memoryPool); } @@ -840,14 +776,14 @@ namespace omniruntime::reader { } uint64_t OmniIntegerColumnReader::skip(uint64_t numValues) { - numValues = ColumnReader::skip(numValues); + numValues = OmniColumnReader::skip(numValues); rle->skip(numValues); return numValues; } void OmniIntegerColumnReader::seekToRowGroup( std::unordered_map& positions) { - ColumnReader::seekToRowGroup(positions); + OmniColumnReader::seekToRowGroup(positions); rle->seek(positions.at(columnId)); } @@ -876,7 +812,7 @@ namespace omniruntime::reader { OmniDecimal64ColumnReader::OmniDecimal64ColumnReader(const Type& type, StripeStreams& stripe - ): ColumnReader(type, stripe) { + ): OmniColumnReader(type, stripe) { scale = static_cast(type.getScale()); precision = static_cast(type.getPrecision()); valueStream = stripe.getStream(columnId, orc::proto::Stream_Kind_DATA, true); @@ -897,7 +833,7 @@ namespace omniruntime::reader { } uint64_t OmniDecimal64ColumnReader::skip(uint64_t numValues) { - numValues = ColumnReader::skip(numValues); + numValues = OmniColumnReader::skip(numValues); uint64_t skipped = 0; while (skipped < numValues) { readBuffer(); @@ -911,7 +847,7 @@ namespace omniruntime::reader { void OmniDecimal64ColumnReader::seekToRowGroup( std::unordered_map& positions) { - ColumnReader::seekToRowGroup(positions); + OmniColumnReader::seekToRowGroup(positions); valueStream->seek(positions.at(columnId)); scaleDecoder->seek(positions.at(columnId)); // clear buffer state after seek @@ -998,7 +934,7 @@ namespace omniruntime::reader { StripeStreams& stripe, bool isInstantType, common::JulianGregorianRebase *julianPtr - ): ColumnReader(type, stripe), + ): OmniColumnReader(type, stripe), writerTimezone(isInstantType ? 
orc::getTimezoneByName("GMT") : stripe.getWriterTimezone()), @@ -1025,7 +961,7 @@ namespace omniruntime::reader { } uint64_t OmniTimestampColumnReader::skip(uint64_t numValues) { - numValues = ColumnReader::skip(numValues); + numValues = OmniColumnReader::skip(numValues); secondsRle->skip(numValues); nanoRle->skip(numValues); return numValues; @@ -1033,13 +969,13 @@ namespace omniruntime::reader { void OmniTimestampColumnReader::seekToRowGroup( std::unordered_map& positions) { - ColumnReader::seekToRowGroup(positions); + OmniColumnReader::seekToRowGroup(positions); secondsRle->seek(positions.at(columnId)); nanoRle->seek(positions.at(columnId)); } OmniStringDirectColumnReader::OmniStringDirectColumnReader(const Type& type, StripeStreams& stripe) - : ColumnReader(type, stripe) { + : OmniColumnReader(type, stripe) { RleVersion rleVersion = omniConvertRleVersion(stripe.getEncoding(columnId).kind()); std::unique_ptr stream = stripe.getStream(columnId, orc::proto::Stream_Kind_LENGTH, true); @@ -1052,6 +988,9 @@ namespace omniruntime::reader { throw orc::ParseError("DATA stream not found in StringDirectColumn"); lastBuffer = nullptr; lastBufferLength = 0; + if (type.getKind() == orc::TypeKind::CHAR) { + isChar = true; + } } OmniStringDirectColumnReader::~OmniStringDirectColumnReader() { @@ -1060,7 +999,7 @@ namespace omniruntime::reader { uint64_t OmniStringDirectColumnReader::skip(uint64_t numValues) { const size_t BUFFER_SIZE = 1024; - numValues = ColumnReader::skip(numValues); + numValues = OmniColumnReader::skip(numValues); int64_t buffer[BUFFER_SIZE]; uint64_t done = 0; size_t totalBytes = 0; @@ -1091,18 +1030,17 @@ namespace omniruntime::reader { return numValues; } - size_t OmniStringDirectColumnReader::computeSize(const int64_t* lengths, - const char* notNull, - uint64_t numValues) { + size_t OmniStringDirectColumnReader::computeSize(const int64_t* lengths, uint64_t *nulls, + uint64_t numValues) { size_t totalLength = 0; - if (notNull) { - for(size_t i=0; i < numValues; ++i) { - if (notNull[i]) { + if (nulls) { + for(size_t i = 0; i < numValues; ++i) { + if (!BitUtil::IsBitSet(nulls, i)) { totalLength += static_cast(lengths[i]); } } } else { - for(size_t i=0; i < numValues; ++i) { + for(size_t i = 0; i < numValues; ++i) { totalLength += static_cast(lengths[i]); } } @@ -1111,7 +1049,7 @@ namespace omniruntime::reader { void OmniStringDirectColumnReader::seekToRowGroup( std::unordered_map& positions) { - ColumnReader::seekToRowGroup(positions); + OmniColumnReader::seekToRowGroup(positions); blobStream->seek(positions.at(columnId)); lengthRle->seek(positions.at(columnId)); // clear buffer state after seek @@ -1120,8 +1058,7 @@ namespace omniruntime::reader { } OmniStringDictionaryColumnReader::OmniStringDictionaryColumnReader(const Type& type, StripeStreams& stripe) - : ColumnReader(type, stripe), - dictionary(new orc::StringDictionary(stripe.getMemoryPool())) { + : OmniColumnReader(type, stripe), dictionary(new orc::StringDictionary(stripe.getMemoryPool())) { RleVersion rleVersion = omniConvertRleVersion(stripe.getEncoding(columnId) .kind()); uint32_t dictSize = stripe.getEncoding(columnId).dictionarysize(); @@ -1156,6 +1093,9 @@ namespace omniruntime::reader { "DICTIONARY_DATA stream not found in StringDictionaryColumn"); } omniReadFully(dictionary->dictionaryBlob.data(), blobSize, blobStream.get()); + if (type.getKind() == orc::TypeKind::CHAR) { + isChar = true; + } } OmniStringDictionaryColumnReader::~OmniStringDictionaryColumnReader() { @@ -1163,19 +1103,19 @@ namespace 
omniruntime::reader { } uint64_t OmniStringDictionaryColumnReader::skip(uint64_t numValues) { - numValues = ColumnReader::skip(numValues); + numValues = OmniColumnReader::skip(numValues); rle->skip(numValues); return numValues; } void OmniStringDictionaryColumnReader::seekToRowGroup( std::unordered_map& positions) { - ColumnReader::seekToRowGroup(positions); + OmniColumnReader::seekToRowGroup(positions); rle->seek(positions.at(columnId)); } - OmniBooleanColumnReader::OmniBooleanColumnReader(const orc::Type& type, - orc::StripeStreams& stripe): ColumnReader(type, stripe){ + OmniBooleanColumnReader::OmniBooleanColumnReader(const orc::Type& type, orc::StripeStreams& stripe) + : OmniColumnReader(type, stripe){ std::unique_ptr stream = stripe.getStream(columnId, orc::proto::Stream_Kind_DATA, true); if (stream == nullptr) @@ -1188,7 +1128,7 @@ namespace omniruntime::reader { } uint64_t OmniBooleanColumnReader::skip(uint64_t numValues) { - numValues = ColumnReader::skip(numValues); + numValues = OmniColumnReader::skip(numValues); rle->skip(numValues); return numValues; } @@ -1196,14 +1136,13 @@ namespace omniruntime::reader { void OmniBooleanColumnReader::seekToRowGroup( std::unordered_map& positions) { - ColumnReader::seekToRowGroup(positions); + OmniColumnReader::seekToRowGroup(positions); rle->seek(positions.at(columnId)); } - OmniByteColumnReader::OmniByteColumnReader(const Type& type, - StripeStreams& stripe - ): ColumnReader(type, stripe){ + OmniByteColumnReader::OmniByteColumnReader(const Type& type, StripeStreams& stripe) + : OmniColumnReader(type, stripe){ std::unique_ptr stream = stripe.getStream(columnId, orc::proto::Stream_Kind_DATA, true); if (stream == nullptr) @@ -1216,21 +1155,21 @@ namespace omniruntime::reader { } uint64_t OmniByteColumnReader::skip(uint64_t numValues) { - numValues = ColumnReader::skip(numValues); + numValues = OmniColumnReader::skip(numValues); rle->skip(numValues); return numValues; } void OmniByteColumnReader::seekToRowGroup( std::unordered_map& positions) { - ColumnReader::seekToRowGroup(positions); + OmniColumnReader::seekToRowGroup(positions); rle->seek(positions.at(columnId)); } OmniDoubleColumnReader::OmniDoubleColumnReader(const Type& type, StripeStreams& stripe - ): ColumnReader(type, stripe), + ): OmniColumnReader(type, stripe), columnKind(type.getKind()), bytesPerValue((type.getKind() == orc::FLOAT) ? 
4 : 8), @@ -1246,7 +1185,7 @@ namespace omniruntime::reader { } uint64_t OmniDoubleColumnReader::skip(uint64_t numValues) { - numValues = ColumnReader::skip(numValues); + numValues = OmniColumnReader::skip(numValues); if (static_cast(bufferEnd - bufferPointer) >= bytesPerValue * numValues) { @@ -1269,7 +1208,7 @@ namespace omniruntime::reader { void OmniDoubleColumnReader::seekToRowGroup( std::unordered_map& positions) { - ColumnReader::seekToRowGroup(positions); + OmniColumnReader::seekToRowGroup(positions); inputStream->seek(positions.at(columnId)); // clear buffer state after seek bufferEnd = nullptr; diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniColReader.hh b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniColReader.hh index f3dbdeb6822dc2436cc8ed5f0023e8c883d2f9e3..6a08e694351c2aba1c2f7a1c97189b4b4998706b 100644 --- a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniColReader.hh +++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniColReader.hh @@ -29,50 +29,80 @@ #include "common/JulianGregorianRebase.h" namespace omniruntime::reader { - class OmniStructColumnReader: public orc::ColumnReader { + + class OmniColumnReader: public orc::ColumnReader { + public: + OmniColumnReader(const orc::Type& type, orc::StripeStreams& stripe); + + virtual ~OmniColumnReader() {} + + /** + * Skip number of specified rows. + */ + virtual uint64_t skip(uint64_t numValues); + + /** + * Read OmniVector in OmniTypeId, which contains specified rows. + */ + virtual void next( + omniruntime::vec::BaseVector *omniVector, + uint64_t numValues, + uint64_t *incomingNulls, + int omniTypeId) { + throw std::runtime_error("next() in base class should not be called"); + } + + /** + * Seek to beginning of a row group in the current stripe + * @param positions a list of PositionProviders storing the positions + */ + virtual void seekToRowGroup(std::unordered_map &positions); + + std::unique_ptr notNullDecoder; + }; + + class OmniStructColumnReader: public OmniColumnReader { private: - std::vector> children; + std::vector> children; public: - OmniStructColumnReader(const orc::Type& type, orc::StripeStreams& stipe, common::JulianGregorianRebase *julianPtr); + OmniStructColumnReader(const orc::Type& type, orc::StripeStreams& stipe, + common::JulianGregorianRebase *julianPtr); uint64_t skip(uint64_t numValues) override; /** * direct read VectorBatch in next * @param omniVecBatch the VectorBatch to push - * @param numValues the VectorBatch to push - * @param notNull the VectorBatch to push - * @param baseTp the vectorBatch to push + * @param numValues the numValues of VectorBatch + * @param notNull the notNull array indicates value not null + * @param baseTp the orc type * @param omniTypeId the omniTypeId to push */ void next(void *&omniVecBatch, uint64_t numValues, char *notNull, const orc::Type& baseTp, int* omniTypeId) override; - void nextEncoded(orc::ColumnVectorBatch& rowBatch, - uint64_t numValues, - char *notNull) override; - void seekToRowGroup( std::unordered_map& positions) override; private: /** - * direct read VectorBatch in next + * direct read VectorBatch in next for omni * @param omniVecBatch the VectorBatch to push - * @param numValues the VectorBatch to push - * @param notNull the VectorBatch to push - * @param baseTp the vectorBatch to push + * @param numValues the numValues of VectorBatch + * @param notNull the notNull array indicates value not null + * @param baseTp the orc type * @param omniTypeId the omniTypeId to push */ template - void nextInternal(std::vector 
&vecs, uint64_t numValues, char *notNull, - const orc::Type& baseTp, int* omniTypeId); + void nextInternal(std::vector &vecs, uint64_t numValues, + uint64_t *incomingNulls, const orc::Type& baseTp, int* omniTypeId); - omniruntime::type::DataTypeId getOmniTypeByOrcType(const orc::Type* type); + // Get default omni type from orc type. + omniruntime::type::DataTypeId getDefaultOmniType(const orc::Type *type); }; - class OmniBooleanColumnReader: public orc::ColumnReader { + class OmniBooleanColumnReader: public OmniColumnReader { protected: std::unique_ptr rle; @@ -82,13 +112,13 @@ namespace omniruntime::reader { uint64_t skip(uint64_t numValues) override; - void next(void*& vec, uint64_t numValues, char *notNull, - const orc::Type& baseTp, int* omniTypeId) override; + void next(omniruntime::vec::BaseVector *vec, uint64_t numValues, + uint64_t *incomingNulls, int omniTypeId) override; void seekToRowGroup(std::unordered_map& positions) override; }; - class OmniByteColumnReader: public orc::ColumnReader { + class OmniByteColumnReader: public OmniColumnReader { protected: std::unique_ptr rle; @@ -98,14 +128,13 @@ namespace omniruntime::reader { uint64_t skip(uint64_t numValues) override; - void next(void*& vec, uint64_t numValues, char *notNull, - const orc::Type& baseTp, int* omniTypeId) override; + void next(omniruntime::vec::BaseVector *vec, uint64_t numValues, + uint64_t *incomingNulls, int omniTypeId) override; - void seekToRowGroup( - std::unordered_map& positions) override; + void seekToRowGroup(std::unordered_map& positions) override; }; - class OmniIntegerColumnReader: public orc::ColumnReader { + class OmniIntegerColumnReader: public OmniColumnReader { protected: std::unique_ptr rle; @@ -115,20 +144,19 @@ namespace omniruntime::reader { uint64_t skip(uint64_t numValues) override; - void next(void*& vec, uint64_t numValues, char *notNull, - const orc::Type& baseTp, int* omniTypeId) override; + void next(omniruntime::vec::BaseVector *vec, uint64_t numValues, + uint64_t *incomingNulls, int omniTypeId) override; - void seekToRowGroup( - std::unordered_map& positions) override; + void seekToRowGroup(std::unordered_map& positions) override; }; - class OmniTimestampColumnReader: public orc::ColumnReader { + class OmniTimestampColumnReader: public OmniColumnReader { public: - void next(void*& vec, uint64_t numValues, char* notNull, - const orc::Type& baseTp, int* omniTypeId) override; + void next(omniruntime::vec::BaseVector *vec, uint64_t numValues, + uint64_t *incomingNulls, int omniTypeId) override; - template - void nextByType(void*& omnivec, uint64_t numValues, char* notNull, const orc::Type& baseTp, int* omniTypeId); + template + void nextByType(T *data, uint64_t numValues, uint64_t *nulls); private: std::unique_ptr secondsRle; @@ -147,23 +175,21 @@ namespace omniruntime::reader { uint64_t skip(uint64_t numValues) override; - void seekToRowGroup( - std::unordered_map& positions) override; + void seekToRowGroup(std::unordered_map& positions) override; }; - class OmniDoubleColumnReader: public orc::ColumnReader { + class OmniDoubleColumnReader: public OmniColumnReader { public: OmniDoubleColumnReader(const orc::Type& type, orc::StripeStreams& stripe); ~OmniDoubleColumnReader() override; uint64_t skip(uint64_t numValues) override; - void next(void*& vec, uint64_t numValues, char* notNull, - const orc::Type& baseTp, int* omniTypeId) override; + void next(omniruntime::vec::BaseVector *vec, uint64_t numValues, + uint64_t *incomingNulls, int omniTypeId) override; - template - void 
nextByType(omniruntime::vec::BaseVector*& omnivec, uint64_t numValues, char* notNull, - const orc::Type& baseTp); + template + void nextByType(T *data, uint64_t numValues, uint64_t *nulls); void seekToRowGroup( std::unordered_map& positions) override; @@ -207,14 +233,15 @@ namespace omniruntime::reader { }; - class OmniStringDictionaryColumnReader: public orc::ColumnReader { + class OmniStringDictionaryColumnReader: public OmniColumnReader { public: - void next(void*& vec, uint64_t numValues, char* notNull, - const orc::Type& baseTp, int* omniTypeId) override; + void next(omniruntime::vec::BaseVector *vec, uint64_t numValues, + uint64_t *incomingNulls, int omniTypeId) override; private: std::shared_ptr dictionary; std::unique_ptr rle; + bool isChar = false; public: OmniStringDictionaryColumnReader(const orc::Type& type, orc::StripeStreams& stipe); @@ -222,30 +249,30 @@ namespace omniruntime::reader { uint64_t skip(uint64_t numValues) override; - void seekToRowGroup( - std::unordered_map& positions) override; + void seekToRowGroup(std::unordered_map& positions) override; }; - class OmniStringDirectColumnReader: public orc::ColumnReader { + class OmniStringDirectColumnReader: public OmniColumnReader { public: - void next(void*& vec, uint64_t numValues, char* notNull, - const orc::Type& baseTp, int* omniTypeId) override; + void next(omniruntime::vec::BaseVector *vec, uint64_t numValues, + uint64_t *incomingNulls, int omniTypeId) override; private: std::unique_ptr lengthRle; std::unique_ptr blobStream; const char *lastBuffer; size_t lastBufferLength; + bool isChar = false; /** * Compute the total length of the values. * @param lengths the array of lengths - * @param notNull the array of notNull flags + * @param nulls the bits of nulls flags * @param numValues the lengths of the arrays * @return the total number of bytes for the non-null values */ - size_t computeSize(const int64_t *lengths, const char *notNull, - uint64_t numValues); + size_t computeSize(const int64_t *lengths, uint64_t *nulls, + uint64_t numValues); public: OmniStringDirectColumnReader(const orc::Type& type, orc::StripeStreams& stipe); @@ -257,10 +284,10 @@ namespace omniruntime::reader { std::unordered_map& positions) override; }; - class OmniDecimal64ColumnReader: public orc::ColumnReader { + class OmniDecimal64ColumnReader: public OmniColumnReader { public: - void next(void*&vec, uint64_t numValues, char* notNull, - const orc::Type& baseTp, int* omniTypeId) override; + void next(omniruntime::vec::BaseVector *vec, uint64_t numValues, + uint64_t *incomingNulls, int omniTypeId) override; public: static const uint32_t MAX_PRECISION_64 = 18; @@ -315,6 +342,38 @@ namespace omniruntime::reader { } } + void ReadValuesBatch(int64_t *data, int64_t numValues) { + for (int i = 0; i < numValues; ++i) { + int64_t value = 0; + size_t offset = 0; + while (true) { + readBuffer(); + unsigned char ch = static_cast(*(buffer++)); + value |= static_cast(ch & 0x7f) << offset; + offset += 7; + if (!(ch & 0x80)) { + break; + } + } + data[i] = value; + } + } + + void UnScaleBatch(int64_t *data, int64_t *scales, int64_t numValues) { + for (int i = 0; i < numValues; i++) { + int32_t currentScale = static_cast(scales[i]); + if (scale > currentScale && + static_cast(scale - currentScale) <= MAX_PRECISION_64) { + data[i] *= POWERS_OF_TEN[scale - currentScale]; + } else if (scale < currentScale && + static_cast(currentScale - scale) <= MAX_PRECISION_64) { + data[i] /= POWERS_OF_TEN[currentScale - scale]; + } else if (scale != currentScale) { + throw 
orc::ParseError("Decimal scale out of range"); + } + } + } + public: OmniDecimal64ColumnReader(const orc::Type& type, orc::StripeStreams& stipe); ~OmniDecimal64ColumnReader() override; @@ -327,8 +386,8 @@ namespace omniruntime::reader { class OmniDecimal128ColumnReader : public OmniDecimal64ColumnReader { public: - void next(void*&vec, uint64_t numValues, char* notNull, - const orc::Type& baseTp, int* omniTypeId) override; + void next(omniruntime::vec::BaseVector *vec, uint64_t numValues, + uint64_t *incomingNulls, int omniTypeId) override; public: OmniDecimal128ColumnReader(const orc::Type& type, orc::StripeStreams& stipe); @@ -349,8 +408,8 @@ namespace omniruntime::reader { OmniDecimalHive11ColumnReader(const orc::Type& type, orc::StripeStreams& stipe); ~OmniDecimalHive11ColumnReader() override; - void next(void*&vec, uint64_t numValues, char* notNull, - const orc::Type& baseTp, int* omniTypeId) override; + void next(omniruntime::vec::BaseVector *vec, uint64_t numValues, + uint64_t *incomingNulls, int omniTypeId) override; }; std::unique_ptr omniBuildReader(const orc::Type& type, diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRLEv2.cc b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRLEv2.cc index a5c515d3cb11759f7950d4e143c25b70627abe0b..2184d32da4b3bc5b1b323408696e642810bed74d 100644 --- a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRLEv2.cc +++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRLEv2.cc @@ -18,27 +18,27 @@ #include "OmniRLEv2.hh" #include "orc/RLEV2Util.hh" +#include "OrcDecodeUtils.hh" using omniruntime::vec::VectorBatch; namespace omniruntime::reader { - std::unique_ptr makeFixLenthVector(uint64_t numValues, + std::unique_ptr makeFixedLengthVector(uint64_t numValues, omniruntime::type::DataTypeId dataTypeId) { switch (dataTypeId) { - case omniruntime::type::OMNI_BOOLEAN: - return std::make_unique>(numValues); - case omniruntime::type::OMNI_SHORT: - return std::make_unique>(numValues); - case omniruntime::type::OMNI_INT: - return std::make_unique>(numValues); - case omniruntime::type::OMNI_LONG: - return std::make_unique>(numValues); - case omniruntime::type::OMNI_DATE32: - return std::make_unique>(numValues); - case omniruntime::type::OMNI_DATE64: - return std::make_unique>(numValues); - default: - throw std::runtime_error("Not support for this type: " + dataTypeId); + case omniruntime::type::OMNI_BOOLEAN: + return std::make_unique>(numValues); + case omniruntime::type::OMNI_SHORT: + return std::make_unique>(numValues); + case omniruntime::type::OMNI_INT: + case omniruntime::type::OMNI_DATE32: + return std::make_unique>(numValues); + case omniruntime::type::OMNI_LONG: + case omniruntime::type::OMNI_DATE64: + case omniruntime::type::OMNI_TIMESTAMP: + return std::make_unique>(numValues); + default: + throw std::runtime_error("MakeFixedLengthVector Not support for this type: " + dataTypeId); } } @@ -48,7 +48,31 @@ namespace omniruntime::reader { case omniruntime::type::OMNI_DOUBLE: return std::make_unique>(numValues); default: - throw std::runtime_error("Not support double vector for this type: " + dataTypeId); + throw std::runtime_error("MakeDoubleVector Not support double vector for this type: " + dataTypeId); + } + } + + std::unique_ptr makeDecimalVector(uint64_t numValues, + omniruntime::type::DataTypeId dataTypeId) { + switch (dataTypeId) { + case omniruntime::type::OMNI_DECIMAL64: + return std::make_unique>(numValues); + case omniruntime::type::OMNI_DECIMAL128: + return std::make_unique>(numValues); + default: + throw 
std::runtime_error("makeDecimalVector Not support vector for this type: " + dataTypeId); + } + } + + std::unique_ptr makeVarcharVector(uint64_t numValues, + omniruntime::type::DataTypeId dataTypeId) { + switch (dataTypeId) { + case omniruntime::type::OMNI_CHAR: + case omniruntime::type::OMNI_VARCHAR: + return std::make_unique>>(numValues); + default: + throw std::runtime_error("MakeVarcharVector Not support vector for this type: " + dataTypeId); } } @@ -62,36 +86,75 @@ namespace omniruntime::reader { case orc::TypeKind::TIMESTAMP: case orc::TypeKind::TIMESTAMP_INSTANT: case orc::TypeKind::LONG: - return makeFixLenthVector(numValues, dataTypeId); + return makeFixedLengthVector(numValues, dataTypeId); case orc::TypeKind::DOUBLE: return makeDoubleVector(numValues, dataTypeId); case orc::TypeKind::CHAR: - throw std::runtime_error("CHAR not finished!!!"); case orc::TypeKind::STRING: case orc::TypeKind::VARCHAR: - throw std::runtime_error("VARCHAR not finished!!!"); + return makeVarcharVector(numValues, dataTypeId); case orc::TypeKind::DECIMAL: - throw std::runtime_error("DECIMAL should not in here!!!"); + return makeDecimalVector(numValues, dataTypeId); default: { - throw std::runtime_error("Not support For This Type: " + baseTp->getKind()); + throw std::runtime_error("Not support For This ORC Type: " + baseTp->getKind()); } } } - void OmniRleDecoderV2::next(omniruntime::vec::BaseVector*& omnivec, uint64_t numValues, char* notNull, - const orc::Type* baseTp, int omniTypeId) { - uint64_t nRead = 0; + void OmniRleDecoderV2::next(int64_t *data, uint64_t numValues, uint64_t *nulls) { + next(data, numValues, nulls); + } - auto dataTypeId = static_cast(omniTypeId); - std::unique_ptr tempOmnivec = makeNewVector(numValues, baseTp, dataTypeId); - auto pushOmniVec = tempOmnivec.get(); + void OmniRleDecoderV2::next(int32_t *data, uint64_t numValues, uint64_t *nulls) { + next(data, numValues, nulls); + } + + void OmniRleDecoderV2::next(int16_t *data, uint64_t numValues, uint64_t *nulls) { + next(data, numValues, nulls); + } + + void OmniRleDecoderV2::next(bool *data, uint64_t numValues, uint64_t *nulls) { + next(data, numValues, nulls); + } + + void OmniRleDecoderV2::next(omniruntime::vec::BaseVector *omnivec, uint64_t numValues, + uint64_t *nulls, int omniTypeId) { + switch (omniTypeId) { + case omniruntime::type::OMNI_BOOLEAN: { + auto boolValues = omniruntime::vec::unsafe::UnsafeVector::GetRawValues( + static_cast*>(omnivec)); + return next(boolValues, numValues, nulls); + } + case omniruntime::type::OMNI_SHORT: { + auto shortValues = omniruntime::vec::unsafe::UnsafeVector::GetRawValues( + static_cast*>(omnivec)); + return next(shortValues, numValues, nulls); + } + case omniruntime::type::OMNI_INT: + case omniruntime::type::OMNI_DATE32: { + auto intValues = omniruntime::vec::unsafe::UnsafeVector::GetRawValues( + static_cast*>(omnivec)); + return next(intValues, numValues, nulls); + } + case omniruntime::type::OMNI_LONG: + case omniruntime::type::OMNI_DATE64: { + auto longValues = omniruntime::vec::unsafe::UnsafeVector::GetRawValues( + static_cast*>(omnivec)); + return next(longValues, numValues, nulls); + } + default: + throw std::runtime_error("OmniRleDecoderV2 Not support For This Type: " + omniTypeId); + } + } + + template + void OmniRleDecoderV2::next(T *data, uint64_t numValues, uint64_t *nulls) { + uint64_t nRead = 0; while (nRead < numValues) { // SKip any nulls before attempting to read first byte. 
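// The loop below walks the null bitmap directly: a set bit marks a NULL row, and such
// rows are skipped without consuming anything from the data stream, because the RLEv2
// runs that follow store literals only for non-null positions. Once a value-carrying
// row is reached, the run header byte (firstByte) is read and dispatched to the
// SHORT_REPEAT / DIRECT / PATCHED_BASE / DELTA decoders, which fill only non-null slots.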
- while (notNull && !notNull[nRead]) { - tempOmnivec->SetNull(nRead); + while (nulls && BitUtil::IsBitSet(nulls, nRead)) { if (++nRead == numValues) { - omnivec = tempOmnivec.release(); return; //ended with null values } } @@ -104,30 +167,28 @@ namespace omniruntime::reader { uint64_t offset = nRead, length = numValues - nRead; orc::EncodingType enc = static_cast((firstByte >> 6) & 0x03); - switch (static_cast(enc)) { + switch (static_cast(enc)) { case orc::SHORT_REPEAT: - nRead += nextShortRepeatsByType(pushOmniVec, offset, length, notNull, dataTypeId); + nRead += nextShortRepeats(data, offset, length, nulls); break; case orc::DIRECT: - nRead += nextDirect(pushOmniVec, offset, length, notNull, dataTypeId); + nRead += nextDirect(data, offset, length, nulls); break; case orc::PATCHED_BASE: - nRead += nextPatchedByType(pushOmniVec, offset, length, notNull, dataTypeId); + nRead += nextPatched(data, offset, length, nulls); break; case orc::DELTA: - nRead += nextDeltaByType(pushOmniVec, offset, length, notNull, dataTypeId); + nRead += nextDelta(data, offset, length, nulls); break; default: throw orc::ParseError("unknown encoding"); } } - - omnivec = tempOmnivec.release(); } - uint64_t OmniRleDecoderV2::nextDirect(omniruntime::vec::BaseVector*& OmniVec, - uint64_t offset, uint64_t numValues, const char* const notNull, - omniruntime::type::DataTypeId dataTypeId) { + template + uint64_t OmniRleDecoderV2::nextDirect(T *data, uint64_t offset, uint64_t numValues, + uint64_t *nulls) { if (runRead == runLength) { // extract the number of fixed bits unsigned char fbo = (firstByte >> 1) & 0x1f; @@ -140,65 +201,19 @@ namespace omniruntime::reader { runLength += 1; runRead = 0; - readLongsByType(OmniVec, 0, runLength, offset , numValues, bitSize, dataTypeId, notNull); + readLongs(literals.data(), 0, runLength, bitSize); if (isSigned) { - for (uint64_t i = 0; i < runLength; ++i) { - literals[i] = orc::unZigZag(static_cast(literals[i])); - } + UnZigZagBatchHEFs8p2(reinterpret_cast(literals.data()), runLength); } } - return copyDataFromBufferByType(OmniVec, offset, numValues, notNull, dataTypeId); - } - - uint64_t OmniRleDecoderV2::nextShortRepeatsByType(omniruntime::vec::BaseVector*& OmniVec, - uint64_t offset, uint64_t numValues, const char* const notNull, - omniruntime::type::DataTypeId dataTypeId) { - switch (dataTypeId) { - case omniruntime::type::OMNI_BOOLEAN: - return nextShortRepeats - (OmniVec, offset, numValues, notNull); - case omniruntime::type::OMNI_SHORT: - return nextShortRepeats - (OmniVec, offset, numValues, notNull); - case omniruntime::type::OMNI_INT: - return nextShortRepeats - (OmniVec, offset, numValues, notNull); - case omniruntime::type::OMNI_LONG: - return nextShortRepeatsLongType - (OmniVec, offset, numValues, notNull); - case omniruntime::type::OMNI_DATE32: - return nextShortRepeats - (OmniVec, offset, numValues, notNull); - case omniruntime::type::OMNI_DATE64: - return nextShortRepeatsLongType - (OmniVec, offset, numValues, notNull); - case omniruntime::type::OMNI_DOUBLE: - return nextShortRepeats - (OmniVec, offset, numValues, notNull); - case omniruntime::type::OMNI_CHAR: - throw std::runtime_error("nextShortRepeats_type CHAR not finished!!!"); - case omniruntime::type::OMNI_VARCHAR: - throw std::runtime_error("nextShortRepeats_type VARCHAR not finished!!!"); - case omniruntime::type::OMNI_DECIMAL64: - throw std::runtime_error("nextShortRepeats_type DECIMAL64 should not in here!!!"); - case omniruntime::type::OMNI_DECIMAL128: - throw std::runtime_error("nextShortRepeats_type 
DECIMAL128 should not in here!!!"); - default: - printf("nextShortRepeats_type switch no process!!!"); - } - - return 0; + return copyDataFromBuffer(data, offset, numValues, nulls); } - template - uint64_t OmniRleDecoderV2::nextShortRepeats(omniruntime::vec::BaseVector*& OmniVec, - uint64_t offset, uint64_t numValues, const char* const notNull) { - using namespace omniruntime::type; - using T = typename NativeType::type; - auto vec = reinterpret_cast*>(OmniVec); - + template + uint64_t OmniRleDecoderV2::nextShortRepeats(T *data, uint64_t offset, uint64_t numValues, + uint64_t *nulls) { if (runRead == runLength) { // extract the number of fixed bytes uint64_t byteSize = (firstByte >> 3) & 0x07; @@ -218,117 +233,98 @@ namespace omniruntime::reader { } uint64_t nRead = std::min(runLength - runRead, numValues); - - if (notNull) { - for(uint64_t pos = offset; pos < offset + nRead; ++pos) { - if (notNull[pos]) { - vec->SetValue(static_cast(pos), static_cast(literals[0])); + if constexpr (std::is_same_v) { + if (nulls) { + uint64_t i = offset; + uint64_t end = offset + nRead; + uint64_t skipNum = std::min(BitUtil::Nbytes(offset) * 8 - offset, nRead); + for (; i < offset + skipNum; i++) { + if (!BitUtil::IsBitSet(nulls, i)) { + data[i] = literals[0]; + runRead++; + } + } + uint8_t mask; + for (; i + 8 <= end; i += 8) { + mask = reinterpret_cast(nulls)[i / 8]; + if (UNLIKELY(mask == 255)) { + continue; + } + if (mask == 0) { + data[i] = literals[0]; + data[i + 1] = literals[0]; + data[i + 2] = literals[0]; + data[i + 3] = literals[0]; + data[i + 4] = literals[0]; + data[i + 5] = literals[0]; + data[i + 6] = literals[0]; + data[i + 7] = literals[0]; + runRead += 8; + continue; + } + auto *maskArr = notNullBitMask[mask]; + for (int j = 1; j <= *maskArr; j++) { + auto notNullIndex = i + maskArr[j]; + data[notNullIndex] = literals[0]; + runRead++; + } + } + for (; i < end; i++) { + if (!BitUtil::IsBitSet(nulls, i)) { + data[i] = literals[0]; + runRead++; + } + } + } else { + for(uint64_t pos = offset; pos < offset + nRead; ++pos) { + data[pos] = literals[0]; ++runRead; - } else { - vec->SetNull(static_cast(pos)); } } } else { - for(uint64_t pos = offset; pos < offset + nRead; ++pos) { - vec->SetValue(static_cast(pos), static_cast(literals[0])); - ++runRead; + if (nulls) { + uint64_t i = offset; + uint64_t end = offset + nRead; + uint64_t skipNum = std::min(BitUtil::Nbytes(offset) * 8 - offset, nRead); + for (; i < offset + skipNum; i++) { + if (!BitUtil::IsBitSet(nulls, i)) { + data[i] = static_cast(literals[0]); + runRead++; + } + } + uint8_t mask; + for (; i + 8 <= end; i += 8) { + mask = reinterpret_cast(nulls)[i / 8]; + if (UNLIKELY(mask == 255)) { + continue; + } + auto *maskArr = notNullBitMask[mask]; + for (int j = 1; j <= *maskArr; j++) { + auto notNullIndex = i + maskArr[j]; + data[notNullIndex] = static_cast(literals[0]); + runRead++; + } + } + for (; i < end; i++) { + if (!BitUtil::IsBitSet(nulls, i)) { + data[i] = static_cast(literals[0]); + runRead++; + } + } + } else { + for(uint64_t pos = offset; pos < offset + nRead; ++pos) { + data[pos] = static_cast(literals[0]); + ++runRead; + } } } return nRead; } - template - uint64_t OmniRleDecoderV2::nextShortRepeatsLongType(omniruntime::vec::BaseVector*& OmniVec, - uint64_t offset, uint64_t numValues, const char* const notNull) { - using namespace omniruntime::type; - using T = typename NativeType::type; - auto vec = reinterpret_cast*>(OmniVec); - - if (runRead == runLength) { - // extract the number of fixed bytes - uint64_t byteSize = 
(firstByte >> 3) & 0x07; - byteSize += 1; - - runLength = firstByte & 0x07; - // run lengths values are stored only after MIN_REPEAT value is met - runLength += MIN_REPEAT; - runRead = 0; - - // read the repeated value which is store using fixed bytes - literals[0] = readLongBE(byteSize); - - if (isSigned) { - literals[0] = orc::unZigZag(static_cast(literals[0])); - } - } - - uint64_t nRead = std::min(runLength - runRead, numValues); - - if (notNull) { - for(uint64_t pos = offset; pos < offset + nRead; ++pos) { - if (notNull[pos]) { - vec->SetValue(pos, static_cast(literals[0])); - ++runRead; - } else { - vec->SetNull(pos); - } - } - } else { - int64_t values[nRead]; - std::fill(values, values + nRead, literals[0]); - vec->SetValues(offset, values, nRead); - runRead += nRead; - } - - return nRead; - } - - - uint64_t OmniRleDecoderV2::nextPatchedByType(omniruntime::vec::BaseVector*& OmniVec, - uint64_t offset, uint64_t numValues, const char* const notNull, - omniruntime::type::DataTypeId dataTypeId) { - switch (dataTypeId) { - case omniruntime::type::OMNI_BOOLEAN: - return nextPatched - (OmniVec, offset, numValues, notNull, dataTypeId); - case omniruntime::type::OMNI_SHORT: - return nextPatched - (OmniVec, offset, numValues, notNull, dataTypeId); - case omniruntime::type::OMNI_INT: - return nextPatched - (OmniVec, offset, numValues, notNull, dataTypeId); - case omniruntime::type::OMNI_LONG: - return nextPatched - (OmniVec, offset, numValues, notNull, dataTypeId); - case omniruntime::type::OMNI_DATE32: - return nextPatched - (OmniVec, offset, numValues, notNull, dataTypeId); - case omniruntime::type::OMNI_DATE64: - return nextPatched - (OmniVec, offset, numValues, notNull, dataTypeId); - case omniruntime::type::OMNI_DOUBLE: - return nextPatched - (OmniVec, offset, numValues, notNull, dataTypeId); - case omniruntime::type::OMNI_CHAR: - throw std::runtime_error("nextPatched_type CHAR not finished!!!"); - case omniruntime::type::OMNI_VARCHAR: - throw std::runtime_error("nextPatched_type VARCHAR not finished!!!"); - case omniruntime::type::OMNI_DECIMAL64: - throw std::runtime_error("nextPatched_type DECIMAL64 should not in here!!!"); - case omniruntime::type::OMNI_DECIMAL128: - throw std::runtime_error("nextPatched_type DECIMAL128 should not in here!!!"); - default: - printf("nextPatched_type switch no process!!!"); - } - - return 0; - } - - template - uint64_t OmniRleDecoderV2::nextPatched(omniruntime::vec::BaseVector*& OmniVec, uint64_t offset, - uint64_t numValues, const char* const notNull, - omniruntime::type::DataTypeId dataTypeId) { + template + uint64_t OmniRleDecoderV2::nextPatched(T *data, uint64_t offset, uint64_t numValues, + uint64_t *nulls) { if (runRead == runLength) { // extract the number of fixed bits unsigned char fbo = (firstByte >> 1) & 0x1f; @@ -372,7 +368,7 @@ namespace omniruntime::reader { base = -base; } - orc::RleDecoderV2::readLongs(literals.data(), 0, runLength, bitSize); + readLongs(literals.data(), 0, runLength, bitSize); // any remaining bits are thrown out resetReadLongs(); @@ -385,7 +381,7 @@ namespace omniruntime::reader { "(patchBitSize + pgw > 64)!"); } uint32_t cfb = orc::getClosestFixedBits(patchBitSize + pgw); - orc::RleDecoderV2::readLongs(unpackedPatch.data(), 0, pl, cfb); + readLongs(unpackedPatch.data(), 0, pl, cfb); // any remaining bits are thrown out resetReadLongs(); @@ -422,53 +418,12 @@ namespace omniruntime::reader { } } - return copyDataFromBufferByType(OmniVec, offset, numValues, notNull, dataTypeId); + return copyDataFromBuffer(data, offset, 
numValues, nulls); } - uint64_t OmniRleDecoderV2::nextDeltaByType(omniruntime::vec::BaseVector*& OmniVec, - uint64_t offset, uint64_t numValues, const char* const notNull, - omniruntime::type::DataTypeId dataTypeId) { - switch (dataTypeId) { - case omniruntime::type::OMNI_BOOLEAN: - return nextDelta - (OmniVec, offset, numValues, notNull, dataTypeId); - case omniruntime::type::OMNI_SHORT: - return nextDelta - (OmniVec, offset, numValues, notNull, dataTypeId); - case omniruntime::type::OMNI_INT: - return nextDelta - (OmniVec, offset, numValues, notNull, dataTypeId); - case omniruntime::type::OMNI_LONG: - return nextDelta - (OmniVec, offset, numValues, notNull, dataTypeId); - case omniruntime::type::OMNI_DATE32: - return nextDelta - (OmniVec, offset, numValues, notNull, dataTypeId); - case omniruntime::type::OMNI_DATE64: - return nextDelta - (OmniVec, offset, numValues, notNull, dataTypeId); - case omniruntime::type::OMNI_DOUBLE: - return nextDelta - (OmniVec, offset, numValues, notNull, dataTypeId); - case omniruntime::type::OMNI_CHAR: - throw std::runtime_error("nextShortRepeats_type CHAR not finished!!!"); - case omniruntime::type::OMNI_VARCHAR: - throw std::runtime_error("nextShortRepeats_type VARCHAR not finished!!!"); - case omniruntime::type::OMNI_DECIMAL64: - throw std::runtime_error("nextShortRepeats_type DECIMAL64 should not in here!!!"); - case omniruntime::type::OMNI_DECIMAL128: - throw std::runtime_error("nextShortRepeats_type DECIMAL128 should not in here!!!"); - default: - printf("nextShortRepeats_type switch no process!!!"); - } - - return 0; - } - - template - uint64_t OmniRleDecoderV2::nextDelta(omniruntime::vec::BaseVector*& OmniVec, - uint64_t offset, uint64_t numValues, const char* const notNull, - omniruntime::type::DataTypeId dataTypeId) { + template + uint64_t OmniRleDecoderV2::nextDelta(T *data, uint64_t offset, uint64_t numValues, + uint64_t *nulls) { if (runRead == runLength) { // extract the number of fixed bits unsigned char fbo = (firstByte >> 1) &0x1f; @@ -515,7 +470,7 @@ namespace omniruntime::reader { // value to result buffer. if the delta base value is negative then it // is a decreasing sequence else an increasing sequence. // read deltas using the literals buffer. 
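// readLongs here only bit-unpacks the remaining deltas into the shared literals buffer
// (slots 2..runLength); the deltaBase sign handling and the scatter into the output both
// happen afterwards, which is why the unpack helpers take just (data, offset, len) and
// the null bitmap is consulted later, in copyDataFromBuffer.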
- orc::RleDecoderV2::readLongs(literals.data(), 2, runLength - 2, bitSize); + readLongs(literals.data(), 2, runLength - 2, bitSize); if (deltaBase < 0) { for (uint64_t i = 2; i < runLength; ++i) { @@ -529,82 +484,36 @@ namespace omniruntime::reader { } } - return copyDataFromBufferByType(OmniVec, offset, numValues, notNull, dataTypeId); - } - - void OmniRleDecoderV2::readLongsByType(omniruntime::vec::BaseVector*& OmniVec, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t numValues, uint64_t fbs, - omniruntime::type::DataTypeId dataTypeId, const char* const notNull) { - switch (dataTypeId) { - case omniruntime::type::OMNI_BOOLEAN: - return readLongs - (OmniVec, literals.data(), offset, len, omniOffset, numValues, fbs, notNull); - case omniruntime::type::OMNI_SHORT: - return readLongs - (OmniVec, literals.data(), offset, len, omniOffset, numValues, fbs, notNull); - case omniruntime::type::OMNI_INT: - return readLongs - (OmniVec, literals.data(), offset, len, omniOffset, numValues, fbs, notNull); - case omniruntime::type::OMNI_LONG: - return readLongs - (OmniVec, literals.data(), offset, len, omniOffset, numValues, fbs, notNull); - case omniruntime::type::OMNI_DATE32: - return readLongs - (OmniVec, literals.data(), offset, len, omniOffset, numValues, fbs, notNull); - case omniruntime::type::OMNI_DATE64: - return readLongs - (OmniVec, literals.data(), offset, len, omniOffset, numValues, fbs, notNull); - case omniruntime::type::OMNI_DOUBLE: - return readLongs - (OmniVec, literals.data(), offset, len, omniOffset, numValues, fbs, notNull); - case omniruntime::type::OMNI_CHAR: - throw std::runtime_error("copyDataFromBuffer_type CHAR not finished!!!"); - case omniruntime::type::OMNI_VARCHAR: - throw std::runtime_error("copyDataFromBuffer_type VARCHAR not finished!!!"); - case omniruntime::type::OMNI_DECIMAL64: - throw std::runtime_error("copyDataFromBuffer_type DECIMAL64 should not in here!!!"); - case omniruntime::type::OMNI_DECIMAL128: - throw std::runtime_error("copyDataFromBuffer_type DECIMAL128 should not in here!!!"); - default: - printf("copyDataFromBuffer switch no process!!!"); - } - - return; + return copyDataFromBuffer(data, offset, numValues, nulls); } - template - void OmniRleDecoderV2::readLongs(omniruntime::vec::BaseVector*& OmniVec, int64_t *data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t numValues, uint64_t fbs, - const char* const notNull) { + void OmniRleDecoderV2::readLongs(int64_t *data, uint64_t offset, uint64_t len, uint64_t fbs) { switch (fbs) { case 4: - return unrolledUnpack4(OmniVec, data, offset, len, omniOffset, numValues, notNull); + return unrolledUnpack4(data, offset, len); case 8: - return unrolledUnpack8(OmniVec, data, offset, len, omniOffset, numValues, notNull); + return unrolledUnpack8(data, offset, len); case 16: - return unrolledUnpack16(OmniVec, data, offset, len, omniOffset, numValues, notNull); + return unrolledUnpack16(data, offset, len); case 24: - return unrolledUnpack24(OmniVec, data, offset, len, omniOffset, numValues, notNull); + return unrolledUnpack24(data, offset, len); case 32: - return unrolledUnpack32(OmniVec, data, offset, len, omniOffset, numValues, notNull); + return unrolledUnpack32(data, offset, len); case 40: - return unrolledUnpack40(OmniVec, data, offset, len, omniOffset, numValues, notNull); + return unrolledUnpack40(data, offset, len); case 48: - return unrolledUnpack48(OmniVec, data, offset, len, omniOffset, numValues, notNull); + return unrolledUnpack48(data, offset, len); case 56: - return 
unrolledUnpack56(OmniVec, data, offset, len, omniOffset, numValues, notNull); + return unrolledUnpack56(data, offset, len); case 64: - return unrolledUnpack64(OmniVec, data, offset, len, omniOffset, numValues, notNull); + return unrolledUnpack64(data, offset, len); default: // Fallback to the default implementation for deprecated bit size. - return plainUnpackLongs(OmniVec, data, offset, len, omniOffset, numValues, notNull, fbs); + return plainUnpackLongs(data, offset, len, fbs); } } - template - void OmniRleDecoderV2::plainUnpackLongs(omniruntime::vec::BaseVector*& OmniVec, int64_t* data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t omniNumValues, - const char* const notNull, uint64_t fbs) { + void OmniRleDecoderV2::plainUnpackLongs(int64_t *data, uint64_t offset, uint64_t len, uint64_t fbs) { for (uint64_t i = offset; i < (offset + len); i++) { uint64_t result = 0; uint64_t bitsLeftToRead = fbs; @@ -629,10 +538,7 @@ namespace omniruntime::reader { return; } - template - void OmniRleDecoderV2::unrolledUnpack64(omniruntime::vec::BaseVector*& OmniVec, int64_t* data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t omniNumValues, - const char* const notNull) { + void OmniRleDecoderV2::unrolledUnpack64(int64_t *data, uint64_t offset, uint64_t len) { uint64_t curIdx = offset; while (curIdx < offset + len) { // Exhaust the buffer @@ -671,10 +577,7 @@ namespace omniruntime::reader { return; } - template - void OmniRleDecoderV2::unrolledUnpack56(omniruntime::vec::BaseVector*& OmniVec, int64_t* data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t omniNumValues, - const char* const notNull) { + void OmniRleDecoderV2::unrolledUnpack56(int64_t *data, uint64_t offset, uint64_t len) { uint64_t curIdx = offset; while (curIdx < offset + len) { // Exhaust the buffer @@ -711,10 +614,7 @@ namespace omniruntime::reader { return; } - template - void OmniRleDecoderV2::unrolledUnpack48(omniruntime::vec::BaseVector*& OmniVec, int64_t* data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t omniNumValues, - const char* const notNull) { + void OmniRleDecoderV2::unrolledUnpack48(int64_t *data, uint64_t offset, uint64_t len) { uint64_t curIdx = offset; while (curIdx < offset + len) { // Exhaust the buffer @@ -749,10 +649,7 @@ namespace omniruntime::reader { return; } - template - void OmniRleDecoderV2::unrolledUnpack40(omniruntime::vec::BaseVector*& OmniVec, int64_t* data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t omniNumValues, - const char* const notNull) { + void OmniRleDecoderV2::unrolledUnpack40(int64_t *data, uint64_t offset, uint64_t len) { uint64_t curIdx = offset; while (curIdx < offset + len) { // Exhaust the buffer @@ -785,10 +682,7 @@ namespace omniruntime::reader { return; } - template - void OmniRleDecoderV2::unrolledUnpack32(omniruntime::vec::BaseVector*& OmniVec, int64_t* data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t omniNumValues, - const char* const notNull) { + void OmniRleDecoderV2::unrolledUnpack32(int64_t *data, uint64_t offset, uint64_t len) { uint64_t curIdx = offset; while (curIdx < offset + len) { // Exhaust the buffer @@ -820,10 +714,7 @@ namespace omniruntime::reader { return; } - template - void OmniRleDecoderV2::unrolledUnpack24(omniruntime::vec::BaseVector*& OmniVec, int64_t* data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t omniNumValues, - const char* const notNull) { + void OmniRleDecoderV2::unrolledUnpack24(int64_t *data, uint64_t offset, uint64_t len) { 
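// Like the other unrolledUnpackN helpers, this one refills the compressed input buffer
// whenever it is exhausted and then shifts a fixed number of input bytes into each
// int64_t output slot; because the bit width is fixed per function, the inner loop avoids
// the per-value width branching of plainUnpackLongs, which readLongs keeps only as the
// fallback for deprecated bit sizes.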
uint64_t curIdx = offset; while (curIdx < offset + len) { // Exhaust the buffer @@ -853,10 +744,7 @@ namespace omniruntime::reader { return; } - template - void OmniRleDecoderV2::unrolledUnpack16(omniruntime::vec::BaseVector*& OmniVec, int64_t* data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t omniNumValues, - const char* const notNull) { + void OmniRleDecoderV2::unrolledUnpack16(int64_t *data, uint64_t offset, uint64_t len) { uint64_t curIdx = offset; while (curIdx < offset + len) { // Exhaust the buffer @@ -884,10 +772,7 @@ namespace omniruntime::reader { return; } - template - void OmniRleDecoderV2::unrolledUnpack8(omniruntime::vec::BaseVector*& OmniVec, int64_t* data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t omniNumValues, - const char* const notNull) { + void OmniRleDecoderV2::unrolledUnpack8(int64_t *data, uint64_t offset, uint64_t len) { uint64_t curIdx = offset; while (curIdx < offset + len) { // Exhaust the buffer @@ -907,10 +792,7 @@ namespace omniruntime::reader { return; } - template - void OmniRleDecoderV2::unrolledUnpack4(omniruntime::vec::BaseVector*& OmniVec, int64_t* data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t omniNumValues, - const char* const notNull) { + void OmniRleDecoderV2::unrolledUnpack4(int64_t *data, uint64_t offset, uint64_t len) { uint64_t curIdx = offset; while (curIdx < offset + len) { // Make sure bitsLeft is 0 before the loop. bitsLeft can only be 0, 4, or 8. @@ -943,82 +825,87 @@ namespace omniruntime::reader { return; } - uint64_t OmniRleDecoderV2::copyDataFromBufferByType(omniruntime::vec::BaseVector*& tempOmnivec, uint64_t offset, - uint64_t numValues, const char* notNull, - omniruntime::type::DataTypeId dataTypeId) { - switch (dataTypeId) { - case omniruntime::type::OMNI_BOOLEAN: - return copyDataFromBuffer(tempOmnivec, offset, numValues, notNull); - case omniruntime::type::OMNI_SHORT: - return copyDataFromBuffer(tempOmnivec, offset, numValues, notNull); - case omniruntime::type::OMNI_INT: - return copyDataFromBuffer(tempOmnivec, offset, numValues, notNull); - case omniruntime::type::OMNI_LONG: - return copyDataFromBufferTo64bit(tempOmnivec, offset, numValues, notNull); - case omniruntime::type::OMNI_DATE32: - return copyDataFromBuffer(tempOmnivec, offset, numValues, notNull); - case omniruntime::type::OMNI_DATE64: - return copyDataFromBufferTo64bit(tempOmnivec, offset, numValues, - notNull); - case omniruntime::type::OMNI_DOUBLE: - return copyDataFromBuffer(tempOmnivec, offset, numValues, notNull); - case omniruntime::type::OMNI_CHAR: - throw std::runtime_error("copyDataFromBuffer_type CHAR not finished!!!"); - case omniruntime::type::OMNI_VARCHAR: - throw std::runtime_error("copyDataFromBuffer_type VARCHAR not finished!!!"); - case omniruntime::type::OMNI_DECIMAL64: - throw std::runtime_error("copyDataFromBuffer_type DECIMAL64 should not in here!!!"); - case omniruntime::type::OMNI_DECIMAL128: - throw std::runtime_error("copyDataFromBuffer_type DECIMAL128 should not in here!!!"); - default: - printf("copyDataFromBuffer switch no process!!!"); - } - - return 0; - } - - template - uint64_t OmniRleDecoderV2::copyDataFromBuffer(omniruntime::vec::BaseVector*& OmniVec, uint64_t offset, - uint64_t numValues, const char* notNull) { - using namespace omniruntime::type; - using T = typename NativeType::type; - auto vec = reinterpret_cast*>(OmniVec); + template + uint64_t OmniRleDecoderV2::copyDataFromBuffer(T *data, uint64_t offset, uint64_t numValues, + uint64_t *nulls) { uint64_t nRead = 
std::min(runLength - runRead, numValues); - if (notNull) { - for (uint64_t i = offset; i < (offset + nRead); ++i) { - if (notNull[i]) { - vec->SetValue(static_cast(i), static_cast(literals[runRead++])); - } else { - vec->SetNull(static_cast(i)); - } - } + if constexpr (std::is_same_v) { + if (nulls) { + uint64_t i = offset; + uint64_t end = offset + nRead; + uint64_t skipNum = std::min(BitUtil::Nbytes(offset) * 8 - offset, nRead); + for (; i < offset + skipNum; i++) { + if (!BitUtil::IsBitSet(nulls, i)) { + data[i] = literals[runRead++]; + } + } + uint8_t mask; + for (; i + 8 <= end; i += 8) { + mask = reinterpret_cast(nulls)[i / 8]; + if (UNLIKELY(mask == 255)) { + continue; + } + if (mask == 0) { + data[i] = literals[runRead]; + data[i + 1] = literals[runRead + 1]; + data[i + 2] = literals[runRead + 2]; + data[i + 3] = literals[runRead + 3]; + data[i + 4] = literals[runRead + 4]; + data[i + 5] = literals[runRead + 5]; + data[i + 6] = literals[runRead + 6]; + data[i + 7] = literals[runRead + 7]; + runRead += 8; + continue; + } + auto *maskArr = notNullBitMask[mask]; + for (int j = 1; j <= *maskArr; j++) { + auto notNullIndex = i + maskArr[j]; + data[notNullIndex] = literals[runRead++]; + } + } + for (; i < end; i++) { + if (!BitUtil::IsBitSet(nulls, i)) { + data[i] = literals[runRead++]; + } + } + } else { + memcpy_s(data + offset, nRead * sizeof(int64_t), literals.data() + runRead, nRead * sizeof(int64_t)); + runRead += nRead; + } } else { - for (uint64_t i = offset; i < (offset + nRead); ++i) { - vec->SetValue(static_cast(i), static_cast(literals[runRead++])); - } - } - return nRead; - } - - template - uint64_t OmniRleDecoderV2::copyDataFromBufferTo64bit(omniruntime::vec::BaseVector*& OmniVec, uint64_t offset, - uint64_t numValues, const char* notNull) { - using namespace omniruntime::type; - using T = typename NativeType::type; - auto vec = reinterpret_cast*>(OmniVec); - uint64_t nRead = std::min(runLength - runRead, numValues); - if (notNull) { - for (uint64_t i = offset; i < (offset + nRead); ++i) { - if (notNull[i]) { - vec->SetValue(static_cast(i), static_cast(literals[runRead++])); - } else { - vec->SetNull(static_cast(i)); + if (nulls) { + uint64_t i = offset; + uint64_t end = offset + nRead; + uint64_t skipNum = std::min(BitUtil::Nbytes(offset) * 8 - offset, nRead); + for (; i < offset + skipNum; i++) { + if (!BitUtil::IsBitSet(nulls, i)) { + data[i] = static_cast(literals[runRead++]); + } + } + uint8_t mask; + for (; i + 8 <= end; i += 8) { + mask = reinterpret_cast(nulls)[i / 8]; + if (UNLIKELY(mask == 255)) { + continue; + } + auto *maskArr = notNullBitMask[mask]; + for (int j = 1; j <= *maskArr; j++) { + auto notNullIndex = i + maskArr[j]; + data[notNullIndex] = static_cast(literals[runRead++]); + } + } + for (; i < end; i++) { + if (!BitUtil::IsBitSet(nulls, i)) { + data[i] = static_cast(literals[runRead++]); + } + } + } else { + for (uint64_t i = offset; i < (offset + nRead); ++i) { + data[i] = static_cast(literals[runRead++]); } } - } else { - vec->SetValues(static_cast(offset), literals.data() + runRead, static_cast(nRead)); - runRead += nRead; } + return nRead; } } \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRLEv2.hh b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRLEv2.hh index d7377958ed45519a046721a65e7a9733ffd57585..0965ce31d085a0b4d7cc1fc93a958b1f9d19528f 100644 --- a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRLEv2.hh +++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRLEv2.hh @@ -25,12 
+25,18 @@ namespace omniruntime::reader { - std::unique_ptr makeFixLenthVector(uint64_t numValues, + std::unique_ptr makeFixedLengthVector(uint64_t numValues, omniruntime::type::DataTypeId dataTypeId); std::unique_ptr makeDoubleVector(uint64_t numValues, omniruntime::type::DataTypeId dataTypeId); + std::unique_ptr makeVarcharVector(uint64_t numValues, + omniruntime::type::DataTypeId dataTypeId); + + std::unique_ptr makeDecimalVector(uint64_t numValues, + omniruntime::type::DataTypeId dataTypeId); + std::unique_ptr makeNewVector(uint64_t numValues, const orc::Type* baseTp, omniruntime::type::DataTypeId dataTypeId); @@ -43,117 +49,60 @@ namespace omniruntime::reader { * direct read VectorBatch in next * @param omnivec the BaseVector to push * @param numValues the numValues to push - * @param notNull the nullarrays to push + * @param nulls the nulls bits to push * @param baseTp the orcType to push * @param omniTypeId the int* of omniType to push */ - void next(omniruntime::vec::BaseVector*& omnivec, uint64_t numValues, char* notNull, - const orc::Type* baseTp, int omniTypeId); + void next(omniruntime::vec::BaseVector *omnivec, uint64_t numValues, + uint64_t *nulls, int omniTypeId); - void next(int64_t* data, uint64_t numValues, const char* notNull) { - orc::RleDecoderV2::next(data, numValues, notNull); - } + void next(int64_t *data, uint64_t numValues, uint64_t *nulls); + + void next(int32_t *data, uint64_t numValues, uint64_t *nulls); + + void next(int16_t *data, uint64_t numValues, uint64_t *nulls); + + void next(bool *data, uint64_t numValues, uint64_t *nulls); + + template + void next(T *data, uint64_t numValues, uint64_t *nulls); + + template + uint64_t nextShortRepeats(T *data, uint64_t offset, uint64_t numValues, uint64_t *nulls); + + template + uint64_t nextDirect(T *data, uint64_t offset, uint64_t numValues, uint64_t *nulls); + + template + uint64_t nextPatched(T *data, uint64_t offset, uint64_t numValues, uint64_t *nulls); + + template + uint64_t nextDelta(T *data, uint64_t offset, uint64_t numValues, uint64_t *nulls); + + template + uint64_t copyDataFromBuffer(T *data, uint64_t offset, uint64_t numValues, uint64_t *nulls); + + void readLongs(int64_t *data, uint64_t offset, uint64_t len, uint64_t fbs); + + void unrolledUnpack4(int64_t *data, uint64_t offset, uint64_t len); + + void unrolledUnpack8(int64_t *data, uint64_t offset, uint64_t len); + + void unrolledUnpack16(int64_t *data, uint64_t offset, uint64_t len); + + void unrolledUnpack24(int64_t *data, uint64_t offset, uint64_t len); + + void unrolledUnpack32(int64_t *data, uint64_t offset, uint64_t len); + + void unrolledUnpack40(int64_t *data, uint64_t offset, uint64_t len); + + void unrolledUnpack48(int64_t *data, uint64_t offset, uint64_t len); + + void unrolledUnpack56(int64_t *data, uint64_t offset, uint64_t len); + + void unrolledUnpack64(int64_t *data, uint64_t offset, uint64_t len); - uint64_t nextDirect(omniruntime::vec::BaseVector*& OmniVec, uint64_t offset, uint64_t numValues, - const char* const notNull, omniruntime::type::DataTypeId dataTypeId); - - uint64_t nextShortRepeatsByType(omniruntime::vec::BaseVector*& OmniVec, uint64_t offset, uint64_t numValues, - const char* const notNull, omniruntime::type::DataTypeId dataTypeId); - - template - uint64_t nextShortRepeats(omniruntime::vec::BaseVector*& omnivec, uint64_t offset, uint64_t numValues, - const char* notNull); - - template - uint64_t nextShortRepeatsLongType(omniruntime::vec::BaseVector*& OmniVec, - uint64_t offset, uint64_t numValues, const char* const 
notNull); - - - uint64_t nextPatchedByType(omniruntime::vec::BaseVector*& OmniVec, uint64_t offset, uint64_t numValues, - const char* const notNull, omniruntime::type::DataTypeId dataTypeId); - - template - uint64_t nextPatched(omniruntime::vec::BaseVector*& omnivec, uint64_t offset, uint64_t numValues, - const char* notNull, omniruntime::type::DataTypeId dataTypeId); - - template - uint64_t nextDelta(omniruntime::vec::BaseVector*& omnivec, uint64_t offset, uint64_t numValues, - const char* notNull, omniruntime::type::DataTypeId dataTypeId); - - uint64_t nextDeltaByType(omniruntime::vec::BaseVector*& OmniVec, - uint64_t offset, uint64_t numValues, const char* const notNull, - omniruntime::type::DataTypeId dataTypeId); - - uint64_t copyDataFromBufferByType(omniruntime::vec::BaseVector*& tempOmnivec, uint64_t offset, - uint64_t numValues, const char* notNull, - omniruntime::type::DataTypeId dataTypeId); - - template - uint64_t copyDataFromBufferTo64bit(omniruntime::vec::BaseVector*& OmniVec, uint64_t offset, - uint64_t numValues, const char* notNull); - - template - uint64_t copyDataFromBuffer(omniruntime::vec::BaseVector*& OmniVec, uint64_t offset, - uint64_t numValues, const char* notNull); - - void readLongsByType(omniruntime::vec::BaseVector*& OmniVec, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t numValues, uint64_t fbs, - omniruntime::type::DataTypeId dataTypeId, const char* const notNull); - - template - void readLongs(omniruntime::vec::BaseVector*& OmniVec, int64_t *data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t numValues, uint64_t fbs, - const char* const notNull); - - template - void unrolledUnpack4(omniruntime::vec::BaseVector*& OmniVec, int64_t* data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t omniNumValues, - const char* const notNull); - - template - void unrolledUnpack8(omniruntime::vec::BaseVector*& OmniVec, int64_t* data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t omniNumValues, - const char* const notNull); - - template - void unrolledUnpack16(omniruntime::vec::BaseVector*& OmniVec, int64_t* data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t omniNumValues, - const char* const notNull); - - template - void unrolledUnpack24(omniruntime::vec::BaseVector*& OmniVec, int64_t* data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t omniNumValues, - const char* const notNull); - - template - void unrolledUnpack32(omniruntime::vec::BaseVector*& OmniVec, int64_t* data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t omniNumValues, - const char* const notNull); - - template - void unrolledUnpack40(omniruntime::vec::BaseVector*& OmniVec, int64_t* data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t omniNumValues, - const char* const notNull); - - template - void unrolledUnpack48(omniruntime::vec::BaseVector*& OmniVec, int64_t* data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t omniNumValues, - const char* const notNull); - - template - void unrolledUnpack56(omniruntime::vec::BaseVector*& OmniVec, int64_t* data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t omniNumValues, - const char* const notNull); - - template - void unrolledUnpack64(omniruntime::vec::BaseVector*& OmniVec, int64_t* data, uint64_t offset, - uint64_t len, uint64_t omniOffset, uint64_t omniNumValues, - const char* const notNull); - - template - void plainUnpackLongs(omniruntime::vec::BaseVector*& OmniVec, int64_t* data, uint64_t offset, - 
uint64_t len, uint64_t omniOffset, uint64_t omniNumValues, - const char* const notNull, uint64_t fbs); + void plainUnpackLongs(int64_t *data, uint64_t offset, uint64_t len, uint64_t fbs); }; } diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRowReaderImpl.cc b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRowReaderImpl.cc index 2d7cd8adb0f81e4634d0de3c14daf11164c73b6f..95d3b6b4531c98939bf278354d501fe2182a21cf 100644 --- a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRowReaderImpl.cc +++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRowReaderImpl.cc @@ -98,12 +98,14 @@ namespace omniruntime::reader { std::unique_ptr OmniReaderImpl::createRowReader() const { RowReaderOptions defaultOpts; std::unique_ptr julianPtr; - return createRowReader(defaultOpts, julianPtr); + std::unique_ptr predicate; + return createRowReader(defaultOpts, julianPtr, predicate); } std::unique_ptr OmniReaderImpl::createRowReader( - const RowReaderOptions& opts, std::unique_ptr &julianPtr) const { - return std::unique_ptr(new OmniRowReaderImpl(contents, opts, julianPtr)); + const RowReaderOptions& opts, std::unique_ptr &julianPtr, + std::unique_ptr &predicate) const { + return std::unique_ptr(new OmniRowReaderImpl(contents, opts, julianPtr, predicate)); } void OmniRowReaderImpl::startNextStripe() { @@ -235,4 +237,9 @@ namespace omniruntime::reader { } return rowsToRead; } + + common::PredicateCondition *OmniRowReaderImpl::getPredicatePtr() + { + return predicatePtr == nullptr ? nullptr : predicatePtr.get(); + } } \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRowReaderImpl.hh b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRowReaderImpl.hh index 8abb14d37729abfe651436860db4e48f868b2d2a..62e137f33c14404ea98d44d709fef29dd014c336 100644 --- a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRowReaderImpl.hh +++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRowReaderImpl.hh @@ -22,6 +22,7 @@ #include "orc/RowReader/Reader.hh" #include #include "common/JulianGregorianRebase.h" +#include "common/PredicateUtil.h" namespace omniruntime::reader { class OmniReaderImpl : public orc::ReaderImpl { @@ -33,14 +34,16 @@ namespace omniruntime::reader { std::unique_ptr createRowReader() const override; std::unique_ptr createRowReader(const orc::RowReaderOptions& options, - std::unique_ptr &julianPtr) const; + std::unique_ptr &julianPtr, + std::unique_ptr &predicate) const; }; class OmniRowReaderImpl : public orc::RowReaderImpl { public: OmniRowReaderImpl(std::shared_ptr contents, const orc::RowReaderOptions& options, - std::unique_ptr &julianPtr) - : orc::RowReaderImpl(contents, options), julianPtr(std::move(julianPtr)) + std::unique_ptr &julianPtr, + std::unique_ptr &predicate) + : orc::RowReaderImpl(contents, options), julianPtr(std::move(julianPtr)), predicatePtr(std::move(predicate)) {} /** * direct read VectorBatch in next @@ -53,8 +56,11 @@ namespace omniruntime::reader { void startNextStripe() override; + common::PredicateCondition *getPredicatePtr(); + private: std::unique_ptr julianPtr; + std::unique_ptr predicatePtr; }; std::unique_ptr omniCreateReader(std::unique_ptr stream, diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcDecodeUtils.hh b/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcDecodeUtils.hh new file mode 100644 index 0000000000000000000000000000000000000000..1a4b1e9c1196888b35f2c271f1a689decb7ca3ba --- /dev/null +++ 
b/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcDecodeUtils.hh @@ -0,0 +1,308 @@ +/** + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ORC_DECODE_UTILS +#define ORC_DECODE_UTILS + +#include "arm_neon.h" +namespace omniruntime::reader { + +void inline UnZigZagBatch(uint64_t *data, uint64_t numValues) { + for (uint64_t i = 0; i < numValues; i++) { + data[i] = data[i] >> 1 ^ -(data[i] & 1); + } +} + +void inline UnZigZagBatchHEFs8p2(uint64_t *data, uint64_t numValues) { + uint64_t vec_data_s0_p0; + uint64_t vec_data_s1_p0; + uint64_t vec_data_s2_p0; + uint64_t vec_data_s3_p0; + uint64_t vec_data_s4_p0; + uint64_t vec_data_s5_p0; + uint64_t vec_data_s6_p0; + uint64_t vec_data_s7_p0; + uint64_t vec_data_s0_p1; + uint64_t vec_data_s1_p1; + uint64_t vec_data_s2_p1; + uint64_t vec_data_s3_p1; + uint64_t vec_data_s4_p1; + uint64_t vec_data_s5_p1; + uint64_t vec_data_s6_p1; + uint64_t vec_data_s7_p1; + + uint64_t neg_lsb_s0_p0; + uint64_t neg_lsb_s1_p0; + uint64_t neg_lsb_s2_p0; + uint64_t neg_lsb_s3_p0; + uint64_t neg_lsb_s4_p0; + uint64_t neg_lsb_s5_p0; + uint64_t neg_lsb_s6_p0; + uint64_t neg_lsb_s7_p0; + uint64_t neg_lsb_s0_p1; + uint64_t neg_lsb_s1_p1; + uint64_t neg_lsb_s2_p1; + uint64_t neg_lsb_s3_p1; + uint64_t neg_lsb_s4_p1; + uint64_t neg_lsb_s5_p1; + uint64_t neg_lsb_s6_p1; + uint64_t neg_lsb_s7_p1; + + uint64_t shifted_data_s0_p0; + uint64_t shifted_data_s1_p0; + uint64_t shifted_data_s2_p0; + uint64_t shifted_data_s3_p0; + uint64_t shifted_data_s4_p0; + uint64_t shifted_data_s5_p0; + uint64_t shifted_data_s6_p0; + uint64_t shifted_data_s7_p0; + uint64_t shifted_data_s0_p1; + uint64_t shifted_data_s1_p1; + uint64_t shifted_data_s2_p1; + uint64_t shifted_data_s3_p1; + uint64_t shifted_data_s4_p1; + uint64_t shifted_data_s5_p1; + uint64_t shifted_data_s6_p1; + uint64_t shifted_data_s7_p1; + + uint64_t lsb_mask_s0_p0; + uint64_t lsb_mask_s1_p0; + uint64_t lsb_mask_s2_p0; + uint64_t lsb_mask_s3_p0; + uint64_t lsb_mask_s4_p0; + uint64_t lsb_mask_s5_p0; + uint64_t lsb_mask_s6_p0; + uint64_t lsb_mask_s7_p0; + uint64_t lsb_mask_s0_p1; + uint64_t lsb_mask_s1_p1; + uint64_t lsb_mask_s2_p1; + uint64_t lsb_mask_s3_p1; + uint64_t lsb_mask_s4_p1; + uint64_t lsb_mask_s5_p1; + uint64_t lsb_mask_s6_p1; + uint64_t lsb_mask_s7_p1; + + uint64_t i = 0; + for (; i + 16 <= numValues; i += 16) { + // load data[i] + vec_data_s0_p0 = *(data + i); + vec_data_s1_p0 = *(data + i + 1); + vec_data_s2_p0 = *(data + i + 2); + vec_data_s3_p0 = *(data + i + 3); + vec_data_s4_p0 = *(data + i + 4); + vec_data_s5_p0 = *(data + i + 5); + vec_data_s6_p0 = *(data + i + 6); + vec_data_s7_p0 = *(data + i + 7); + vec_data_s0_p1 = *(data + i + 8); + vec_data_s1_p1 = *(data + i + 9); + 
vec_data_s2_p1 = *(data + i + 10); + vec_data_s3_p1 = *(data + i + 11); + vec_data_s4_p1 = *(data + i + 12); + vec_data_s5_p1 = *(data + i + 13); + vec_data_s6_p1 = *(data + i + 14); + vec_data_s7_p1 = *(data + i + 15); + + // compute data[i] & 1 + lsb_mask_s0_p0 = vec_data_s0_p0 & 1; + lsb_mask_s1_p0 = vec_data_s1_p0 & 1; + lsb_mask_s2_p0 = vec_data_s2_p0 & 1; + lsb_mask_s3_p0 = vec_data_s3_p0 & 1; + lsb_mask_s4_p0 = vec_data_s4_p0 & 1; + lsb_mask_s5_p0 = vec_data_s5_p0 & 1; + lsb_mask_s6_p0 = vec_data_s6_p0 & 1; + lsb_mask_s7_p0 = vec_data_s7_p0 & 1; + lsb_mask_s0_p1 = vec_data_s0_p1 & 1; + lsb_mask_s1_p1 = vec_data_s1_p1 & 1; + lsb_mask_s2_p1 = vec_data_s2_p1 & 1; + lsb_mask_s3_p1 = vec_data_s3_p1 & 1; + lsb_mask_s4_p1 = vec_data_s4_p1 & 1; + lsb_mask_s5_p1 = vec_data_s5_p1 & 1; + lsb_mask_s6_p1 = vec_data_s6_p1 & 1; + lsb_mask_s7_p1 = vec_data_s7_p1 & 1; + + // compute -(data[i] & 1) + neg_lsb_s0_p0 = -lsb_mask_s0_p0; + neg_lsb_s1_p0 = -lsb_mask_s1_p0; + neg_lsb_s2_p0 = -lsb_mask_s2_p0; + neg_lsb_s3_p0 = -lsb_mask_s3_p0; + neg_lsb_s4_p0 = -lsb_mask_s4_p0; + neg_lsb_s5_p0 = -lsb_mask_s5_p0; + neg_lsb_s6_p0 = -lsb_mask_s6_p0; + neg_lsb_s7_p0 = -lsb_mask_s7_p0; + neg_lsb_s0_p1 = -lsb_mask_s0_p1; + neg_lsb_s1_p1 = -lsb_mask_s1_p1; + neg_lsb_s2_p1 = -lsb_mask_s2_p1; + neg_lsb_s3_p1 = -lsb_mask_s3_p1; + neg_lsb_s4_p1 = -lsb_mask_s4_p1; + neg_lsb_s5_p1 = -lsb_mask_s5_p1; + neg_lsb_s6_p1 = -lsb_mask_s6_p1; + neg_lsb_s7_p1 = -lsb_mask_s7_p1; + + // compute data[i] >> 1 + shifted_data_s0_p0 = vec_data_s0_p0 >> 1; + shifted_data_s1_p0 = vec_data_s1_p0 >> 1; + shifted_data_s2_p0 = vec_data_s2_p0 >> 1; + shifted_data_s3_p0 = vec_data_s3_p0 >> 1; + shifted_data_s4_p0 = vec_data_s4_p0 >> 1; + shifted_data_s5_p0 = vec_data_s5_p0 >> 1; + shifted_data_s6_p0 = vec_data_s6_p0 >> 1; + shifted_data_s7_p0 = vec_data_s7_p0 >> 1; + shifted_data_s0_p1 = vec_data_s0_p1 >> 1; + shifted_data_s1_p1 = vec_data_s1_p1 >> 1; + shifted_data_s2_p1 = vec_data_s2_p1 >> 1; + shifted_data_s3_p1 = vec_data_s3_p1 >> 1; + shifted_data_s4_p1 = vec_data_s4_p1 >> 1; + shifted_data_s5_p1 = vec_data_s5_p1 >> 1; + shifted_data_s6_p1 = vec_data_s6_p1 >> 1; + shifted_data_s7_p1 = vec_data_s7_p1 >> 1; + + // compute and store data[i] + *(data + i) = shifted_data_s0_p0 ^ neg_lsb_s0_p0; + *(data + i + 1) = shifted_data_s1_p0 ^ neg_lsb_s1_p0; + *(data + i + 2) = shifted_data_s2_p0 ^ neg_lsb_s2_p0; + *(data + i + 3) = shifted_data_s3_p0 ^ neg_lsb_s3_p0; + *(data + i + 4) = shifted_data_s4_p0 ^ neg_lsb_s4_p0; + *(data + i + 5) = shifted_data_s5_p0 ^ neg_lsb_s5_p0; + *(data + i + 6) = shifted_data_s6_p0 ^ neg_lsb_s6_p0; + *(data + i + 7) = shifted_data_s7_p0 ^ neg_lsb_s7_p0; + *(data + i + 8) = shifted_data_s0_p1 ^ neg_lsb_s0_p1; + *(data + i + 9) = shifted_data_s1_p1 ^ neg_lsb_s1_p1; + *(data + i + 10) = shifted_data_s2_p1 ^ neg_lsb_s2_p1; + *(data + i + 11) = shifted_data_s3_p1 ^ neg_lsb_s3_p1; + *(data + i + 12) = shifted_data_s4_p1 ^ neg_lsb_s4_p1; + *(data + i + 13) = shifted_data_s5_p1 ^ neg_lsb_s5_p1; + *(data + i + 14) = shifted_data_s6_p1 ^ neg_lsb_s6_p1; + *(data + i + 15) = shifted_data_s7_p1 ^ neg_lsb_s7_p1; + } + + // handle left + for (; i < numValues; i++) { + data[i] = data[i] >> 1 ^ -(data[i] & 1); + } +} + +inline void BitsToBoolsBatch(bool *dest, uint8_t *source, uint64_t bitsLeft) { + uint8_t* data = reinterpret_cast(dest); + for (uint64_t i = 0; i < bitsLeft; i++) { + uint64_t byteIndex = i / 8; + uint64_t bitPosition = 7 - (i % 8); + data[i] = !((source[byteIndex] >> bitPosition) & 0x1); + } +} + +inline void 
BitsToBoolsBatchHEFs0p1(bool *dest, uint8_t *source, uint64_t bitsLeft) {
+    uint8_t* data = reinterpret_cast(dest);
+    uint16_t *read = reinterpret_cast(source);
+
+    // fill all zero in vector lane
+    uint8x16_t zero = vdupq_n_u8(0);
+
+    // Identify the bit in specified position
+    uint8_t array[16] = {0x80, 0x40, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01,
+                         0x80, 0x40, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01};
+    uint8x16_t mask = vld1q_u8(array);
+
+    uint64_t i = 0;
+    // handle 16 elements in one batch
+    for (; i + 16 <= bitsLeft; i += 16) {
+        // load two bytes
+        uint16_t value = read[i/16];
+        uint8_t byte0 = (value >> 8) & 0xFF; // first byte
+        uint8_t byte1 = value & 0xFF; // second byte
+
+        uint8x8_t bits0 = vdup_n_u8(byte0);
+        uint8x8_t bits1 = vdup_n_u8(byte1);
+        uint8x16_t bits = vcombine_u8(bits1, bits0);
+        uint8x16_t result = vmvnq_u8(bits);
+
+        // compute masked result and compare with zero
+        uint8x16_t expandedBytes = vcgtq_u8(vandq_u8(result, mask), zero);
+        // move the most significant bit of each byte to the least significant bit
+        uint8x16_t store = vshrq_n_u8(expandedBytes, 7);
+
+        // store to data
+        vst1q_u8(data + i, store);
+    }
+
+    // handle leftover elements
+    for (; i < bitsLeft; i++) {
+        uint64_t byteIndex = i / 8;
+        uint64_t bitPosition = 7 - (i % 8);
+        data[i] = !((source[byteIndex] >> bitPosition) & 0x1);
+    }
+}
+
+struct BitNotFlip {
+    BitNotFlip() {
+        for (int i = 0; i < (1 << N); i++) {
+            // invert the byte, then reverse its bit order
+            memo_[i] = flip(~i);
+        }
+    }
+
+    static uint8_t flip(uint8_t byte) {
+        return ((byte & 0x01) << 7) | ((byte & 0x02) << 5) | ((byte & 0x04) << 3) | ((byte & 0x08) << 1) |
+               ((byte & 0x10) >> 1) | ((byte & 0x20) >> 3) | ((byte & 0x40) >> 5) | ((byte & 0x80) >> 7);
+    }
+
+    uint8_t inline operator[](size_t i) const {
+        return memo_[i];
+    }
+
+private:
+    static constexpr int N = 8;
+    uint8_t memo_[1 << N]{0};
+};
+
+const static BitNotFlip bitNotFlip;
+
+inline void ReverseAndFlipBytes(uint8_t *bytes, int length) {
+    for (int i = 0; i < length; i++) {
+        bytes[i] = bitNotFlip[bytes[i]];
+    }
+}
+
+struct NotNullBitMask {
+    NotNullBitMask() {
+        for (int i = 0; i < (1 << N); i++) {
+            int32_t startIndex = i * (N + 1);
+            int32_t index = startIndex;
+            for (int bit = 0; bit < N; bit++) {
+                if ((i & (1 << bit)) == 0) {
+                    memo_[++index] = bit;
+                }
+            }
+            memo_[startIndex] = index - startIndex;
+        }
+    }
+
+    const inline uint8_t* operator[](size_t i) const {
+        return memo_ + (i * (N + 1));
+    }
+
+private:
+    static constexpr int N = 8;
+    uint8_t memo_[(1 << N) * (N + 1)]{0};
+};
+
+const static NotNullBitMask notNullBitMask;
+
+}
+
+#endif // ORC_DECODE_UTILS
\ No newline at end of file
diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetDecoder.h b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetDecoder.h
index 8a04e338c5e3b54d0bae20d8f71de7558ad9bd60..e94586b6f90f74bf397780e7ec1dee481cb70ba0 100644
--- a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetDecoder.h
+++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetDecoder.h
@@ -59,8 +59,7 @@ namespace omniruntime::reader {
    // TODO: optimize batch move
    template
-   inline int SpacedExpand(T* buffer, int num_values, int null_count,
-                           bool* nulls) {
+   inline int SpacedExpand(T* buffer, int num_values, int null_count, uint8_t* nulls, int64_t nullsOffset) {
        int idx_decode = num_values - null_count;
        memset_s(static_cast(buffer + idx_decode), null_count * sizeof(T), 0, null_count * sizeof(T));
        if (idx_decode == 0) {
@@ -68,7 +67,7 @@
            return num_values;
        }
        for (int i = num_values - 1; i >= 0;
--i) { - if (!nulls[i]) { + if (!BitUtil::IsBitSet(nulls, nullsOffset + i)) { idx_decode--; memmove_s(buffer + i, sizeof(T), buffer + idx_decode, sizeof(T)); } @@ -82,8 +81,7 @@ namespace omniruntime::reader { public: using T = typename DType::c_type; - virtual int DecodeSpaced(T* buffer, int num_values, int null_count, - bool* nulls) { + virtual int DecodeSpaced(T* buffer, int num_values, int null_count, uint8_t* nulls, int64_t nullsOffset) { if (null_count > 0) { int values_to_read = num_values - null_count; int values_read = Decode(buffer, values_to_read); @@ -91,7 +89,7 @@ namespace omniruntime::reader { throw ::parquet::ParquetException("Number of values / definition_levels read did not match"); } - return SpacedExpand(buffer, num_values, null_count, nulls); + return SpacedExpand(buffer, num_values, null_count, nulls, nullsOffset); } else { return Decode(buffer, num_values); } @@ -105,7 +103,7 @@ namespace omniruntime::reader { ::parquet::ParquetException::NYI("ParquetTypedDecoder for DecodeArrowNonNull"); } - virtual int DecodeArrow(int num_values, int null_count, bool* nulls, + virtual int DecodeArrow(int num_values, int null_count, uint8_t* nulls, int64_t offset, omniruntime::vec::BaseVector** outBaseVec) { ::parquet::ParquetException::NYI("ParquetTypedDecoder for DecodeArrow"); } @@ -228,7 +226,7 @@ namespace omniruntime::reader { ::parquet::ParquetException::NYI("ParquetTypedDecoder for DecodeArrowNonNull"); } - virtual int DecodeArrow(int num_values, int null_count, bool* nulls, + virtual int DecodeArrow(int num_values, int null_count, uint8_t* nulls, int64_t offset, omniruntime::vec::BaseVector** outBaseVec) { ::parquet::ParquetException::NYI("ParquetTypedDecoder for DecodeArrow"); } @@ -275,7 +273,7 @@ namespace omniruntime::reader { return result; } - int DecodeArrow(int num_values, int null_count, bool* nulls, + int DecodeArrow(int num_values, int null_count, uint8_t* nulls, int64_t offset, omniruntime::vec::BaseVector** vec) override { int result = 0; PARQUET_THROW_NOT_OK(DecodeArrowDense(num_values, null_count, nulls, @@ -284,7 +282,7 @@ namespace omniruntime::reader { } private: - arrow::Status DecodeArrowDense(int num_values, int null_count, bool* nulls, + arrow::Status DecodeArrowDense(int num_values, int null_count, uint8_t* nulls, int64_t offset, int* out_num_values, omniruntime::vec::BaseVector** out) { constexpr int32_t kBufferSize = 1024; @@ -299,7 +297,7 @@ namespace omniruntime::reader { int pos_indices = 0; for (int i = 0; i < num_values; i++) { - if (!nulls[offset + i]) { + if (!BitUtil::IsBitSet(nulls, offset + i)) { if (num_indices == pos_indices) { const auto batch_size = std::min(kBufferSize, num_values - null_count - values_decoded); @@ -465,7 +463,7 @@ namespace omniruntime::reader { return result; } - int DecodeArrow(int num_values, int null_count, bool* nulls, + int DecodeArrow(int num_values, int null_count, uint8_t* nulls, int64_t offset, omniruntime::vec::BaseVector** outBaseVec) { int result = 0; PARQUET_THROW_NOT_OK(DecodeArrowDense(num_values, null_count, nulls, @@ -474,7 +472,7 @@ namespace omniruntime::reader { } private: - arrow::Status DecodeArrowDense(int num_values, int null_count, bool* nulls, + arrow::Status DecodeArrowDense(int num_values, int null_count, uint8_t* nulls, int64_t offset, int* out_values_decoded, omniruntime::vec::BaseVector** out) { int values_decoded = 0; @@ -482,7 +480,7 @@ namespace omniruntime::reader { omniruntime::vec::LargeStringContainer>*>(*out); for (int i = 0; i < num_values; i++) { - if (!nulls[offset + i]) { + 
if (!BitUtil::IsBitSet(nulls, offset + i)) { if (ARROW_PREDICT_FALSE(len_ < 4)) { ::parquet::ParquetException::EofException(); } @@ -632,8 +630,7 @@ namespace omniruntime::reader { using Base = ParquetPlainDecoder<::parquet::FLBAType>; using Base::ParquetPlainDecoder; - int DecodeSpaced(T* buffer, int num_values, int null_count, - bool* nulls) override { + int DecodeSpaced(T* buffer, int num_values, int null_count, uint8_t* nulls, int64_t nullsOffset) override { int values_to_read = num_values - null_count; Decode(buffer, values_to_read); return num_values; diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.cpp b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.cpp index 262f1498d125c89710a734a81955c4039288b6e1..87ecefe218ada4e80d8ced742ec9ca44130b88cf 100644 --- a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.cpp +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.cpp @@ -501,10 +501,10 @@ Status RawBytesToDecimal64Bytes(const uint8_t* bytes, int32_t length, } void DefLevelsToNullsSIMD(const int16_t* def_levels, int64_t num_def_levels, const int16_t max_def_level, - int64_t* values_read, int64_t* null_count, bool* nulls) { + int64_t* values_read, int64_t* null_count, uint8_t* nulls, int64_t nullsOffset) { for (int64_t i = 0; i < num_def_levels; ++i) { if (def_levels[i] < max_def_level) { - nulls[i] = true; + BitUtil::SetBit(nulls, nullsOffset + i, true); (*null_count)++; } } @@ -512,9 +512,10 @@ void DefLevelsToNullsSIMD(const int16_t* def_levels, int64_t num_def_levels, con } void DefLevelsToNulls(const int16_t* def_levels, int64_t num_def_levels, LevelInfo level_info, - int64_t* values_read, int64_t* null_count, bool* nulls) { + int64_t* values_read, int64_t* null_count, uint8_t* nulls, int64_t nullsOffset) { if (level_info.rep_level == 0) { - DefLevelsToNullsSIMD(def_levels, num_def_levels, level_info.def_level, values_read, null_count, nulls); + DefLevelsToNullsSIMD(def_levels, num_def_levels, level_info.def_level, values_read, null_count, nulls, + nullsOffset); } else { ::ParquetException::NYI("rep_level > 0 NYI"); } diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.h b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.h index 44cfd75ea7c5790a023e5f37b8ea995127a7a0d6..64c4302f003f3eda34bd1bbb9f0e2b6f1f5ce510 100644 --- a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.h +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.h @@ -52,7 +52,7 @@ namespace omniruntime::reader { ::arrow::Status RawBytesToDecimal64Bytes(const uint8_t* bytes, int32_t length, BaseVector** out_buf, int64_t index); void DefLevelsToNulls(const int16_t* def_levels, int64_t num_def_levels, ::parquet::internal::LevelInfo level_info, - int64_t* values_read, int64_t* null_count, bool* nulls); + int64_t* values_read, int64_t* null_count, uint8_t* nulls, int64_t nullsOffset); template class ParquetColumnReaderBase { @@ -578,7 +578,7 @@ namespace omniruntime::reader { virtual void ReadValuesSpaced(int64_t values_with_nulls, int64_t null_count) { int64_t num_decoded = this->current_decoder_->DecodeSpaced( ValuesHead(), static_cast(values_with_nulls), - static_cast(null_count), nulls_ + values_written_); + static_cast(null_count), nulls_, values_written_); CheckNumberDecoded(num_decoded, values_with_nulls); } @@ -612,7 +612,7 @@ namespace omniruntime::reader { if 
(leaf_info_.HasNullableValues()) { int64_t values_read = 0; DefLevelsToNulls(def_levels() + start_levels_position, levels_position_ - start_levels_position, leaf_info_, - &values_read, &null_count, nulls_ + start_levels_position); + &values_read, &null_count, nulls_, start_levels_position); values_to_read = values_read - null_count; DCHECK_GE(values_to_read, 0); ReadValuesSpaced(values_read, null_count); @@ -671,7 +671,7 @@ namespace omniruntime::reader { ::parquet::internal::LevelInfo leaf_info_; omniruntime::vec::BaseVector* vec_ = nullptr; uint8_t* parquet_vec_ = nullptr; - bool* nulls_ = nullptr; + uint8_t* nulls_ = nullptr; int32_t byte_width_; }; @@ -752,7 +752,7 @@ namespace omniruntime::reader { } int index = 0; for (int64_t i = 0; i < values_written_; i++) { - if (nulls_ == nullptr || !nulls_[i]) { + if (nulls_ == nullptr || !BitUtil::IsBitSet(nulls_, i)) { PARQUET_THROW_NOT_OK(RawBytesToDecimal64Bytes(GetParquetVecHeadPtr(index++), byte_width_, &vec_, i)); } } @@ -797,7 +797,7 @@ namespace omniruntime::reader { } int index = 0; for (int64_t i = 0; i < values_written_; i++) { - if (nulls_ == nullptr || !nulls_[i]) { + if (nulls_ == nullptr || !BitUtil::IsBitSet(nulls_, i)) { PARQUET_THROW_NOT_OK(RawBytesToDecimal128Bytes(GetParquetVecHeadPtr(index++), byte_width_, &vec_, i)); } } diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetWriter.cpp b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetWriter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3e53ef5bc579f93a018a4e5e27f774915f2f2a31 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetWriter.cpp @@ -0,0 +1,439 @@ +/** + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ParquetWriter.h" +#include "ParquetReader.h" +#include "arrow/array/array_base.h" +#include "arrow/array/array_binary.h" +#include "arrow/array/array_primitive.h" +#include "arrow/array/data.h" +#include +#include +#include +#include "arrow/util/bitmap.h" +#include "arrow/chunked_array.h" +#include "arrow/buffer_builder.h" +#include "arrow/table.h" +#include "arrowadapter/FileSystemAdapter.h" +#include "common/UriInfo.h" +#include "jni/jni_common.h" +#include "parquet/arrow/reader.h" +#include "parquet/exception.h" +#include "parquet/properties.h" +//#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace arrow; +using namespace arrow::internal; +using namespace parquet::arrow; +using namespace omniruntime::writer; +using namespace omniruntime::reader; + +static std::mutex mutex_; + +namespace omniruntime::writer +{ + + arrow::Status ParquetWriter::InitRecordWriter(UriInfo &uri, std::string &ugi) + { + parquet::WriterProperties::Builder writer_properties; + parquet::ArrowWriterProperties::Builder arrow_writer_properties; + + arrow::Status result; + mutex_.lock(); + Filesystem *fs = GetFileSystemPtr(uri, ugi, result); + mutex_.unlock(); + if (fs == nullptr || fs->filesys_ptr == nullptr) { + return arrow::Status::IOError(result); + } + + std::string uriPath = uri.ToString(); + std::string path = uri.Path(); + auto res = createDirectories(get_parent_path(path)); + if (res != 0) { + throw std::runtime_error("Create local directories fail"); + } + std::shared_ptr outputStream; + ARROW_ASSIGN_OR_RAISE(outputStream, fs->filesys_ptr->OpenOutputStream(path)); + + writer_properties.disable_dictionary(); + auto fileWriterResult = FileWriter::Open( + *schema_, arrow::default_memory_pool(), outputStream, + writer_properties.build(), parquet::default_arrow_writer_properties()); + if (!fileWriterResult.ok()) { + std::cerr<<"Error opening file writer: "< 0) { + struct stat st = {0}; + if (stat(dir.c_str(), &st) == -1) { + if (mkdir(dir.c_str(), 0755) == -1 && errno != EEXIST) { + return -1; + } + } + } + pos = path.find_first_of("/", pos + 1); + } + struct stat st = {0}; + if (stat(path.c_str(), &st) == -1) { + if (mkdir(path.c_str(), 0755) == -1 && errno != EEXIST) { + return -1; + } + } + return 0; + } + + std::shared_ptr<::arrow::ChunkedArray> buildBooleanChunk(DataTypeId typeId, BaseVector *baseVector, + bool isSplitWrite = false, long startPos = 0, + long endPos = 0) + { + using T = typename NativeType::type; + auto vector = (Vector *)baseVector; + + if (!isSplitWrite) { + startPos = 0; + endPos = vector->GetSize(); + } + + int64_t vectorSize = endPos - startPos; + bool values[vectorSize]; + int64_t index = 0; + auto bitmapBuffer = AllocateBitmap(vectorSize).ValueOrDie(); + arrow::internal::Bitmap bitmap(bitmapBuffer, 0, vectorSize); + bitmap.SetBitsTo(true); + + if (vector->HasNull()) { + for (long j = startPos; j < endPos; j++) { + if (vector->IsNull(j)) { + bitmap.SetBitTo(index, false); + } else if(isSplitWrite) { + values[index] = vector->GetValue(j); + } + index++; + } + } else if (isSplitWrite) { + for (long j = startPos; j < endPos; j++) { + values[index] = vector->GetValue(j); + index++; + } + } + + TypedBufferBuilder builder; + builder.Resize(vectorSize); + + builder.Append(reinterpret_cast(isSplitWrite?values:VectorHelper::UnsafeGetValues(vector)), vectorSize); + auto maybe_buffer = builder.Finish(); + std::shared_ptr databuffer = *maybe_buffer; + + std::vector> buffers; + buffers.emplace_back(bitmapBuffer); + 
buffers.emplace_back(databuffer);
+
+        auto booleanType = std::make_shared();
+        auto arrayData = arrow::ArrayData::Make(booleanType, vectorSize, buffers);
+
+        std::vector> arrayVector;
+        auto booleanArray = std::make_shared(arrayData);
+        arrayVector.emplace_back(booleanArray);
+
+        return arrow::ChunkedArray::Make(arrayVector, booleanType).ValueOrDie();
+    }
+
+    template
+    std::shared_ptr<::arrow::ChunkedArray> buildChunk(BaseVector *baseVector,
+                                                      bool isSplitWrite = false, long startPos = 0,
+                                                      long endPos = 0)
+    {
+        using T = typename NativeType::type;
+        auto vector = (Vector *)baseVector;
+        if (!isSplitWrite) {
+            startPos = 0;
+            endPos = vector->GetSize();
+        }
+        int64_t vectorSize = endPos - startPos;
+        ChunkType values[vectorSize];
+        int64_t index = 0;
+
+        auto bitmapBuffer = AllocateBitmap(vectorSize).ValueOrDie();
+        arrow::internal::Bitmap bitmap(bitmapBuffer, 0, vectorSize);
+        bitmap.SetBitsTo(true);
+        if (vector->HasNull()) {
+            for (long j = startPos; j < endPos; j++) {
+                if (vector->IsNull(j)) {
+                    bitmap.SetBitTo(index, false);
+                } else if (isSplitWrite) {
+                    values[index] = vector->GetValue(j);
+                }
+                index++;
+            }
+        } else if (isSplitWrite) {
+            for (long j = startPos; j < endPos; j++) {
+                values[index] = vector->GetValue(j);
+                index++;
+            }
+        }
+
+        TypedBufferBuilder builder;
+        builder.Resize(vectorSize);
+        builder.Append(reinterpret_cast(isSplitWrite?values:VectorHelper::UnsafeGetValues(vector)), vectorSize);
+        auto dataBuffer = *builder.Finish();
+        std::vector> buffers;
+        buffers.emplace_back(bitmapBuffer);
+        buffers.emplace_back(dataBuffer);
+
+        auto arrowType = std::make_shared();
+        auto arrayData = arrow::ArrayData::Make(arrowType, vectorSize, buffers);
+        std::vector> arrayVector;
+        auto arrowArray = std::make_shared>(arrayData);
+        arrayVector.emplace_back(arrowArray);
+        return ChunkedArray::Make(arrayVector, arrowType).ValueOrDie();
+    }
+
+    std::shared_ptr buildVarcharChunk(DataTypeId typeId, BaseVector *baseVector,
+                                      bool isSplitWrite = false, long startPos = 0,
+                                      long endPos = 0)
+    {
+        auto vector = static_cast> *>(baseVector);
+
+        if (!isSplitWrite) {
+            startPos = 0;
+            endPos = vector->GetSize();
+        }
+
+        int64_t vectorSize = endPos - startPos;
+        auto bitmapBuffer = AllocateBitmap(vectorSize).ValueOrDie();
+        arrow::internal::Bitmap bitmap(bitmapBuffer, 0, vectorSize);
+        bitmap.SetBitsTo(true);
+
+        TypedBufferBuilder offsetsBuilder;
+        TypedBufferBuilder valuesBuilder;
+        int32_t currentOffset = 0;
+        offsetsBuilder.Append(0);
+        valuesBuilder.Resize(vectorSize);
+
+        int64_t index = 0;
+        for (long j = startPos; j < endPos; j++) {
+            if (vector->IsNull(j)) {
+                bitmap.SetBitTo(index, false);
+            }
+            index++;
+            std::string strValue = std::string(vector->GetValue(j));
+            size_t length = strValue.length();
+            currentOffset += length;
+            offsetsBuilder.Append(currentOffset);
+            valuesBuilder.Append(strValue.data(), length);
+        }
+
+        auto offsetsBuffer = offsetsBuilder.Finish().ValueOrDie();
+        auto valuesBuffer = valuesBuilder.Finish().ValueOrDie();
+
+        std::vector> buffers;
+
+        buffers.emplace_back(bitmapBuffer);
+        buffers.emplace_back(offsetsBuffer);
+        buffers.emplace_back(valuesBuffer);
+
+        auto utf8Type = std::make_shared();
+        auto arrayData = arrow::ArrayData::Make(utf8Type, vectorSize, buffers);
+
+        std::vector> arrayVector;
+        auto stringArray = std::make_shared(arrayData);
+        arrayVector.emplace_back(stringArray);
+
+        return ChunkedArray::Make(arrayVector, utf8Type).ValueOrDie();
+    }
+
+    std::shared_ptr buildDecimal64Chunk(DataTypeId typeId, BaseVector *baseVector,
+                                        int precision, int scale,
bool isSplitWrite = false, + long startPos = 0, long endPos = 0) + { + using T = typename NativeType::type; + auto vector = (Vector *)baseVector; + + if (!isSplitWrite) { + startPos = 0; + endPos = vector->GetSize(); + } + int64_t vectorSize = endPos - startPos; + auto bitmapBuffer = AllocateBitmap(vectorSize).ValueOrDie(); + arrow::internal::Bitmap bitmap(bitmapBuffer, 0, vectorSize); + bitmap.SetBitsTo(true); + BufferBuilder builder; + builder.Resize(vectorSize); + std::vector decimalArray; + + int64_t index = 0; + for (long j = startPos; j < endPos; j++) { + BasicDecimal128 basicDecimal128(0, vector->GetValue(j)); + decimalArray.emplace_back(BasicDecimal128(basicDecimal128)); + if (vector->IsNull(j)) { + bitmap.SetBitTo(index, false); + } + index++; + } + + builder.Append(decimalArray.data(), decimalArray.size() * sizeof(arrow::Decimal128)); + auto dataBuffer = *builder.Finish(); + std::vector> buffers; + buffers.emplace_back(bitmapBuffer); + buffers.emplace_back(dataBuffer); + + auto decimal128Type = std::make_shared(precision, scale); + auto arrayData = arrow::ArrayData::Make(decimal128Type, vectorSize, buffers); + std::vector> arrayVector; + auto decimal128Array = std::make_shared(arrayData); + arrayVector.emplace_back(decimal128Array); + return ChunkedArray::Make(arrayVector, decimal128Type).ValueOrDie(); + } + + std::shared_ptr buildDecimal128Chunk(DataTypeId typeId, BaseVector *baseVector, + int precision, int scale, bool isSplitWrite = false, + long startPos = 0, long endPos = 0) + { + using T = typename NativeType::type; + auto vector = (Vector *)baseVector; + + if (!isSplitWrite) { + startPos = 0; + endPos = vector->GetSize(); + } + int64_t vectorSize = endPos - startPos; + auto bitmapBuffer = AllocateBitmap(vectorSize).ValueOrDie(); + arrow::internal::Bitmap bitmap(bitmapBuffer, 0, vectorSize); + bitmap.SetBitsTo(true); + BufferBuilder builder; + builder.Resize(vectorSize); + std::vector decimalArray; + + int64_t index = 0; + for (long j = startPos; j < endPos; j++) { + auto decimalValue = vector->GetValue(j); + BasicDecimal128 basicDecimal128(vector->GetValue(j).HighBits(), vector->GetValue(j).LowBits()); + decimalArray.emplace_back(BasicDecimal128(basicDecimal128)); + if (vector->IsNull(j)) { + bitmap.SetBitTo(index, false); + } + index++; + } + + builder.Append(decimalArray.data(), decimalArray.size() * sizeof(arrow::Decimal128)); + auto dataBuffer = *builder.Finish(); + std::vector> buffers; + buffers.emplace_back(bitmapBuffer); + buffers.emplace_back(dataBuffer); + + auto decimal128Type = std::make_shared(precision, scale); + auto arrayData = arrow::ArrayData::Make(decimal128Type, vectorSize, buffers); + std::vector> arrayVector; + auto decimal128Array = std::make_shared(arrayData); + arrayVector.emplace_back(decimal128Array); + return ChunkedArray::Make(arrayVector, decimal128Type).ValueOrDie(); + } + + void ParquetWriter::write(long *vecNativeId, int colNums, + const int *omniTypes, + const unsigned char *dataColumnsIds, + bool isSplitWrite, long startPos , long endPos) + { + std::vector> chunks; + int decimalIndex = 0; + int precision = 0; + int scale = 0; + for (int i = 0; i < colNums; ++i) { + if (!dataColumnsIds[i]) { + continue; + } + + auto vec = (BaseVector *)vecNativeId[i]; + auto typeId = static_cast(omniTypes[i]); + switch (typeId) { + case OMNI_BOOLEAN: + chunks.emplace_back(buildBooleanChunk(typeId, vec, isSplitWrite, startPos, endPos)); + break; + case OMNI_SHORT: + chunks.emplace_back(buildChunk(vec, isSplitWrite, startPos, endPos)); + break; + case 
OMNI_INT: + chunks.emplace_back(buildChunk(vec, isSplitWrite, startPos, endPos)); + break; + case OMNI_LONG: + chunks.emplace_back(buildChunk(vec, isSplitWrite, startPos, endPos)); + break; + case OMNI_DATE32: + chunks.emplace_back(buildChunk(vec, isSplitWrite, startPos, endPos)); + break; + case OMNI_DATE64: + chunks.emplace_back(buildChunk(vec, isSplitWrite, startPos, endPos)); + break; + case OMNI_DOUBLE: + chunks.emplace_back(buildChunk(vec, isSplitWrite, startPos, endPos)); + break; + case OMNI_VARCHAR: + chunks.emplace_back(buildVarcharChunk(typeId, vec, isSplitWrite, startPos, endPos)); + break; + case OMNI_DECIMAL64: + precision = precisions[decimalIndex]; + scale = scales[decimalIndex]; + chunks.emplace_back(buildDecimal64Chunk(typeId, vec, precision, scale, isSplitWrite, startPos, endPos)); + decimalIndex++; + break; + case OMNI_DECIMAL128: + precision = precisions[decimalIndex]; + scale = scales[decimalIndex]; + chunks.emplace_back(buildDecimal128Chunk(typeId, vec, precision, scale, isSplitWrite, startPos, endPos)); + decimalIndex++; + break; + default: + throw std::runtime_error( + "Native columnar write not support for this type: " + std::to_string(typeId)); + } + } + auto numRows = chunks.empty() ? 0 : chunks[0]->length(); + + auto table = arrow::Table::Make(schema_, std::move(chunks), numRows); + if (!arrow_writer) { + throw std::runtime_error("Arrow writer is not initialized"); + } + PARQUET_THROW_NOT_OK(arrow_writer->WriteTable(*table)); + } + +} // namespace omniruntime::writer diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetWriter.h b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetWriter.h new file mode 100644 index 0000000000000000000000000000000000000000..9c51981d1bfff8ac8c8cdb085c7b7386a8c5c123 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetWriter.h @@ -0,0 +1,54 @@ +/** + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NATIVE_READER_PARQUETWRITER_H +#define NATIVE_READER_PARQUETWRITER_H + + + +#include +#include +#include "common/UriInfo.h" +#include "parquet/arrow/writer.h" + +using namespace arrow::internal; + +namespace omniruntime::writer +{ + std::string get_parent_path(const std::string& path); + int createDirectories(const std::string &path); + class ParquetWriter + { + public: + ParquetWriter() {} + + arrow::Status InitRecordWriter(UriInfo &uri, std::string &ugi); + std::shared_ptr BuildField(const std::string &name, int typeId, bool nullable); + void write(long *vecNativeId, int colNums, const int *omniTypes, const unsigned char *dataColumnsIds, + bool isSplitWrite = false, long starPos = 0, long endPos = 0); + void write(); + + public: + std::unique_ptr arrow_writer; + std::shared_ptr schema_; + std::vector precisions; + std::vector scales; + }; +} +#endif // NATIVE_READER_PARQUETWRITER_H \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/test/CMakeLists.txt b/omnioperator/omniop-native-reader/cpp/test/CMakeLists.txt index 6bf7e38f1c89d8c1334a8610e4fa3c4cd2adee4a..accbcb728b33a56eb26df9df3d403811ac5d224b 100644 --- a/omnioperator/omniop-native-reader/cpp/test/CMakeLists.txt +++ b/omnioperator/omniop-native-reader/cpp/test/CMakeLists.txt @@ -31,7 +31,7 @@ target_link_libraries(${TP_TEST_TARGET} pthread stdc++ dl - boostkit-omniop-vector-1.7.0-aarch64 + boostkit-omniop-vector-1.8.0-aarch64 securec native_reader) diff --git a/omnioperator/omniop-native-reader/cpp/test/tablescan/orc_scan_test.cpp b/omnioperator/omniop-native-reader/cpp/test/tablescan/orc_scan_test.cpp index 3defb021c1be87252c189bcdc35f18d0e3e570cf..3e23386d514e8171e279b0ee69d56a58d5f08b08 100644 --- a/omnioperator/omniop-native-reader/cpp/test/tablescan/orc_scan_test.cpp +++ b/omnioperator/omniop-native-reader/cpp/test/tablescan/orc_scan_test.cpp @@ -56,8 +56,9 @@ protected: rowReaderOptions.include(includedColumns); std::unique_ptr julianPtr; + std::unique_ptr predicatePtr; auto readerPtr = static_cast(reader.get()); - rowReader = readerPtr->createRowReader(rowReaderOptions, julianPtr).release(); + rowReader = readerPtr->createRowReader(rowReaderOptions, julianPtr, predicatePtr).release(); omniruntime::reader::OmniRowReaderImpl *rowReaderPtr = (omniruntime::reader::OmniRowReaderImpl*) rowReader; rowReaderPtr->next(&recordBatch, nullptr, 4096); } diff --git a/omnioperator/omniop-native-reader/java/pom.xml b/omnioperator/omniop-native-reader/java/pom.xml index 8a40c93eac15ec938a24759a532cdfeef401f894..e11c919e2e5a7d54d0e046e8a30692b101fc2ca6 100644 --- a/omnioperator/omniop-native-reader/java/pom.xml +++ b/omnioperator/omniop-native-reader/java/pom.xml @@ -15,7 +15,7 @@ 2.12 3.3.1 - 1.7.0 + 1.8.0 ${spark.version}-${omniruntime.version} FALSE ../cpp/ @@ -118,6 +118,7 @@ bash ${cpp.dir}/build.sh + release> ${plugin.cpp.test} diff --git a/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/write/jni/ParquetColumnarBatchJniWriter.java b/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/write/jni/ParquetColumnarBatchJniWriter.java new file mode 100644 index 0000000000000000000000000000000000000000..aa94fc62adeaaf45818f05f1e269fade55f1352f --- /dev/null +++ b/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/write/jni/ParquetColumnarBatchJniWriter.java @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. 
+ * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.write.jni; + +import com.huawei.boostkit.scan.jni.NativeReaderLoader; + +import org.json.JSONObject; + +public class ParquetColumnarBatchJniWriter { + public ParquetColumnarBatchJniWriter() { + NativeReaderLoader.getInstance(); + } + + public native void initializeWriter(JSONObject var1, long writer); + + public native long initializeSchema(long writer, String[] fieldNames, int[] fieldTypes, boolean[] nullables, int[][] decimalParam); + + public native void write(long writer, long[] vecNativeId, int[] omniTypes, boolean[] dataColumnsIds, int rowNums); + + public native void splitWrite(long writer, long[] vecNativeId, int[] omniTypes, boolean[] dataColumnsIds, long starPos, long endPos); + + public native void close(long writer); +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/CMakeLists.txt b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/CMakeLists.txt index 2b8b87a7e11e64b61da1d10b11553d4a168bd4ef..a96a687dd860f22894fa9721e32bde8dda487b03 100644 --- a/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/CMakeLists.txt @@ -38,7 +38,7 @@ target_include_directories(${PROJ_TARGET} PUBLIC /opt/lib/include) target_link_libraries (${PROJ_TARGET} PUBLIC protobuf.a z - boostkit-omniop-vector-1.7.0-aarch64 + boostkit-omniop-vector-1.8.0-aarch64 ock_shuffle gcov ) diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/CMakeLists.txt b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/CMakeLists.txt index 59b17d497c8156ce798c1235df173d523c9b9188..1b387c9826294b56873003fbe289b95ddab8f218 100644 --- a/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/CMakeLists.txt @@ -28,7 +28,7 @@ target_link_libraries(${TP_TEST_TARGET} pthread stdc++ dl - boostkit-omniop-vector-1.7.0-aarch64 + boostkit-omniop-vector-1.8.0-aarch64 securec ock_columnar_shuffle) diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/pom.xml b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/pom.xml index 092ad1ccc2d1a7e3be01762ad9c81ece27391f84..8e82d5a4eddcb3a63298bb6a028d846193697f80 100644 --- a/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/pom.xml +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/pom.xml @@ -6,7 +6,7 @@ com.huawei.ock omniop-spark-extension-ock - 24.0.0 + 25.0.0 cpp/ @@ -19,7 +19,7 @@ com.huawei.kunpeng boostkit-omniop-spark - 3.3.1-1.7.0 + 3.3.1-1.8.0 compile @@ -27,7 +27,7 @@ ock-omniop-shuffle-manager jar 
Huawei Open Computing Kit for Spark, shuffle manager - 24.0.0 + 25.0.0 diff --git a/omnioperator/omniop-spark-extension-ock/pom.xml b/omnioperator/omniop-spark-extension-ock/pom.xml index d852d7a3d6a8d7f11da676392610e7e7c88847ad..9165b47dd925e0949db7ec52e14d0855dc67476d 100644 --- a/omnioperator/omniop-spark-extension-ock/pom.xml +++ b/omnioperator/omniop-spark-extension-ock/pom.xml @@ -8,7 +8,7 @@ omniop-spark-extension-ock pom Huawei Open Computing Kit for Spark - 24.0.0 + 25.0.0 3.3.1 @@ -20,7 +20,7 @@ spark-3.3 3.2.0 3.1.1 - 24.0.0 + 25.0.0 @@ -62,13 +62,13 @@ com.huawei.boostkit boostkit-omniop-bindings - 1.7.0 + 1.8.0 aarch64 com.huawei.kunpeng boostkit-omniop-spark - 3.3.1-1.7.0 + 3.3.1-1.8.0 com.huawei.ock diff --git a/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt index 10f630ad13925922872540fb13b379a0b52e15b3..eb88f6841f013934da589f981c16c942ecbd87d5 100644 --- a/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt @@ -11,7 +11,7 @@ set(CMAKE_CXX_COMPILER "g++") set(root_directory ${PROJECT_BINARY_DIR}) set(CMAKE_CXX_FLAGS_DEBUG "-pipe -g -Wall -fPIC -fno-common -fno-stack-protector") -set(CMAKE_CXX_FLAGS_RELEASE "-O2 -pipe -Wall -Wtrampolines -D_FORTIFY_SOURCE=2 -O2 -fPIC -finline-functions -fstack-protector-strong -s -Wl,-z,noexecstack -Wl,-z,relro,-z,now") +set(CMAKE_CXX_FLAGS_RELEASE "-O3 -march=armv8-a+crc -pipe -Wall -Wtrampolines -D_FORTIFY_SOURCE=2 -fPIC -finline-functions -fstack-protector-strong -s -Wl,-z,noexecstack -Wl,-z,relro,-z,now") if (DEFINED COVERAGE) if(${COVERAGE} STREQUAL "ON") diff --git a/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt index 12fbc44a057a3cf07263db554607845d47d20992..15458706bd159928a9b4bf11fa6b25b31577f44f 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt @@ -14,6 +14,7 @@ set (SOURCE_FILES common/common.cpp jni/SparkJniWrapper.cpp jni/jni_common.cpp + jni/deserializer.cpp ) #Find required protobuf package @@ -42,7 +43,7 @@ target_link_libraries (${PROJ_TARGET} PUBLIC snappy lz4 zstd - boostkit-omniop-vector-1.7.0-aarch64 + boostkit-omniop-vector-1.8.0-aarch64 ) set_target_properties(${PROJ_TARGET} PROPERTIES diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/deserializer.cpp b/omnioperator/omniop-spark-extension/cpp/src/jni/deserializer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..69f6ad31598a8ed35c2f5d9e34867121f98a0cb2 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/deserializer.cpp @@ -0,0 +1,213 @@ +/** + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "jni_common.h" +#include "deserializer.hh" +#include "common/common.h" + +using namespace omniruntime::vec; + +JNIEXPORT jlong JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_columnarShuffleParseInit( + JNIEnv *env, jobject obj, jlong address, jint length) +{ + JNI_FUNC_START + // tranform protobuf bytes to VecBatch + auto *vecBatch = new spark::VecBatch(); + vecBatch->ParseFromArray(reinterpret_cast(address), length); + return (jlong)(vecBatch); + JNI_FUNC_END(runtimeExceptionClass) +} + +JNIEXPORT void JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_columnarShuffleParseClose( + JNIEnv *env, jobject obj, jlong address) +{ + JNI_FUNC_START + spark::VecBatch* vecBatch = reinterpret_cast(address); + delete vecBatch; + JNI_FUNC_END_VOID(runtimeExceptionClass) +} + +JNIEXPORT jint JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_columnarShuffleParseVecCount( + JNIEnv *env, jobject obj, jlong address) +{ + JNI_FUNC_START + spark::VecBatch* vecBatch = reinterpret_cast(address); + return vecBatch->veccnt(); + JNI_FUNC_END(runtimeExceptionClass) +} + +JNIEXPORT jint JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_columnarShuffleParseRowCount( + JNIEnv *env, jobject obj, jlong address) +{ + JNI_FUNC_START + spark::VecBatch* vecBatch = reinterpret_cast(address); + return vecBatch->rowcnt(); + JNI_FUNC_END(runtimeExceptionClass) +} + +JNIEXPORT void JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_columnarShuffleParseBatch( + JNIEnv *env, jobject obj, jlong address, jintArray typeIdArray, jintArray precisionArray, + jintArray scaleArray, jlongArray vecNativeIdArray) +{ + spark::VecBatch* vecBatch = reinterpret_cast(address); + int32_t vecCount = vecBatch->veccnt(); + int32_t rowCount = vecBatch->rowcnt(); + omniruntime::vec::BaseVector* vecs[vecCount]; + + JNI_FUNC_START + jint *typeIdArrayElements = env->GetIntArrayElements(typeIdArray, NULL); + jint *precisionArrayElements = env->GetIntArrayElements(precisionArray, NULL); + jint *scaleArrayElements = env->GetIntArrayElements(scaleArray, NULL); + jlong *vecNativeIdArrayElements = env->GetLongArrayElements(vecNativeIdArray, NULL); + + for (auto i = 0; i < vecCount; ++i) { + const spark::Vec& protoVec = vecBatch->vecs(i); + const spark::VecType& protoTypeId = protoVec.vectype(); + scaleArrayElements[i] = protoTypeId.scale(); + precisionArrayElements[i] = protoTypeId.precision(); + typeIdArrayElements[i] = static_cast(protoTypeId.typeid_()); + + // create native vector + auto vectorDataTypeId = static_cast(protoTypeId.typeid_()); + vecs[i] = VectorHelper::CreateVector(OMNI_FLAT, vectorDataTypeId, rowCount); + vecNativeIdArrayElements[i] = (jlong)(vecs[i]); + + + auto values = protoVec.values().data(); + auto offsets = protoVec.offset().data(); + auto nulls = protoVec.nulls().data(); + + if (vectorDataTypeId == OMNI_CHAR || vectorDataTypeId == OMNI_VARCHAR) { + auto charVec = reinterpret_cast> *>(vecs[i]); + char *valuesAddress = + omniruntime::vec::unsafe::UnsafeStringVector::ExpandStringBuffer(charVec, protoVec.values().size()); + auto offsetsAddress = (uint8_t *)VectorHelper::UnsafeGetOffsetsAddr(vecs[i]); + memcpy_s(valuesAddress, protoVec.values().size(), values, protoVec.values().size()); + memcpy_s(offsetsAddress, protoVec.offset().size(), offsets, protoVec.offset().size()); + } else 
{ + auto *valuesAddress = (uint8_t *)VectorHelper::UnsafeGetValues(vecs[i]); + memcpy_s(valuesAddress, protoVec.values().size(), values, protoVec.values().size()); + } + for (auto j = 0; j < protoVec.nulls().size(); ++j) { + if (int(nulls[j])) { + vecs[i]->SetNull(j); + } + } + } + + env->ReleaseIntArrayElements(typeIdArray, typeIdArrayElements, 0); + env->ReleaseIntArrayElements(precisionArray, precisionArrayElements, 0); + env->ReleaseIntArrayElements(scaleArray, scaleArrayElements, 0); + env->ReleaseLongArrayElements(vecNativeIdArray, vecNativeIdArrayElements, 0); + JNI_FUNC_END_WITH_VECTORS(runtimeExceptionClass, vecs) +} + + +JNIEXPORT jlong JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_rowShuffleParseInit( + JNIEnv *env, jobject obj, jlong address, jint length) +{ + JNI_FUNC_START + // tranform protobuf bytes to ProtoRowBatch + auto *protoRowBatch = new spark::ProtoRowBatch(); + protoRowBatch->ParseFromArray(reinterpret_cast(address), length); + return (jlong)(protoRowBatch); + JNI_FUNC_END(runtimeExceptionClass) +} + +JNIEXPORT void JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_rowShuffleParseClose( + JNIEnv *env, jobject obj, jlong address) +{ + JNI_FUNC_START + spark::ProtoRowBatch* protoRowBatch = reinterpret_cast(address); + delete protoRowBatch; + JNI_FUNC_END_VOID(runtimeExceptionClass) +} + +JNIEXPORT jint JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_rowShuffleParseVecCount( + JNIEnv *env, jobject obj, jlong address) +{ + JNI_FUNC_START + spark::ProtoRowBatch* protoRowBatch = reinterpret_cast(address); + return protoRowBatch->veccnt(); + JNI_FUNC_END(runtimeExceptionClass) +} + +JNIEXPORT jint JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_rowShuffleParseRowCount( + JNIEnv *env, jobject obj, jlong address) +{ + JNI_FUNC_START + spark::ProtoRowBatch* protoRowBatch = reinterpret_cast(address); + return protoRowBatch->rowcnt(); + JNI_FUNC_END(runtimeExceptionClass) +} + +JNIEXPORT void JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_rowShuffleParseBatch( + JNIEnv *env, jobject obj, jlong address, jintArray typeIdArray, jintArray precisionArray, + jintArray scaleArray, jlongArray vecNativeIdArray) +{ + spark::ProtoRowBatch* protoRowBatch = reinterpret_cast(address); + int32_t vecCount = protoRowBatch->veccnt(); + int32_t rowCount = protoRowBatch->rowcnt(); + omniruntime::vec::BaseVector* vecs[vecCount]; + std::vector omniDataTypeIds(vecCount); + + JNI_FUNC_START + jint *typeIdArrayElements = env->GetIntArrayElements(typeIdArray, NULL); + jint *precisionArrayElements = env->GetIntArrayElements(precisionArray, NULL); + jint *scaleArrayElements = env->GetIntArrayElements(scaleArray, NULL); + jlong *vecNativeIdArrayElements = env->GetLongArrayElements(vecNativeIdArray, NULL); + + for (auto i = 0; i < vecCount; ++i) { + const spark::VecType& protoTypeId = protoRowBatch->vectypes(i); + scaleArrayElements[i] = protoTypeId.scale(); + precisionArrayElements[i] = protoTypeId.precision(); + typeIdArrayElements[i] = static_cast(protoTypeId.typeid_()); + omniDataTypeIds[i] = static_cast(protoTypeId.typeid_()); + + // create native vector + auto vectorDataTypeId = static_cast(protoTypeId.typeid_()); + vecs[i] = VectorHelper::CreateVector(OMNI_FLAT, vectorDataTypeId, rowCount); + vecNativeIdArrayElements[i] = (jlong)(vecs[i]); + } + + auto *parser = new RowParser(omniDataTypeIds); + char *rows = const_cast(protoRowBatch->rows().data()); + const 
int32_t *offsets = reinterpret_cast<const int32_t *>(protoRowBatch->offsets().data()); + // offsets[i] is the byte offset of row i inside the flattened rows buffer + for (auto i = 0; i < rowCount; ++i) { + char *rowPtr = rows + offsets[i]; + parser->ParseOneRow(reinterpret_cast<uint8_t *>(rowPtr), vecs, i); + } + + env->ReleaseIntArrayElements(typeIdArray, typeIdArrayElements, 0); + env->ReleaseIntArrayElements(precisionArray, precisionArrayElements, 0); + env->ReleaseIntArrayElements(scaleArray, scaleArrayElements, 0); + env->ReleaseLongArrayElements(vecNativeIdArray, vecNativeIdArrayElements, 0); + delete parser; + JNI_FUNC_END_WITH_VECTORS(runtimeExceptionClass, vecs) +} diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/deserializer.hh b/omnioperator/omniop-spark-extension/cpp/src/jni/deserializer.hh new file mode 100644 index 0000000000000000000000000000000000000000..32cc871d1136212dc14b2094d3081450941d7e0d --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/deserializer.hh @@ -0,0 +1,82 @@ +/** + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
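The row-shuffle reader above relies on a flattened batch layout: every serialized row is appended to one `rows` byte buffer, and `offsets` stores the starting position of each row (the prefix sum that `ProtoWritePartitionByRow`/`protoSpillPartitionByRow` build further down in this diff). The standalone sketch below uses purely illustrative names and no protobuf; it only shows that layout being written and read back. Unlike the real batch it also keeps a trailing end offset so row lengths can be recovered without a row parser.

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main()
{
    // Hypothetical variable-length rows to pack into one flat buffer.
    std::vector<std::string> inputRows = {"alpha", "bo", "charlie"};

    // Writer side: offsets[i + 1] = offsets[i] + length(row i), then append payloads.
    std::vector<int32_t> offsets(inputRows.size() + 1, 0);
    for (size_t i = 0; i < inputRows.size(); ++i) {
        offsets[i + 1] = offsets[i] + static_cast<int32_t>(inputRows[i].size());
    }
    std::string rows;
    rows.reserve(static_cast<size_t>(offsets.back()));
    for (const auto &row : inputRows) {
        rows.append(row);
    }

    // Reader side: slice row i back out of the flat buffer via offsets[i].
    for (size_t i = 0; i < inputRows.size(); ++i) {
        const char *rowPtr = rows.data() + offsets[i];
        auto length = static_cast<size_t>(offsets[i + 1] - offsets[i]);
        std::cout << "row " << i << ": " << std::string(rowPtr, length) << "\n";
    }
    return 0;
}
```

This is also why the vec_data.proto change below drops the repeated `ProtoRow` message: a single `bytes rows` field plus a `bytes offsets` field avoids one length-prefixed sub-message per row.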
+ */ + +#include <jni.h> +#include <cstdint> + +#include "vec_data.pb.h" + +#ifndef SPARK_JNI_DESERIALIZER +#define SPARK_JNI_DESERIALIZER +#ifdef __cplusplus +extern "C" { +#endif + +JNIEXPORT void JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_columnarShuffleParseBatch( + JNIEnv *env, jobject obj, jlong address, jintArray typeIdArray, jintArray precisionArray, + jintArray scaleArray, jlongArray vecNativeIdArray); + +JNIEXPORT jlong JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_columnarShuffleParseInit( + JNIEnv *env, jobject obj, jlong address, jint length); + +JNIEXPORT void JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_columnarShuffleParseClose( + JNIEnv *env, jobject obj, jlong address); + +JNIEXPORT jint JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_columnarShuffleParseVecCount( + JNIEnv *env, jobject obj, jlong address); + +JNIEXPORT jint JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_columnarShuffleParseRowCount( + JNIEnv *env, jobject obj, jlong address); + +JNIEXPORT void JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_columnarShuffleParseColumnarBatch( + JNIEnv *env, jobject obj, jlong address, jintArray typeIdArray, jintArray precisionArray, + jintArray scaleArray, jlongArray vecNativeIdArray); + +JNIEXPORT jlong JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_rowShuffleParseInit( + JNIEnv *env, jobject obj, jlong address, jint length); + +JNIEXPORT void JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_rowShuffleParseClose( + JNIEnv *env, jobject obj, jlong address); + +JNIEXPORT jint JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_rowShuffleParseVecCount( + JNIEnv *env, jobject obj, jlong address); + +JNIEXPORT jint JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_rowShuffleParseRowCount( + JNIEnv *env, jobject obj, jlong address); + +JNIEXPORT void JNICALL +Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_rowShuffleParseBatch( + JNIEnv *env, jobject obj, jlong address, jintArray typeIdArray, jintArray precisionArray, + jintArray scaleArray, jlongArray vecNativeIdArray); + + +#ifdef __cplusplus +} +#endif +#endif diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h b/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h index 964fab6dfc06ac692294fa212f40afc21a4d1041..23fc600398a056b36f06f9a23c730a12336cc900 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h @@ -57,6 +57,16 @@ jmethodID GetMethodID(JNIEnv* env, jclass this_class, const char* name, const char* sig) return 0; \ } +#define JNI_FUNC_END_WITH_VECTORS(exceptionClass, vectors) \ + } catch (const std::exception &e) { \ + for (auto vec : vectors) { \ + delete vec; \ + } \ + env->ThrowNew(exceptionClass, e.what()); \ + return; \ + } \ + + extern jclass runtimeExceptionClass; extern jclass splitResultClass; extern jclass jsonClass; diff --git a/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto b/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto index 33ee64ec84d12b4bbf76c579645a8cb8a2e9db7d..869105748977fa573853d6f181fe5ccef2a2300a 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto +++ b/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto @@ 
-59,14 +59,10 @@ message VecType { TimeUnit timeUnit = 6; } -message ProtoRow { - bytes data = 1; - uint32 length = 2; -} - message ProtoRowBatch { int32 rowCnt = 1; int32 vecCnt = 2; repeated VecType vecTypes = 3; - repeated ProtoRow rows = 4; + bytes rows = 4; + bytes offsets = 5; } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp index ccfa024c1601311703542ed511bd7b26456f6a99..4265d4c03e50949328142f1608e3f82a53b1f8ba 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp @@ -20,6 +20,8 @@ #include "splitter.h" #include "utils.h" +using namespace omniruntime::vec; + SplitOptions SplitOptions::Defaults() { return SplitOptions(); } // 计算分区id,每个batch初始化 @@ -236,12 +238,18 @@ void Splitter::SplitBinaryVector(BaseVector *varcharVector, int col_schema) { if (varcharVector->GetEncoding() == OMNI_DICTIONARY) { auto vc = reinterpret_cast> *>( varcharVector); - cached_vectorbatch_size_ += num_rows * (sizeof(bool) + sizeof(int32_t)); + cached_vectorbatch_size_ += num_rows * (sizeof(bool) + sizeof(int32_t)); for (auto row = 0; row < num_rows; ++row) { auto pid = partition_id_[row]; uint8_t *dst = nullptr; uint32_t str_len = 0; - if (!vc->IsNull(row)) { + if constexpr (hasNull) { + if (!vc->IsNull(row)) { + std::string_view value = vc->GetValue(row); + dst = reinterpret_cast(reinterpret_cast(value.data())); + str_len = static_cast(value.length()); + } + } else { std::string_view value = vc->GetValue(row); dst = reinterpret_cast(reinterpret_cast(value.data())); str_len = static_cast(value.length()); @@ -272,11 +280,17 @@ void Splitter::SplitBinaryVector(BaseVector *varcharVector, int col_schema) { } else { auto vc = reinterpret_cast> *>(varcharVector); cached_vectorbatch_size_ += num_rows * (sizeof(bool) + sizeof(int32_t)) + sizeof(int32_t); - for (auto row = 0; row < num_rows; ++row) { + for (auto row = 0; row < num_rows; ++row) { auto pid = partition_id_[row]; uint8_t *dst = nullptr; uint32_t str_len = 0; - if (!vc->IsNull(row)) { + if constexpr (hasNull) { + if (!vc->IsNull(row)) { + std::string_view value = vc->GetValue(row); + dst = reinterpret_cast(reinterpret_cast(value.data())); + str_len = static_cast(value.length()); + } + } else { std::string_view value = vc->GetValue(row); dst = reinterpret_cast(reinterpret_cast(value.data())); str_len = static_cast(value.length()); @@ -366,7 +380,7 @@ int Splitter::SplitFixedWidthValidityBuffer(VectorBatch& vb){ for (auto row = 0; row < num_rows; ++row) { auto pid = partition_id_[row]; auto dst_offset = partition_buffer_idx_base_[pid] + partition_buffer_idx_offset_[pid]; - dst_addrs[pid][dst_offset] = src_addr[row]; + dst_addrs[pid][dst_offset] = omniruntime::BitUtil::IsBitSet(src_addr, row); partition_buffer_idx_offset_[pid]++; } } @@ -629,7 +643,6 @@ int Splitter::SplitByRow(VectorBatch *vecBatch) { partition_rows[0].emplace_back(rowInfo); total_input_size += rowInfo->length; } - delete vecBatch; } else { auto pidVec = reinterpret_cast *>(vecBatch->Get(0)); auto tmpVectorBatch = new VectorBatch(rowCount); @@ -966,20 +979,28 @@ int32_t Splitter::ProtoWritePartitionByRow(int32_t partition_id, std::unique_ptr } int64_t offset = batchCount * options_.spill_batch_row_num; + std::vector offset_vec(onceCopyRow + 1, 0); auto rowInfoPtr = partition_rows[partition_id].data() + offset; for (uint64_t i = 0; i < onceCopyRow; ++i) { RowInfo *rowInfo = 
rowInfoPtr[i]; - spark::ProtoRow *protoRow = protoRowBatch->add_rows(); - protoRow->set_data(rowInfo->row, rowInfo->length); - protoRow->set_length(rowInfo->length); - // free row memory + offset_vec[i + 1] = offset_vec[i] + rowInfo->length; + } + std::string rows; + rows.reserve(offset_vec[onceCopyRow]); + for (uint64_t i = 0; i < onceCopyRow; ++i) { + RowInfo *rowInfo = rowInfoPtr[i]; + rows.append(reinterpret_cast(rowInfo->row), rowInfo->length); + // free row memeory delete rowInfo; } + protoRowBatch->set_rows(std::move(rows)); + protoRowBatch->set_offsets(reinterpret_cast(offset_vec.data()), onceCopyRow * sizeof(int32_t)); - if (protoRowBatch->ByteSizeLong() > UINT32_MAX) { + auto byteSizeLong = protoRowBatch->ByteSizeLong(); + if (byteSizeLong > UINT32_MAX) { throw std::runtime_error("Unsafe static_cast long to uint_32t."); } - uint32_t protoRowBatchSize = reversebytes_uint32t(static_cast(protoRowBatch->ByteSizeLong())); + uint32_t protoRowBatchSize = reversebytes_uint32t(static_cast(byteSizeLong)); if (bufferStream->Next(&bufferOut, &sizeOut)) { memcpy_s(bufferOut, sizeof(protoRowBatchSize), &protoRowBatchSize, sizeof(protoRowBatchSize)); if (sizeof(protoRowBatchSize) < static_cast(sizeOut)) { @@ -1118,20 +1139,28 @@ int Splitter::protoSpillPartitionByRow(int32_t partition_id, std::unique_ptr offset_vec(onceCopyRow + 1, 0); auto rowInfoPtr = partition_rows[partition_id].data() + offset; for (uint64_t i = 0; i < onceCopyRow; ++i) { RowInfo *rowInfo = rowInfoPtr[i]; - spark::ProtoRow *protoRow = protoRowBatch->add_rows(); - protoRow->set_data(rowInfo->row, rowInfo->length); - protoRow->set_length(rowInfo->length); - // free row memory + offset_vec[i + 1] = offset_vec[i] + rowInfo->length; + } + std::string rows; + rows.reserve(offset_vec[onceCopyRow]); + for (uint64_t i = 0; i < onceCopyRow; ++i) { + RowInfo *rowInfo = rowInfoPtr[i]; + rows.append(reinterpret_cast(rowInfo->row), rowInfo->length); + // free row memeory delete rowInfo; } + protoRowBatch->set_rows(std::move(rows)); + protoRowBatch->set_offsets(reinterpret_cast(offset_vec.data()), onceCopyRow * sizeof(int32_t)); - if (protoRowBatch->ByteSizeLong() > UINT32_MAX) { + auto byteSizeLong = protoRowBatch->ByteSizeLong(); + if (byteSizeLong > UINT32_MAX) { throw std::runtime_error("Unsafe static_cast long to uint_32t."); } - uint32_t protoRowBatchSize = reversebytes_uint32t(static_cast(protoRowBatch->ByteSizeLong())); + uint32_t protoRowBatchSize = reversebytes_uint32t(static_cast(byteSizeLong)); void *buffer = nullptr; if (!bufferStream->NextNBytes(&buffer, sizeof(protoRowBatchSize))) { throw std::runtime_error("Allocate Memory Failed: Flush Spilled Data, Next failed."); diff --git a/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt index c192b33f03af8a38d0462d924e6296cf10ab1434..dd1e37907a289a1646b358a56c2143c1650ed8cf 100644 --- a/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt @@ -29,7 +29,7 @@ target_link_libraries(${TP_TEST_TARGET} pthread stdc++ dl - boostkit-omniop-vector-1.7.0-aarch64 + boostkit-omniop-vector-1.8.0-aarch64 securec spark_columnar_plugin) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala new file mode 100644 index 
0000000000000000000000000000000000000000..0915b1511f7534206b0c8a8b055b424069f84560 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -0,0 +1,2917 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.io.NotSerializableException +import java.util.Properties +import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture, TimeoutException, TimeUnit } +import java.util.concurrent.atomic.AtomicInteger + +import scala.annotation.tailrec +import scala.collection.Map +import scala.collection.mutable +import scala.collection.mutable.{HashMap, HashSet, ListBuffer} +import scala.concurrent.duration._ +import scala.util.control.NonFatal + +import com.google.common.util.concurrent.{Futures, SettableFuture} + +import org.apache.spark._ +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.errors.SparkCoreErrors +import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config +import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY +import org.apache.spark.network.shuffle.{BlockStoreClient, MergeFinalizerListener} +import org.apache.spark.network.shuffle.protocol.MergeStatuses +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} +import org.apache.spark.rdd.{RDD, RDDCheckpointData} +import org.apache.spark.resource.ResourceProfile +import org.apache.spark.resource.ResourceProfile.{DEFAULT_RESOURCE_PROFILE_ID, EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY} +import org.apache.spark.rpc.RpcTimeout +import org.apache.spark.storage._ +import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat +import org.apache.spark.util._ + +/** + * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of + * stages for each job, keeps track of which RDDs and stage outputs are materialized, and finds a + * minimal schedule to run the job. It then submits stages as TaskSets to an underlying + * TaskScheduler implementation that runs them on the cluster. A TaskSet contains fully independent + * tasks that can run right away based on the data that's already on the cluster (e.g. map output + * files from previous stages), though it may fail if this data becomes unavailable. + * + * Spark stages are created by breaking the RDD graph at shuffle boundaries. 
RDD operations with + * "narrow" dependencies, like map() and filter(), are pipelined together into one set of tasks + * in each stage, but operations with shuffle dependencies require multiple stages (one to write a + * set of map output files, and another to read those files after a barrier). In the end, every + * stage will have only shuffle dependencies on other stages, and may compute multiple operations + * inside it. The actual pipelining of these operations happens in the RDD.compute() functions of + * various RDDs + * + * In addition to coming up with a DAG of stages, the DAGScheduler also determines the preferred + * locations to run each task on, based on the current cache status, and passes these to the + * low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being + * lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are + * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task + * a small number of times before cancelling the whole stage. + * + * When looking through this code, there are several key concepts: + * + * - Jobs (represented by [[ActiveJob]]) are the top-level work items submitted to the scheduler. + * For example, when the user calls an action, like count(), a job will be submitted through + * submitJob. Each Job may require the execution of multiple stages to build intermediate data. + * + * - Stages ([[Stage]]) are sets of tasks that compute intermediate results in jobs, where each + * task computes the same function on partitions of the same RDD. Stages are separated at shuffle + * boundaries, which introduce a barrier (where we must wait for the previous stage to finish to + * fetch outputs). There are two types of stages: [[ResultStage]], for the final stage that + * executes an action, and [[ShuffleMapStage]], which writes map output files for a shuffle. + * Stages are often shared across multiple jobs, if these jobs reuse the same RDDs. + * + * - Tasks are individual units of work, each sent to one machine. + * + * - Cache tracking: the DAGScheduler figures out which RDDs are cached to avoid recomputing them + * and likewise remembers which shuffle map stages have already produced output files to avoid + * redoing the map side of a shuffle. + * + * - Preferred locations: the DAGScheduler also computes where to run each task in a stage based + * on the preferred locations of its underlying RDDs, or the location of cached or shuffle data. + * + * - Cleanup: all data structures are cleared when the running jobs that depend on them finish, + * to prevent memory leaks in a long-running application. + * + * To recover from failures, the same stage might need to run multiple times, which are called + * "attempts". If the TaskScheduler reports that a task failed because a map output file from a + * previous stage was lost, the DAGScheduler resubmits that lost stage. This is detected through a + * CompletionEvent with FetchFailed, or an ExecutorLost event. The DAGScheduler will wait a small + * amount of time to see whether other nodes or tasks fail, then resubmit TaskSets for any lost + * stage(s) that compute the missing tasks. As part of this process, we might also have to create + * Stage objects for old (finished) stages where we previously cleaned up the Stage object. Since + * tasks from the old attempt of a stage could still be running, care must be taken to map any + * events received in the correct Stage object. 
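+ * For example, if tasks of a ResultStage fail with FetchFailed because a map output file written by its parent ShuffleMapStage was lost, the DAGScheduler re-runs only the missing map tasks of that parent as a new attempt and then resubmits the ResultStage.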
+ * + * Here's a checklist to use when making or reviewing changes to this class: + * + * - All data structures should be cleared when the jobs involving them end to avoid indefinite + * accumulation of state in long-running programs. + * + * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to + * include the new structure. This will help to catch memory leaks. + */ +private[spark] class DAGScheduler( + private[scheduler] val sc: SparkContext, + private[scheduler] val taskScheduler: TaskScheduler, + listenerBus: LiveListenerBus, + mapOutputTracker: MapOutputTrackerMaster, + blockManagerMaster: BlockManagerMaster, + env: SparkEnv, + clock: Clock = new SystemClock()) + extends Logging { + + def this(sc: SparkContext, taskScheduler: TaskScheduler) = { + this( + sc, + taskScheduler, + sc.listenerBus, + sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], + sc.env.blockManager.master, + sc.env) + } + + def this(sc: SparkContext) = this(sc, sc.taskScheduler) + + private[spark] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) + + private[scheduler] val nextJobId = new AtomicInteger(0) + private[scheduler] def numTotalJobs: Int = nextJobId.get() + private val nextStageId = new AtomicInteger(0) + + private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]] + private[scheduler] val stageIdToStage = new HashMap[Int, Stage] + /** + * Mapping from shuffle dependency ID to the ShuffleMapStage that will generate the data for + * that dependency. Only includes stages that are part of currently running job (when the job(s) + * that require the shuffle stage complete, the mapping will be removed, and the only record of + * the shuffle data will be in the MapOutputTracker). + */ + private[scheduler] val shuffleIdToMapStage = new HashMap[Int, ShuffleMapStage] + private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob] + + // Stages we need to run whose parents aren't done + private[scheduler] val waitingStages = new HashSet[Stage] + + // Stages we are running right now + private[scheduler] val runningStages = new HashSet[Stage] + + // Stages that must be resubmitted due to fetch failures + private[scheduler] val failedStages = new HashSet[Stage] + + private[scheduler] val activeJobs = new HashSet[ActiveJob] + + /** + * Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids + * and its values are arrays indexed by partition numbers. Each array value is the set of + * locations where that RDD partition is cached. + * + * All accesses to this map should be guarded by synchronizing on it (see SPARK-4454). + */ + private val cacheLocs = new HashMap[Int, IndexedSeq[Seq[TaskLocation]]] + + /** + * Tracks the latest epoch of a fully processed error related to the given executor. (We use + * the MapOutputTracker's epoch number, which is sent with every task.) + * + * When an executor fails, it can affect the results of many tasks, and we have to deal with + * all of them consistently. We don't simply ignore all future results from that executor, + * as the failures may have been transient; but we also don't want to "overreact" to follow- + * on errors we receive. Furthermore, we might receive notification of a task success, after + * we find out the executor has actually failed; we'll assume those successes are, in fact, + * simply delayed notifications and the results have been lost, if the tasks started in the + * same or an earlier epoch. 
In particular, we use this to control when we tell the + * BlockManagerMaster that the BlockManager has been lost. + */ + private val executorFailureEpoch = new HashMap[String, Long] + + /** + * Tracks the latest epoch of a fully processed error where shuffle files have been lost from + * the given executor. + * + * This is closely related to executorFailureEpoch. They only differ for the executor when + * there is an external shuffle service serving shuffle files and we haven't been notified that + * the entire worker has been lost. In that case, when an executor is lost, we do not update + * the shuffleFileLostEpoch; we wait for a fetch failure. This way, if only the executor + * fails, we do not unregister the shuffle data as it can still be served; but if there is + * a failure in the shuffle service (resulting in fetch failure), we unregister the shuffle + * data only once, even if we get many fetch failures. + */ + private val shuffleFileLostEpoch = new HashMap[String, Long] + + private [scheduler] val outputCommitCoordinator = env.outputCommitCoordinator + + // A closure serializer that we reuse. + // This is only safe because DAGScheduler runs in a single thread. + private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() + + /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ + private val disallowStageRetryForTest = sc.getConf.get(TEST_NO_STAGE_RETRY) + + private val shouldMergeResourceProfiles = sc.getConf.get(config.RESOURCE_PROFILE_MERGE_CONFLICTS) + + /** + * Whether to unregister all the outputs on the host in condition that we receive a FetchFailure, + * this is set default to false, which means, we only unregister the outputs related to the exact + * executor(instead of the host) on a FetchFailure. + */ + private[scheduler] val unRegisterOutputOnHostOnFetchFailure = + sc.getConf.get(config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE) + + /** + * Number of consecutive stage attempts allowed before a stage is aborted. + */ + private[scheduler] val maxConsecutiveStageAttempts = + sc.getConf.getInt("spark.stage.maxConsecutiveAttempts", + DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS) + + /** + * Number of max concurrent tasks check failures for each barrier job. + */ + private[scheduler] val barrierJobIdToNumTasksCheckFailures = new ConcurrentHashMap[Int, Int] + + /** + * Time in seconds to wait between a max concurrent tasks check failure and the next check. + */ + private val timeIntervalNumTasksCheck = sc.getConf + .get(config.BARRIER_MAX_CONCURRENT_TASKS_CHECK_INTERVAL) + + /** + * Max number of max concurrent tasks check failures allowed for a job before fail the job + * submission. 
+ */ + private val maxFailureNumTasksCheck = sc.getConf + .get(config.BARRIER_MAX_CONCURRENT_TASKS_CHECK_MAX_FAILURES) + + private val messageScheduler = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message") + + private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) + taskScheduler.setDAGScheduler(this) + + private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(sc.getConf, isDriver = true) + + private val blockManagerMasterDriverHeartbeatTimeout = + sc.getConf.get(config.STORAGE_BLOCKMANAGER_MASTER_DRIVER_HEARTBEAT_TIMEOUT).millis + + private val shuffleMergeResultsTimeoutSec = + sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_RESULTS_TIMEOUT) + + private val shuffleMergeFinalizeWaitSec = + sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_TIMEOUT) + + private val shuffleMergeWaitMinSizeThreshold = + sc.getConf.get(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT) + + private val shufflePushMinRatio = sc.getConf.get(config.PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO) + + private val shuffleMergeFinalizeNumThreads = + sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_THREADS) + + // Since SparkEnv gets initialized after DAGScheduler, externalShuffleClient needs to be + // initialized lazily + private lazy val externalShuffleClient: Option[BlockStoreClient] = + if (pushBasedShuffleEnabled) { + Some(env.blockManager.blockStoreClient) + } else { + None + } + + // Use multi-threaded scheduled executor. The merge finalization task could take some time, + // depending on the time to establish connections to mergers, and the time to get MergeStatuses + // from all the mergers. + private val shuffleMergeFinalizeScheduler = + ThreadUtils.newDaemonThreadPoolScheduledExecutor("shuffle-merge-finalizer", + shuffleMergeFinalizeNumThreads) + + /** + * Called by the TaskSetManager to report task's starting. + */ + def taskStarted(task: Task[_], taskInfo: TaskInfo): Unit = { + eventProcessLoop.post(BeginEvent(task, taskInfo)) + } + + /** + * Called by the TaskSetManager to report that a task has completed + * and results are being fetched remotely. + */ + def taskGettingResult(taskInfo: TaskInfo): Unit = { + eventProcessLoop.post(GettingResultEvent(taskInfo)) + } + + /** + * Called by the TaskSetManager to report task completions or failures. + */ + def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: Seq[AccumulatorV2[_, _]], + metricPeaks: Array[Long], + taskInfo: TaskInfo): Unit = { + eventProcessLoop.post( + CompletionEvent(task, reason, result, accumUpdates, metricPeaks, taskInfo)) + } + + /** + * Update metrics for in-progress tasks and let the master know that the BlockManager is still + * alive. Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. 
+ */ + def executorHeartbeatReceived( + execId: String, + // (taskId, stageId, stageAttemptId, accumUpdates) + accumUpdates: Array[(Long, Int, Int, Seq[AccumulableInfo])], + blockManagerId: BlockManagerId, + // (stageId, stageAttemptId) -> metrics + executorUpdates: mutable.Map[(Int, Int), ExecutorMetrics]): Boolean = { + listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates, + executorUpdates)) + blockManagerMaster.driverHeartbeatEndPoint.askSync[Boolean]( + BlockManagerHeartbeat(blockManagerId), + new RpcTimeout(blockManagerMasterDriverHeartbeatTimeout, "BlockManagerHeartbeat")) + } + + /** + * Called by TaskScheduler implementation when an executor fails. + */ + def executorLost(execId: String, reason: ExecutorLossReason): Unit = { + eventProcessLoop.post(ExecutorLost(execId, reason)) + } + + /** + * Called by TaskScheduler implementation when a worker is removed. + */ + def workerRemoved(workerId: String, host: String, message: String): Unit = { + eventProcessLoop.post(WorkerRemoved(workerId, host, message)) + } + + /** + * Called by TaskScheduler implementation when a host is added. + */ + def executorAdded(execId: String, host: String): Unit = { + eventProcessLoop.post(ExecutorAdded(execId, host)) + } + + /** + * Called by the TaskSetManager to cancel an entire TaskSet due to either repeated failures or + * cancellation of the job itself. + */ + def taskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]): Unit = { + eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception)) + } + + /** + * Called by the TaskSetManager when it decides a speculative task is needed. + */ + def speculativeTaskSubmitted(task: Task[_]): Unit = { + eventProcessLoop.post(SpeculativeTaskSubmitted(task)) + } + + /** + * Called by the TaskSetManager when a taskset becomes unschedulable due to executors being + * excluded because of too many task failures and dynamic allocation is enabled. + */ + def unschedulableTaskSetAdded( + stageId: Int, + stageAttemptId: Int): Unit = { + eventProcessLoop.post(UnschedulableTaskSetAdded(stageId, stageAttemptId)) + } + + /** + * Called by the TaskSetManager when an unschedulable taskset becomes schedulable and dynamic + * allocation is enabled. + */ + def unschedulableTaskSetRemoved( + stageId: Int, + stageAttemptId: Int): Unit = { + eventProcessLoop.post(UnschedulableTaskSetRemoved(stageId, stageAttemptId)) + } + + private[scheduler] + def getCacheLocs(rdd: RDD[_]): IndexedSeq[Seq[TaskLocation]] = cacheLocs.synchronized { + // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times + if (!cacheLocs.contains(rdd.id)) { + // Note: if the storage level is NONE, we don't need to get locations from block manager. + val locs: IndexedSeq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) { + IndexedSeq.fill(rdd.partitions.length)(Nil) + } else { + val blockIds = + rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] + blockManagerMaster.getLocations(blockIds).map { bms => + bms.map(bm => TaskLocation(bm.host, bm.executorId)) + } + } + cacheLocs(rdd.id) = locs + } + cacheLocs(rdd.id) + } + + private def clearCacheLocs(): Unit = cacheLocs.synchronized { + cacheLocs.clear() + } + + /** + * Gets a shuffle map stage if one exists in shuffleIdToMapStage. Otherwise, if the + * shuffle map stage doesn't already exist, this method will create the shuffle map stage in + * addition to any missing ancestor shuffle map stages. 
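+ * For example, if the given dependency's RDD itself depends on an earlier shuffle whose stage has not been created yet, the stage for that ancestor shuffle is created first, and only then the stage for the given dependency.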
+ */ + private def getOrCreateShuffleMapStage( + shuffleDep: ShuffleDependency[_, _, _], + firstJobId: Int): ShuffleMapStage = { + shuffleIdToMapStage.get(shuffleDep.shuffleId) match { + case Some(stage) => + stage + + case None => + // Create stages for all missing ancestor shuffle dependencies. + getMissingAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep => + // Even though getMissingAncestorShuffleDependencies only returns shuffle dependencies + // that were not already in shuffleIdToMapStage, it's possible that by the time we + // get to a particular dependency in the foreach loop, it's been added to + // shuffleIdToMapStage by the stage creation process for an earlier dependency. See + // SPARK-13902 for more information. + if (!shuffleIdToMapStage.contains(dep.shuffleId)) { + createShuffleMapStage(dep, firstJobId) + } + } + // Finally, create a stage for the given shuffle dependency. + createShuffleMapStage(shuffleDep, firstJobId) + } + } + + /** + * Check to make sure we don't launch a barrier stage with unsupported RDD chain pattern. The + * following patterns are not supported: + * 1. Ancestor RDDs that have different number of partitions from the resulting RDD (e.g. + * union()/coalesce()/first()/take()/PartitionPruningRDD); + * 2. An RDD that depends on multiple barrier RDDs (e.g. barrierRdd1.zip(barrierRdd2)). + */ + private def checkBarrierStageWithRDDChainPattern(rdd: RDD[_], numTasksInStage: Int): Unit = { + if (rdd.isBarrier() && + !traverseParentRDDsWithinStage(rdd, (r: RDD[_]) => + r.getNumPartitions == numTasksInStage && + r.dependencies.count(_.rdd.isBarrier()) <= 1)) { + throw SparkCoreErrors.barrierStageWithRDDChainPatternError() + } + } + + /** + * Creates a ShuffleMapStage that generates the given shuffle dependency's partitions. If a + * previously run stage generated the same shuffle data, this function will copy the output + * locations that are still available from the previous shuffle to avoid unnecessarily + * regenerating data. + */ + def createShuffleMapStage[K, V, C]( + shuffleDep: ShuffleDependency[K, V, C], jobId: Int): ShuffleMapStage = { + val rdd = shuffleDep.rdd + val (shuffleDeps, resourceProfiles) = getShuffleDependenciesAndResourceProfiles(rdd) + val resourceProfile = mergeResourceProfilesForStage(resourceProfiles) + checkBarrierStageWithDynamicAllocation(rdd) + checkBarrierStageWithNumSlots(rdd, resourceProfile) + checkBarrierStageWithRDDChainPattern(rdd, rdd.getNumPartitions) + val numTasks = rdd.partitions.length + val parents = getOrCreateParentStages(shuffleDeps, jobId) + val id = nextStageId.getAndIncrement() + val stage = new ShuffleMapStage( + id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker, + resourceProfile.id) + + stageIdToStage(id) = stage + shuffleIdToMapStage(shuffleDep.shuffleId) = stage + updateJobIdStageIdMaps(jobId, stage) + + if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { + // Kind of ugly: need to register RDDs with the cache and map output tracker here + // since we can't do it in the RDD constructor because # of partitions is unknown + logInfo(s"Registering RDD ${rdd.id} (${rdd.getCreationSite}) as input to " + + s"shuffle ${shuffleDep.shuffleId}") + mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length, + shuffleDep.partitioner.numPartitions) + } + stage + } + + /** + * We don't support run a barrier stage with dynamic resource allocation enabled, it shall lead + * to some confusing behaviors (e.g. 
with dynamic resource allocation enabled, it may happen that + * we acquire some executors (but not enough to launch all the tasks in a barrier stage) and + * later release them due to executor idle time expire, and then acquire again). + * + * We perform the check on job submit and fail fast if running a barrier stage with dynamic + * resource allocation enabled. + * + * TODO SPARK-24942 Improve cluster resource management with jobs containing barrier stage + */ + private def checkBarrierStageWithDynamicAllocation(rdd: RDD[_]): Unit = { + if (rdd.isBarrier() && Utils.isDynamicAllocationEnabled(sc.getConf)) { + throw SparkCoreErrors.barrierStageWithDynamicAllocationError() + } + } + + /** + * Check whether the barrier stage requires more slots (to be able to launch all tasks in the + * barrier stage together) than the total number of active slots currently. Fail current check + * if trying to submit a barrier stage that requires more slots than current total number. If + * the check fails consecutively beyond a configured number for a job, then fail current job + * submission. + */ + private def checkBarrierStageWithNumSlots(rdd: RDD[_], rp: ResourceProfile): Unit = { + if (rdd.isBarrier()) { + val numPartitions = rdd.getNumPartitions + val maxNumConcurrentTasks = sc.maxNumConcurrentTasks(rp) + if (numPartitions > maxNumConcurrentTasks) { + throw SparkCoreErrors.numPartitionsGreaterThanMaxNumConcurrentTasksError(numPartitions, + maxNumConcurrentTasks) + } + } + } + + private[scheduler] def mergeResourceProfilesForStage( + stageResourceProfiles: HashSet[ResourceProfile]): ResourceProfile = { + logDebug(s"Merging stage rdd profiles: $stageResourceProfiles") + val resourceProfile = if (stageResourceProfiles.size > 1) { + if (shouldMergeResourceProfiles) { + val startResourceProfile = stageResourceProfiles.head + val mergedProfile = stageResourceProfiles.drop(1) + .foldLeft(startResourceProfile)((a, b) => mergeResourceProfiles(a, b)) + // compared merged profile with existing ones so we don't add it over and over again + // if the user runs the same operation multiple times + val resProfile = sc.resourceProfileManager.getEquivalentProfile(mergedProfile) + resProfile match { + case Some(existingRp) => existingRp + case None => + // this ResourceProfile could be different if it was merged so we have to add it to + // our ResourceProfileManager + sc.resourceProfileManager.addResourceProfile(mergedProfile) + mergedProfile + } + } else { + throw new IllegalArgumentException("Multiple ResourceProfiles specified in the RDDs for " + + "this stage, either resolve the conflicting ResourceProfiles yourself or enable " + + s"${config.RESOURCE_PROFILE_MERGE_CONFLICTS.key} and understand how Spark handles " + + "the merging them.") + } + } else { + if (stageResourceProfiles.size == 1) { + stageResourceProfiles.head + } else { + sc.resourceProfileManager.defaultResourceProfile + } + } + resourceProfile + } + + // This is a basic function to merge resource profiles that takes the max + // value of the profiles. We may want to make this more complex in the future as + // you may want to sum some resources (like memory). 
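+ // For example (illustrative values only): merging a profile requesting {executor cores = 2, memory = 4g}
+ // with one requesting {executor cores = 4, memory = 2g} yields {executor cores = 4, memory = 4g},
+ // i.e. the per-resource maximum rather than the sum.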
+ private[scheduler] def mergeResourceProfiles( + r1: ResourceProfile, + r2: ResourceProfile): ResourceProfile = { + val mergedExecKeys = r1.executorResources ++ r2.executorResources + val mergedExecReq = mergedExecKeys.map { case (k, v) => + val larger = r1.executorResources.get(k).map( x => + if (x.amount > v.amount) x else v).getOrElse(v) + k -> larger + } + val mergedTaskKeys = r1.taskResources ++ r2.taskResources + val mergedTaskReq = mergedTaskKeys.map { case (k, v) => + val larger = r1.taskResources.get(k).map( x => + if (x.amount > v.amount) x else v).getOrElse(v) + k -> larger + } + new ResourceProfile(mergedExecReq, mergedTaskReq) + } + + /** + * Create a ResultStage associated with the provided jobId. + */ + private def createResultStage( + rdd: RDD[_], + func: (TaskContext, Iterator[_]) => _, + partitions: Array[Int], + jobId: Int, + callSite: CallSite): ResultStage = { + val (shuffleDeps, resourceProfiles) = getShuffleDependenciesAndResourceProfiles(rdd) + val resourceProfile = mergeResourceProfilesForStage(resourceProfiles) + checkBarrierStageWithDynamicAllocation(rdd) + checkBarrierStageWithNumSlots(rdd, resourceProfile) + checkBarrierStageWithRDDChainPattern(rdd, partitions.toSet.size) + val parents = getOrCreateParentStages(shuffleDeps, jobId) + val id = nextStageId.getAndIncrement() + val stage = new ResultStage(id, rdd, func, partitions, parents, jobId, + callSite, resourceProfile.id) + stageIdToStage(id) = stage + updateJobIdStageIdMaps(jobId, stage) + stage + } + + /** + * Get or create the list of parent stages for the given shuffle dependencies. The new + * Stages will be created with the provided firstJobId. + */ + private def getOrCreateParentStages(shuffleDeps: HashSet[ShuffleDependency[_, _, _]], + firstJobId: Int): List[Stage] = { + shuffleDeps.map { shuffleDep => + getOrCreateShuffleMapStage(shuffleDep, firstJobId) + }.toList + } + + /** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */ + private def getMissingAncestorShuffleDependencies( + rdd: RDD[_]): ListBuffer[ShuffleDependency[_, _, _]] = { + val ancestors = new ListBuffer[ShuffleDependency[_, _, _]] + val visited = new HashSet[RDD[_]] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new ListBuffer[RDD[_]] + waitingForVisit += rdd + while (waitingForVisit.nonEmpty) { + val toVisit = waitingForVisit.remove(0) + if (!visited(toVisit)) { + visited += toVisit + val (shuffleDeps, _) = getShuffleDependenciesAndResourceProfiles(toVisit) + shuffleDeps.foreach { shuffleDep => + if (!shuffleIdToMapStage.contains(shuffleDep.shuffleId)) { + ancestors.prepend(shuffleDep) + waitingForVisit.prepend(shuffleDep.rdd) + } // Otherwise, the dependency and its ancestors have already been registered. + } + } + } + ancestors + } + + /** + * Returns shuffle dependencies that are immediate parents of the given RDD and the + * ResourceProfiles associated with the RDDs for this stage. + * + * This function will not return more distant ancestors for shuffle dependencies. For example, + * if C has a shuffle dependency on B which has a shuffle dependency on A: + * + * A <-- B <-- C + * + * calling this function with rdd C will only return the B <-- C dependency. + * + * This function is scheduler-visible for the purpose of unit testing. 
+ */ + private[scheduler] def getShuffleDependenciesAndResourceProfiles( + rdd: RDD[_]): (HashSet[ShuffleDependency[_, _, _]], HashSet[ResourceProfile]) = { + val parents = new HashSet[ShuffleDependency[_, _, _]] + val resourceProfiles = new HashSet[ResourceProfile] + val visited = new HashSet[RDD[_]] + val waitingForVisit = new ListBuffer[RDD[_]] + waitingForVisit += rdd + while (waitingForVisit.nonEmpty) { + val toVisit = waitingForVisit.remove(0) + if (!visited(toVisit)) { + visited += toVisit + Option(toVisit.getResourceProfile).foreach(resourceProfiles += _) + toVisit.dependencies.foreach { + case shuffleDep: ShuffleDependency[_, _, _] => + parents += shuffleDep + case dependency => + waitingForVisit.prepend(dependency.rdd) + } + } + } + (parents, resourceProfiles) + } + + /** + * Traverses the given RDD and its ancestors within the same stage and checks whether all of the + * RDDs satisfy a given predicate. + */ + private def traverseParentRDDsWithinStage(rdd: RDD[_], predicate: RDD[_] => Boolean): Boolean = { + val visited = new HashSet[RDD[_]] + val waitingForVisit = new ListBuffer[RDD[_]] + waitingForVisit += rdd + while (waitingForVisit.nonEmpty) { + val toVisit = waitingForVisit.remove(0) + if (!visited(toVisit)) { + if (!predicate(toVisit)) { + return false + } + visited += toVisit + toVisit.dependencies.foreach { + case _: ShuffleDependency[_, _, _] => + // Not within the same stage with current rdd, do nothing. + case dependency => + waitingForVisit.prepend(dependency.rdd) + } + } + } + true + } + + private def getMissingParentStages(stage: Stage): List[Stage] = { + val missing = new HashSet[Stage] + val visited = new HashSet[RDD[_]] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new ListBuffer[RDD[_]] + waitingForVisit += stage.rdd + def visit(rdd: RDD[_]): Unit = { + if (!visited(rdd)) { + visited += rdd + val rddHasUncachedPartitions = getCacheLocs(rdd).contains(Nil) + if (rddHasUncachedPartitions) { + for (dep <- rdd.dependencies) { + dep match { + case shufDep: ShuffleDependency[_, _, _] => + val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId) + // Mark mapStage as available with shuffle outputs only after shuffle merge is + // finalized with push based shuffle. If not, subsequent ShuffleMapStage won't + // read from merged output as the MergeStatuses are not available. + if (!mapStage.isAvailable || !mapStage.shuffleDep.shuffleMergeFinalized) { + missing += mapStage + } else { + // Forward the nextAttemptId if skipped and get visited for the first time. + // Otherwise, once it gets retried, + // 1) the stuffs in stage info become distorting, e.g. 
task num, input byte, e.t.c + // 2) the first attempt starts from 0-idx, it will not be marked as a retry + mapStage.increaseAttemptIdOnFirstSkip() + } + case narrowDep: NarrowDependency[_] => + waitingForVisit.prepend(narrowDep.rdd) + } + } + } + } + } + while (waitingForVisit.nonEmpty) { + visit(waitingForVisit.remove(0)) + } + missing.toList + } + + /** Invoke `.partitions` on the given RDD and all of its ancestors */ + private def eagerlyComputePartitionsForRddAndAncestors(rdd: RDD[_]): Unit = { + val startTime = System.nanoTime + val visitedRdds = new HashSet[RDD[_]] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new ListBuffer[RDD[_]] + waitingForVisit += rdd + + def visit(rdd: RDD[_]): Unit = { + if (!visitedRdds(rdd)) { + visitedRdds += rdd + + // Eagerly compute: + rdd.partitions + + for (dep <- rdd.dependencies) { + waitingForVisit.prepend(dep.rdd) + } + } + } + + while (waitingForVisit.nonEmpty) { + visit(waitingForVisit.remove(0)) + } + logDebug("eagerlyComputePartitionsForRddAndAncestors for RDD %d took %f seconds" + .format(rdd.id, (System.nanoTime - startTime) / 1e9)) + } + + /** + * Registers the given jobId among the jobs that need the given stage and + * all of that stage's ancestors. + */ + private def updateJobIdStageIdMaps(jobId: Int, stage: Stage): Unit = { + @tailrec + def updateJobIdStageIdMapsList(stages: List[Stage]): Unit = { + if (stages.nonEmpty) { + val s = stages.head + s.jobIds += jobId + jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id + val parentsWithoutThisJobId = s.parents.filter { ! _.jobIds.contains(jobId) } + updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail) + } + } + updateJobIdStageIdMapsList(List(stage)) + } + + /** + * Removes state for job and any stages that are not needed by any other job. Does not + * handle cancelling tasks or notifying the SparkListener about finished jobs/stages/tasks. + * + * @param job The job whose state to cleanup. 
+ */ + private def cleanupStateForJobAndIndependentStages(job: ActiveJob): Unit = { + val registeredStages = jobIdToStageIds.get(job.jobId) + if (registeredStages.isEmpty || registeredStages.get.isEmpty) { + logError("No stages registered for job " + job.jobId) + } else { + stageIdToStage.filterKeys(stageId => registeredStages.get.contains(stageId)).foreach { + case (stageId, stage) => + val jobSet = stage.jobIds + if (!jobSet.contains(job.jobId)) { + logError( + "Job %d not registered for stage %d even though that stage was registered for the job" + .format(job.jobId, stageId)) + } else { + def removeStage(stageId: Int): Unit = { + // data structures based on Stage + for (stage <- stageIdToStage.get(stageId)) { + if (runningStages.contains(stage)) { + logDebug("Removing running stage %d".format(stageId)) + runningStages -= stage + } + for ((k, v) <- shuffleIdToMapStage.find(_._2 == stage)) { + shuffleIdToMapStage.remove(k) + } + if (waitingStages.contains(stage)) { + logDebug("Removing stage %d from waiting set.".format(stageId)) + waitingStages -= stage + } + if (failedStages.contains(stage)) { + logDebug("Removing stage %d from failed set.".format(stageId)) + failedStages -= stage + } + } + // data structures based on StageId + stageIdToStage -= stageId + logDebug("After removal of stage %d, remaining stages = %d" + .format(stageId, stageIdToStage.size)) + } + + jobSet -= job.jobId + if (jobSet.isEmpty) { // no other job needs this stage + removeStage(stageId) + } + } + } + } + jobIdToStageIds -= job.jobId + jobIdToActiveJob -= job.jobId + activeJobs -= job + job.finalStage match { + case r: ResultStage => r.removeActiveJob() + case m: ShuffleMapStage => m.removeActiveJob(job) + } + } + + private def eagerPartitions(rdd: RDD[_], visited: mutable.HashSet[RDD[_]] = new mutable.HashSet[RDD[_]]): Unit = try { + rdd.partitions + rdd.dependencies.foreach { + dependencies => { + if (visited.contains(dependencies.rdd)) { + return + } + visited.add(dependencies.rdd) + eagerPartitions(dependencies.rdd, visited) + } + } + } catch { + case t: Throwable => logError("Error is eagerPartitions, ignoring", t) + } + + + /** + * Submit an action job to the scheduler. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like first() + * @param callSite where in the user program this job was called + * @param resultHandler callback to pass each result to + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + * + * @return a JobWaiter object that can be used to block until the job finishes executing + * or can be used to cancel the job. + * + * @throws IllegalArgumentException when partitions ids are illegal + */ + def submitJob[T, U]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + callSite: CallSite, + resultHandler: (Int, U) => Unit, + properties: Properties): JobWaiter[U] = { + // Check to make sure we are not launching a task on a partition that does not exist. + val maxPartitions = rdd.partitions.length + partitions.find(p => p >= maxPartitions || p < 0).foreach { p => + throw new IllegalArgumentException( + "Attempting to access a non-existent partition: " + p + ". 
" + + "Total number of partitions: " + maxPartitions) + } + + // SPARK-23626: `RDD.getPartitions()` can be slow, so we eagerly compute + // `.partitions` on every RDD in the DAG to ensure that `getPartitions()` + // is evaluated outside of the DAGScheduler's single-threaded event loop: + eagerlyComputePartitionsForRddAndAncestors(rdd) + + val jobId = nextJobId.getAndIncrement() + if (partitions.isEmpty) { + val clonedProperties = Utils.cloneProperties(properties) + if (sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) == null) { + clonedProperties.setProperty(SparkContext.SPARK_JOB_DESCRIPTION, callSite.shortForm) + } + val time = clock.getTimeMillis() + listenerBus.post( + SparkListenerJobStart(jobId, time, Seq.empty, clonedProperties)) + listenerBus.post( + SparkListenerJobEnd(jobId, time, JobSucceeded)) + // Return immediately if the job is running 0 tasks + return new JobWaiter[U](this, jobId, 0, resultHandler) + } + + assert(partitions.nonEmpty) + val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] + val waiter = new JobWaiter[U](this, jobId, partitions.size, resultHandler) + eagerPartitions(rdd) + eventProcessLoop.post(JobSubmitted( + jobId, rdd, func2, partitions.toArray, callSite, waiter, + Utils.cloneProperties(properties))) + waiter + } + + /** + * Run an action job on the given RDD and pass all the results to the resultHandler function as + * they arrive. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like first() + * @param callSite where in the user program this job was called + * @param resultHandler callback to pass each result to + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + * + * @note Throws `Exception` when the job fails + */ + def runJob[T, U]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + callSite: CallSite, + resultHandler: (Int, U) => Unit, + properties: Properties): Unit = { + val start = System.nanoTime + val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties) + ThreadUtils.awaitReady(waiter.completionFuture, Duration.Inf) + waiter.completionFuture.value.get match { + case scala.util.Success(_) => + logInfo("Job %d finished: %s, took %f s".format + (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) + case scala.util.Failure(exception) => + logInfo("Job %d failed: %s, took %f s".format + (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) + // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler. + val callerStackTrace = Thread.currentThread().getStackTrace.tail + exception.setStackTrace(exception.getStackTrace ++ callerStackTrace) + throw exception + } + } + + /** + * Run an approximate job on the given RDD and pass all the results to an ApproximateEvaluator + * as they arrive. Returns a partial result object from the evaluator. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param evaluator `ApproximateEvaluator` to receive the partial results + * @param callSite where in the user program this job was called + * @param timeout maximum time to wait for the job, in milliseconds + * @param properties scheduler properties to attach to this job, e.g. 
fair scheduler pool name + */ + def runApproximateJob[T, U, R]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + evaluator: ApproximateEvaluator[U, R], + callSite: CallSite, + timeout: Long, + properties: Properties): PartialResult[R] = { + val jobId = nextJobId.getAndIncrement() + val clonedProperties = Utils.cloneProperties(properties) + if (rdd.partitions.isEmpty) { + // Return immediately if the job is running 0 tasks + val time = clock.getTimeMillis() + listenerBus.post(SparkListenerJobStart(jobId, time, Seq[StageInfo](), clonedProperties)) + listenerBus.post(SparkListenerJobEnd(jobId, time, JobSucceeded)) + return new PartialResult(evaluator.currentResult(), true) + } + + // SPARK-23626: `RDD.getPartitions()` can be slow, so we eagerly compute + // `.partitions` on every RDD in the DAG to ensure that `getPartitions()` + // is evaluated outside of the DAGScheduler's single-threaded event loop: + eagerlyComputePartitionsForRddAndAncestors(rdd) + + val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) + val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] + eagerPartitions(rdd) + eventProcessLoop.post(JobSubmitted( + jobId, rdd, func2, rdd.partitions.indices.toArray, callSite, listener, + clonedProperties)) + listener.awaitResult() // Will throw an exception if the job fails + } + + /** + * Submit a shuffle map stage to run independently and get a JobWaiter object back. The waiter + * can be used to block until the job finishes executing or can be used to cancel the job. + * This method is used for adaptive query planning, to run map stages and look at statistics + * about their outputs before submitting downstream stages. + * + * @param dependency the ShuffleDependency to run a map stage for + * @param callback function called with the result of the job, which in this case will be a + * single MapOutputStatistics object showing how much data was produced for each partition + * @param callSite where in the user program this job was submitted + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + */ + def submitMapStage[K, V, C]( + dependency: ShuffleDependency[K, V, C], + callback: MapOutputStatistics => Unit, + callSite: CallSite, + properties: Properties): JobWaiter[MapOutputStatistics] = { + + val rdd = dependency.rdd + val jobId = nextJobId.getAndIncrement() + if (rdd.partitions.length == 0) { + throw SparkCoreErrors.cannotRunSubmitMapStageOnZeroPartitionRDDError() + } + + // SPARK-23626: `RDD.getPartitions()` can be slow, so we eagerly compute + // `.partitions` on every RDD in the DAG to ensure that `getPartitions()` + // is evaluated outside of the DAGScheduler's single-threaded event loop: + eagerlyComputePartitionsForRddAndAncestors(rdd) + + // We create a JobWaiter with only one "task", which will be marked as complete when the whole + // map stage has completed, and will be passed the MapOutputStatistics for that stage. + // This makes it easier to avoid race conditions between the user code and the map output + // tracker that might result if we told the user the stage had finished, but then they queries + // the map output tracker and some node failures had caused the output statistics to be lost. 
+ val waiter = new JobWaiter[MapOutputStatistics]( + this, jobId, 1, + (_: Int, r: MapOutputStatistics) => callback(r)) + eventProcessLoop.post(MapStageSubmitted( + jobId, dependency, callSite, waiter, Utils.cloneProperties(properties))) + waiter + } + + /** + * Cancel a job that is running or waiting in the queue. + */ + def cancelJob(jobId: Int, reason: Option[String]): Unit = { + logInfo("Asked to cancel job " + jobId) + eventProcessLoop.post(JobCancelled(jobId, reason)) + } + + /** + * Cancel all jobs in the given job group ID. + */ + def cancelJobGroup(groupId: String): Unit = { + logInfo("Asked to cancel job group " + groupId) + eventProcessLoop.post(JobGroupCancelled(groupId)) + } + + /** + * Cancel all jobs that are running or waiting in the queue. + */ + def cancelAllJobs(): Unit = { + eventProcessLoop.post(AllJobsCancelled) + } + + private[scheduler] def doCancelAllJobs(): Unit = { + // Cancel all running jobs. + runningStages.map(_.firstJobId).foreach(handleJobCancellation(_, + Option("as part of cancellation of all jobs"))) + activeJobs.clear() // These should already be empty by this point, + jobIdToActiveJob.clear() // but just in case we lost track of some jobs... + } + + /** + * Cancel all jobs associated with a running or scheduled stage. + */ + def cancelStage(stageId: Int, reason: Option[String]): Unit = { + eventProcessLoop.post(StageCancelled(stageId, reason)) + } + + /** + * Receives notification about shuffle push for a given shuffle from one map + * task has completed + */ + def shufflePushCompleted(shuffleId: Int, shuffleMergeId: Int, mapIndex: Int): Unit = { + eventProcessLoop.post(ShufflePushCompleted(shuffleId, shuffleMergeId, mapIndex)) + } + + /** + * Kill a given task. It will be retried. + * + * @return Whether the task was successfully killed. + */ + def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean = { + taskScheduler.killTaskAttempt(taskId, interruptThread, reason) + } + + /** + * Resubmit any failed stages. Ordinarily called after a small amount of time has passed since + * the last fetch failure. + */ + private[scheduler] def resubmitFailedStages(): Unit = { + if (failedStages.nonEmpty) { + // Failed stages may be removed by job cancellation, so failed might be empty even if + // the ResubmitFailedStages event has been scheduled. + logInfo("Resubmitting failed stages") + clearCacheLocs() + val failedStagesCopy = failedStages.toArray + failedStages.clear() + for (stage <- failedStagesCopy.sortBy(_.firstJobId)) { + submitStage(stage) + } + } + } + + /** + * Check for waiting stages which are now eligible for resubmission. + * Submits stages that depend on the given parent stage. Called when the parent stage completes + * successfully. + */ + private def submitWaitingChildStages(parent: Stage): Unit = { + logTrace(s"Checking if any dependencies of $parent are now runnable") + logTrace("running: " + runningStages) + logTrace("waiting: " + waitingStages) + logTrace("failed: " + failedStages) + val childStages = waitingStages.filter(_.parents.contains(parent)).toArray + waitingStages --= childStages + for (stage <- childStages.sortBy(_.firstJobId)) { + submitStage(stage) + } + } + + /** Finds the earliest-created active job that needs the stage */ + // TODO: Probably should actually find among the active jobs that need this + // stage the one with the highest priority (highest-priority pool, earliest created). + // That should take care of at least part of the priority inversion problem with + // cross-job dependencies. 
+ private def activeJobForStage(stage: Stage): Option[Int] = { + val jobsThatUseStage: Array[Int] = stage.jobIds.toArray.sorted + jobsThatUseStage.find(jobIdToActiveJob.contains) + } + + private[scheduler] def handleJobGroupCancelled(groupId: String): Unit = { + // Cancel all jobs belonging to this job group. + // First finds all active jobs with this group id, and then kill stages for them. + val activeInGroup = activeJobs.filter { activeJob => + Option(activeJob.properties).exists { + _.getProperty(SparkContext.SPARK_JOB_GROUP_ID) == groupId + } + } + val jobIds = activeInGroup.map(_.jobId) + jobIds.foreach(handleJobCancellation(_, + Option("part of cancelled job group %s".format(groupId)))) + } + + private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = { + // Note that there is a chance that this task is launched after the stage is cancelled. + // In that case, we wouldn't have the stage anymore in stageIdToStage. + val stageAttemptId = + stageIdToStage.get(task.stageId).map(_.latestInfo.attemptNumber).getOrElse(-1) + listenerBus.post(SparkListenerTaskStart(task.stageId, stageAttemptId, taskInfo)) + } + + private[scheduler] def handleSpeculativeTaskSubmitted(task: Task[_]): Unit = { + listenerBus.post(SparkListenerSpeculativeTaskSubmitted(task.stageId, task.stageAttemptId)) + } + + private[scheduler] def handleUnschedulableTaskSetAdded( + stageId: Int, + stageAttemptId: Int): Unit = { + listenerBus.post(SparkListenerUnschedulableTaskSetAdded(stageId, stageAttemptId)) + } + + private[scheduler] def handleUnschedulableTaskSetRemoved( + stageId: Int, + stageAttemptId: Int): Unit = { + listenerBus.post(SparkListenerUnschedulableTaskSetRemoved(stageId, stageAttemptId)) + } + + private[scheduler] def handleTaskSetFailed( + taskSet: TaskSet, + reason: String, + exception: Option[Throwable]): Unit = { + stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason, exception) } + } + + private[scheduler] def cleanUpAfterSchedulerStop(): Unit = { + for (job <- activeJobs) { + val error = + new SparkException(s"Job ${job.jobId} cancelled because SparkContext was shut down") + job.listener.jobFailed(error) + // Tell the listeners that all of the running stages have ended. Don't bother + // cancelling the stages because if the DAG scheduler is stopped, the entire application + // is in the process of getting stopped. + val stageFailedMessage = "Stage cancelled because SparkContext was shut down" + // The `toArray` here is necessary so that we don't iterate over `runningStages` while + // mutating it. + runningStages.toArray.foreach { stage => + markStageAsFinished(stage, Some(stageFailedMessage)) + } + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error))) + } + } + + private[scheduler] def handleGetTaskResult(taskInfo: TaskInfo): Unit = { + listenerBus.post(SparkListenerTaskGettingResult(taskInfo)) + } + + private[scheduler] def handleJobSubmitted(jobId: Int, + finalRDD: RDD[_], + func: (TaskContext, Iterator[_]) => _, + partitions: Array[Int], + callSite: CallSite, + listener: JobListener, + properties: Properties): Unit = { + var finalStage: ResultStage = null + try { + // New stage creation may throw an exception if, for example, jobs are run on a + // HadoopRDD whose underlying HDFS files have been deleted. 
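+      // The JobSubmitted event handled here normally originates from an RDD action; a minimal
+      // sketch of that path, assuming a live SparkContext `sc`:
+      // {{{
+      //   val rdd = sc.parallelize(1 to 100, numSlices = 4)
+      //   rdd.count()   // SparkContext.runJob -> DAGScheduler.submitJob -> JobSubmitted
+      // }}}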
+ finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite) + } catch { + case e: BarrierJobSlotsNumberCheckFailed => + // If jobId doesn't exist in the map, Scala coverts its value null to 0: Int automatically. + val numCheckFailures = barrierJobIdToNumTasksCheckFailures.compute(jobId, + (_: Int, value: Int) => value + 1) + + logWarning(s"Barrier stage in job $jobId requires ${e.requiredConcurrentTasks} slots, " + + s"but only ${e.maxConcurrentTasks} are available. " + + s"Will retry up to ${maxFailureNumTasksCheck - numCheckFailures + 1} more times") + + if (numCheckFailures <= maxFailureNumTasksCheck) { + messageScheduler.schedule( + new Runnable { + override def run(): Unit = eventProcessLoop.post(JobSubmitted(jobId, finalRDD, func, + partitions, callSite, listener, properties)) + }, + timeIntervalNumTasksCheck, + TimeUnit.SECONDS + ) + return + } else { + // Job failed, clear internal data. + barrierJobIdToNumTasksCheckFailures.remove(jobId) + listener.jobFailed(e) + return + } + + case e: Exception => + logWarning("Creating new stage failed due to exception - job: " + jobId, e) + listener.jobFailed(e) + return + } + // Job submitted, clear internal data. + barrierJobIdToNumTasksCheckFailures.remove(jobId) + + val job = new ActiveJob(jobId, finalStage, callSite, listener, properties) + clearCacheLocs() + logInfo("Got job %s (%s) with %d output partitions".format( + job.jobId, callSite.shortForm, partitions.length)) + logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + + val jobSubmissionTime = clock.getTimeMillis() + jobIdToActiveJob(jobId) = job + activeJobs += job + finalStage.setActiveJob(job) + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, + Utils.cloneProperties(properties))) + submitStage(finalStage) + } + + private[scheduler] def handleMapStageSubmitted(jobId: Int, + dependency: ShuffleDependency[_, _, _], + callSite: CallSite, + listener: JobListener, + properties: Properties): Unit = { + // Submitting this map stage might still require the creation of some parent stages, so make + // sure that happens. + var finalStage: ShuffleMapStage = null + try { + // New stage creation may throw an exception if, for example, jobs are run on a + // HadoopRDD whose underlying HDFS files have been deleted. 
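+      // The ShuffleDependency resolved below usually comes from a wide transformation; this
+      // handler is only reached when such a map stage is submitted on its own (see submitMapStage
+      // above). A minimal sketch of a job that produces one, assuming a live SparkContext `sc`:
+      // {{{
+      //   val counts = sc.textFile("hdfs:///data/words")
+      //     .flatMap(_.split("\\s+"))
+      //     .map((_, 1))
+      //     .reduceByKey(_ + _)   // introduces a ShuffleDependency with its own map stage
+      // }}}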
+ finalStage = getOrCreateShuffleMapStage(dependency, jobId) + } catch { + case e: Exception => + logWarning("Creating new stage failed due to exception - job: " + jobId, e) + listener.jobFailed(e) + return + } + + val job = new ActiveJob(jobId, finalStage, callSite, listener, properties) + clearCacheLocs() + logInfo("Got map stage job %s (%s) with %d output partitions".format( + jobId, callSite.shortForm, dependency.rdd.partitions.length)) + logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + + val jobSubmissionTime = clock.getTimeMillis() + jobIdToActiveJob(jobId) = job + activeJobs += job + finalStage.addActiveJob(job) + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, + Utils.cloneProperties(properties))) + submitStage(finalStage) + + // If the whole stage has already finished, tell the listener and remove it + if (finalStage.isAvailable) { + markMapStageJobAsFinished(job, mapOutputTracker.getStatistics(dependency)) + } + } + + /** Submits stage, but first recursively submits any missing parents. */ + private def submitStage(stage: Stage): Unit = { + val jobId = activeJobForStage(stage) + if (jobId.isDefined) { + logDebug(s"submitStage($stage (name=${stage.name};" + + s"jobs=${stage.jobIds.toSeq.sorted.mkString(",")}))") + if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) { + val missing = getMissingParentStages(stage).sortBy(_.id) + logDebug("missing: " + missing) + if (missing.isEmpty) { + logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") + submitMissingTasks(stage, jobId.get) + } else { + for (parent <- missing) { + submitStage(parent) + } + waitingStages += stage + } + } + } else { + abortStage(stage, "No active job for stage " + stage.id, None) + } + } + + /** + * `PythonRunner` needs to know what the pyspark memory and cores settings are for the profile + * being run. Pass them in the local properties of the task if it's set for the stage profile. + */ + private def addPySparkConfigsToProperties(stage: Stage, properties: Properties): Unit = { + val rp = sc.resourceProfileManager.resourceProfileFromId(stage.resourceProfileId) + val pysparkMem = rp.getPySparkMemory + // use the getOption on EXECUTOR_CORES.key instead of using the EXECUTOR_CORES config reader + // because the default for this config isn't correct for standalone mode. Here we want + // to know if it was explicitly set or not. The default profile always has it set to either + // what user specified or default so special case it here. + val execCores = if (rp.id == DEFAULT_RESOURCE_PROFILE_ID) { + sc.conf.getOption(config.EXECUTOR_CORES.key) + } else { + val profCores = rp.getExecutorCores.map(_.toString) + if (profCores.isEmpty) sc.conf.getOption(config.EXECUTOR_CORES.key) else profCores + } + pysparkMem.map(mem => properties.setProperty(PYSPARK_MEMORY_LOCAL_PROPERTY, mem.toString)) + execCores.map(cores => properties.setProperty(EXECUTOR_CORES_LOCAL_PROPERTY, cores)) + } + + /** + * If push based shuffle is enabled, set the shuffle services to be used for the given + * shuffle map stage for block push/merge. 
+ * + * Even with dynamic resource allocation kicking in and significantly reducing the number + * of available active executors, we would still be able to get sufficient shuffle service + * locations for block push/merge by getting the historical locations of past executors. + */ + private def prepareShuffleServicesForShuffleMapStage(stage: ShuffleMapStage): Unit = { + assert(stage.shuffleDep.shuffleMergeAllowed && !stage.shuffleDep.isShuffleMergeFinalizedMarked) + if (stage.shuffleDep.getMergerLocs.isEmpty) { + getAndSetShufflePushMergerLocations(stage) + } + + val shuffleId = stage.shuffleDep.shuffleId + val shuffleMergeId = stage.shuffleDep.shuffleMergeId + if (stage.shuffleDep.shuffleMergeEnabled) { + logInfo(s"Shuffle merge enabled before starting the stage for $stage with shuffle" + + s" $shuffleId and shuffle merge $shuffleMergeId with" + + s" ${stage.shuffleDep.getMergerLocs.size} merger locations") + } else { + logInfo(s"Shuffle merge disabled for $stage with shuffle $shuffleId" + + s" and shuffle merge $shuffleMergeId, but can get enabled later adaptively" + + s" once enough mergers are available") + } + } + + private def getAndSetShufflePushMergerLocations(stage: ShuffleMapStage): Seq[BlockManagerId] = { + val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations( + stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId) + if (mergerLocs.nonEmpty) { + stage.shuffleDep.setMergerLocs(mergerLocs) + } + + logDebug(s"Shuffle merge locations for shuffle ${stage.shuffleDep.shuffleId} with" + + s" shuffle merge ${stage.shuffleDep.shuffleMergeId} is" + + s" ${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}") + mergerLocs + } + + /** Called when stage's parents are available and we can now do its task. */ + private def submitMissingTasks(stage: Stage, jobId: Int): Unit = { + logDebug("submitMissingTasks(" + stage + ")") + + // Before find missing partition, do the intermediate state clean work first. + // The operation here can make sure for the partially completed intermediate stage, + // `findMissingPartitions()` returns all partitions every time. + stage match { + case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable => + mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) + sms.shuffleDep.newShuffleMergeState() + case _ => + } + + // Figure out the indexes of partition ids to compute. + val partitionsToCompute: Seq[Int] = stage.findMissingPartitions() + + // Use the scheduling pool, job group, description, etc. from an ActiveJob associated + // with this Stage + val properties = jobIdToActiveJob(jobId).properties + addPySparkConfigsToProperties(stage, properties) + + runningStages += stage + // SparkListenerStageSubmitted should be posted before testing whether tasks are + // serializable. If tasks are not serializable, a SparkListenerStageCompleted event + // will be posted, which should always come after a corresponding SparkListenerStageSubmitted + // event. + stage match { + case s: ShuffleMapStage => + outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId = s.numPartitions - 1) + // Only generate merger location for a given shuffle dependency once. + if (s.shuffleDep.shuffleMergeAllowed) { + if (!s.shuffleDep.isShuffleMergeFinalizedMarked) { + prepareShuffleServicesForShuffleMapStage(s) + } else { + // Disable Shuffle merge for the retry/reuse of the same shuffle dependency if it has + // already been merge finalized. 
If the shuffle dependency was previously assigned + // merger locations but the corresponding shuffle map stage did not complete + // successfully, we would still enable push for its retry. + s.shuffleDep.setShuffleMergeAllowed(false) + logInfo(s"Push-based shuffle disabled for $stage (${stage.name}) since it" + + " is already shuffle merge finalized") + } + } + case s: ResultStage => + outputCommitCoordinator.stageStart( + stage = s.id, maxPartitionId = s.rdd.partitions.length - 1) + } + val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try { + stage match { + case s: ShuffleMapStage => + partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap + case s: ResultStage => + partitionsToCompute.map { id => + val p = s.partitions(id) + (id, getPreferredLocs(stage.rdd, p)) + }.toMap + } + } catch { + case NonFatal(e) => + stage.makeNewStageAttempt(partitionsToCompute.size) + listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, + Utils.cloneProperties(properties))) + abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e)) + runningStages -= stage + return + } + + stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq) + + // If there are tasks to execute, record the submission time of the stage. Otherwise, + // post the even without the submission time, which indicates that this stage was + // skipped. + if (partitionsToCompute.nonEmpty) { + stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) + } + listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, + Utils.cloneProperties(properties))) + + // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times. + // Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast + // the serialized copy of the RDD and for each task we will deserialize it, which means each + // task gets a different copy of the RDD. This provides stronger isolation between tasks that + // might modify state of objects referenced in their closures. This is necessary in Hadoop + // where the JobConf/Configuration object is not thread-safe. + var taskBinary: Broadcast[Array[Byte]] = null + var partitions: Array[Partition] = null + try { + // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). + // For ResultTask, serialize and broadcast (rdd, func). + var taskBinaryBytes: Array[Byte] = null + // taskBinaryBytes and partitions are both effected by the checkpoint status. We need + // this synchronization in case another concurrent job is checkpointing this RDD, so we get a + // consistent view of both variables. + RDDCheckpointData.synchronized { + taskBinaryBytes = stage match { + case stage: ShuffleMapStage => + JavaUtils.bufferToArray( + closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) + case stage: ResultStage => + JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) + } + + partitions = stage.rdd.partitions + } + + if (taskBinaryBytes.length > TaskSetManager.TASK_SIZE_TO_WARN_KIB * 1024) { + logWarning(s"Broadcasting large task binary with size " + + s"${Utils.bytesToString(taskBinaryBytes.length)}") + } + taskBinary = sc.broadcast(taskBinaryBytes) + } catch { + // In the case of a failure during serialization, abort the stage. 
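+      // Oversized or unserializable task binaries are usually caused by closures that capture
+      // large driver-side objects; the conventional fix is an explicit broadcast. A minimal
+      // sketch, assuming a SparkContext `sc` and a large driver-side `lookup` map (names are
+      // illustrative only):
+      // {{{
+      //   val lookupBc = sc.broadcast(lookup)
+      //   rdd.map(x => lookupBc.value.getOrElse(x, 0)).count()
+      // }}}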
+ case e: NotSerializableException => + abortStage(stage, "Task not serializable: " + e.toString, Some(e)) + runningStages -= stage + + // Abort execution + return + case e: Throwable => + abortStage(stage, s"Task serialization failed: $e\n${Utils.exceptionString(e)}", Some(e)) + runningStages -= stage + + // Abort execution + return + } + + val tasks: Seq[Task[_]] = try { + val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array() + stage match { + case stage: ShuffleMapStage => + stage.pendingPartitions.clear() + partitionsToCompute.map { id => + val locs = taskIdToLocations(id) + val part = partitions(id) + stage.pendingPartitions += id + new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber, + taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId), + Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier()) + } + + case stage: ResultStage => + partitionsToCompute.map { id => + val p: Int = stage.partitions(id) + val part = partitions(p) + val locs = taskIdToLocations(id) + new ResultTask(stage.id, stage.latestInfo.attemptNumber, + taskBinary, part, locs, id, properties, serializedTaskMetrics, + Option(jobId), Option(sc.applicationId), sc.applicationAttemptId, + stage.rdd.isBarrier()) + } + } + } catch { + case NonFatal(e) => + abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e)) + runningStages -= stage + return + } + + if (tasks.nonEmpty) { + logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " + + s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})") + taskScheduler.submitTasks(new TaskSet( + tasks.toArray, stage.id, stage.latestInfo.attemptNumber, jobId, properties, + stage.resourceProfileId)) + } else { + // Because we posted SparkListenerStageSubmitted earlier, we should mark + // the stage as completed here in case there are no tasks to run + markStageAsFinished(stage, None) + + stage match { + case stage: ShuffleMapStage => + logDebug(s"Stage ${stage} is actually done; " + + s"(available: ${stage.isAvailable}," + + s"available outputs: ${stage.numAvailableOutputs}," + + s"partitions: ${stage.numPartitions})") + markMapStageJobsAsFinished(stage) + case stage : ResultStage => + logDebug(s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})") + } + submitWaitingChildStages(stage) + } + } + + /** + * Merge local values from a task into the corresponding accumulators previously registered + * here on the driver. + * + * Although accumulators themselves are not thread-safe, this method is called only from one + * thread, the one that runs the scheduling loop. This means we only handle one task + * completion event at a time so we don't need to worry about locking the accumulators. + * This still doesn't stop the caller from updating the accumulator outside the scheduler, + * but that's not our problem since there's nothing we can do about that. 
+ */ + private def updateAccumulators(event: CompletionEvent): Unit = { + val task = event.task + val stage = stageIdToStage(task.stageId) + + event.accumUpdates.foreach { updates => + val id = updates.id + try { + // Find the corresponding accumulator on the driver and update it + val acc: AccumulatorV2[Any, Any] = AccumulatorContext.get(id) match { + case Some(accum) => accum.asInstanceOf[AccumulatorV2[Any, Any]] + case None => + throw SparkCoreErrors.accessNonExistentAccumulatorError(id) + } + acc.merge(updates.asInstanceOf[AccumulatorV2[Any, Any]]) + // To avoid UI cruft, ignore cases where value wasn't updated + if (acc.name.isDefined && !updates.isZero) { + stage.latestInfo.accumulables(id) = acc.toInfo(None, Some(acc.value)) + event.taskInfo.setAccumulables( + acc.toInfo(Some(updates.value), Some(acc.value)) +: event.taskInfo.accumulables) + } + } catch { + case NonFatal(e) => + // Log the class name to make it easy to find the bad implementation + val accumClassName = AccumulatorContext.get(id) match { + case Some(accum) => accum.getClass.getName + case None => "Unknown class" + } + logError( + s"Failed to update accumulator $id ($accumClassName) for task ${task.partitionId}", + e) + } + } + } + + private def postTaskEnd(event: CompletionEvent): Unit = { + val taskMetrics: TaskMetrics = + if (event.accumUpdates.nonEmpty) { + try { + TaskMetrics.fromAccumulators(event.accumUpdates) + } catch { + case NonFatal(e) => + val taskId = event.taskInfo.taskId + logError(s"Error when attempting to reconstruct metrics for task $taskId", e) + null + } + } else { + null + } + + listenerBus.post(SparkListenerTaskEnd(event.task.stageId, event.task.stageAttemptId, + Utils.getFormattedClassName(event.task), event.reason, event.taskInfo, + new ExecutorMetrics(event.metricPeaks), taskMetrics)) + } + + /** + * Check [[SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL]] in job properties to see if we should + * interrupt running tasks. Returns `false` if the property value is not a boolean value + */ + private def shouldInterruptTaskThread(job: ActiveJob): Boolean = { + if (job.properties == null) { + false + } else { + val shouldInterruptThread = + job.properties.getProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false") + try { + shouldInterruptThread.toBoolean + } catch { + case e: IllegalArgumentException => + logWarning(s"${SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL} in Job ${job.jobId} " + + s"is invalid: $shouldInterruptThread. Using 'false' instead", e) + false + } + } + } + + private[scheduler] def checkAndScheduleShuffleMergeFinalize( + shuffleStage: ShuffleMapStage): Unit = { + // Check if a finalize task has already been scheduled. This is to prevent scenarios + // where we don't schedule multiple shuffle merge finalization which can happen due to + // stage retry or shufflePushMinRatio is already hit etc. + if (shuffleStage.shuffleDep.getFinalizeTask.isEmpty) { + // 1. Stage indeterminate and some map outputs are not available - finalize + // immediately without registering shuffle merge results. + // 2. Stage determinate and some map outputs are not available - decide to + // register merge results based on map outputs size available and + // shuffleMergeWaitMinSizeThreshold. + // 3. All shuffle outputs available - decide to register merge results based + // on map outputs size available and shuffleMergeWaitMinSizeThreshold. + val totalSize = { + lazy val computedTotalSize = + mapOutputTracker.getStatistics(shuffleStage.shuffleDep). 
+ bytesByPartitionId.filter(_ > 0).sum + if (shuffleStage.isAvailable) { + computedTotalSize + } else { + if (shuffleStage.isIndeterminate) { + 0L + } else { + computedTotalSize + } + } + } + + if (totalSize < shuffleMergeWaitMinSizeThreshold) { + scheduleShuffleMergeFinalize(shuffleStage, delay = 0, registerMergeResults = false) + } else { + scheduleShuffleMergeFinalize(shuffleStage, shuffleMergeFinalizeWaitSec) + } + } + } + + /** + * Responds to a task finishing. This is called inside the event loop so it assumes that it can + * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside. + */ + private[scheduler] def handleTaskCompletion(event: CompletionEvent): Unit = { + val task = event.task + val stageId = task.stageId + + outputCommitCoordinator.taskCompleted( + stageId, + task.stageAttemptId, + task.partitionId, + event.taskInfo.attemptNumber, // this is a task attempt number + event.reason) + + if (!stageIdToStage.contains(task.stageId)) { + // The stage may have already finished when we get this event -- e.g. maybe it was a + // speculative task. It is important that we send the TaskEnd event in any case, so listeners + // are properly notified and can chose to handle it. For instance, some listeners are + // doing their own accounting and if they don't get the task end event they think + // tasks are still running when they really aren't. + postTaskEnd(event) + + // Skip all the actions if the stage has been cancelled. + return + } + + val stage = stageIdToStage(task.stageId) + + // Make sure the task's accumulators are updated before any other processing happens, so that + // we can post a task end event before any jobs or stages are updated. The accumulators are + // only updated in certain cases. + event.reason match { + case Success => + task match { + case rt: ResultTask[_, _] => + val resultStage = stage.asInstanceOf[ResultStage] + resultStage.activeJob match { + case Some(job) => + // Only update the accumulator once for each result task. + if (!job.finished(rt.outputId)) { + updateAccumulators(event) + } + case None => // Ignore update if task's job has finished. + } + case _ => + updateAccumulators(event) + } + case _: ExceptionFailure | _: TaskKilled => updateAccumulators(event) + case _ => + } + postTaskEnd(event) + + event.reason match { + case Success => + // An earlier attempt of a stage (which is zombie) may still have running tasks. If these + // tasks complete, they still count and we can mark the corresponding partitions as + // finished. Here we notify the task scheduler to skip running tasks for the same partition, + // to save resource. + if (task.stageAttemptId < stage.latestInfo.attemptNumber()) { + taskScheduler.notifyPartitionCompletion(stageId, task.partitionId) + } + + task match { + case rt: ResultTask[_, _] => + // Cast to ResultStage here because it's part of the ResultTask + // TODO Refactor this out to a function that accepts a ResultStage + val resultStage = stage.asInstanceOf[ResultStage] + resultStage.activeJob match { + case Some(job) => + if (!job.finished(rt.outputId)) { + job.finished(rt.outputId) = true + job.numFinished += 1 + // If the whole job has finished, remove it + if (job.numFinished == job.numPartitions) { + markStageAsFinished(resultStage) + cancelRunningIndependentStages(job, s"Job ${job.jobId} is finished.") + cleanupStateForJobAndIndependentStages(job) + try { + // killAllTaskAttempts will fail if a SchedulerBackend does not implement + // killTask. + logInfo(s"Job ${job.jobId} is finished. 
Cancelling potential speculative " + + "or zombie tasks for this job") + // ResultStage is only used by this job. It's safe to kill speculative or + // zombie tasks in this stage. + taskScheduler.killAllTaskAttempts( + stageId, + shouldInterruptTaskThread(job), + reason = "Stage finished") + } catch { + case e: UnsupportedOperationException => + logWarning(s"Could not cancel tasks for stage $stageId", e) + } + listenerBus.post( + SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded)) + } + + // taskSucceeded runs some user code that might throw an exception. Make sure + // we are resilient against that. + try { + job.listener.taskSucceeded(rt.outputId, event.result) + } catch { + case e: Throwable if !Utils.isFatalError(e) => + // TODO: Perhaps we want to mark the resultStage as failed? + job.listener.jobFailed(new SparkDriverExecutionException(e)) + } + } + case None => + logInfo("Ignoring result from " + rt + " because its job has finished") + } + + case smt: ShuffleMapTask => + val shuffleStage = stage.asInstanceOf[ShuffleMapStage] + shuffleStage.pendingPartitions -= task.partitionId + val status = event.result.asInstanceOf[MapStatus] + val execId = status.location.executorId + logDebug("ShuffleMapTask finished on " + execId) + if (executorFailureEpoch.contains(execId) && + smt.epoch <= executorFailureEpoch(execId)) { + logInfo(s"Ignoring possibly bogus $smt completion from executor $execId") + } else { + // The epoch of the task is acceptable (i.e., the task was launched after the most + // recent failure we're aware of for the executor), so mark the task's output as + // available. + mapOutputTracker.registerMapOutput( + shuffleStage.shuffleDep.shuffleId, smt.partitionId, status) + } + + if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) { + if (!shuffleStage.shuffleDep.isShuffleMergeFinalizedMarked && + shuffleStage.shuffleDep.getMergerLocs.nonEmpty) { + checkAndScheduleShuffleMergeFinalize(shuffleStage) + } else { + processShuffleMapStageCompletion(shuffleStage) + } + } + } + + case FetchFailed(bmAddress, shuffleId, _, mapIndex, reduceId, failureMessage) => + val failedStage = stageIdToStage(task.stageId) + val mapStage = shuffleIdToMapStage(shuffleId) + + if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) { + logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" + + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + + s"(attempt ${failedStage.latestInfo.attemptNumber}) running") + } else { + failedStage.failedAttemptIds.add(task.stageAttemptId) + val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || + disallowStageRetryForTest + + // It is likely that we receive multiple FetchFailed for a single stage (because we have + // multiple tasks running concurrently on different executors). In that case, it is + // possible the fetch failure has already been handled by the scheduler. + if (runningStages.contains(failedStage)) { + logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + + s"due to a fetch failure from $mapStage (${mapStage.name})") + markStageAsFinished(failedStage, errorMessage = Some(failureMessage), + willRetry = !shouldAbortStage) + } else { + logDebug(s"Received fetch failure from $task, but it's from $failedStage which is no " + + "longer running") + } + + if (mapStage.rdd.isBarrier()) { + // Mark all the map as broken in the map stage, to ensure retry all the tasks on + // resubmitted stage attempt. 
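+            // Barrier stages launch all of their tasks together and the tasks may coordinate
+            // with each other, so a partial retry is not safe; every map output is therefore
+            // dropped below and the whole stage is recomputed. For context, a minimal sketch of
+            // a barrier stage (illustrative only):
+            // {{{
+            //   rdd.barrier().mapPartitions { iter =>
+            //     org.apache.spark.BarrierTaskContext.get().barrier()   // global sync point
+            //     iter
+            //   }
+            // }}}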
+            // TODO: SPARK-35547: Clean all push-based shuffle metadata like merge enabled and
+            // TODO: finalized as we are clearing all the merge results.
+            mapOutputTracker.unregisterAllMapAndMergeOutput(shuffleId)
+          } else if (mapIndex != -1) {
+            // Mark the map whose fetch failed as broken in the map stage
+            mapOutputTracker.unregisterMapOutput(shuffleId, mapIndex, bmAddress)
+            if (pushBasedShuffleEnabled) {
+              // Possibly unregister the merge result <shuffleId, reduceId>, if the FetchFailed
+              // mapIndex is part of the merge result of <shuffleId, reduceId>
+              mapOutputTracker.
+                unregisterMergeResult(shuffleId, reduceId, bmAddress, Option(mapIndex))
+            }
+          }
+
+          if (failedStage.rdd.isBarrier()) {
+            failedStage match {
+              case failedMapStage: ShuffleMapStage =>
+                // Mark all the map outputs as broken in the map stage, to ensure that all tasks
+                // are retried on the resubmitted stage attempt.
+                mapOutputTracker.unregisterAllMapAndMergeOutput(failedMapStage.shuffleDep.shuffleId)
+
+              case failedResultStage: ResultStage =>
+                // Abort the failed result stage since we may have committed output for some
+                // partitions.
+                val reason = "Could not recover from a failed barrier ResultStage. Most recent " +
+                  s"failure reason: $failureMessage"
+                abortStage(failedResultStage, reason, None)
+            }
+          }
+
+          if (shouldAbortStage) {
+            val abortMessage = if (disallowStageRetryForTest) {
+              "Fetch failure will not retry stage due to testing config"
+            } else {
+              s"$failedStage (${failedStage.name}) has failed the maximum allowable number of " +
+                s"times: $maxConsecutiveStageAttempts. Most recent failure reason:\n" +
+                failureMessage
+            }
+            abortStage(failedStage, abortMessage, None)
+          } else { // update failedStages and make sure a ResubmitFailedStages event is enqueued
+            // TODO: Cancel running tasks in the failed stage -- cf. SPARK-17064
+            val noResubmitEnqueued = !failedStages.contains(failedStage)
+            failedStages += failedStage
+            failedStages += mapStage
+            if (noResubmitEnqueued) {
+              // If the map stage is INDETERMINATE, which means the map tasks may return
+              // different results when re-tried, we need to re-try all the tasks of the failed
+              // stage and its succeeding stages, because the input data will be changed after the
+              // map tasks are re-tried.
+              // Note that, if the map stage is UNORDERED, we are fine. The shuffle partitioner is
+              // guaranteed to be determinate, so the input data of the reducers will not change
+              // even if the map tasks are re-tried.
+              if (mapStage.isIndeterminate) {
+                // It's a little tricky to find all the succeeding stages of `mapStage`, because
+                // each stage only knows its parents, not its children. Here we traverse the stages
+                // from the leaf nodes (the result stages of active jobs), and roll back all the
+                // stages in the stage chains that connect to the `mapStage`. To speed up the stage
+                // traversal, we collect the stages to roll back first. If a stage needs to roll
+                // back, all its succeeding stages need to roll back too.
+                val stagesToRollback = HashSet[Stage](mapStage)
+
+                def collectStagesToRollback(stageChain: List[Stage]): Unit = {
+                  if (stagesToRollback.contains(stageChain.head)) {
+                    stageChain.drop(1).foreach(s => stagesToRollback += s)
+                  } else {
+                    stageChain.head.parents.foreach { s =>
+                      collectStagesToRollback(s :: stageChain)
+                    }
+                  }
+                }
+
+                def generateErrorMessage(stage: Stage): String = {
+                  "A shuffle map stage with indeterminate output was failed and retried. " +
+                    s"However, Spark cannot rollback the $stage to re-process the input data, " +
+                    "and has to fail this job. 
Please eliminate the indeterminacy by " + + "checkpointing the RDD before repartition and try again." + } + + activeJobs.foreach(job => collectStagesToRollback(job.finalStage :: Nil)) + + // The stages will be rolled back after checking + val rollingBackStages = HashSet[Stage](mapStage) + stagesToRollback.foreach { + case mapStage: ShuffleMapStage => + val numMissingPartitions = mapStage.findMissingPartitions().length + if (numMissingPartitions < mapStage.numTasks) { + if (sc.getConf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) { + val reason = "A shuffle map stage with indeterminate output was failed " + + "and retried. However, Spark can only do this while using the new " + + "shuffle block fetching protocol. Please check the config " + + "'spark.shuffle.useOldFetchProtocol', see more detail in " + + "SPARK-27665 and SPARK-25341." + abortStage(mapStage, reason, None) + } else { + rollingBackStages += mapStage + } + } + + case resultStage: ResultStage if resultStage.activeJob.isDefined => + val numMissingPartitions = resultStage.findMissingPartitions().length + if (numMissingPartitions < resultStage.numTasks) { + // TODO: support to rollback result tasks. + abortStage(resultStage, generateErrorMessage(resultStage), None) + } + + case _ => + } + logInfo(s"The shuffle map stage $mapStage with indeterminate output was failed, " + + s"we will roll back and rerun below stages which include itself and all its " + + s"indeterminate child stages: $rollingBackStages") + } + + // We expect one executor failure to trigger many FetchFailures in rapid succession, + // but all of those task failures can typically be handled by a single resubmission of + // the failed stage. We avoid flooding the scheduler's event queue with resubmit + // messages by checking whether a resubmit is already in the event queue for the + // failed stage. If there is already a resubmit enqueued for a different failed + // stage, that event would also be sufficient to handle the current failed stage, but + // producing a resubmit for each failed stage makes debugging and logging a little + // simpler while not producing an overwhelming number of scheduler events. + logInfo( + s"Resubmitting $mapStage (${mapStage.name}) and " + + s"$failedStage (${failedStage.name}) due to fetch failure" + ) + messageScheduler.schedule( + new Runnable { + override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) + }, + DAGScheduler.RESUBMIT_TIMEOUT, + TimeUnit.MILLISECONDS + ) + } + } + + // TODO: mark the executor as failed only if there were lots of fetch failures on it + if (bmAddress != null) { + val externalShuffleServiceEnabled = env.blockManager.externalShuffleServiceEnabled + val isHostDecommissioned = taskScheduler + .getExecutorDecommissionState(bmAddress.executorId) + .exists(_.workerHost.isDefined) + + // Shuffle output of all executors on host `bmAddress.host` may be lost if: + // - External shuffle service is enabled, so we assume that all shuffle data on node is + // bad. + // - Host is decommissioned, thus all executors on that host will die. + val shuffleOutputOfEntireHostLost = externalShuffleServiceEnabled || + isHostDecommissioned + val hostToUnregisterOutputs = if (shuffleOutputOfEntireHostLost + && unRegisterOutputOnHostOnFetchFailure) { + Some(bmAddress.host) + } else { + // Unregister shuffle data just for one executor (we don't have any + // reason to believe shuffle data has been lost for the entire host). 
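+            // Whether the whole host is unregistered above is controlled (in stock Spark) by the
+            // flag behind unRegisterOutputOnHostOnFetchFailure, believed to be
+            // spark.files.fetchFailure.unRegisterOutputOnHost; a sketch of enabling it:
+            // {{{
+            //   val conf = new org.apache.spark.SparkConf()
+            //     .set("spark.files.fetchFailure.unRegisterOutputOnHost", "true")
+            // }}}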
+ None + } + removeExecutorAndUnregisterOutputs( + execId = bmAddress.executorId, + fileLost = true, + hostToUnregisterOutputs = hostToUnregisterOutputs, + maybeEpoch = Some(task.epoch), + // shuffleFileLostEpoch is ignored when a host is decommissioned because some + // decommissioned executors on that host might have been removed before this fetch + // failure and might have bumped up the shuffleFileLostEpoch. We ignore that, and + // proceed with unconditional removal of shuffle outputs from all executors on that + // host, including from those that we still haven't confirmed as lost due to heartbeat + // delays. + ignoreShuffleFileLostEpoch = isHostDecommissioned) + } + } + + case failure: TaskFailedReason if task.isBarrier => + // Also handle the task failed reasons here. + failure match { + case Resubmitted => + handleResubmittedFailure(task, stage) + + case _ => // Do nothing. + } + + // Always fail the current stage and retry all the tasks when a barrier task fail. + val failedStage = stageIdToStage(task.stageId) + if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) { + logInfo(s"Ignoring task failure from $task as it's from $failedStage attempt" + + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + + s"(attempt ${failedStage.latestInfo.attemptNumber}) running") + } else { + logInfo(s"Marking $failedStage (${failedStage.name}) as failed due to a barrier task " + + "failed.") + val message = s"Stage failed because barrier task $task finished unsuccessfully.\n" + + failure.toErrorString + try { + // killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask. + val reason = s"Task $task from barrier stage $failedStage (${failedStage.name}) " + + "failed." + val job = jobIdToActiveJob.get(failedStage.firstJobId) + val shouldInterrupt = job.exists(j => shouldInterruptTaskThread(j)) + taskScheduler.killAllTaskAttempts(stageId, shouldInterrupt, reason) + } catch { + case e: UnsupportedOperationException => + // Cannot continue with barrier stage if failed to cancel zombie barrier tasks. + // TODO SPARK-24877 leave the zombie tasks and ignore their completion events. + logWarning(s"Could not kill all tasks for stage $stageId", e) + abortStage(failedStage, "Could not kill zombie barrier tasks for stage " + + s"$failedStage (${failedStage.name})", Some(e)) + } + markStageAsFinished(failedStage, Some(message)) + + failedStage.failedAttemptIds.add(task.stageAttemptId) + // TODO Refactor the failure handling logic to combine similar code with that of + // FetchFailed. + val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || + disallowStageRetryForTest + + if (shouldAbortStage) { + val abortMessage = if (disallowStageRetryForTest) { + "Barrier stage will not retry stage due to testing config. Most recent failure " + + s"reason: $message" + } else { + s"$failedStage (${failedStage.name}) has failed the maximum allowable number of " + + s"times: $maxConsecutiveStageAttempts. Most recent failure reason: $message" + } + abortStage(failedStage, abortMessage, None) + } else { + failedStage match { + case failedMapStage: ShuffleMapStage => + // Mark all the map as broken in the map stage, to ensure retry all the tasks on + // resubmitted stage attempt. + mapOutputTracker.unregisterAllMapAndMergeOutput(failedMapStage.shuffleDep.shuffleId) + + case failedResultStage: ResultStage => + // Abort the failed result stage since we may have committed output for some + // partitions. 
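+                  // Committed output cannot be taken back, e.g. files already written by the
+                  // succeeded partitions of a save job such as (illustrative path only):
+                  // {{{
+                  //   rdd.map(_.toString).saveAsTextFile("hdfs:///tmp/out")
+                  // }}}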
+                val reason = "Could not recover from a failed barrier ResultStage. Most recent " +
+                  s"failure reason: $message"
+                abortStage(failedResultStage, reason, None)
+            }
+            // In case multiple task failures triggered for a single stage attempt, ensure we only
+            // resubmit the failed stage once.
+            val noResubmitEnqueued = !failedStages.contains(failedStage)
+            failedStages += failedStage
+            if (noResubmitEnqueued) {
+              logInfo(s"Resubmitting $failedStage (${failedStage.name}) due to barrier stage " +
+                "failure.")
+              messageScheduler.schedule(new Runnable {
+                override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+              }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
+            }
+          }
+        }
+
+      case Resubmitted =>
+        handleResubmittedFailure(task, stage)
+
+      case _: TaskCommitDenied =>
+        // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits
+
+      case _: ExceptionFailure | _: TaskKilled =>
+        // Nothing left to do, already handled above for accumulator updates.
+
+      case TaskResultLost =>
+        // Do nothing here; the TaskScheduler handles these failures and resubmits the task.
+
+      case _: ExecutorLostFailure | UnknownReason =>
+        // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler
+        // will abort the job.
+    }
+  }
+
+  /**
+   * Schedules shuffle merge finalization.
+   *
+   * @param stage the stage to finalize shuffle merge
+   * @param delay how long to wait before finalizing shuffle merge
+   * @param registerMergeResults indicates whether the DAGScheduler would register the received
+   *                             MergeStatus with MapOutputTracker and wait to schedule the reduce
+   *                             stage until MergeStatus have been received from all mergers or
+   *                             reaches timeout. For very small shuffles, this could be set to
+   *                             false to avoid impact to job runtime.
+   */
+  private[scheduler] def scheduleShuffleMergeFinalize(
+      stage: ShuffleMapStage,
+      delay: Long,
+      registerMergeResults: Boolean = true): Unit = {
+    val shuffleDep = stage.shuffleDep
+    val scheduledTask: Option[ScheduledFuture[_]] = shuffleDep.getFinalizeTask
+    scheduledTask match {
+      case Some(task) =>
+        // If we find an already scheduled task, check if the task has been triggered yet.
+        // If it's already triggered, do nothing. Otherwise, cancel it and schedule a new
+        // one for immediate execution. Note that we should get here only when
+        // handleShufflePushCompleted schedules a finalize task after the shuffle map stage
+        // completed earlier and scheduled a task with default delay.
+        // The current task should be coming from handleShufflePushCompleted, thus the
+        // delay should be 0 and registerMergeResults should be true.
+        assert(delay == 0 && registerMergeResults)
+        if (task.getDelay(TimeUnit.NANOSECONDS) > 0 && task.cancel(false)) {
+          logInfo(s"$stage (${stage.name}) scheduled for finalizing shuffle merge immediately " +
+            s"after cancelling previously scheduled task.")
+          shuffleDep.setFinalizeTask(
+            shuffleMergeFinalizeScheduler.schedule(
+              new Runnable {
+                override def run(): Unit = finalizeShuffleMerge(stage, registerMergeResults)
+              },
+              0,
+              TimeUnit.SECONDS
+            )
+          )
+        } else {
+          logInfo(s"$stage (${stage.name}) existing scheduled task for finalizing shuffle merge" +
+            s" would either be in-progress or finished. No need to schedule shuffle merge" +
+            s" finalization again.")
+        }
+      case None =>
+        // If no previous finalization task is scheduled, schedule the finalization task.
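+        // The delay used here is driven by the push-based shuffle settings; a sketch of a
+        // typical enablement (config keys assumed to match stock Spark, not defined here):
+        // {{{
+        //   val conf = new org.apache.spark.SparkConf()
+        //     .set("spark.shuffle.push.enabled", "true")
+        //     .set("spark.shuffle.push.finalize.timeout", "10s")   // feeds shuffleMergeFinalizeWaitSec
+        // }}}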
+ logInfo(s"$stage (${stage.name}) scheduled for finalizing shuffle merge in $delay s") + shuffleDep.setFinalizeTask( + shuffleMergeFinalizeScheduler.schedule( + new Runnable { + override def run(): Unit = finalizeShuffleMerge(stage, registerMergeResults) + }, + delay, + TimeUnit.SECONDS + ) + ) + } + } + + /** + * DAGScheduler notifies all the remote shuffle services chosen to serve shuffle merge request for + * the given shuffle map stage to finalize the shuffle merge process for this shuffle. This is + * invoked in a separate thread to reduce the impact on the DAGScheduler main thread, as the + * scheduler might need to talk to 1000s of shuffle services to finalize shuffle merge. + * + * @param stage ShuffleMapStage to finalize shuffle merge for + * @param registerMergeResults indicate whether DAGScheduler would register the received + * MergeStatus with MapOutputTracker and wait to schedule the reduce + * stage until MergeStatus have been received from all mergers or + * reaches timeout. For very small shuffle, this could be set to + * false to avoid impact to job runtime. + */ + private[scheduler] def finalizeShuffleMerge( + stage: ShuffleMapStage, + registerMergeResults: Boolean = true): Unit = { + logInfo(s"$stage (${stage.name}) finalizing the shuffle merge with registering merge " + + s"results set to $registerMergeResults") + val shuffleId = stage.shuffleDep.shuffleId + val shuffleMergeId = stage.shuffleDep.shuffleMergeId + val numMergers = stage.shuffleDep.getMergerLocs.length + val results = (0 until numMergers).map(_ => SettableFuture.create[Boolean]()) + externalShuffleClient.foreach { shuffleClient => + if (!registerMergeResults) { + results.foreach(_.set(true)) + // Finalize in separate thread as shuffle merge is a no-op in this case + shuffleMergeFinalizeScheduler.schedule(new Runnable { + override def run(): Unit = { + stage.shuffleDep.getMergerLocs.foreach { + case shuffleServiceLoc => + // Sends async request to shuffle service to finalize shuffle merge on that host. + // Since merge statuses will not be registered in this case, + // we pass a no-op listener. + shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host, + shuffleServiceLoc.port, shuffleId, shuffleMergeId, + new MergeFinalizerListener { + override def onShuffleMergeSuccess(statuses: MergeStatuses): Unit = { + } + + override def onShuffleMergeFailure(e: Throwable): Unit = { + } + }) + } + } + }, 0, TimeUnit.SECONDS) + } else { + stage.shuffleDep.getMergerLocs.zipWithIndex.foreach { + case (shuffleServiceLoc, index) => + // Sends async request to shuffle service to finalize shuffle merge on that host + // TODO: SPARK-35536: Cancel finalizeShuffleMerge if the stage is cancelled + // TODO: during shuffleMergeFinalizeWaitSec + shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host, + shuffleServiceLoc.port, shuffleId, shuffleMergeId, + new MergeFinalizerListener { + override def onShuffleMergeSuccess(statuses: MergeStatuses): Unit = { + assert(shuffleId == statuses.shuffleId) + eventProcessLoop.post(RegisterMergeStatuses(stage, MergeStatus. 
+ convertMergeStatusesToMergeStatusArr(statuses, shuffleServiceLoc))) + results(index).set(true) + } + + override def onShuffleMergeFailure(e: Throwable): Unit = { + logWarning(s"Exception encountered when trying to finalize shuffle " + + s"merge on ${shuffleServiceLoc.host} for shuffle $shuffleId", e) + // Do not fail the future as this would cause dag scheduler to prematurely + // give up on waiting for merge results from the remaining shuffle services + // if one fails + results(index).set(false) + } + }) + } + } + // DAGScheduler only waits for a limited amount of time for the merge results. + // It will attempt to submit the next stage(s) irrespective of whether merge results + // from all shuffle services are received or not. + try { + Futures.allAsList(results: _*).get(shuffleMergeResultsTimeoutSec, TimeUnit.SECONDS) + } catch { + case _: TimeoutException => + logInfo(s"Timed out on waiting for merge results from all " + + s"$numMergers mergers for shuffle $shuffleId") + } finally { + eventProcessLoop.post(ShuffleMergeFinalized(stage)) + } + } + } + + private def processShuffleMapStageCompletion(shuffleStage: ShuffleMapStage): Unit = { + markStageAsFinished(shuffleStage) + logInfo("looking for newly runnable stages") + logInfo("running: " + runningStages) + logInfo("waiting: " + waitingStages) + logInfo("failed: " + failedStages) + + // This call to increment the epoch may not be strictly necessary, but it is retained + // for now in order to minimize the changes in behavior from an earlier version of the + // code. This existing behavior of always incrementing the epoch following any + // successful shuffle map stage completion may have benefits by causing unneeded + // cached map outputs to be cleaned up earlier on executors. In the future we can + // consider removing this call, but this will require some extra investigation. + // See https://github.com/apache/spark/pull/17955/files#r117385673 for more details. + mapOutputTracker.incrementEpoch() + + clearCacheLocs() + + if (!shuffleStage.isAvailable) { + // Some tasks had failed; let's resubmit this shuffleStage. + // TODO: Lower-level scheduler should also deal with this + logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name + + ") because some of its tasks had failed: " + + shuffleStage.findMissingPartitions().mkString(", ")) + submitStage(shuffleStage) + } else { + markMapStageJobsAsFinished(shuffleStage) + submitWaitingChildStages(shuffleStage) + } + } + + private[scheduler] def handleRegisterMergeStatuses( + stage: ShuffleMapStage, + mergeStatuses: Seq[(Int, MergeStatus)]): Unit = { + // Register merge statuses if the stage is still running and shuffle merge is not finalized yet. + // TODO: SPARK-35549: Currently merge statuses results which come after shuffle merge + // TODO: is finalized is not registered. + if (runningStages.contains(stage) && !stage.shuffleDep.isShuffleMergeFinalizedMarked) { + mapOutputTracker.registerMergeResults(stage.shuffleDep.shuffleId, mergeStatuses) + } + } + + private[scheduler] def handleShuffleMergeFinalized(stage: ShuffleMapStage, + shuffleMergeId: Int): Unit = { + // Check if update is for the same merge id - finalization might have completed for an earlier + // adaptive attempt while the stage might have failed/killed and shuffle id is getting + // re-executing now. + if (stage.shuffleDep.shuffleMergeId == shuffleMergeId) { + // When it reaches here, there is a possibility that the stage will be resubmitted again + // because of various reasons. 
Some of these could be: + // a) Stage results are not available. All the tasks completed once so the + // pendingPartitions is empty but due to an executor failure some of the map outputs are not + // available any more, so the stage will be re-submitted. + // b) Stage failed due to a task failure. + // We should mark the stage as merged finalized irrespective of what state it is in. + // This will prevent the push being enabled for the re-attempt. + // Note: for indeterminate stages, this doesn't matter at all, since the merge finalization + // related state is reset during the stage submission. + stage.shuffleDep.markShuffleMergeFinalized() + if (stage.pendingPartitions.isEmpty) + if (runningStages.contains(stage)) { + processShuffleMapStageCompletion(stage) + } else if (stage.isIndeterminate) { + // There are 2 possibilities here - stage is either cancelled or it will be resubmitted. + // If this is an indeterminate stage which is cancelled, we unregister all its merge + // results here just to free up some memory. If the indeterminate stage is resubmitted, + // merge results are cleared again when the newer attempt is submitted. + mapOutputTracker.unregisterAllMergeResult(stage.shuffleDep.shuffleId) + // For determinate stages, which have completed merge finalization, we don't need to + // unregister merge results - since the stage retry, or any other stage computing the + // same shuffle id, can use it. + } + } + } + + private[scheduler] def handleShufflePushCompleted( + shuffleId: Int, shuffleMergeId: Int, mapIndex: Int): Unit = { + shuffleIdToMapStage.get(shuffleId) match { + case Some(mapStage) => + val shuffleDep = mapStage.shuffleDep + // Only update shufflePushCompleted events for the current active stage map tasks. + // This is required to prevent shuffle merge finalization by dangling tasks of a + // previous attempt in the case of indeterminate stage. + if (shuffleDep.shuffleMergeId == shuffleMergeId) { + if (!shuffleDep.isShuffleMergeFinalizedMarked && + shuffleDep.incPushCompleted(mapIndex).toDouble / shuffleDep.rdd.partitions.length + >= shufflePushMinRatio) { + scheduleShuffleMergeFinalize(mapStage, delay = 0) + } + } + case None => + } + } + + private def handleResubmittedFailure(task: Task[_], stage: Stage): Unit = { + logInfo(s"Resubmitted $task, so marking it as still running.") + stage match { + case sms: ShuffleMapStage => + sms.pendingPartitions += task.partitionId + + case _ => + throw SparkCoreErrors.sendResubmittedTaskStatusForShuffleMapStagesOnlyError() + } + } + + private[scheduler] def markMapStageJobsAsFinished(shuffleStage: ShuffleMapStage): Unit = { + // Mark any map-stage jobs waiting on this stage as finished + if (shuffleStage.isAvailable && shuffleStage.mapStageJobs.nonEmpty) { + val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep) + for (job <- shuffleStage.mapStageJobs) { + markMapStageJobAsFinished(job, stats) + } + } + } + + /** + * Responds to an executor being lost. This is called inside the event loop, so it assumes it can + * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. + * + * We will also assume that we've lost all shuffle blocks associated with the executor if the + * executor serves its own blocks (i.e., we're not using an external shuffle service), or the + * entire Standalone worker is lost. 
+ */ + private[scheduler] def handleExecutorLost( + execId: String, + workerHost: Option[String]): Unit = { + // if the cluster manager explicitly tells us that the entire worker was lost, then + // we know to unregister shuffle output. (Note that "worker" specifically refers to the process + // from a Standalone cluster, where the shuffle service lives in the Worker.) + val fileLost = workerHost.isDefined || !env.blockManager.externalShuffleServiceEnabled + removeExecutorAndUnregisterOutputs( + execId = execId, + fileLost = fileLost, + hostToUnregisterOutputs = workerHost, + maybeEpoch = None) + } + + /** + * Handles removing an executor from the BlockManagerMaster as well as unregistering shuffle + * outputs for the executor or optionally its host. + * + * @param execId executor to be removed + * @param fileLost If true, indicates that we assume we've lost all shuffle blocks associated + * with the executor; this happens if the executor serves its own blocks (i.e., we're not + * using an external shuffle service), the entire Standalone worker is lost, or a FetchFailed + * occurred (in which case we presume all shuffle data related to this executor to be lost). + * @param hostToUnregisterOutputs (optional) executor host if we're unregistering all the + * outputs on the host + * @param maybeEpoch (optional) the epoch during which the failure was caught (this prevents + * reprocessing for follow-on fetch failures) + */ + private def removeExecutorAndUnregisterOutputs( + execId: String, + fileLost: Boolean, + hostToUnregisterOutputs: Option[String], + maybeEpoch: Option[Long] = None, + ignoreShuffleFileLostEpoch: Boolean = false): Unit = { + val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch) + logDebug(s"Considering removal of executor $execId; " + + s"fileLost: $fileLost, currentEpoch: $currentEpoch") + if (!executorFailureEpoch.contains(execId) || executorFailureEpoch(execId) < currentEpoch) { + executorFailureEpoch(execId) = currentEpoch + logInfo(s"Executor lost: $execId (epoch $currentEpoch)") + if (pushBasedShuffleEnabled) { + // Remove fetchFailed host in the shuffle push merger list for push based shuffle + hostToUnregisterOutputs.foreach( + host => blockManagerMaster.removeShufflePushMergerLocation(host)) + } + blockManagerMaster.removeExecutor(execId) + clearCacheLocs() + } + if (fileLost) { + val remove = if (ignoreShuffleFileLostEpoch) { + true + } else if (!shuffleFileLostEpoch.contains(execId) || + shuffleFileLostEpoch(execId) < currentEpoch) { + shuffleFileLostEpoch(execId) = currentEpoch + true + } else { + false + } + if (remove) { + hostToUnregisterOutputs match { + case Some(host) => + logInfo(s"Shuffle files lost for host: $host (epoch $currentEpoch)") + mapOutputTracker.removeOutputsOnHost(host) + case None => + logInfo(s"Shuffle files lost for executor: $execId (epoch $currentEpoch)") + mapOutputTracker.removeOutputsOnExecutor(execId) + } + } + } + } + + /** + * Responds to a worker being removed. This is called inside the event loop, so it assumes it can + * modify the scheduler's internal state. Use workerRemoved() to post a loss event from outside. + * + * We will assume that we've lost all shuffle blocks associated with the host if a worker is + * removed, so we will remove them all from MapStatus. + * + * @param workerId identifier of the worker that is removed. + * @param host host of the worker that is removed. + * @param message the reason why the worker is removed. 
+ */ + private[scheduler] def handleWorkerRemoved( + workerId: String, + host: String, + message: String): Unit = { + logInfo("Shuffle files lost for worker %s on host %s".format(workerId, host)) + mapOutputTracker.removeOutputsOnHost(host) + clearCacheLocs() + } + + private[scheduler] def handleExecutorAdded(execId: String, host: String): Unit = { + // remove from executorFailureEpoch(execId) ? + if (executorFailureEpoch.contains(execId)) { + logInfo("Host added was in lost list earlier: " + host) + executorFailureEpoch -= execId + } + shuffleFileLostEpoch -= execId + + if (pushBasedShuffleEnabled) { + // Only set merger locations for stages that are not yet finished and have empty mergers + shuffleIdToMapStage.filter { case (_, stage) => + stage.shuffleDep.shuffleMergeAllowed && stage.shuffleDep.getMergerLocs.isEmpty && + runningStages.contains(stage) + }.foreach { case(_, stage: ShuffleMapStage) => + if (getAndSetShufflePushMergerLocations(stage).nonEmpty) { + logInfo(s"Shuffle merge enabled adaptively for $stage with shuffle" + + s" ${stage.shuffleDep.shuffleId} and shuffle merge" + + s" ${stage.shuffleDep.shuffleMergeId} with ${stage.shuffleDep.getMergerLocs.size}" + + s" merger locations") + mapOutputTracker.registerShufflePushMergerLocations(stage.shuffleDep.shuffleId, + stage.shuffleDep.getMergerLocs) + } + } + } + } + + private[scheduler] def handleStageCancellation(stageId: Int, reason: Option[String]): Unit = { + stageIdToStage.get(stageId) match { + case Some(stage) => + val jobsThatUseStage: Array[Int] = stage.jobIds.toArray + jobsThatUseStage.foreach { jobId => + val reasonStr = reason match { + case Some(originalReason) => + s"because $originalReason" + case None => + s"because Stage $stageId was cancelled" + } + handleJobCancellation(jobId, Option(reasonStr)) + } + case None => + logInfo("No active jobs to kill for Stage " + stageId) + } + } + + private[scheduler] def handleJobCancellation(jobId: Int, reason: Option[String]): Unit = { + if (!jobIdToStageIds.contains(jobId)) { + logDebug("Trying to cancel unregistered job " + jobId) + } else { + failJobAndIndependentStages( + jobIdToActiveJob(jobId), "Job %d cancelled %s".format(jobId, reason.getOrElse(""))) + } + } + + /** + * Marks a stage as finished and removes it from the list of running stages. + */ + private def markStageAsFinished( + stage: Stage, + errorMessage: Option[String] = None, + willRetry: Boolean = false): Unit = { + val serviceTime = stage.latestInfo.submissionTime match { + case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) + case _ => "Unknown" + } + if (errorMessage.isEmpty) { + logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) + stage.latestInfo.completionTime = Some(clock.getTimeMillis()) + + // Clear failure count for this stage, now that it's succeeded. + // We only limit consecutive failures of stage attempts,so that if a stage is + // re-used many times in a long-running job, unrelated failures don't eventually cause the + // stage to be aborted. + stage.clearFailures() + } else { + stage.latestInfo.stageFailed(errorMessage.get) + logInfo(s"$stage (${stage.name}) failed in $serviceTime s due to ${errorMessage.get}") + } + updateStageInfoForPushBasedShuffle(stage) + if (!willRetry) { + outputCommitCoordinator.stageEnd(stage.id) + } + listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + runningStages -= stage + } + + /** + * Aborts all jobs depending on a particular Stage. 
This is called in response to a task set + * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. + */ + private[scheduler] def abortStage( + failedStage: Stage, + reason: String, + exception: Option[Throwable]): Unit = { + if (!stageIdToStage.contains(failedStage.id)) { + // Skip all the actions if the stage has been removed. + return + } + val dependentJobs: Seq[ActiveJob] = + activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq + failedStage.latestInfo.completionTime = Some(clock.getTimeMillis()) + updateStageInfoForPushBasedShuffle(failedStage) + for (job <- dependentJobs) { + failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason", exception) + } + if (dependentJobs.isEmpty) { + logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") + } + } + + private def updateStageInfoForPushBasedShuffle(stage: Stage): Unit = { + // With adaptive shuffle mergers, StageInfo's + // isPushBasedShuffleEnabled and shuffleMergers need to be updated at the end. + stage match { + case s: ShuffleMapStage => + stage.latestInfo.setPushBasedShuffleEnabled(s.shuffleDep.shuffleMergeEnabled) + if (s.shuffleDep.shuffleMergeEnabled) { + stage.latestInfo.setShuffleMergerCount(s.shuffleDep.getMergerLocs.size) + } + case _ => + } + } + + /** Cancel all independent, running stages that are only used by this job. */ + private def cancelRunningIndependentStages(job: ActiveJob, reason: String): Boolean = { + var ableToCancelStages = true + val stages = jobIdToStageIds(job.jobId) + if (stages.isEmpty) { + logError(s"No stages registered for job ${job.jobId}") + } + stages.foreach { stageId => + val jobsForStage: Option[HashSet[Int]] = stageIdToStage.get(stageId).map(_.jobIds) + if (jobsForStage.isEmpty || !jobsForStage.get.contains(job.jobId)) { + logError( + "Job %d not registered for stage %d even though that stage was registered for the job" + .format(job.jobId, stageId)) + } else if (jobsForStage.get.size == 1) { + if (!stageIdToStage.contains(stageId)) { + logError(s"Missing Stage for stage with id $stageId") + } else { + // This stage is only used by the job, so finish the stage if it is running. + val stage = stageIdToStage(stageId) + if (runningStages.contains(stage)) { + try { // cancelTasks will fail if a SchedulerBackend does not implement killTask + taskScheduler.cancelTasks(stageId, shouldInterruptTaskThread(job)) + markStageAsFinished(stage, Some(reason)) + } catch { + case e: UnsupportedOperationException => + logWarning(s"Could not cancel tasks for stage $stageId", e) + ableToCancelStages = false + } + } + } + } + } + ableToCancelStages + } + + /** Fails a job and all stages that are only used by that job, and cleans up relevant state. */ + private def failJobAndIndependentStages( + job: ActiveJob, + failureReason: String, + exception: Option[Throwable] = None): Unit = { + if (cancelRunningIndependentStages(job, failureReason)) { + // SPARK-15783 important to cleanup state first, just for tests where we have some asserts + // against the state. Otherwise we have a *little* bit of flakiness in the tests. + cleanupStateForJobAndIndependentStages(job) + val error = new SparkException(failureReason, exception.orNull) + job.listener.jobFailed(error) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error))) + } + } + + /** Return true if one of stage's ancestors is target. 
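cancelRunningIndependentStages above only cancels a running stage when the cancelled job is that stage's sole user; stages shared with other jobs keep running. A small illustrative Java check of that condition (the map below is a hypothetical stand-in for stageIdToStage(...).jobIds, not the scheduler's actual data structure):

```java
import java.util.Map;
import java.util.Set;

// Sketch of the "independent stage" test: cancel only if this job is the stage's only user.
final class IndependentStageCheck {
    static boolean onlyUsedByJob(Map<Integer, Set<Integer>> stageIdToJobIds, int stageId, int jobId) {
        Set<Integer> jobs = stageIdToJobIds.get(stageId);
        return jobs != null && jobs.contains(jobId) && jobs.size() == 1;
    }

    public static void main(String[] args) {
        Map<Integer, Set<Integer>> stageJobs = Map.of(
                3, Set.of(0),      // stage 3 is only used by job 0 -> safe to cancel with job 0
                4, Set.of(0, 1));  // stage 4 is shared with job 1 -> must keep running
        System.out.println(onlyUsedByJob(stageJobs, 3, 0)); // true
        System.out.println(onlyUsedByJob(stageJobs, 4, 0)); // false
    }
}
```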
*/ + private def stageDependsOn(stage: Stage, target: Stage): Boolean = { + if (stage == target) { + return true + } + val visitedRdds = new HashSet[RDD[_]] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new ListBuffer[RDD[_]] + waitingForVisit += stage.rdd + def visit(rdd: RDD[_]): Unit = { + if (!visitedRdds(rdd)) { + visitedRdds += rdd + for (dep <- rdd.dependencies) { + dep match { + case shufDep: ShuffleDependency[_, _, _] => + val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId) + if (!mapStage.isAvailable) { + waitingForVisit.prepend(mapStage.rdd) + } // Otherwise there's no need to follow the dependency back + case narrowDep: NarrowDependency[_] => + waitingForVisit.prepend(narrowDep.rdd) + } + } + } + } + while (waitingForVisit.nonEmpty) { + visit(waitingForVisit.remove(0)) + } + visitedRdds.contains(target.rdd) + } + + /** + * Gets the locality information associated with a partition of a particular RDD. + * + * This method is thread-safe and is called from both DAGScheduler and SparkContext. + * + * @param rdd whose partitions are to be looked at + * @param partition to lookup locality information for + * @return list of machines that are preferred by the partition + */ + private[spark] + def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = { + getPreferredLocsInternal(rdd, partition, new HashSet) + } + + /** + * Recursive implementation for getPreferredLocs. + * + * This method is thread-safe because it only accesses DAGScheduler state through thread-safe + * methods (getCacheLocs()); please be careful when modifying this method, because any new + * DAGScheduler state accessed by it may require additional synchronization. + */ + private def getPreferredLocsInternal( + rdd: RDD[_], + partition: Int, + visited: HashSet[(RDD[_], Int)]): Seq[TaskLocation] = { + // If the partition has already been visited, no need to re-visit. + // This avoids exponential path exploration. SPARK-695 + if (!visited.add((rdd, partition))) { + // Nil has already been returned for previously visited partitions. + return Nil + } + // If the partition is cached, return the cache locations + val cached = getCacheLocs(rdd)(partition) + if (cached.nonEmpty) { + return cached + } + // If the RDD has some placement preferences (as is the case for input RDDs), get those + val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList + if (rddPrefs.nonEmpty) { + return rddPrefs.filter(_ != null).map(TaskLocation(_)) + } + + // If the RDD has narrow dependencies, pick the first partition of the first narrow dependency + // that has any placement preferences. Ideally we would choose based on transfer sizes, + // but this will do for now. + rdd.dependencies.foreach { + case n: NarrowDependency[_] => + for (inPart <- n.getParents(partition)) { + val locs = getPreferredLocsInternal(n.rdd, inPart, visited) + if (locs != Nil) { + return locs + } + } + + case _ => + } + + Nil + } + + /** Mark a map stage job as finished with the given output stats, and report to its listener. 
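stageDependsOn above deliberately replaces recursion with an explicit stack plus a visited set, so arbitrarily deep RDD lineages cannot overflow the JVM stack and each RDD is expanded at most once. A self-contained Java sketch of the same walk over a plain adjacency map (illustrative only; the real code walks rdd.dependencies, treats unavailable shuffle-map stages as edges, and checks the visited set at the end, whereas this sketch simply returns early on a hit):

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Sketch of a stack-based ancestor search; "parents" is a hypothetical adjacency map.
final class LineageWalk {
    static boolean dependsOn(int start, int target, Map<Integer, List<Integer>> parents) {
        Set<Integer> visited = new HashSet<>();
        Deque<Integer> toVisit = new ArrayDeque<>();
        toVisit.push(start);
        while (!toVisit.isEmpty()) {
            int node = toVisit.pop();
            if (visited.add(node)) {            // each node is expanded at most once
                if (node == target) {
                    return true;
                }
                for (int parent : parents.getOrDefault(node, List.of())) {
                    toVisit.push(parent);
                }
            }
        }
        return false;
    }
}
```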
*/ + def markMapStageJobAsFinished(job: ActiveJob, stats: MapOutputStatistics): Unit = { + // In map stage jobs, we only create a single "task", which is to finish all of the stage + // (including reusing any previous map outputs, etc); so we just mark task 0 as done + job.finished(0) = true + job.numFinished += 1 + job.listener.taskSucceeded(0, stats) + cleanupStateForJobAndIndependentStages(job) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded)) + } + + def stop(): Unit = { + messageScheduler.shutdownNow() + shuffleMergeFinalizeScheduler.shutdownNow() + eventProcessLoop.stop() + taskScheduler.stop() + } + + eventProcessLoop.start() +} + +private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler) + extends EventLoop[DAGSchedulerEvent]("dag-scheduler-event-loop") with Logging { + + private[this] val timer = dagScheduler.metricsSource.messageProcessingTimer + + /** + * The main event loop of the DAG scheduler. + */ + override def onReceive(event: DAGSchedulerEvent): Unit = { + val timerContext = timer.time() + try { + doOnReceive(event) + } finally { + timerContext.stop() + } + } + + private def doOnReceive(event: DAGSchedulerEvent): Unit = event match { + case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) => + dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) + + case MapStageSubmitted(jobId, dependency, callSite, listener, properties) => + dagScheduler.handleMapStageSubmitted(jobId, dependency, callSite, listener, properties) + + case StageCancelled(stageId, reason) => + dagScheduler.handleStageCancellation(stageId, reason) + + case JobCancelled(jobId, reason) => + dagScheduler.handleJobCancellation(jobId, reason) + + case JobGroupCancelled(groupId) => + dagScheduler.handleJobGroupCancelled(groupId) + + case AllJobsCancelled => + dagScheduler.doCancelAllJobs() + + case ExecutorAdded(execId, host) => + dagScheduler.handleExecutorAdded(execId, host) + + case ExecutorLost(execId, reason) => + val workerHost = reason match { + case ExecutorProcessLost(_, workerHost, _) => workerHost + case ExecutorDecommission(workerHost) => workerHost + case _ => None + } + dagScheduler.handleExecutorLost(execId, workerHost) + + case WorkerRemoved(workerId, host, message) => + dagScheduler.handleWorkerRemoved(workerId, host, message) + + case BeginEvent(task, taskInfo) => + dagScheduler.handleBeginEvent(task, taskInfo) + + case SpeculativeTaskSubmitted(task) => + dagScheduler.handleSpeculativeTaskSubmitted(task) + + case UnschedulableTaskSetAdded(stageId, stageAttemptId) => + dagScheduler.handleUnschedulableTaskSetAdded(stageId, stageAttemptId) + + case UnschedulableTaskSetRemoved(stageId, stageAttemptId) => + dagScheduler.handleUnschedulableTaskSetRemoved(stageId, stageAttemptId) + + case GettingResultEvent(taskInfo) => + dagScheduler.handleGetTaskResult(taskInfo) + + case completion: CompletionEvent => + dagScheduler.handleTaskCompletion(completion) + + case TaskSetFailed(taskSet, reason, exception) => + dagScheduler.handleTaskSetFailed(taskSet, reason, exception) + + case ResubmitFailedStages => + dagScheduler.resubmitFailedStages() + + case RegisterMergeStatuses(stage, mergeStatuses) => + dagScheduler.handleRegisterMergeStatuses(stage, mergeStatuses) + + case ShuffleMergeFinalized(stage) => + dagScheduler.handleShuffleMergeFinalized(stage, stage.shuffleDep.shuffleMergeId) + + case ShufflePushCompleted(shuffleId, shuffleMergeId, mapIndex) => + 
dagScheduler.handleShufflePushCompleted(shuffleId, shuffleMergeId, mapIndex) + } + + override def onError(e: Throwable): Unit = { + logError("DAGSchedulerEventProcessLoop failed; shutting down SparkContext", e) + try { + dagScheduler.doCancelAllJobs() + } catch { + case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t) + } + dagScheduler.sc.stopInNewThread() + } + + override def onStop(): Unit = { + // Cancel any active jobs in postStop hook + dagScheduler.cleanUpAfterSchedulerStop() + } +} + +private[spark] object DAGScheduler { + // The time, in millis, to wait for fetch failure events to stop coming in after one is detected; + // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one + // as more failure events come in + val RESUBMIT_TIMEOUT = 200 + + // Number of consecutive stage attempts allowed before a stage is aborted + val DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS = 4 +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/pom.xml b/omnioperator/omniop-spark-extension/pom.xml index 4f3241582730ff5b7ac2ce2b8acc6b3c02d3572d..4c93cff7fce6d9a1baf63ecb320f4611ec83580a 100644 --- a/omnioperator/omniop-spark-extension/pom.xml +++ b/omnioperator/omniop-spark-extension/pom.xml @@ -20,7 +20,7 @@ 2.12.10 2.12 3.3.1 - 1.7.0 + 1.8.0 3.2.2 3.13.0-h19 0.6.1 diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java index 1aab7425b4d703e3f20310277db3ca22db566dee..3f8b2db2cba5e7347d8ff5117f8980f6cb062258 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java @@ -19,6 +19,7 @@ package com.huawei.boostkit.spark.jni; import com.huawei.boostkit.scan.jni.OrcColumnarBatchJniReader; +import com.huawei.boostkit.spark.predicate.*; import com.huawei.boostkit.spark.timestamp.JulianGregorianRebase; import com.huawei.boostkit.spark.timestamp.TimestampUtil; import nova.hetu.omniruntime.type.DataType; @@ -62,9 +63,12 @@ import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.Arrays; +import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.TimeZone; +import java.util.stream.Collectors; +import java.util.stream.IntStream; public class OrcColumnarBatchScanReader { private static final Logger LOGGER = LoggerFactory.getLogger(OrcColumnarBatchScanReader.class); @@ -176,7 +180,8 @@ public class OrcColumnarBatchScanReader { * @param pushedFilter the filter push down to native * @param requiredSchema the columns read from native */ - public long initializeRecordReaderJava(long offset, long length, Filter pushedFilter, StructType requiredSchema) { + public long initializeRecordReaderJava(long offset, long length, Filter pushedFilter, StructType requiredSchema, + Boolean vecPredicateFilter, Boolean filterPushDown) { this.requiredSchema = requiredSchema; JSONObject job = new JSONObject(); @@ -184,12 +189,20 @@ public class OrcColumnarBatchScanReader { job.put("length", length); if (pushedFilter != null) { - JSONObject jsonExpressionTree = new JSONObject(); - JSONObject jsonLeaves = new 
JSONObject(); - boolean flag = canPushDown(pushedFilter, jsonExpressionTree, jsonLeaves); - if (flag) { - job.put("expressionTree", jsonExpressionTree); - job.put("leaves", jsonLeaves); + if (filterPushDown != null && filterPushDown) { + JSONObject jsonExpressionTree = new JSONObject(); + JSONObject jsonLeaves = new JSONObject(); + boolean flag = canPushDown(pushedFilter, jsonExpressionTree, jsonLeaves); + if (flag) { + job.put("expressionTree", jsonExpressionTree); + job.put("leaves", jsonLeaves); + } + } + if (vecPredicateFilter != null && vecPredicateFilter) { + String vecPredicateCondition = buildVecPredicateCondition(pushedFilter); + if (vecPredicateCondition != null) { + job.put("vecPredicateCondition", vecPredicateCondition); + } } } @@ -525,4 +538,87 @@ public class OrcColumnarBatchScanReader { throw new UnsupportedOperationException("Unsupported orc push down filter date type: " + literal.getClass().getSimpleName()); } + + private String buildVecPredicateCondition(Filter filterPredicate) { + try { + Map nameToIndex = IntStream.range(0, includedColumns.size()) + .boxed() + .collect(Collectors.toMap(includedColumns::get, i -> i)); + return buildPredicateCondition(filterPredicate, nameToIndex).reduce().toString(); + } catch (Exception e) { + LOGGER.info("Unable to build vec predicate because " + e.getMessage()); + return null; + } + } + + private PredicateCondition buildPredicateCondition(Filter filterPredicate, Map nameToIndex) { + if (filterPredicate instanceof And) { + return new AndPredicateCondition(buildPredicateCondition(((And) filterPredicate).left(), nameToIndex), + buildPredicateCondition(((And) filterPredicate).right(), nameToIndex)); + } else if (filterPredicate instanceof Or) { + return new OrPredicateCondition(buildPredicateCondition(((Or) filterPredicate).left(), nameToIndex), + buildPredicateCondition(((Or) filterPredicate).right(), nameToIndex)); + } else if (filterPredicate instanceof Not) { + return new NotPredicateCondition(buildPredicateCondition(((Not) filterPredicate).child(), nameToIndex)); + } else if (filterPredicate instanceof EqualTo) { + return buildLeafPredicateCondition(PredicateOperatorType.EQUAL_TO, ((EqualTo) filterPredicate).attribute(), + ((EqualTo) filterPredicate).value(), nameToIndex); + } else if (filterPredicate instanceof GreaterThan) { + return buildLeafPredicateCondition(PredicateOperatorType.GREATER_THAN, + ((GreaterThan) filterPredicate).attribute(), ((GreaterThan) filterPredicate).value(), nameToIndex); + } else if (filterPredicate instanceof GreaterThanOrEqual) { + return buildLeafPredicateCondition(PredicateOperatorType.GREATER_THAN_OR_EQUAL, + ((GreaterThanOrEqual) filterPredicate).attribute(), ((GreaterThanOrEqual) filterPredicate).value(), nameToIndex); + } else if (filterPredicate instanceof LessThan) { + return buildLeafPredicateCondition(PredicateOperatorType.LESS_THAN, + ((LessThan) filterPredicate).attribute(), ((LessThan) filterPredicate).value(), nameToIndex); + } else if (filterPredicate instanceof LessThanOrEqual) { + return buildLeafPredicateCondition(PredicateOperatorType.LESS_THAN_OR_EQUAL, + ((LessThanOrEqual) filterPredicate).attribute(), ((LessThanOrEqual) filterPredicate).value(), nameToIndex); + } else if (filterPredicate instanceof IsNotNull) { + return buildLeafPredicateCondition(PredicateOperatorType.IS_NOT_NULL, + ((IsNotNull) filterPredicate).attribute(), "-1", nameToIndex); + } else if (filterPredicate instanceof IsNull) { + return buildLeafPredicateCondition(PredicateOperatorType.IS_NULL, + ((IsNull) 
filterPredicate).attribute(), "-1", nameToIndex); + } else { + throw new UnsupportedOperationException("Unsupported orc vec predicate operation: " + + filterPredicate.getClass().getSimpleName()); + } + } + + private PredicateCondition buildLeafPredicateCondition(PredicateOperatorType op, String attribute, Object literal, + Map nameToIndex) { + Integer index = nameToIndex.get(attribute); + if (index == null) { + throw new UnsupportedOperationException("Attribute is not found in nameToIndex. attribute: " + attribute); + } + if (op == PredicateOperatorType.IS_NOT_NULL || op == PredicateOperatorType.IS_NULL) { + return new LeafPredicateCondition(op, index, DataType.DataTypeId.OMNI_INT, "-1"); + } + DataType.DataTypeId dataType = getSupportPredicateDataType(attribute); + String value = getLiteralValue(literal); + return new LeafPredicateCondition(op, index, dataType, value); + } + + private DataType.DataTypeId getSupportPredicateDataType(String attribute) { + StructField field = requiredSchema.apply(attribute); + org.apache.spark.sql.types.DataType dataType = field.dataType(); + if (dataType instanceof ShortType) { + return DataType.DataTypeId.OMNI_SHORT; + } else if (dataType instanceof IntegerType) { + return DataType.DataTypeId.OMNI_INT; + } else if (dataType instanceof LongType) { + return DataType.DataTypeId.OMNI_LONG; + } else if (dataType instanceof DoubleType) { + return DataType.DataTypeId.OMNI_DOUBLE; + } else if (dataType instanceof DateType) { + return DataType.DataTypeId.OMNI_DATE32; + } else if (dataType instanceof BooleanType) { + return DataType.DataTypeId.OMNI_BOOLEAN; + } else { + throw new UnsupportedOperationException("Unsupported orc vec predicate data type: " + + dataType.getClass().getSimpleName()); + } + } } diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchScanReader.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchScanReader.java index 40fcb06d9203e2b7ac0a95ea1c79b02959cf511b..96d2f56dd76485ba6f43c5ede8bbf970f40c094a 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchScanReader.java +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchScanReader.java @@ -31,6 +31,8 @@ import org.apache.parquet.schema.Type; import org.apache.spark.sql.catalyst.util.RebaseDateTime; import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec; import org.apache.spark.sql.execution.datasources.DataSourceUtils; + +import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.sources.And; import org.apache.spark.sql.sources.EqualTo; import org.apache.spark.sql.sources.Filter; @@ -80,6 +82,10 @@ public class ParquetColumnarBatchScanReader { private final Function1 int96RebaseFunc; + private boolean nativeSupportDateRebase; + + private final Function1 dateRebaseFunc; + private final List parquetTypes; private ArrayList allFieldsNames; @@ -93,12 +99,15 @@ public class ParquetColumnarBatchScanReader { : DataSourceUtils.createTimestampRebaseFuncInRead(datetimeRebaseSpec, "Parquet"); this.int96RebaseFunc = int96RebaseSpec.mode() == null ? micros -> micros : DataSourceUtils.createTimestampRebaseFuncInRead(int96RebaseSpec, "Parquet INT96"); + this.dateRebaseFunc = datetimeRebaseSpec.mode() == null ? 
date-> date + : DataSourceUtils.createDateRebaseFuncInRead(datetimeRebaseSpec.mode(),"Parquet"); this.parquetTypes = parquetTypes; jniReader = new ParquetColumnarBatchJniReader(); } private void addJulianGregorianInfo(JSONObject job) { - if (Arrays.stream(requiredSchema.fields()).noneMatch(field -> field.dataType() instanceof TimestampType)) { + if (Arrays.stream(requiredSchema.fields()).noneMatch(field -> field.dataType() instanceof TimestampType) + && Arrays.stream(requiredSchema.fields()).noneMatch(field -> field.dataType() instanceof DateType)) { return; } TimestampUtil instance = TimestampUtil.getInstance(); @@ -384,6 +393,20 @@ public class ParquetColumnarBatchScanReader { } } + private void tryToAdjustDateVec(IntVec intVec, long rowNumber, int index) { + Type parquetType = parquetTypes.get(index); + if (parquetType == null) { + throw new RuntimeException("parquetType is null. index: " + index); + } + if (parquetType.getLogicalTypeAnnotation() instanceof LogicalTypeAnnotation.DateLogicalTypeAnnotation) { + int adjustValue; + for (int rowIndex = 0; rowIndex < rowNumber; rowIndex++) { + adjustValue = (int) dateRebaseFunc.apply(intVec.get(rowIndex)); + intVec.set(rowIndex, adjustValue); + } + } + } + public int next(Vec[] vecList, boolean[] missingColumns, List types) { int colsCount = missingColumns.length; long[] vecNativeIds = new long[types.size()]; @@ -417,6 +440,7 @@ public class ParquetColumnarBatchScanReader { vecList[i] = new VarcharVec(vecNativeIds[nativeGetId]); } else if (type instanceof DateType) { vecList[i] = new IntVec(vecNativeIds[nativeGetId]); + tryToAdjustDateVec((IntVec) vecList[i], rtn, i); } else if (type instanceof ByteType) { vecList[i] = new VarcharVec(vecNativeIds[nativeGetId]); } else if (type instanceof TimestampType) { diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchWriter.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchWriter.java new file mode 100644 index 0000000000000000000000000000000000000000..a15d8b0e32de38492fe790df6a9a2e9f481799e8 --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchWriter.java @@ -0,0 +1,330 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
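tryToAdjustDateVec above rebases every value of a DATE column from the hybrid Julian/Gregorian day numbering used by legacy writers to the Proleptic Gregorian day numbering that Spark 3.x uses internally, via the dateRebaseFunc built from the file's rebase spec. A standalone sketch of the same per-value rewrite using Spark's RebaseDateTime helper directly (the writer later in this diff does the reverse conversion); a plain int[] stands in for the native IntVec, and this is an illustration, not the reader's actual code path:

```java
import org.apache.spark.sql.catalyst.util.RebaseDateTime;

// Sketch: each element is a day count since the epoch; rebase legacy-calendar values in place.
final class DateRebaseSketch {
    static void rebaseJulianToGregorian(int[] days) {
        for (int i = 0; i < days.length; i++) {
            days[i] = RebaseDateTime.rebaseJulianToGregorianDays(days[i]);
        }
    }
}
```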
+ */ + +package com.huawei.boostkit.spark.jni; + +import com.huawei.boostkit.scan.jni.ParquetColumnarBatchJniReader; +import com.huawei.boostkit.write.jni.OrcColumnarBatchJniWriter; +import com.huawei.boostkit.write.jni.ParquetColumnarBatchJniWriter; + +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.*; + +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.orc.OrcFile; +import org.apache.spark.sql.catalyst.util.RebaseDateTime; +import org.apache.spark.sql.execution.vectorized.OmniColumnVector; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.CharType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.VarcharType; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.json.JSONObject; + +import java.io.IOException; +import java.net.URI; + +public class ParquetColumnarBatchWriter { + public long writer; + + public long schema; + + public ParquetColumnarBatchJniWriter jniWriter; + + public ParquetColumnarBatchWriter() { + jniWriter = new ParquetColumnarBatchJniWriter(); + } + + public enum ParquetLibTypeKind { + NA, + /// Boolean as 1 bit, LSB bit-packed ordering + BOOL, + + /// Unsigned 8-bit little-endian integer + UINT8, + + /// Signed 8-bit little-endian integer + INT8, + + /// Unsigned 16-bit little-endian integer + UINT16, + + /// Signed 16-bit little-endian integer + INT16, + + /// Unsigned 32-bit little-endian integer + UINT32, + + /// Signed 32-bit little-endian integer + INT32, + + /// Unsigned 64-bit little-endian integer + UINT64, + + /// Signed 64-bit little-endian integer + INT64, + + /// 2-byte floating point value + HALF_FLOAT, + + /// 4-byte floating point value + FLOAT, + + /// 8-byte floating point value + DOUBLE, + + /// UTF8 variable-length string as List + STRING, + + /// Variable-length bytes (no guarantee of UTF8-ness) + BINARY, + + /// Fixed-size binary. Each value occupies the same number of bytes + FIXED_SIZE_BINARY, + + /// int32_t days since the UNIX epoch + DATE32, + + /// int64_t milliseconds since the UNIX epoch + DATE64, + + /// Exact timestamp encoded with int64 since UNIX epoch + /// Default unit millisecond + TIMESTAMP, + + /// Time as signed 32-bit integer, representing either seconds or + /// milliseconds since midnight + TIME32, + + /// Time as signed 64-bit integer, representing either microseconds or + /// nanoseconds since midnight + TIME64, + + /// YEAR_MONTH interval in SQL style + INTERVAL_MONTHS, + + /// DAY_TIME interval in SQL style + INTERVAL_DAY_TIME, + + /// Precision- and scale-based decimal type with 128 bits. + DECIMAL128, + + /// Defined for backward-compatibility. + // DECIMAL = DECIMAL128, + + /// Precision- and scale-based decimal type with 256 bits. 
+ DECIMAL256, + + /// A list of some logical data type + LIST, + + /// Struct of logical types + STRUCT, + + /// Sparse unions of logical types + SPARSE_UNION, + + /// Dense unions of logical types + DENSE_UNION, + + /// Dictionary-encoded type, also called "categorical" or "factor" + /// in other programming languages. Holds the dictionary value + /// type but not the dictionary itself, which is part of the + /// ArrayData struct + DICTIONARY, + + /// Map, a repeated struct logical type + MAP, + + /// Custom data type, implemented by user + EXTENSION, + + /// Fixed size list of some logical type + FIXED_SIZE_LIST, + + /// Measure of elapsed time in either seconds, milliseconds, microseconds + /// or nanoseconds. + DURATION, + + /// Like STRING, but with 64-bit offsets + LARGE_STRING, + + /// Like BINARY, but with 64-bit offsets + LARGE_BINARY, + + /// Like LIST, but with 64-bit offsets + LARGE_LIST, + + /// Calendar interval type with three fields. + INTERVAL_MONTH_DAY_NANO, + + // Leave this at the end + MAX_ID + } + + public void initializeWriterJava(Path path) throws IOException { + JSONObject writerOptionsJson = new JSONObject(); + String ugi = UserGroupInformation.getCurrentUser().toString(); + + URI uri = path.toUri(); + + writerOptionsJson.put("uri", path.toString()); + writerOptionsJson.put("ugi", ugi); + + writerOptionsJson.put("host", uri.getHost() == null ? "" : uri.getHost()); + writerOptionsJson.put("scheme", uri.getScheme() == null ? "" : uri.getScheme()); + writerOptionsJson.put("port", uri.getPort()); + writerOptionsJson.put("path", uri.getPath() == null ? "" : uri.getPath()); + + // writer = jniWriter.initializeWriter(writerOptionsJson); + jniWriter.initializeWriter(writerOptionsJson, writer); + } + + public void convertGreGorianToJulian(IntVec intVec, int startPos, int endPos) { + int julianValue; + for (int rowIndex = startPos; rowIndex < endPos; rowIndex++) { + julianValue = RebaseDateTime.rebaseGregorianToJulianDays(intVec.get(rowIndex)); + intVec.set(rowIndex, julianValue); + } + } + + public void initializeSchemaJava(StructType dataSchema) { + int schemaLength = dataSchema.length(); + String[] fieldNames = new String[schemaLength]; + int[] fieldTypes = new int[schemaLength]; + boolean[] nullables = new boolean[schemaLength]; + for (int i = 0; i < schemaLength; i++) { + StructField field = dataSchema.fields()[i]; + fieldNames[i] = field.name(); + fieldTypes[i] = sparkTypeToParquetLibType(field.dataType()); + nullables[i] = field.nullable(); + } + writer = jniWriter.initializeSchema(writer, fieldNames, fieldTypes, nullables, extractDecimalParam(dataSchema)); + } + + public int sparkTypeToParquetLibType(DataType dataType) { + int parquetType; + if (dataType instanceof BooleanType) { + parquetType = ParquetLibTypeKind.BOOL.ordinal(); + } else if (dataType instanceof ShortType) { + parquetType = ParquetLibTypeKind.INT16.ordinal(); + } else if (dataType instanceof IntegerType) { + parquetType = ParquetLibTypeKind.INT32.ordinal(); + } else if (dataType instanceof LongType) { + parquetType = ParquetLibTypeKind.INT64.ordinal(); + } else if (dataType instanceof DateType) { + DateType dateType = (DateType) dataType; + switch (dateType.defaultSize()) { + case 4: + parquetType = ParquetLibTypeKind.DATE32.ordinal(); + break; + case 8: + parquetType = ParquetLibTypeKind.DATE64.ordinal(); + break; + default: + throw new RuntimeException( + "UnSupport size " + dateType.defaultSize() + " of date type"); + } + } else if (dataType instanceof DoubleType) { + parquetType = 
ParquetLibTypeKind.DOUBLE.ordinal(); + } else if (dataType instanceof VarcharType) { + parquetType = ParquetLibTypeKind.STRING.ordinal(); + } else if (dataType instanceof StringType) { + parquetType = ParquetLibTypeKind.STRING.ordinal(); + } else if (dataType instanceof CharType) { + parquetType = ParquetLibTypeKind.STRING.ordinal(); + } else if (dataType instanceof DecimalType) { + DecimalType decimalType = (DecimalType) dataType; + switch (decimalType.defaultSize()) { + case 8: + case 16: + parquetType = ParquetLibTypeKind.DECIMAL128.ordinal(); + break; + default: + throw new RuntimeException( + "UnSupport size " + decimalType.defaultSize() + " of decimal type"); + } + } else { + throw new RuntimeException( + "UnSupport type convert spark type " + dataType.simpleString() + " to parquet lib type"); + } + return parquetType; + } + + public int[][] extractDecimalParam(StructType dataSchema) { + int paramNum = 2; + int precisionIndex = 0; + int scaleIndex = 1; + int[][] decimalParams = new int[dataSchema.length()][paramNum]; + for (int i = 0; i < dataSchema.length(); i++) { + DataType dataType = dataSchema.fields()[i].dataType(); + if (dataType instanceof DecimalType) { + DecimalType decimal = (DecimalType) dataType; + decimalParams[i][precisionIndex] = decimal.precision(); + decimalParams[i][scaleIndex] = decimal.scale(); + } + } + return decimalParams; + } + + public void write(int[] omniTypes, boolean[] dataColumnsIds, ColumnarBatch batch) { + + long[] vecNativeIds = new long[batch.numCols()]; + for (int i = 0; i < batch.numCols(); i++) { + OmniColumnVector omniVec = (OmniColumnVector) batch.column(i); + Vec vec = omniVec.getVec(); + vecNativeIds[i] = vec.getNativeVector(); + boolean isDateType = (omniTypes[i] == 8); + if (isDateType) { + convertGreGorianToJulian((IntVec) vec, 0, batch.numRows()); + } + } + + jniWriter.write(writer, vecNativeIds, omniTypes, dataColumnsIds, batch.numRows()); + } + + public void splitWrite(int[] omniTypes, int[] allOmniTypes, boolean[] dataColumnsIds, ColumnarBatch inputBatch, long startPos, long endPos) { + long[] vecNativeIds = new long[inputBatch.numCols()]; + for (int i = 0; i < inputBatch.numCols(); i++) { + OmniColumnVector omniVec = (OmniColumnVector) inputBatch.column(i); + Vec vec = omniVec.getVec(); + vecNativeIds[i] = vec.getNativeVector(); + boolean isDateType = (allOmniTypes[i] == 8); + if (isDateType) { + convertGreGorianToJulian((IntVec) vec, (int) startPos, (int) endPos); + } + } + + jniWriter.splitWrite(writer, vecNativeIds, omniTypes, dataColumnsIds, startPos, endPos); + } + + public void close() { + jniWriter.close(writer); + } + +} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/AndPredicateCondition.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/AndPredicateCondition.java new file mode 100644 index 0000000000000000000000000000000000000000..00dd0691b3257023f54851183cf4328c8adef121 --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/AndPredicateCondition.java @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
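A hypothetical end-to-end call sequence for ParquetColumnarBatchWriter above, inferred from its method bodies: open the native writer for a path, register the schema, write one or more batches, then close. The omniTypes, dataColumnsIds and ColumnarBatch arguments are assumed to be prepared by the caller the same way the surrounding extension code prepares them; this is a sketch, not the extension's actual write path.

```java
import java.io.IOException;

import com.huawei.boostkit.spark.jni.ParquetColumnarBatchWriter;

import org.apache.hadoop.fs.Path;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.vectorized.ColumnarBatch;

// Illustrative writer lifecycle built only from the public methods shown in this diff.
final class ParquetWriteSketch {
    static void writeOneBatch(Path outputFile, StructType schema, int[] omniTypes,
                              boolean[] dataColumnsIds, ColumnarBatch batch) throws IOException {
        ParquetColumnarBatchWriter writer = new ParquetColumnarBatchWriter();
        writer.initializeWriterJava(outputFile);   // opens the native writer for this URI
        writer.initializeSchemaJava(schema);       // maps Spark fields to Parquet lib types
        writer.write(omniTypes, dataColumnsIds, batch);
        writer.close();
    }
}
```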
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.predicate; + +public class AndPredicateCondition extends BinaryPredicateCondition { + + public AndPredicateCondition(PredicateCondition left, PredicateCondition right) { + super(PredicateOperatorType.AND, left, right); + } + + @Override + protected PredicateCondition reduce(boolean upExistsNotOp) { + PredicateCondition leftReduce = left.reduce(upExistsNotOp); + PredicateCondition rightReduce = right.reduce(upExistsNotOp); + if (leftReduce == LeafPredicateCondition.TRUE_PREDICATE_CONDITION && + rightReduce == LeafPredicateCondition.TRUE_PREDICATE_CONDITION) { + return LeafPredicateCondition.TRUE_PREDICATE_CONDITION; + } + if (upExistsNotOp && (leftReduce == LeafPredicateCondition.TRUE_PREDICATE_CONDITION || + rightReduce == LeafPredicateCondition.TRUE_PREDICATE_CONDITION)) { + return LeafPredicateCondition.TRUE_PREDICATE_CONDITION; + } + if (leftReduce == LeafPredicateCondition.TRUE_PREDICATE_CONDITION) { + return rightReduce; + } + if (rightReduce == LeafPredicateCondition.TRUE_PREDICATE_CONDITION) { + return leftReduce; + } + return new AndPredicateCondition(leftReduce, rightReduce); + } +} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/BinaryPredicateCondition.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/BinaryPredicateCondition.java new file mode 100644 index 0000000000000000000000000000000000000000..e772fb69c40e9f73a0fc61fb1de38953906a56fc --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/BinaryPredicateCondition.java @@ -0,0 +1,38 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
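For readers skimming the predicate classes: LeafPredicateCondition.TRUE_PREDICATE_CONDITION acts as an "always true / could not convert" placeholder, and AndPredicateCondition.reduce drops such a child when it is safe to do so, keeping only the convertible side. An illustrative example in the same package (not part of the patch):

```java
package com.huawei.boostkit.spark.predicate;

import nova.hetu.omniruntime.type.DataType.DataTypeId;

// Worked example of the AND reduction above, using only classes added in this diff.
final class AndReduceExample {
    public static void main(String[] args) {
        PredicateCondition col0GreaterThan5 = new LeafPredicateCondition(
                PredicateOperatorType.GREATER_THAN, 0, DataTypeId.OMNI_INT, "5");
        PredicateCondition and = new AndPredicateCondition(
                col0GreaterThan5, LeafPredicateCondition.TRUE_PREDICATE_CONDITION);

        // Outside a NOT, the unconvertible side is dropped and only the leaf remains,
        // so this prints the JSON for "column 0 > 5".
        System.out.println(and.reduce());
    }
}
```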
+ */ + +package com.huawei.boostkit.spark.predicate; + +public abstract class BinaryPredicateCondition extends PredicateCondition { + + protected PredicateCondition left; + + protected PredicateCondition right; + + public BinaryPredicateCondition(PredicateOperatorType op, PredicateCondition left, PredicateCondition right) { + super(op); + this.left = left; + this.right = right; + } + + @Override + public String toString() { + return String.format("{\"op\":%d,\"left\":%s,\"right\":%s}", + op.ordinal(), left.toString(), right.toString()); + } +} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/LeafPredicateCondition.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/LeafPredicateCondition.java new file mode 100644 index 0000000000000000000000000000000000000000..88f699c34957bfd9d2443919ac322150ec219d4b --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/LeafPredicateCondition.java @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.predicate; + +import nova.hetu.omniruntime.type.DataType.DataTypeId; + +public class LeafPredicateCondition extends PredicateCondition { + + public static final PredicateCondition TRUE_PREDICATE_CONDITION = + new LeafPredicateCondition(PredicateOperatorType.TRUE, -1, DataTypeId.OMNI_INT, "-1"); + + private int index; + + private DataTypeId dataType; + + private String value; + + public LeafPredicateCondition(PredicateOperatorType op, int index, DataTypeId dataType, String value) { + super(op); + this.index = index; + this.dataType = dataType; + this.value = value; + } + + @Override + protected PredicateCondition reduce(boolean upExistsNotOp) { + return this; + } + + @Override + public String toString() { + return String.format("{\"op\":%d,\"index\":%d,\"dataType\":%d,\"value\":\"%s\"}", + op.ordinal(), index, dataType.toValue(), value); + } +} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/NotPredicateCondition.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/NotPredicateCondition.java new file mode 100644 index 0000000000000000000000000000000000000000..ef14ea3bb9588d3899ceaafa45288674e0e8f44c --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/NotPredicateCondition.java @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. 
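The toString methods above define the JSON string that is handed to the native reader under the vecPredicateCondition key earlier in this diff: leaves carry op/index/dataType/value, binary nodes nest their children under left/right, and op is the PredicateOperatorType ordinal. An illustrative dump (not part of the patch):

```java
package com.huawei.boostkit.spark.predicate;

import nova.hetu.omniruntime.type.DataType.DataTypeId;

// Illustrative only: shows the serialized shape produced by the toString methods above.
final class PredicateJsonExample {
    public static void main(String[] args) {
        PredicateCondition notNull = new LeafPredicateCondition(
                PredicateOperatorType.IS_NOT_NULL, 1, DataTypeId.OMNI_INT, "-1");
        PredicateCondition lessThan = new LeafPredicateCondition(
                PredicateOperatorType.LESS_THAN, 1, DataTypeId.OMNI_LONG, "100");

        // Prints something like
        //   {"op":9,"left":{"op":6,...},"right":{"op":4,...}}
        // where 9/6/4 are the ordinals of AND/IS_NOT_NULL/LESS_THAN and the dataType field
        // is the numeric OMNI type id.
        System.out.println(new AndPredicateCondition(notNull, lessThan).toString());
    }
}
```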
+ * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.predicate; + +public class NotPredicateCondition extends PredicateCondition { + + PredicateCondition child; + + public NotPredicateCondition(PredicateCondition child) { + super(PredicateOperatorType.NOT); + this.child = child; + } + + @Override + protected PredicateCondition reduce(boolean upExistsNotOp) { + PredicateCondition childReduce = child.reduce(true); + if (childReduce == LeafPredicateCondition.TRUE_PREDICATE_CONDITION) { + return LeafPredicateCondition.TRUE_PREDICATE_CONDITION; + } + return new NotPredicateCondition(childReduce); + } + + @Override + public String toString() { + return String.format("{\"op\":%d,\"child\":%s}", op.ordinal(), child.toString()); + } +} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/OrPredicateCondition.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/OrPredicateCondition.java new file mode 100644 index 0000000000000000000000000000000000000000..5c9f527d675561443bdf702ccf21b4244ea262df --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/OrPredicateCondition.java @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
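NotPredicateCondition.reduce above passes upExistsNotOp=true down to its child, which is what makes the reduction conservative: if anything under a NOT could not be converted, the whole negated subtree collapses to the TRUE placeholder instead of pushing down a partially negated filter. An illustrative trace (not part of the patch):

```java
package com.huawei.boostkit.spark.predicate;

import nova.hetu.omniruntime.type.DataType.DataTypeId;

// Trace of the conservative NOT handling: Not(And(leaf, TRUE)) reduces all the way to TRUE.
final class NotReduceExample {
    public static void main(String[] args) {
        PredicateCondition leaf = new LeafPredicateCondition(
                PredicateOperatorType.EQUAL_TO, 0, DataTypeId.OMNI_INT, "7");
        PredicateCondition notOfAnd = new NotPredicateCondition(
                new AndPredicateCondition(leaf, LeafPredicateCondition.TRUE_PREDICATE_CONDITION));

        // Under the NOT, the AND may not simply drop its TRUE side, so the reduction
        // collapses to the shared TRUE placeholder instance.
        System.out.println(notOfAnd.reduce() == LeafPredicateCondition.TRUE_PREDICATE_CONDITION); // true
    }
}
```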
+ */ + +package com.huawei.boostkit.spark.predicate; + +public class OrPredicateCondition extends BinaryPredicateCondition { + + public OrPredicateCondition(PredicateCondition left, PredicateCondition right) { + super(PredicateOperatorType.OR, left, right); + } + + @Override + protected PredicateCondition reduce(boolean upExistsNotOp) { + PredicateCondition leftReduce = left.reduce(upExistsNotOp); + PredicateCondition rightReduce = right.reduce(upExistsNotOp); + if (leftReduce == LeafPredicateCondition.TRUE_PREDICATE_CONDITION || + rightReduce == LeafPredicateCondition.TRUE_PREDICATE_CONDITION) { + return LeafPredicateCondition.TRUE_PREDICATE_CONDITION; + } + return new OrPredicateCondition(leftReduce, rightReduce); + } +} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/PredicateCondition.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/PredicateCondition.java new file mode 100644 index 0000000000000000000000000000000000000000..5296337918a8efb5595b24e15b23f6e18da6215a --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/PredicateCondition.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.predicate; + +public abstract class PredicateCondition { + protected PredicateOperatorType op; + + public PredicateCondition(PredicateOperatorType op) { + this.op = op; + } + + protected abstract PredicateCondition reduce(boolean upExistsNotOp); + + public PredicateCondition reduce() { + return reduce(false); + } +} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/PredicateOperatorType.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/PredicateOperatorType.java new file mode 100644 index 0000000000000000000000000000000000000000..88a28c55f37af124ec861235901fa90a2b535e7e --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/predicate/PredicateOperatorType.java @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
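Putting the package together: the ORC reader hunk earlier in this diff builds one of these trees from Spark source Filters, reduces it, and serializes it into the vecPredicateCondition JSON. A condensed illustrative sketch of that flow for two filter shapes; the nameToIndex map and the OMNI_INT column type are assumptions made for the example, whereas the real reader derives both from includedColumns and the read schema:

```java
package com.huawei.boostkit.spark.predicate;

import java.util.Map;

import nova.hetu.omniruntime.type.DataType.DataTypeId;

import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.GreaterThan;
import org.apache.spark.sql.sources.IsNotNull;

// Illustrative: map a Spark source Filter to a condition, reduce it, and emit the JSON.
final class VecPredicateSketch {
    static String toVecPredicate(Filter filter, Map<String, Integer> nameToIndex) {
        PredicateCondition condition;
        if (filter instanceof GreaterThan) {
            GreaterThan gt = (GreaterThan) filter;
            condition = new LeafPredicateCondition(PredicateOperatorType.GREATER_THAN,
                    nameToIndex.get(gt.attribute()), DataTypeId.OMNI_INT, String.valueOf(gt.value()));
        } else if (filter instanceof IsNotNull) {
            IsNotNull nn = (IsNotNull) filter;
            condition = new LeafPredicateCondition(PredicateOperatorType.IS_NOT_NULL,
                    nameToIndex.get(nn.attribute()), DataTypeId.OMNI_INT, "-1");
        } else {
            condition = LeafPredicateCondition.TRUE_PREDICATE_CONDITION; // not convertible
        }
        return condition.reduce().toString();
    }
}
```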
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.predicate; + +public enum PredicateOperatorType { + TRUE, + EQUAL_TO, + GREATER_THAN, + GREATER_THAN_OR_EQUAL, + LESS_THAN, + LESS_THAN_OR_EQUAL, + IS_NOT_NULL, + IS_NULL, + OR, + AND, + NOT +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java index 0859173ca18d6f8ae4730b57c6be0f474d15f94f..1471d6047db63283e260f19a5f91dad06d5e332b 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java @@ -19,12 +19,13 @@ package com.huawei.boostkit.spark.serialize; +import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import com.huawei.boostkit.spark.jni.NativeLoader; import nova.hetu.omniruntime.type.*; import nova.hetu.omniruntime.utils.OmniRuntimeException; import nova.hetu.omniruntime.vector.*; -import nova.hetu.omniruntime.vector.serialize.OmniRowDeserializer; +import nova.hetu.omniruntime.type.DataType.DataTypeId; import org.apache.spark.sql.execution.vectorized.OmniColumnVector; import org.apache.spark.sql.types.DataType; @@ -34,30 +35,52 @@ import org.apache.spark.sql.vectorized.ColumnarBatch; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.lang.reflect.Field; +import java.util.Arrays; + +import sun.misc.Unsafe; + public class ShuffleDataSerializer { private static final Logger LOG = LoggerFactory.getLogger(NativeLoader.class); + private static final Unsafe unsafe; + private static final long BYTE_ARRAY_BASE_OFFSET; - public static ColumnarBatch deserialize(boolean isRowShuffle, byte[] bytes) { - if (!isRowShuffle) { - return deserializeByColumn(bytes); - } else { - return deserializeByRow(bytes); + static { + try { + Field field = Unsafe.class.getDeclaredField("theUnsafe"); + field.setAccessible(true); + unsafe = (Unsafe) field.get(null); + BYTE_ARRAY_BASE_OFFSET = unsafe.arrayBaseOffset(byte[].class); + } catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException("create unsafe object failed. 
errmsg:" + e.getMessage()); } } - public static ColumnarBatch deserializeByColumn(byte[] bytes) { + public static ColumnarBatch deserialize(boolean isRowShuffle, byte[] bytes, int readSize) { ColumnVector[] vecs = null; + long address = -1; + ShuffleDataSerializerUtils deserializer = null; try { - VecData.VecBatch vecBatch = VecData.VecBatch.parseFrom(bytes); - int vecCount = vecBatch.getVecCnt(); - int rowCount = vecBatch.getRowCnt(); + address = unsafe.allocateMemory(readSize); + unsafe.copyMemory(bytes, BYTE_ARRAY_BASE_OFFSET, null, address, readSize); + + deserializer = new ShuffleDataSerializerUtils(); + deserializer.init(address, readSize, isRowShuffle); + int vecCount = deserializer.getVecCount(); + int rowCount = deserializer.getRowCount(); + + int[] typeIdArray = new int[vecCount]; + int[] precisionArray = new int[vecCount]; + int[] scaleArray = new int[vecCount]; + long[] vecNativeIdArray = new long[vecCount]; + deserializer.parse(typeIdArray, precisionArray, scaleArray, vecNativeIdArray); vecs = new ColumnVector[vecCount]; for (int i = 0; i < vecCount; i++) { - vecs[i] = buildVec(vecBatch.getVecs(i), rowCount); + vecs[i] = buildVec(typeIdArray[i], vecNativeIdArray[i], rowCount, precisionArray[i], scaleArray[i]); } + deserializer.close(); + unsafe.freeMemory(address); return new ColumnarBatch(vecs, rowCount); - } catch (InvalidProtocolBufferException e) { - throw new RuntimeException("deserialize failed. errmsg:" + e.getMessage()); } catch (OmniRuntimeException e) { if (vecs != null) { for (int i = 0; i < vecs.length; i++) { @@ -67,184 +90,74 @@ public class ShuffleDataSerializer { } } } - throw new RuntimeException("deserialize failed. errmsg:" + e.getMessage()); - } - } - - public static ColumnarBatch deserializeByRow(byte[] bytes) { - try { - VecData.ProtoRowBatch rowBatch = VecData.ProtoRowBatch.parseFrom(bytes); - int vecCount = rowBatch.getVecCnt(); - int rowCount = rowBatch.getRowCnt(); - OmniColumnVector[] columnarVecs = new OmniColumnVector[vecCount]; - long[] omniVecs = new long[vecCount]; - int[] omniTypes = new int[vecCount]; - createEmptyVec(rowBatch, omniTypes, omniVecs, columnarVecs, vecCount, rowCount); - OmniRowDeserializer deserializer = new OmniRowDeserializer(omniTypes, omniVecs); - - for (int rowIdx = 0; rowIdx < rowCount; rowIdx++) { - VecData.ProtoRow protoRow = rowBatch.getRows(rowIdx); - byte[] array = protoRow.getData().toByteArray(); - deserializer.parse(array, rowIdx); + if (deserializer != null) { + deserializer.close(); } - - // update initial varchar vector because it's capacity might have been expanded. - for (int i = 0; i < vecCount; i++) { - if (omniTypes[i] == VarcharDataType.VARCHAR.getId().toValue()) { - Vec varcharVec = VecFactory.create(omniVecs[i], VecEncoding.OMNI_VEC_ENCODING_FLAT, VarcharDataType.VARCHAR); - columnarVecs[i].setVec(varcharVec); - } + if (address != -1) { + unsafe.freeMemory(address); } - - deserializer.close(); - return new ColumnarBatch(columnarVecs, rowCount); - } catch (InvalidProtocolBufferException e) { throw new RuntimeException("deserialize failed. 
errmsg:" + e.getMessage()); } } - private static ColumnVector buildVec(VecData.Vec protoVec, int vecSize) { - VecData.VecType protoTypeId = protoVec.getVecType(); + private static ColumnVector buildVec(int typeId, long vecNativeId, int vecSize, int precision, int scale) { Vec vec; DataType type; - switch (protoTypeId.getTypeId()) { - case VEC_TYPE_INT: + switch (DataTypeId.values()[typeId]) { + case OMNI_INT: type = DataTypes.IntegerType; - vec = new IntVec(vecSize); + vec = new IntVec(vecNativeId); break; - case VEC_TYPE_DATE32: + case OMNI_DATE32: type = DataTypes.DateType; - vec = new IntVec(vecSize); + vec = new IntVec(vecNativeId); break; - case VEC_TYPE_LONG: + case OMNI_LONG: type = DataTypes.LongType; - vec = new LongVec(vecSize); + vec = new LongVec(vecNativeId); break; - case VEC_TYPE_TIMESTAMP: + case OMNI_TIMESTAMP: type = DataTypes.TimestampType; - vec = new LongVec(vecSize); + vec = new LongVec(vecNativeId); break; - case VEC_TYPE_DATE64: + case OMNI_DATE64: type = DataTypes.DateType; - vec = new LongVec(vecSize); + vec = new LongVec(vecNativeId); break; - case VEC_TYPE_DECIMAL64: - type = DataTypes.createDecimalType(protoTypeId.getPrecision(), protoTypeId.getScale()); - vec = new LongVec(vecSize); + case OMNI_DECIMAL64: + type = DataTypes.createDecimalType(precision, scale); + vec = new LongVec(vecNativeId); break; - case VEC_TYPE_SHORT: + case OMNI_SHORT: type = DataTypes.ShortType; - vec = new ShortVec(vecSize); + vec = new ShortVec(vecNativeId); break; - case VEC_TYPE_BOOLEAN: + case OMNI_BOOLEAN: type = DataTypes.BooleanType; - vec = new BooleanVec(vecSize); + vec = new BooleanVec(vecNativeId); break; - case VEC_TYPE_DOUBLE: + case OMNI_DOUBLE: type = DataTypes.DoubleType; - vec = new DoubleVec(vecSize); + vec = new DoubleVec(vecNativeId); break; - case VEC_TYPE_VARCHAR: - case VEC_TYPE_CHAR: + case OMNI_VARCHAR: + case OMNI_CHAR: type = DataTypes.StringType; - vec = new VarcharVec(protoVec.getValues().size(), vecSize); - if (vec instanceof VarcharVec) { - ((VarcharVec) vec).setOffsetsBuf(protoVec.getOffset().toByteArray()); - } + vec = new VarcharVec(vecNativeId); break; - case VEC_TYPE_DECIMAL128: - type = DataTypes.createDecimalType(protoTypeId.getPrecision(), protoTypeId.getScale()); - vec = new Decimal128Vec(vecSize); + case OMNI_DECIMAL128: + type = DataTypes.createDecimalType(precision, scale); + vec = new Decimal128Vec(vecNativeId); break; - case VEC_TYPE_TIME32: - case VEC_TYPE_TIME64: - case VEC_TYPE_INTERVAL_DAY_TIME: - case VEC_TYPE_INTERVAL_MONTHS: + case OMNI_TIME32: + case OMNI_TIME64: + case OMNI_INTERVAL_DAY_TIME: + case OMNI_INTERVAL_MONTHS: default: - throw new IllegalStateException("Unexpected value: " + protoTypeId.getTypeId()); - } - vec.setValuesBuf(protoVec.getValues().toByteArray()); - if(protoVec.getNulls().size() != 0) { - vec.setNullsBuf(protoVec.getNulls().toByteArray()); + throw new IllegalStateException("Unexpected value: " + typeId); } OmniColumnVector vecTmp = new OmniColumnVector(vecSize, type, false); vecTmp.setVec(vec); return vecTmp; } - - public static void createEmptyVec(VecData.ProtoRowBatch rowBatch, int[] omniTypes, long[] omniVecs, OmniColumnVector[] columnarVectors, int vecCount, int rowCount) { - for (int i = 0; i < vecCount; i++) { - VecData.VecType protoTypeId = rowBatch.getVecTypes(i); - DataType sparkType; - Vec omniVec; - switch (protoTypeId.getTypeId()) { - case VEC_TYPE_INT: - sparkType = DataTypes.IntegerType; - omniTypes[i] = IntDataType.INTEGER.getId().toValue(); - omniVec = new IntVec(rowCount); - break; - case 
VEC_TYPE_DATE32: - sparkType = DataTypes.DateType; - omniTypes[i] = Date32DataType.DATE32.getId().toValue(); - omniVec = new IntVec(rowCount); - break; - case VEC_TYPE_LONG: - sparkType = DataTypes.LongType; - omniTypes[i] = LongDataType.LONG.getId().toValue(); - omniVec = new LongVec(rowCount); - break; - case VEC_TYPE_TIMESTAMP: - sparkType = DataTypes.TimestampType; - omniTypes[i] = TimestampDataType.TIMESTAMP.getId().toValue(); - omniVec = new LongVec(rowCount); - break; - case VEC_TYPE_DATE64: - sparkType = DataTypes.DateType; - omniTypes[i] = Date64DataType.DATE64.getId().toValue(); - omniVec = new LongVec(rowCount); - break; - case VEC_TYPE_DECIMAL64: - sparkType = DataTypes.createDecimalType(protoTypeId.getPrecision(), protoTypeId.getScale()); - omniTypes[i] = new Decimal64DataType(protoTypeId.getPrecision(), protoTypeId.getScale()).getId().toValue(); - omniVec = new LongVec(rowCount); - break; - case VEC_TYPE_SHORT: - sparkType = DataTypes.ShortType; - omniTypes[i] = ShortDataType.SHORT.getId().toValue(); - omniVec = new ShortVec(rowCount); - break; - case VEC_TYPE_BOOLEAN: - sparkType = DataTypes.BooleanType; - omniTypes[i] = BooleanDataType.BOOLEAN.getId().toValue(); - omniVec = new BooleanVec(rowCount); - break; - case VEC_TYPE_DOUBLE: - sparkType = DataTypes.DoubleType; - omniTypes[i] = DoubleDataType.DOUBLE.getId().toValue(); - omniVec = new DoubleVec(rowCount); - break; - case VEC_TYPE_VARCHAR: - case VEC_TYPE_CHAR: - sparkType = DataTypes.StringType; - omniTypes[i] = VarcharDataType.VARCHAR.getId().toValue(); - omniVec = new VarcharVec(rowCount); // it's capacity may be expanded. - break; - case VEC_TYPE_DECIMAL128: - sparkType = DataTypes.createDecimalType(protoTypeId.getPrecision(), protoTypeId.getScale()); - omniTypes[i] = new Decimal128DataType(protoTypeId.getPrecision(), protoTypeId.getScale()).getId().toValue(); - omniVec = new Decimal128Vec(rowCount); - break; - case VEC_TYPE_TIME32: - case VEC_TYPE_TIME64: - case VEC_TYPE_INTERVAL_DAY_TIME: - case VEC_TYPE_INTERVAL_MONTHS: - default: - throw new IllegalStateException("Unexpected value: " + protoTypeId.getTypeId()); - } - - omniVecs[i] = omniVec.getNativeVector(); - columnarVectors[i] = new OmniColumnVector(rowCount, sparkType, false); - columnarVectors[i].setVec(omniVec); - } - } } diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializerUtils.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializerUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..e8507b3ac4c7466d3dbed3ca067e3948c9e93bac --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializerUtils.java @@ -0,0 +1,91 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.serialize; + +import com.huawei.boostkit.spark.jni.NativeLoader; + +public class ShuffleDataSerializerUtils { + public long vecBatchAddress; + public boolean isRowShuffle; + + public ShuffleDataSerializerUtils() { + NativeLoader.getInstance(); + } + + public void init(long address, int length, boolean isRowShuffle) { + this.isRowShuffle = isRowShuffle; + if (isRowShuffle) { + this.vecBatchAddress = rowShuffleParseInit(address, length); + } else { + this.vecBatchAddress = columnarShuffleParseInit(address, length); + } + } + + public void close() { + if (isRowShuffle) { + rowShuffleParseClose(vecBatchAddress); + } else { + columnarShuffleParseClose(vecBatchAddress); + } + } + + public int getVecCount() { + if (isRowShuffle) { + return rowShuffleParseVecCount(vecBatchAddress); + } else { + return columnarShuffleParseVecCount(vecBatchAddress); + } + } + + public int getRowCount() { + if (isRowShuffle) { + return rowShuffleParseRowCount(vecBatchAddress); + } else { + return columnarShuffleParseRowCount(vecBatchAddress); + } + } + + public void parse(int[] typeIdArray, int[] precisionArray, int[] scaleArray, long[] vecNativeIdArray) { + if (isRowShuffle) { + rowShuffleParseBatch(vecBatchAddress, typeIdArray, precisionArray, scaleArray, vecNativeIdArray); + } else { + columnarShuffleParseBatch(vecBatchAddress, typeIdArray, precisionArray, scaleArray, vecNativeIdArray); + } + } + + private native long rowShuffleParseInit(long address, int length); + + private native int rowShuffleParseVecCount(long vecBatchAddress); + + private native int rowShuffleParseRowCount(long vecBatchAddress); + + private native void rowShuffleParseClose(long vecBatchAddress); + + private native void rowShuffleParseBatch(long vecBatchAddress, int[] typeIdArray, int[] precisionArray, int[] scaleArray, long[] vecNativeIdArray); + + private native long columnarShuffleParseInit(long address, int length); + + private native int columnarShuffleParseVecCount(long vecBatchAddress); + + private native int columnarShuffleParseRowCount(long vecBatchAddress); + + private native void columnarShuffleParseClose(long vecBatchAddress); + + private native void columnarShuffleParseBatch(long vecBatchAddress, int[] typeIdArray, int[] precisionArray, int[] scaleArray, long[] vecNativeIdArray); +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java index bd0b42463daff40bd7696706c92de1554c5f6ea3..1f3c1a8ef01a4b7d493f77b4ce1ab1e70af93157 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java @@ -66,11 +66,16 @@ public class 
OmniOrcColumnarBatchReader extends RecordReader true + private val SUPPORTED_DATA_TYPES = Set(ShortType, IntegerType, LongType, DoubleType, BooleanType, DateType) + + private def checkBhjRightChild(plan: Any): Boolean = + plan match { + case _: ColumnarFilterExec => true + case _: ColumnarConditionProjectExec => true case _ => false } + + def isTopNExpression(expr: Expression): Boolean = expr match { + case Alias(child, _) => isTopNExpression(child) + case WindowExpression(_: Rank, _) => true + case _ => false + } + + def isStrictTopN(expr: Expression): Boolean = expr match { + case Alias(child, _) => isStrictTopN(child) + case WindowExpression(_: RowNumber, _) => true + case _ => false + } + + private def splitConjunctivePredicates(condition: Expression): Seq[Expression] = { + condition match { + case And(cond1, cond2) => + splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2) + case other => other :: Nil + } + } + + // condition :`hashagg(partial)->hashagg(final)` + // If certain conditions are met, it will transformed into ` hashagg(complete)`. + private def combineHashAgg(plan: HashAggregateExec, child: SparkPlan): SparkPlan = { + var newChild: SparkPlan = new ColumnarHashAggregateExec( + plan.requiredChildDistributionExpressions, + plan.isStreaming, + plan.numShufflePartitions, + plan.groupingExpressions, + plan.aggregateExpressions, + plan.aggregateAttributes, + plan.initialInputBufferOffset, + plan.resultExpressions, + child) + child match { + case exec: ColumnarHashAggregateExec => + if (canCombine(plan, exec)) { + newChild = new ColumnarHashAggregateExec( + exec.requiredChildDistributionExpressions, + exec.isStreaming, + exec.numShufflePartitions, + exec.groupingExpressions, + exec.aggregateExpressions.map( + aggregate => new AggregateExpression(aggregate.aggregateFunction, + Complete, + aggregate.isDistinct, + aggregate.filter, + aggregate.resultId)), + plan.aggregateAttributes, + exec.initialInputBufferOffset, + plan.resultExpressions, + exec.child) + } + case _ => + } + newChild + } + + private def canCombine(plan: HashAggregateExec, child: ColumnarHashAggregateExec): Boolean = { + plan.groupingExpressions.length == child.groupingExpressions.length && + plan.groupingExpressions.diff(child.groupingExpressions).isEmpty && + child.aggregateExpressions.forall(_.mode == Partial) && + plan.aggregateExpressions.forall(_.mode == Final) + } + + private def isAllSupport(condition: Expression): Boolean = { + def isSingleExprDataTypeSupported(expr: Expression): Boolean = { + SUPPORTED_DATA_TYPES.exists(_.equals(expr.dataType)) + } + + def isBinaryExprDataTypeSupported(left: Expression, right: Expression): Boolean = { + isSingleExprDataTypeSupported(left) && isSingleExprDataTypeSupported(right) + } + + def isBinaryExprAttributeSupported(left: Expression, right: Expression): Boolean = { + (isAttribute(left) && isLiteral(right)) || (isAttribute(right) && isLiteral(left)) + } + + def isBinaryExprSupported(left: Expression, right: Expression): Boolean = { + isBinaryExprDataTypeSupported(left, right) && isBinaryExprAttributeSupported(left, right) + } + + def isAttribute(expr: Expression): Boolean = { + expr.isInstanceOf[Attribute] + } + + def isLiteral(expr: Expression): Boolean = { + expr.isInstanceOf[Literal] + } + + condition match { + case equalTo: EqualTo => + isBinaryExprSupported(equalTo.left, equalTo.right) + case greaterThan: GreaterThan => + isBinaryExprSupported(greaterThan.left, greaterThan.right) + case greaterThanOrEqual: GreaterThanOrEqual => + 
isBinaryExprSupported(greaterThanOrEqual.left, greaterThanOrEqual.right) + case lessThan: LessThan => + isBinaryExprSupported(lessThan.left, lessThan.right) + case lessThanOrEqual: LessThanOrEqual => + isBinaryExprSupported(lessThanOrEqual.left, lessThanOrEqual.right) + case isNotNull: IsNotNull => + isAttribute(isNotNull.child) + case isNull: IsNull => + isAttribute(isNull.child) + case and: And => + isAllSupport(and.left) && isAllSupport(and.right) + case or: Or => + isAllSupport(or.left) && isAllSupport(or.right) + case not: Not => + isAllSupport(not.child) + case _ => false + } + } + + private def tryToCombineFilterAndFileSource(plan: FilterExec, child: SparkPlan): SparkPlan = { + if (!enableVecPredicateFilter) { + return ColumnarFilterExec(plan.condition, child) + } + child match { + case ColumnarFileSourceScanExec(relation, _, _, _, _, _, dataFilters, _, _) => + relation.fileFormat match { + case orcFormat: OrcFileFormat => + if (isAllSupport(plan.condition)) { + child + } else { + ColumnarFilterExec(plan.condition, child) + } + case _ => + ColumnarFilterExec(plan.condition, child) + } + case _ => + ColumnarFilterExec(plan.condition, child) + } } def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan = { @@ -119,7 +273,7 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) join.leftKeys, join.rightKeys, join.joinType, - ShimUtil.buildBuildSide(join.buildSide, join.joinType), + join.buildSide, join.condition, join.left, join.right, @@ -158,9 +312,75 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) ColumnarProjectExec(plan.projectList, child) } case plan: FilterExec => + if(enableColumnarTopNSort) { + plan.transform { + case f@FilterExec(condition, w@WindowExec(Seq(windowExpression), _, orderSpec, child: SparkPlan)) => + if (orderSpec.nonEmpty && isTopNExpression(windowExpression) && child.isInstanceOf[SortExec]) { + val sort: SortExec = child.asInstanceOf[SortExec] + var topn = Int.MaxValue + val nonTopNConditions = splitConjunctivePredicates(condition).filter { + case LessThan(e: NamedExpression, IntegerLiteral(n)) + if e.exprId == windowExpression.exprId => + topn = Math.min(topn, n - 1) + false + case GreaterThan(IntegerLiteral(n), e: NamedExpression) + if e.exprId == windowExpression.exprId => + topn = Math.min(topn, n - 1) + false + case LessThanOrEqual(e: NamedExpression, IntegerLiteral(n)) + if e.exprId == windowExpression.exprId => + topn = Math.min(topn, n) + false + case EqualTo(e: NamedExpression, IntegerLiteral(n)) + if n == 1 && e.exprId == windowExpression.exprId => + topn = 1 + false + case EqualTo(IntegerLiteral(n), e: NamedExpression) + if n == 1 && e.exprId == windowExpression.exprId => + topn = 1 + false + case GreaterThanOrEqual(IntegerLiteral(n), e: NamedExpression) + if e.exprId == windowExpression.exprId => + topn = Math.min(topn, n) + false + case _ => true + } + // topn <= SQLConf.get.topNPushDownForWindowThreshold 100. 
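+                // The threshold check below uses topNSortThreshold, now read from
+                // "spark.omni.sql.columnar.topN.threshold" (TOP_N_THRESHOLD, default 100), which replaces the
+                // removed topNPushDownForWindow settings.
+                // Illustrative example (assumed table/column names):
+                //   SELECT * FROM (SELECT *, rank() OVER (PARTITION BY k ORDER BY v) AS rk FROM t) WHERE rk <= 10
+                // matches this branch: the predicate rk <= 10 sets topn = 10, strictTopN stays false for rank()
+                // (isStrictTopN only returns true for row_number), and the remaining non-topN predicates are
+                // re-applied by the ColumnarFilterExec wrapped around the new ColumnarWindowExec/ColumnarTopNSortExec.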
+ val strictTopN = isStrictTopN(windowExpression) + val omniSupport: Boolean = try { + ColumnarTopNSortExec(topn, strictTopN, w.partitionSpec, w.orderSpec, sort.global, sort.child).buildCheck() + true + } catch { + case _: Throwable => false + } + if (topn > 0 && topn <= topNSortThreshold && omniSupport) { + val topNSortExec = ColumnarTopNSortExec( + topn, strictTopN, w.partitionSpec, w.orderSpec, sort.global, replaceWithColumnarPlan(sort.child)) + logInfo(s"Columnar Processing for ${topNSortExec.getClass} is currently supported.") + val newCondition = if (nonTopNConditions.isEmpty) { + Literal.TrueLiteral + } else { + nonTopNConditions.reduce(And) + } + val window = ColumnarWindowExec(w.windowExpression, w.partitionSpec, w.orderSpec, topNSortExec) + return ColumnarFilterExec(newCondition, window) + } else { + logInfo{s"topn: ${topn} is bigger than topNSortThreshold: ${topNSortThreshold}."} + val child = replaceWithColumnarPlan(f.child) + return ColumnarFilterExec(f.condition, child) + } + } else { + val child = replaceWithColumnarPlan(f.child) + return ColumnarFilterExec(f.condition, child) + } + case _ => + val child = replaceWithColumnarPlan(plan.child) + return tryToCombineFilterAndFileSource(plan, child) + } + } val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarFilterExec(plan.condition, child) + tryToCombineFilterAndFileSource(plan, child) case plan: ExpandExec => val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") @@ -276,16 +496,7 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) child) } } else { - new ColumnarHashAggregateExec( - plan.requiredChildDistributionExpressions, - plan.isStreaming, - plan.numShufflePartitions, - plan.groupingExpressions, - plan.aggregateExpressions, - plan.aggregateAttributes, - plan.initialInputBufferOffset, - plan.resultExpressions, - child) + combineHashAgg(plan, child) } } else { if (child.isInstanceOf[ColumnarExpandExec]) { @@ -363,16 +574,7 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) child) } } else { - new ColumnarHashAggregateExec( - plan.requiredChildDistributionExpressions, - plan.isStreaming, - plan.numShufflePartitions, - plan.groupingExpressions, - plan.aggregateExpressions, - plan.aggregateAttributes, - plan.initialInputBufferOffset, - plan.resultExpressions, - child) + combineHashAgg(plan, child) } } @@ -443,7 +645,7 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) plan.leftKeys, plan.rightKeys, plan.joinType, - ShimUtil.buildBuildSide(plan.buildSide, plan.joinType), + plan.buildSide, plan.condition, left, newHashAgg, @@ -456,7 +658,7 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) plan.leftKeys, plan.rightKeys, plan.joinType, - ShimUtil.buildBuildSide(plan.buildSide, plan.joinType), + plan.buildSide, plan.condition, left, right, @@ -471,7 +673,7 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) plan.leftKeys, plan.rightKeys, plan.joinType, - ShimUtil.buildBuildSide(plan.buildSide, plan.joinType), + plan.buildSide, plan.condition, left, right, @@ -487,7 +689,7 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) plan.leftKeys, plan.rightKeys, plan.joinType, - ShimUtil.buildBuildSide(plan.buildSide, plan.joinType), + plan.buildSide, plan.condition, left, right, @@ -760,7 +962,7 @@ case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule wit 
maybe(session, plan) { val planWithFallBackPolicy = transformPlan(fallbackPolicy(), plan, "fallback") - val finalPlan = planWithFallBackPolicy match { + var finalPlan = planWithFallBackPolicy match { case FallbackNode(fallbackPlan) => // skip c2r and r2c replaceWithColumnarPlan fallbackPlan @@ -769,8 +971,35 @@ case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule wit } resetOriginalPlan() resetAdaptiveContext() + if (ShimUtil.isNeedModifyBuildSide) { + finalPlan = modifyBuildSide(finalPlan) + } transformPlan(finallyRules(), finalPlan, "final") } + + def modifyBuildSide(plan: SparkPlan): SparkPlan = { + plan.transform { + case join@ShuffledHashJoinExec( + leftKeys, + rightKeys, + joinType, + buildSide, + condition, + left, + right, + isSkewJoin) => + TransformHints.getHint(join) match { + case _: TRANSFORM_UNSUPPORTED => + if (joinType == LeftOuter && buildSide.equals(BuildLeft)) { + join.copy(buildSide = BuildRight) + } else { + join + } + case _ => + join + } + } + } } case class RemoveTransformHintRule() extends Rule[SparkPlan] { @@ -788,16 +1017,83 @@ class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging { extensions.injectPlannerStrategy(_ => ShuffleJoinStrategy) extensions.injectOptimizerRule(_ => RewriteSelfJoinInInPredicate) extensions.injectOptimizerRule(_ => DelayCartesianProduct) - extensions.injectOptimizerRule(_ => HeuristicJoinReorder) + extensions.injectOptimizerRule(_ => ReorderJoinEnhances) extensions.injectQueryStagePrepRule(session => DedupLeftSemiJoinAQE(session)) - extensions.injectQueryStagePrepRule(_ => TopNPushDownForWindow) extensions.injectQueryStagePrepRule(session => FallbackBroadcastExchange(session)) extensions.injectQueryStagePrepRule(session => PushOrderedLimitThroughAgg(session)) ModifyUtilAdaptor.injectRule(extensions) } } -private class OmniTaskStartExecutorPlugin extends ExecutorPlugin { +private class OmniTaskStartExecutorPlugin extends ExecutorPlugin with Logging { + + private def initTunedOptimization(): Unit = { + val enableTunedOptimization = SparkEnv.get.conf.get(ENABLE_OMNI_TUNED.key, "false").toBoolean + if (enableTunedOptimization) { + // Initialize the process ID variable; -1 indicates that the process ID has not been successfully obtained + var pid: Long = -1 + try { + // Get the runtime name of the JVM, which usually contains the process ID and hostname. + val jvmName = ManagementFactory.getRuntimeMXBean().getName() + val index = jvmName.indexOf('@') + if (index > 0) { + pid = java.lang.Long.parseLong(jvmName.substring(0, index)) + logInfo(s"pid process is ${pid}") + } else { + logError("Unable to parse the process ID from the runtime name.") + return + } + } catch { + case e: Exception => + logError(s"An exception occurred while parsing the process ID: ${e.getMessage}") + return + } + + try { + + val command = s"/usr/sbin/tuned-adm profile spark-omni --pid ${pid}" + logInfo(s"command is ${command}") + val process = Runtime.getRuntime.exec(command) + + var reader = new BufferedReader(new InputStreamReader(process.getInputStream())) + var line: String = null + // read the output of the command line by line and print it + while ({line = reader.readLine(); line != null}) { + logInfo(line) + }
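+ // Note: stdout is drained above before process.waitFor is called, so a chatty tuned-adm invocation
+ // cannot fill the pipe buffer and block the child process. The whole step is gated by
+ // spark.omni.sql.tuned.enabled (default false); NUMA binding is gated separately by
+ // spark.omni.sql.columnar.numaBinding plus spark.omni.sql.columnar.coreRange, e.g.
+ // (illustrative value, assumed range syntax) --conf spark.omni.sql.columnar.coreRange='0-23|24-47',
+ // which initNumaBind splits on '|' into the list of core ranges passed to ExecutorManager.tryTaskSet.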
+ + // wait for the command to finish executing and get the exit status code. + val exitCode = process.waitFor + logInfo(s"The command has finished executing. Exit status code: $exitCode") + + } catch { + case e: Exception => + logError("An error occurred while executing the command", e) + } + } + } + + private def initNumaBind(): Unit = { + val enableNumaBinding: Boolean = SparkEnv.get.conf.get(ENABLED_NUMA_BINDING.key, "false").toBoolean + val bindingInfo = if (!enableNumaBinding) { + NumaBindingInfo(enableNumaBinding = false) + } else { + val tmp = SparkEnv.get.conf.get(NUMA_BINDING_CORE_RANGE.key, "") + if (tmp.isEmpty) { + NumaBindingInfo(enableNumaBinding = false) + } else { + val coreRangeList: Array[String] = tmp.split('|').map(_.trim) + NumaBindingInfo(enableNumaBinding = true, coreRangeList) + } + } + ExecutorManager.tryTaskSet(bindingInfo) + } + + override def init(ctx: PluginContext, extraConf: java.util.Map[String, String]): Unit = { + initTunedOptimization() + initNumaBind() + } + override def onTaskStart(): Unit = { addLeakSafeTaskCompletionListener[Unit](_ => { MemoryManager.reclaimMemory() diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala index aae1d2c86acae75d37a37f27e73d6491eb16caa8..a5f5b599dc4b8f77daad1daafa0147f638494315 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala @@ -24,6 +24,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.shuffle.sort.ColumnarShuffleManager import org.apache.spark.sql.internal.SQLConf +case class NumaBindingInfo(enableNumaBinding: Boolean, totalCoreRange: Array[String] = null, numCoresPerExecutor: Int = -1) {} + class ColumnarPluginConfig(conf: SQLConf) extends Logging { def columnarShuffleStr: String = conf .getConfString("spark.shuffle.manager", "sort") @@ -49,6 +51,8 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { def enableColumnarTopNSort: Boolean = conf.getConf(ENABLE_COLUMNAR_TOP_N_SORT) + def enableColumnarWindowGroupLimit: Boolean = conf.getConf(ENABLE_COLUMNAR_WINDOW_GROUP_LIMIT) + def enableColumnarUnion: Boolean = conf.getConf(ENABLE_COLUMNAR_UNION) def enableColumnarWindow: Boolean = conf.getConf(ENABLE_COLUMNAR_WINDOW) @@ -65,7 +69,7 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { def enableShareBroadcastJoinNestedTable: Boolean = conf.getConf(ENABLE_SHARE_BROADCAST_JOIN_NESTED_TABLE) - def enableHeuristicJoinReorder: Boolean = conf.getConf(ENABLE_HEURISTIC_JOIN_REORDER) + def enableJoinReorderEnhance: Boolean = conf.getConf(ENABLE_JOIN_REORDER_ENHANCE) def enableDelayCartesianProduct: Boolean = conf.getConf(ENABLE_DELAY_CARTESIAN_PRODUCT) @@ -147,9 +151,7 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { def enableGlobalColumnarLimit : Boolean = conf.getConf(ENABLE_GLOBAL_COLUMNAR_LIMIT) - def topNPushDownForWindowThreshold: Int = conf.getConf(TOP_N_PUSH_DOWN_FOR_WINDOW_THRESHOLD) - - def topNPushDownForWindowEnable: Boolean = conf.getConf(TOP_N_PUSH_DOWN_FOR_WINDOW_ENABLE) + def topNSortThreshold: Int = conf.getConf(TOP_N_THRESHOLD) def pushOrderedLimitThroughAggEnable: Boolean = conf.getConf(PUSH_ORDERED_LIMIT_THROUGH_AGG_ENABLE) @@ -190,6 +192,14 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { def
timeParserPolicy: String = conf.getConfString("spark.sql.legacy.timeParserPolicy") def enableOmniUnixTimeFunc: Boolean = conf.getConf(ENABLE_OMNI_UNIXTIME_FUNCTION) + + def enableVecPredicateFilter: Boolean = conf.getConf(ENABLE_VEC_PREDICATE_FILTER) + + def catalogCacheSize: Int = conf.getConf(CATALOG_CACHE_SIZE) + + def catalogCacheExpireTime: Int = conf.getConf(CATALOG_CACHE_EXPIRE_TIME) + + def joinOutputStringTypeCost: Int = conf.getConf(JOIN_OUTPUT_STRING_COST_ESTIMATE) } @@ -261,6 +271,12 @@ object ColumnarPluginConfig { .booleanConf .createWithDefault(true) + val ENABLE_COLUMNAR_WINDOW_GROUP_LIMIT = buildConf("spark.omni.sql.columnar.windowGroupLimit") + .internal() + .doc("enable or disable columnar WindowGroupLimit") + .booleanConf + .createWithDefault(true) + val ENABLE_COLUMNAR_UNION = buildConf("spark.omni.sql.columnar.union") .internal() .doc("enable or disable columnar union") @@ -309,9 +325,9 @@ object ColumnarPluginConfig { .booleanConf .createWithDefault(true) - val ENABLE_HEURISTIC_JOIN_REORDER = buildConf("spark.sql.heuristicJoinReorder.enabled") + val ENABLE_JOIN_REORDER_ENHANCE = buildConf("spark.omni.sql.columnar.JoinReorderEnhance") .internal() - .doc("enable or disable heuristic join reorder") + .doc("enable or disable join reorder enhance") .booleanConf .createWithDefault(true) @@ -538,16 +554,11 @@ object ColumnarPluginConfig { .booleanConf .createWithDefault(true) - val TOP_N_PUSH_DOWN_FOR_WINDOW_THRESHOLD = buildConf("spark.sql.execution.topNPushDownForWindow.threshold") + val TOP_N_THRESHOLD = buildConf("spark.omni.sql.columnar.topN.threshold") .internal() .intConf .createWithDefault(100) - val TOP_N_PUSH_DOWN_FOR_WINDOW_ENABLE = buildConf("spark.sql.execution.topNPushDownForWindow.enabled") - .internal() - .booleanConf - .createWithDefault(true) - val PUSH_ORDERED_LIMIT_THROUGH_AGG_ENABLE = buildConf("spark.omni.sql.columnar.pushOrderedLimitThroughAggEnable.enabled") .internal() .booleanConf @@ -667,4 +678,45 @@ object ColumnarPluginConfig { .booleanConf .createWithDefault(true) + val ENABLE_VEC_PREDICATE_FILTER = buildConf("spark.omni.sql.columnar.vec.predicate.enabled") + .internal() + .doc("enable vectorized predicate filtering") + .booleanConf + .createWithDefault(false) + + val CATALOG_CACHE_SIZE = buildConf("spark.omni.sql.columnar.catalog.cache.size") + .internal() + .doc("set catalog cache size, value <= 0 presents no cache") + .intConf + .createWithDefault(128) + + val CATALOG_CACHE_EXPIRE_TIME = buildConf("spark.omni.sql.columnar.catalog.cache.expire.time") + .internal() + .doc("set catalog cache expire time in seconds") + .intConf + .createWithDefault(600) + + val JOIN_OUTPUT_STRING_COST_ESTIMATE = buildConf("spark.omni.sql.columnar.join.reorder.stringtype.cost") + .internal() + .doc("set string type in join output cost") + .intConf + .createWithDefault(9) + + val ENABLE_OMNI_TUNED = buildConf("spark.omni.sql.tuned.enabled") + .internal() + .doc("enable tuned-adm optimization") + .booleanConf + .createWithDefault(false) + + val ENABLED_NUMA_BINDING = + buildConf("spark.omni.sql.columnar.numaBinding") + .internal() + .booleanConf + .createWithDefault(false) + + val NUMA_BINDING_CORE_RANGE = + buildConf("spark.omni.sql.columnar.coreRange") + .internal() + .stringConf + .createOptional } diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala index 
8f2c94a866bceab91d644311b4996395ee71ab79..d10d60e15a06f61ab177096c4e3519bc0d2e4036 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala @@ -18,6 +18,7 @@ package com.huawei.boostkit.spark +import com.huawei.boostkit.spark.util.ModifyUtilAdaptor import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.rules.Rule @@ -407,8 +408,9 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { } ColumnarDataWritingCommandExec(plan.cmd, plan.child).buildCheck() TransformHints.tagTransformable(plan) - case _ => TransformHints.tagTransformable(plan) - + case _ => + TransformHints.tagTransformable(plan) + ModifyUtilAdaptor.addTransFormableTag(plan) } } catch { case e: UnsupportedOperationException => diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index f4b856b37ac90a85976836eaf0130f0c01faab07..d19ee14710d9c69d3b02da31e79946f984049fe3 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -24,16 +24,18 @@ import nova.hetu.omniruntime.`type`.{BooleanDataType, DataTypeSerializer, Date32 import nova.hetu.omniruntime.constants.FunctionType import nova.hetu.omniruntime.constants.FunctionType.{OMNI_AGGREGATION_TYPE_AVG, OMNI_AGGREGATION_TYPE_COUNT_ALL, OMNI_AGGREGATION_TYPE_COUNT_COLUMN, OMNI_AGGREGATION_TYPE_FIRST_IGNORENULL, OMNI_AGGREGATION_TYPE_FIRST_INCLUDENULL, OMNI_AGGREGATION_TYPE_MAX, OMNI_AGGREGATION_TYPE_MIN, OMNI_AGGREGATION_TYPE_SAMP, OMNI_AGGREGATION_TYPE_SUM, OMNI_WINDOW_TYPE_RANK, OMNI_WINDOW_TYPE_ROW_NUMBER} import nova.hetu.omniruntime.constants.JoinType._ +import nova.hetu.omniruntime.constants.BuildSide._ import nova.hetu.omniruntime.operator.OmniExprVerify import com.huawei.boostkit.spark.ColumnarPluginConfig import com.huawei.boostkit.spark.util.ModifyUtilAdaptor import com.google.gson.{JsonArray, JsonElement, JsonObject, JsonParser} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Expression, _} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.util.CharVarcharUtils.getRawTypeString import org.apache.spark.sql.execution import org.apache.spark.sql.hive.HiveUdfAdaptorUtil @@ -55,11 +57,10 @@ object OmniExpressionAdaptor extends Logging { } def getExprIdMap(inputAttrs: Seq[Attribute]): Map[ExprId, Int] = { - var attrMap: Map[ExprId, Int] = Map() - 
inputAttrs.zipWithIndex.foreach { case (inputAttr, i) => - attrMap += (inputAttr.exprId -> i) - } - attrMap + inputAttrs.iterator + .zipWithIndex + .map { case (attr, idx) => attr.exprId -> idx} + .toMap } def checkOmniJsonWhiteList(filterExpr: String, projections: Array[AnyRef]): Unit = { @@ -165,43 +166,38 @@ object OmniExpressionAdaptor extends Logging { } case sub: Subtract => - ShimUtil.unsupportedEvalModeCheck(sub) val (left, right) = ShimUtil.binaryOperatorAdjust(sub, returnDatatype) new JsonObject().put("exprType", "BINARY") .addOmniExpJsonType("returnType", returnDatatype) - .put("operator", "SUBTRACT") + .put("operator", ShimUtil.transformExpressionByEvalMode(sub)) .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(left, exprsIndexMap)) .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(right, exprsIndexMap)) case add: Add => - ShimUtil.unsupportedEvalModeCheck(add) val (left, right) = ShimUtil.binaryOperatorAdjust(add, returnDatatype) new JsonObject().put("exprType", "BINARY") .addOmniExpJsonType("returnType", returnDatatype) - .put("operator", "ADD") + .put("operator", ShimUtil.transformExpressionByEvalMode(add)) .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(left, exprsIndexMap)) .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(right, exprsIndexMap)) case mult: Multiply => - ShimUtil.unsupportedEvalModeCheck(mult) val (left, right) = ShimUtil.binaryOperatorAdjust(mult, returnDatatype) new JsonObject().put("exprType", "BINARY") .addOmniExpJsonType("returnType", returnDatatype) - .put("operator", "MULTIPLY") + .put("operator", ShimUtil.transformExpressionByEvalMode(mult)) .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(left, exprsIndexMap)) .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(right, exprsIndexMap)) case divide: Divide => - ShimUtil.unsupportedEvalModeCheck(divide) val (left, right) = ShimUtil.binaryOperatorAdjust(divide, returnDatatype) new JsonObject().put("exprType", "BINARY") .addOmniExpJsonType("returnType", returnDatatype) - .put("operator", "DIVIDE") + .put("operator", ShimUtil.transformExpressionByEvalMode(divide)) .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(left, exprsIndexMap)) .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(right, exprsIndexMap)) case mod: Remainder => - ShimUtil.unsupportedEvalModeCheck(mod) val (left, right) = ShimUtil.binaryOperatorAdjust(mod, returnDatatype) new JsonObject().put("exprType", "BINARY") .addOmniExpJsonType("returnType", returnDatatype) @@ -768,12 +764,10 @@ object OmniExpressionAdaptor extends Logging { def toOmniAggFunType(agg: AggregateExpression, isHashAgg: Boolean = false, isMergeCount: Boolean = false): FunctionType = { agg.aggregateFunction match { case sum: Sum => - ShimUtil.unsupportedEvalModeCheck(sum) - OMNI_AGGREGATION_TYPE_SUM + ShimUtil.transformFuncTypeByEvalMode(sum) case Max(_) => OMNI_AGGREGATION_TYPE_MAX case avg: Average => - ShimUtil.unsupportedEvalModeCheck(avg) - OMNI_AGGREGATION_TYPE_AVG + ShimUtil.transformFuncTypeByEvalMode(avg) case Min(_) => OMNI_AGGREGATION_TYPE_MIN case StddevSamp(_, _) => OMNI_AGGREGATION_TYPE_SAMP case Count(Literal(1, IntegerType) :: Nil) | Count(ArrayBuffer(Literal(1, IntegerType))) => @@ -1020,11 +1014,24 @@ object OmniExpressionAdaptor extends Logging { OMNI_JOIN_TYPE_LEFT_SEMI case LeftAnti => OMNI_JOIN_TYPE_LEFT_ANTI + case ExistenceJoin(_) => + OMNI_JOIN_TYPE_EXISTENCE case _ => throw new UnsupportedOperationException(s"Join-type[$joinType] is not supported.") } } + def 
toOmniBuildSide(buildSide: BuildSide): nova.hetu.omniruntime.constants.BuildSide = { + buildSide match { + case BuildLeft => + BUILD_LEFT + case BuildRight => + BUILD_RIGHT + case _ => + throw new UnsupportedOperationException(s"Build-Side[$buildSide] is not supported.") + } + } + def isSimpleColumn(expr: String): Boolean = { val indexOfExprType = expr.indexOf("exprType") val lastIndexOfExprType = expr.lastIndexOf("exprType") @@ -1054,4 +1061,43 @@ object OmniExpressionAdaptor extends Logging { false } } + + def isSimpleExpression(expr: Expression): Boolean = { + expr match { + case attribute: AttributeReference => + true + case alias: Alias => + alias.child.isInstanceOf[AttributeReference] + case _ => + false + } + } + + def isAllSimpleExpression(exprs: Seq[Expression]): Boolean = { + for (expr <- exprs) { + if (!isSimpleExpression(expr)){ + return false + } + } + true + } + + def getExprIdFromExpressions(exprs: Seq[Expression]): ArrayBuffer[ExprId] = { + var ret = ArrayBuffer[ExprId]() + for (expr <- exprs) { + expr match { + case attributeReference: AttributeReference => + ret+=attributeReference.exprId + case alias: Alias => + alias.child match { + case attr: AttributeReference => ret += attr.exprId + case _ => + throw new UnsupportedOperationException(s"Unsupported Expression") + } + case _ => + throw new UnsupportedOperationException(s"Unsupported Expression") + } + } + ret + } } diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/serialize/ColumnarBatchSerializer.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/serialize/ColumnarBatchSerializer.scala index 26e2b7a3e057662fd9abee0db0345beeeb63f523..676b465fd41ebbf94735c07c4f58bbfd13d61726 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/serialize/ColumnarBatchSerializer.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/serialize/ColumnarBatchSerializer.scala @@ -87,7 +87,7 @@ private class ColumnarBatchSerializerInstance( } ByteStreams.readFully(dIn, columnarBuffer, 0, dataSize) // protobuf serialize - val columnarBatch: ColumnarBatch = ShuffleDataSerializer.deserialize(isRowShuffle, columnarBuffer.slice(0, dataSize)) + val columnarBatch: ColumnarBatch = ShuffleDataSerializer.deserialize(isRowShuffle, columnarBuffer, dataSize) dataSize = readSize() if (dataSize == EOF) { dIn.close() @@ -116,7 +116,7 @@ private class ColumnarBatchSerializerInstance( } ByteStreams.readFully(dIn, columnarBuffer, 0, dataSize) // protobuf serialize - val columnarBatch: ColumnarBatch = ShuffleDataSerializer.deserialize(isRowShuffle, columnarBuffer.slice(0, dataSize)) + val columnarBatch: ColumnarBatch = ShuffleDataSerializer.deserialize(isRowShuffle, columnarBuffer, dataSize) numBatchesTotal += 1 numRowsTotal += columnarBatch.numRows() columnarBatch.asInstanceOf[T] diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/util/ModifyUtilAdaptor.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/util/ModifyUtilAdaptor.scala index adb6883d9d1addec30434181a78718f0dc7ec81a..c066171a2a4443a2d93616de102227887be680d0 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/util/ModifyUtilAdaptor.scala +++ 
b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/util/ModifyUtilAdaptor.scala @@ -33,6 +33,8 @@ object ModifyUtilAdaptor { private var injectRuleFunc: SparkSessionExtensions => Unit = _ + private var addTransFormableFunc: SparkPlan => Unit = _ + def configRewriteJsonFunc(func: (Expression, Map[ExprId, Int], DataType, (Expression, Map[ExprId, Int], DataType) => JsonObject) => JsonObject): Unit = { rewriteJsonFunc = func } @@ -49,6 +51,10 @@ object ModifyUtilAdaptor { injectRuleFunc = func } + def configAddTransformableFunc(func: SparkPlan => Unit): Unit = { + addTransFormableFunc = func + } + val registration: Unit = { val ModifyUtilClazz = Thread.currentThread().getContextClassLoader.loadClass("org.apache.spark.sql.util.ModifyUtil") val method = ModifyUtilClazz.getMethod("registerFunc") @@ -73,4 +79,8 @@ object ModifyUtilAdaptor { def injectRule(extensions: SparkSessionExtensions): Unit = { injectRuleFunc(extensions) } + + def addTransFormableTag(plan: SparkPlan): Unit = { + addTransFormableFunc(plan) + } } diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorder.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DelayCartesianProduct.scala similarity index 51% rename from omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorder.scala rename to omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DelayCartesianProduct.scala index f0dd04487fff7420a86e501c64a5345d0794738b..835faa5dd32d6b5fdb8f1cf4bc9abd8ef5be35ee 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorder.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DelayCartesianProduct.scala @@ -17,22 +17,15 @@ package org.apache.spark.sql.catalyst.optimizer -import scala.annotation.tailrec -import scala.collection.mutable - import com.huawei.boostkit.spark.ColumnarPluginConfig - import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, EqualNullSafe, EqualTo, Expression, IsNotNull, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.{And, Expression, PredicateHelper} import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.util.sideBySide - - - /** * Move all cartesian products to the root of the plan */ @@ -162,165 +155,7 @@ object DelayCartesianProduct extends Rule[LogicalPlan] with PredicateHelper { } } -/** - * Firstly, Heuristic reorder join need to execute small joins with filters - * , which can reduce intermediate results - */ -object HeuristicJoinReorder extends Rule[LogicalPlan] - with PredicateHelper with JoinSelectionHelper { - - /** - * Join a list of plans together and push down the conditions into them. - * The joined plan are picked from left to right, thus the final result is a left-deep tree. - * - * @param input a list of LogicalPlans to inner join and the type of inner join. - * @param conditions a list of condition for join. 
- */ - @tailrec - final def createReorderJoin(input: Seq[(LogicalPlan, InnerLike)], conditions: Seq[Expression]) - : LogicalPlan = { - assert(input.size >= 2) - if (input.size == 2) { - val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin) - val ((leftPlan, leftJoinType), (rightPlan, rightJoinType)) = (input(0), input(1)) - val innerJoinType = (leftJoinType, rightJoinType) match { - case (Inner, Inner) => Inner - case (_, _) => Cross - } - // Set the join node ordered so that we don't need to transform them again. - val orderJoin = OrderedJoin(leftPlan, rightPlan, innerJoinType, joinConditions.reduceLeftOption(And)) - if (others.nonEmpty) { - Filter(others.reduceLeft(And), orderJoin) - } else { - orderJoin - } - } else { - val (left, _) :: rest = input.toList - val candidates = rest.filter { planJoinPair => - val plan = planJoinPair._1 - // 1. it has join conditions with the left node - // 2. it has a filter - // 3. it can be broadcast - val isEqualJoinCondition = conditions.flatMap { - case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => None - case EqualNullSafe(l, r) if l.references.isEmpty || r.references.isEmpty => None - case e@EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, plan) => Some(e) - case e@EqualTo(l, r) if canEvaluate(l, plan) && canEvaluate(r, left) => Some(e) - case e@EqualNullSafe(l, r) if canEvaluate(l, left) && canEvaluate(r, plan) => Some(e) - case e@EqualNullSafe(l, r) if canEvaluate(l, plan) && canEvaluate(r, left) => Some(e) - case _ => None - }.nonEmpty - - val hasFilter = plan match { - case f: Filter if hasValuableCondition(f.condition) => true - case Project(_, f: Filter) if hasValuableCondition(f.condition) => true - case _ => false - } - isEqualJoinCondition && hasFilter - } - val (right, innerJoinType) = if (candidates.nonEmpty) { - candidates.minBy(_._1.stats.sizeInBytes) - } else { - rest.head - } - - val joinedRefs = left.outputSet ++ right.outputSet - val selectedJoinConditions = mutable.HashSet.empty[Expression] - val (joinConditions, others) = conditions.partition { e => - // If there are semantically equal conditions, they should come from two different joins. - // So we should not put them into one join. - if (!selectedJoinConditions.contains(e.canonicalized) && e.references.subsetOf(joinedRefs) - && canEvaluateWithinJoin(e)) { - selectedJoinConditions.add(e.canonicalized) - true - } else { - false - } - } - // Set the join node ordered so that we don't need to transform them again. - val joined = OrderedJoin(left, right, innerJoinType, joinConditions.reduceLeftOption(And)) - - // should not have reference to same logical plan - createReorderJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others) - } - } - - private def hasValuableCondition(condition: Expression): Boolean = { - val conditions = splitConjunctivePredicates(condition) - !conditions.forall(_.isInstanceOf[IsNotNull]) - } - - def apply(plan: LogicalPlan): LogicalPlan = { - if (ColumnarPluginConfig.getSessionConf.enableHeuristicJoinReorder) { - val newPlan = plan.transform { - case p@ExtractFiltersAndInnerJoinsByIgnoreProjects(input, conditions) - if input.size > 2 && conditions.nonEmpty => - val reordered = createReorderJoin(input, conditions) - if (p.sameOutput(reordered)) { - reordered - } else { - // Reordering the joins have changed the order of the columns. - // Inject a projection to make sure we restore to the expected ordering. 
- Project(p.output, reordered) - } - } - - // After reordering is finished, convert OrderedJoin back to Join - val result = newPlan.transformDown { - case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond, JoinHint.NONE) - } - if (!result.resolved) { - // In some special cases related to subqueries, we find that after reordering, - val comparedPlans = sideBySide(plan.treeString, result.treeString).mkString("\n") - logWarning("The structural integrity of the plan is broken, falling back to the " + - s"original plan. == Comparing two plans ===\n$comparedPlans") - plan - } else { - result - } - } else { - plan - } - } -} - -/** - * This is different from [[ExtractFiltersAndInnerJoins]] in that it can collect filters and - * inner joins by ignoring projects on top of joins, which are produced by column pruning. - */ -private object ExtractFiltersAndInnerJoinsByIgnoreProjects extends PredicateHelper { - - /** - * Flatten all inner joins, which are next to each other. - * Return a list of logical plans to be joined with a boolean for each plan indicating if it - * was involved in an explicit cross join. Also returns the entire list of join conditions for - * the left-deep tree. - */ - def flattenJoin(plan: LogicalPlan, parentJoinType: InnerLike = Inner) - : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match { - case Join(left, right, joinType: InnerLike, cond, hint) if hint == JoinHint.NONE => - val (plans, conditions) = flattenJoin(left, joinType) - (plans ++ Seq((right, joinType)), conditions ++ - cond.toSeq.flatMap(splitConjunctivePredicates)) - case Filter(filterCondition, j@Join(_, _, _: InnerLike, _, hint)) if hint == JoinHint.NONE => - val (plans, conditions) = flattenJoin(j) - (plans, conditions ++ splitConjunctivePredicates(filterCondition)) - case Project(projectList, child) - if projectList.forall(_.isInstanceOf[Attribute]) => flattenJoin(child) - - case _ => (Seq((plan, parentJoinType)), Seq.empty) - } - - def unapply(plan: LogicalPlan): Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])] - = plan match { - case f@Filter(_, Join(_, _, _: InnerLike, _, _)) => - Some(flattenJoin(f)) - case j@Join(_, _, _, _, hint) if hint == JoinHint.NONE => - Some(flattenJoin(j)) - case _ => None - } -} private object ExtractFiltersAndInnerJoinsForBushy extends PredicateHelper { diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReorderJoinEnhances.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReorderJoinEnhances.scala new file mode 100644 index 0000000000000000000000000000000000000000..3ffd21f0b4229305964393e1320ad41174bb1465 --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReorderJoinEnhances.scala @@ -0,0 +1,233 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import com.huawei.boostkit.spark.ColumnarPluginConfig + +import scala.annotation.tailrec +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.trees.TreePattern._ +import org.apache.spark.sql.types.StringType + +/** + * Reorder the joins and push all the conditions into join, so that the bottom ones have at least + * one condition. + * + * The order of joins will not be changed if all of them already have at least one condition. + * + * If star schema detection is enabled, reorder the star join plans based on heuristics. + */ +object ReorderJoinEnhances extends Rule[LogicalPlan] with PredicateHelper { + /** + * Join a list of plans together and push down the conditions into them. + * + * The joined plan are picked from left to right, prefer those has at least one join condition. + * + * @param input a list of LogicalPlans to inner join and the type of inner join. + * @param conditions a list of condition for join. + */ + @tailrec + final def createOrderedJoin( + input: Seq[(LogicalPlan, InnerLike)], + conditions: Seq[Expression]): LogicalPlan = { + assert(input.size >= 2) + if (input.size == 2) { + val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin) + val ((left, leftJoinType), (right, rightJoinType)) = (input(0), input(1)) + val innerJoinType = (leftJoinType, rightJoinType) match { + case (Inner, Inner) => Inner + case (_, _) => Cross + } + val join = Join(left, right, innerJoinType, + joinConditions.reduceLeftOption(And), JoinHint.NONE) + if (others.nonEmpty) { + Filter(others.reduceLeft(And), join) + } else { + join + } + } else { + val (left, _) :: rest = input.toList + val hasEqualJoinConditionCandidates = rest.filter { + planJoinPair => + planJoinPair._1 match { + case p: Project => p.output.count(a => a.dataType == StringType) < columnarPluginConfig.joinOutputStringTypeCost + case _ => true + } + }.filter { + planJoinPair => + val candidateRight = planJoinPair._1 + conditions.exists { + case EqualTo(l, r) => checkLeftRightExpression(l, r, left, candidateRight) + case EqualNullSafe(l, r) => checkLeftRightExpression(l, r, left, candidateRight) + case _ => false + } + }.filter{ + planJoinPair => + val candidateRight = planJoinPair._1 + val joinedRefs = left.outputSet ++ candidateRight.outputSet + conditions.exists { + e => + e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e) + } + } + + val candidateJoins = hasEqualJoinConditionCandidates.filter { + planJoinPair => + planJoinPair._1 match { + case f: Filter if checkFilterCondition(f.condition) => true + case Project(_, f: Filter) if checkFilterCondition(f.condition) => true + case _ => false + } + } + + // pick the min size one if candidateJoins left + val (right, innerJoinType) = if (candidateJoins.nonEmpty) { + 
candidateJoins.minBy(_._1.stats.sizeInBytes) + } else { + val noHasNotEqualCandidates = hasEqualJoinConditionCandidates.filter{ + planJoinPair => + val candidateRight = planJoinPair._1 + !conditions.exists { + case LessThan(l, r) => checkLeftRightExpression(l, r, left, candidateRight) + case LessThanOrEqual(l, r) => checkLeftRightExpression(l, r, left, candidateRight) + case GreaterThan(l, r) => checkLeftRightExpression(l, r, left, candidateRight) + case GreaterThanOrEqual(l, r) => checkLeftRightExpression(l, r, left, candidateRight) + case _ => false + } + } + if(noHasNotEqualCandidates.nonEmpty) { + noHasNotEqualCandidates.head + } else { + rest.head + } + } + + val joinedRefs = left.outputSet ++ right.outputSet + val dupJoinConditionChecks = mutable.HashSet.empty[Expression] + val (joinConditions, others) = conditions.partition { e => + if (!dupJoinConditionChecks.contains(e.canonicalized) && e.references.subsetOf(joinedRefs) + && canEvaluateWithinJoin(e)) { + dupJoinConditionChecks.add(e.canonicalized) + true + } else { + false + } + } + + val joined = Join(left, right, innerJoinType, + joinConditions.reduceLeftOption(And), JoinHint.NONE) + // should not have reference to same logical plan + createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others) + } + } + + private def checkFilterCondition(condition: Expression): Boolean = { + val conditions = splitConjunctivePredicates(condition) + conf.constraintPropagationEnabled && !conditions.forall(_.isInstanceOf[IsNotNull]) + } + + private def checkLeftRightExpression(l: Expression, r: Expression, left: LogicalPlan, candidateRight: LogicalPlan): Boolean = l.references.nonEmpty && + r.references.nonEmpty && + ((canEvaluate(l, left) && canEvaluate(r, candidateRight)) + || (canEvaluate(r, left) && canEvaluate(l, candidateRight))) + + val columnarPluginConfig = ColumnarPluginConfig.getSessionConf + + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( + columnarPluginConfig.enableJoinReorderEnhance && + _.containsPattern(INNER_LIKE_JOIN)) { + case p @ ExtractFiltersAndInnerJoinsEnhances(input, conditions) + if input.size > 2 && conditions.nonEmpty => + val reordered = if (conf.starSchemaDetection && !conf.cboEnabled) { + val starJoinPlan = StarSchemaDetection.reorderStarJoins(input, conditions) + if (starJoinPlan.nonEmpty) { + val rest = input.filterNot(starJoinPlan.contains(_)) + createOrderedJoin(starJoinPlan ++ rest, conditions) + } else { + createOrderedJoin(input, conditions) + } + } else { + createOrderedJoin(input, conditions) + } + + if (p.sameOutput(reordered)) { + reordered + } else { + // Reordering the joins have changed the order of the columns. + // Inject a projection to make sure we restore to the expected ordering. + Project(p.output, reordered) + } + } +} + +/** + * A pattern that collects the filter and inner joins. + * + * Filter + * | + * inner Join + * / \ ----> (Seq(plan0, plan1, plan2), conditions) + * Filter plan2 + * | + * inner join + * / \ + * plan0 plan1 + * + * Note: This pattern currently only works for left-deep trees. + */ +object ExtractFiltersAndInnerJoinsEnhances extends PredicateHelper { + + /** + * Flatten all inner joins, which are next to each other. + * Return a list of logical plans to be joined with a boolean for each plan indicating if it + * was involved in an explicit cross join. Also returns the entire list of join conditions for + * the left-deep tree. 
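+ * For example (illustrative plan shapes), Filter(f, Join(Join(a, b, Inner, Some(c1)), c, Inner, Some(c2)))
+ * flattens to (Seq((a, Inner), (b, Inner), (c, Inner)), Seq(c1, c2, f)), provided no join hints are set.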
+ */ + def flattenJoin(plan: LogicalPlan, parentJoinType: InnerLike = Inner) + : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match { + case Join(left, right, joinType: InnerLike, cond, hint) if hint == JoinHint.NONE => + val (plans, conditions) = flattenJoin(left, joinType) + (plans ++ Seq((right, joinType)), conditions ++ + cond.toSeq.flatMap(splitConjunctivePredicates)) + case Filter(filterCondition, j @ Join(_, _, _: InnerLike, _, hint)) if hint == JoinHint.NONE => + val (plans, conditions) = flattenJoin(j) + (plans, conditions ++ splitConjunctivePredicates(filterCondition)) + case Project(projectList, child) + if projectList.forall(_.isInstanceOf[Attribute]) => flattenJoin(child) + + case _ => (Seq((plan, parentJoinType)), Seq.empty) + } + + def unapply(plan: LogicalPlan) + : Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])] + = plan match { + case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _, hint)) + if hint == JoinHint.NONE => + Some(flattenJoin(f)) + case j @ Join(_, _, joinType, _, hint) if hint == JoinHint.NONE => + Some(flattenJoin(j)) + case _ => None + } +} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/AbstractUnsafeRowSorter.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/AbstractUnsafeRowSorter.java deleted file mode 100644 index 9ddbd2bd135d97ef215c67802497dfb78788ab16..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/AbstractUnsafeRowSorter.java +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution; - -import java.io.IOException; - -import scala.collection.Iterator; -import scala.math.Ordering; - -import com.google.common.annotations.VisibleForTesting; - -import org.apache.spark.sql.types.StructType; -import org.apache.spark.util.collection.unsafe.sort.RecordComparator; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; - -public abstract class AbstractUnsafeRowSorter -{ - protected final StructType schema; - - /** - * If positive, forces records to be spilled to disk at the give frequency (measured in numbers of records). - * This is only intended to be used in tests. - * */ - protected int testSpillFrequency = 0; - - AbstractUnsafeRowSorter(final StructType schema) { - this.schema = schema; - } - - // This flag makes sure the cleanupResource() has been called. - // After the cleanup work, iterator.next should always return false. 
- // Downstream operator triggers the resource cleanup while they found there's no need to keep the iterator anymore. - // See more detail in SPARK-21492. - boolean isReleased = false; - - public abstract void insertRow(UnsafeRow row) throws IOException; - - public abstract Iterator sort() throws IOException; - - public abstract Iterator sort(Iterator inputIterator) throws IOException; - - /** - * @return the peak memory used so far, in bytes. - * */ - public abstract long getPeakMemoryUsage(); - - /** - * @return the total amount of time spent sorting data (in-memory only). - * */ - public abstract long getSortTimeNanos(); - - public abstract void cleanupResources(); - - /** - * Foreces spills to occur every 'frequency' records. Only for use in tests. - * */ - @VisibleForTesting - void setTestSpillFrequency(int frequency) { - assert frequency > 0 : "Frequency must be positive"; - testSpillFrequency = frequency; - } - - static final class RowComparator extends RecordComparator { - private final Ordering ordering; - private final UnsafeRow row1; - private final UnsafeRow row2; - - RowComparator(Ordering ordering, int numFields) { - this.row1 = new UnsafeRow(numFields); - this.row2 = new UnsafeRow(numFields); - this.ordering = ordering; - } - - @Override - public int compare( - Object baseObj1, - long baseOff1, - int baseLen1, - Object baseObj2, - long baseOff2, - int baseLen2) { - // Note that since ordering doesn't need the total length of the record, we just pass 0 int the row. - row1.pointTo(baseObj1, baseOff1, 0); - row2.pointTo(baseObj2, baseOff2, 0); - return ordering.compare(row1, row2); - } - } -} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala index 9f6d57a1fcd75c90f24217c25612fcf2bef9a758..486d7bbc8ba369cf82cbe04fdcc1357addcadcea 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala @@ -375,6 +375,7 @@ case class ColumnarConditionProjectExec(projectList: Seq[NamedExpression], |$formattedNodeName |${ExplainUtils.generateFieldString("Output", projectList)} |${ExplainUtils.generateFieldString("Input", child.output)} + |Condition : ${condition} |""".stripMargin } } diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala index d302ef4d123bf84e6e497a99f73b8c938336394c..6ea29e5b640e85e6a7fc36eb95c85e401c31edfb 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala @@ -94,6 +94,8 @@ case class ColumnarHashAggregateExec( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numOutputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatches"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, 
"spill size"), + "totalMemSize" -> SQLMetrics.createSizeMetric(sparkContext, "total mem size"), + "usedMemSize" -> SQLMetrics.createSizeMetric(sparkContext, "used mem size"), "numSkippedRows" -> SQLMetrics.createMetric(sparkContext, "number of skipped rows")) protected override def needHashTable: Boolean = true @@ -229,6 +231,8 @@ case class ColumnarHashAggregateExec( val numOutputRows = longMetric("numOutputRows") val numOutputVecBatches= longMetric("numOutputVecBatches") val spillSize = longMetric("spillSize") + val usedMemSize = longMetric("usedMemSize") + val totalMemSize = longMetric("totalMemSize") val numSkippedRows = longMetric("numSkippedRows") val attrExpsIdMap = getExprIdMap(child.output) @@ -254,6 +258,9 @@ case class ColumnarHashAggregateExec( val finalStep = (aggregateExpressions.count(_.mode == Final) == aggregateExpressions.size) + val completeStep = + (aggregateExpressions.count(_.mode == Complete) == aggregateExpressions.size) + var index = 0 for (exp <- aggregateExpressions) { if (exp.filter.isDefined) { @@ -296,6 +303,21 @@ case class ColumnarHashAggregateExec( } case _ => throw new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: ${exp}") } + } else if (exp.mode == Complete) { + exp.aggregateFunction match { + case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_, _) => + omniAggFunctionTypes(index) = toOmniAggFunType(exp, true, true) + omniAggOutputTypes(index) = + toOmniAggInOutType(exp.aggregateFunction.dataType) + omniAggChannels(index) = + toOmniAggInOutJSonExp(exp.aggregateFunction.children, attrExpsIdMap) + omniInputRaws(index) = true + omniOutputPartials(index) = false + if (omniAggFunctionTypes(index) == OMNI_AGGREGATION_TYPE_COUNT_ALL) { + omniAggChannels(index) = null + } + case _ => throw new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: ${exp}") + } } else { throw new UnsupportedOperationException(s"Unsupported aggregate mode: ${exp.mode}") } @@ -339,7 +361,11 @@ case class ColumnarHashAggregateExec( // close operator SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { - spillSize += operator.getSpilledBytes() + val metrics = operator.getMetricsInfo() + // The first 100 metrics are general metrics, the rest are difined differently by every operator. 
+ spillSize += metrics(0) + usedMemSize += metrics(100) + totalMemSize += metrics(101) operator.close() if (hashAggregationWithExprOperatorFactory != null) { hashAggregationWithExprOperatorFactory.close() @@ -385,7 +411,7 @@ case class ColumnarHashAggregateExec( getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) var localSchema = this.schema - if (finalStep) { + if (finalStep || completeStep) { // for final step resultExpressions's inputs from omni-final aggregator val omnifinalOutSchema = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes localSchema = StructType(omnifinalOutSchema.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) @@ -448,7 +474,7 @@ case class ColumnarHashAggregateExec( } } } - if (finalStep) { + if (finalStep || completeStep) { val finalOut = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes val finalAttrExprsIdMap = getExprIdMap(finalOut) val projectInputTypes = finalOut.map( diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala index 410ce4127fe5038682e5195a7ec5e4f586439ab1..675c00af1eb531f54bd0cb08ebb84c3a8e40cb5e 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala @@ -155,11 +155,74 @@ case class ColumnarGlobalLimitExec(limit: Int, child: SparkPlan, offset: Int = 0 copy(child = newChild) def buildCheck(): Unit = { - if (offset > 0) { - throw new UnsupportedOperationException("ColumnarGlobalLimitExec doesn't support offset greater than 0.") - } child.output.foreach(attr => sparkTypeToOmniType(attr.dataType, attr.metadata)) } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val addInputTime = longMetric("addInputTime") + val omniCodegenTime = longMetric("omniCodegenTime") + val getOutputTime = longMetric("getOutputTime") + val numOutputRows = longMetric("numOutputRows") + val numOutputVecBatches = longMetric("numOutputVecBatches") + + child.executeColumnar().mapPartitions { iter => + + val startCodegen = System.nanoTime() + val limitOperatorFactory = new OmniLimitOperatorFactory(limit, offset) + val limitOperator = limitOperatorFactory.createOperator + omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + + // close operator + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + limitOperator.close() + limitOperatorFactory.close() + }) + + val localSchema = this.schema + new Iterator[ColumnarBatch] { + private var results: java.util.Iterator[VecBatch] = _ + + override def hasNext: Boolean = { + while ((results == null || !results.hasNext) && iter.hasNext) { + val batch = iter.next() + val input = transColBatchToOmniVecs(batch) + val vecBatch = new VecBatch(input, batch.numRows()) + val startInput = System.nanoTime() + limitOperator.addInput(vecBatch) + addInputTime += NANOSECONDS.toMillis(System.nanoTime() - startInput) + + val startGetOp = System.nanoTime() + results = limitOperator.getOutput + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + } + if (results == null) { + false + } else { + val startGetOp: Long = System.nanoTime() + val hasNext = results.hasNext + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - 
startGetOp) + hasNext + } + } + + override def next(): ColumnarBatch = { + val startGetOp = System.nanoTime() + val vecBatch = results.next() + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( + vecBatch.getRowCount, localSchema, false) + vectors.zipWithIndex.foreach { case (vector, i) => + vector.reset() + vector.setVec(vecBatch.getVectors()(i)) + } + numOutputRows += vecBatch.getRowCount + numOutputVecBatches += 1 + vecBatch.close() + new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) + } + } + } + } } case class ColumnarTakeOrderedAndProjectExec( @@ -168,7 +231,7 @@ case class ColumnarTakeOrderedAndProjectExec( projectList: Seq[NamedExpression], child: SparkPlan, offset: Int = 0) - extends UnaryExecNode { + extends UnaryExecNodeShim(sortOrder, projectList) { override def supportsColumnar: Boolean = true @@ -218,9 +281,6 @@ case class ColumnarTakeOrderedAndProjectExec( } def buildCheck(): Unit = { - if (offset > 0) { - throw new UnsupportedOperationException("ColumnarTakeOrderedAndProjectExec doesn't support offset greater than 0.") - } genSortParam(child.output, sortOrder) val projectEqualChildOutput = projectList == child.output var omniInputTypes: Array[DataType] = null @@ -245,9 +305,9 @@ case class ColumnarTakeOrderedAndProjectExec( } else { val (sourceTypes, ascending, nullFirsts, sortColsExp) = genSortParam(child.output, sortOrder) - def computeTopN(iter: Iterator[ColumnarBatch], schema: StructType): Iterator[ColumnarBatch] = { + def computeTopN(iter: Iterator[ColumnarBatch], schema: StructType, offset: Int): Iterator[ColumnarBatch] = { val startCodegen = System.nanoTime() - val topNOperatorFactory = new OmniTopNWithExprOperatorFactory(sourceTypes, limit, sortColsExp, ascending, nullFirsts, + val topNOperatorFactory = new OmniTopNWithExprOperatorFactory(sourceTypes, limit, offset, sortColsExp, ascending, nullFirsts, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val topNOperator = topNOperatorFactory.createOperator longMetric("omniCodegenTime") += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) @@ -265,7 +325,7 @@ case class ColumnarTakeOrderedAndProjectExec( } else { val localTopK: RDD[ColumnarBatch] = { child.executeColumnar().mapPartitionsWithIndexInternal { (_, iter) => - computeTopN(iter, this.child.schema) + computeTopN(iter, this.child.schema, 0) } } @@ -302,7 +362,7 @@ case class ColumnarTakeOrderedAndProjectExec( } singlePartitionRDD.mapPartitions { iter => // TopN = omni-top-n + omni-project - val topN: Iterator[ColumnarBatch] = computeTopN(iter, this.child.schema) + val topN: Iterator[ColumnarBatch] = computeTopN(iter, this.child.schema, offset) if (!projectEqualChildOutput) { dealPartitionData(null, null, addInputTime, omniCodegenTime, getOutputTime, omniInputTypes, omniExpressions, topN, this.schema) @@ -313,8 +373,6 @@ case class ColumnarTakeOrderedAndProjectExec( } } - override def outputOrdering: Seq[SortOrder] = sortOrder - override def outputPartitioning: Partitioning = SinglePartition override def simpleString(maxFields: Int): String = { diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala index 
1e60256fcc78fc3f1fbb99ebb377a9f4ac6038fa..7959c88b03582dce4e3aaf508aa5b084f178e388 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala @@ -27,10 +27,12 @@ import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer import com.huawei.boostkit.spark.util.OmniAdaptorUtil import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs import com.huawei.boostkit.spark.vectorized.PartitionInfo +import nova.hetu.omniruntime.`type`.DataType.DataTypeId import nova.hetu.omniruntime.`type`.{DataType, DataTypeSerializer} import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} import nova.hetu.omniruntime.operator.project.OmniProjectOperatorFactory -import nova.hetu.omniruntime.vector.{IntVec, VecBatch} +import nova.hetu.omniruntime.utils.ShuffleHashHelper +import nova.hetu.omniruntime.vector.{IntVec, Vec, VecBatch} import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -315,6 +317,7 @@ object ColumnarShuffleExchangeExec extends Logging { } } + val rddWithPartitionId: RDD[Product2[Int, ColumnarBatch]] = newPartitioning match { case RoundRobinPartitioning(numPartitions) => // 按随机数分区 @@ -347,58 +350,59 @@ object ColumnarShuffleExchangeExec extends Logging { newIter }, isOrderSensitive = isOrderSensitive) case h@HashPartitioning(expressions, numPartitions) => - //containsRollUp(expressions): Avoid data skew caused by rollup expressions. - //expressions.length > 6: Avoid q11 data skew - //expressions.length == 3: Avoid q28 data skew when the resin rule is enabled. 
- if (containsRollUp(expressions) || expressions.length > 6 || expressions.length == 3) { - rdd.mapPartitionsWithIndexInternal((_, cbIter) => { - val partitionKeyExtractor: InternalRow => Any = { - val projection = - UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) - row => projection(row).getInt(0) + rdd.mapPartitionsWithIndexInternal((_, cbIter) => { + val addPid2ColumnBatch = addPidToColumnBatch() + val isAllExpressionSimple = isAllSimpleExpression(expressions) + if (isAllExpressionSimple) { + // 创建一个映射:将 output 中的 Expression 与对应的索引关联起来 + val outputIndexMap = getExprIdMap(outputAttributes) + val exprIds = getExprIdFromExpressions(expressions) + cbIter.map { cb => + val vecs = transColBatchToOmniVecs(cb) + // 提取对应的向量 + val projectedVecs:Array[Long] = exprIds.map(exprId => { + vecs(outputIndexMap(exprId)).getNativeVector + }).toArray + val nativeIntVecAddr = ShuffleHashHelper.computePartitionIds(projectedVecs, numPartitions, cb.numRows()) + addPid2ColumnBatch(new IntVec(nativeIntVecAddr), cb) } - val newIter = computePartitionId(cbIter, partitionKeyExtractor) - newIter - }, isOrderSensitive = isOrderSensitive) - } else { - rdd.mapPartitionsWithIndexInternal((_, cbIter) => { - val addPid2ColumnBatch = addPidToColumnBatch() - // omni project - val genHashExpression = genHashExpr() - val omniExpr: String = genHashExpression(expressions, numPartitions, defaultMm3HashSeed, outputAttributes) - val factory = new OmniProjectOperatorFactory(Array(omniExpr), inputTypes, 1, - new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) - val op = factory.createOperator() - // close operator - addLeakSafeTaskCompletionListener[Unit](_ => { - op.close() - }) - - cbIter.map { cb => - var pidVec: IntVec = null - try { - val vecs = transColBatchToOmniVecs(cb, true) - op.addInput(new VecBatch(vecs, cb.numRows())) - val res = op.getOutput - if (res.hasNext) { - val retBatch = res.next() - pidVec = retBatch.getVectors()(0).asInstanceOf[IntVec] - // close return VecBatch - retBatch.close() - addPid2ColumnBatch(pidVec, cb) - } else { - throw new Exception("Empty Project Operator Result...") - } - } catch { - case e: Exception => - if (pidVec != null) { - pidVec.close() - } - throw e + } else { + // omni project + val genHashExpression = genHashExpr() + val omniExpr: String = genHashExpression(expressions, numPartitions, defaultMm3HashSeed, outputAttributes) + val factory = new OmniProjectOperatorFactory(Array(omniExpr), inputTypes, 1, + new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + val op = factory.createOperator() + // close operator + addLeakSafeTaskCompletionListener[Unit](_ => { + op.close() + factory.close() + }) + cbIter.map { cb => + var pidVec: IntVec = null + try { + val vecs = transColBatchToOmniVecs(cb, true) + op.addInput(new VecBatch(vecs, cb.numRows())) + val res = op.getOutput + if (res.hasNext) { + val retBatch = res.next() + pidVec = retBatch.getVectors()(0).asInstanceOf[IntVec] + // close return VecBatch + retBatch.close() + addPid2ColumnBatch(pidVec.asInstanceOf[IntVec], cb) + } else { + throw new Exception("Empty Project Operator Result...") } + } catch { + case e: Exception => + if (pidVec != null) { + pidVec.close() + } + throw e } - }, isOrderSensitive = isOrderSensitive) - } + } + } + }, isOrderSensitive = isOrderSensitive) case SinglePartition => rdd.mapPartitionsWithIndexInternal((_, cbIter) => { cbIter.map { cb => (0, cb) } diff --git 
a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala deleted file mode 100644 index 0ddf89b8c1c3d36b63e7bebd2f9b12e0b1a7f385..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ /dev/null @@ -1,307 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.util.concurrent.TimeUnit._ -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS -import org.apache.spark.sql.execution.UnsafeExternalRowSorter.PrefixComputer -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.collection.unsafe.sort.PrefixComparator - - -/** - * Base class of [[SortExec]] and [[TopNSortExec]]. All subclasses of this class need to override - * their own sorter which inherits from [[org.apache.spark.sql.execution.AbstractUnsafeRowSorter]] - * to perform corresponding sorting. - * - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - * @param testSpillFrequency Method for configuring periodic spilling in unit tests. - * If set, will spill every 'frequency' records. 
- * */ -abstract class SortExecBase( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan, - testSpillFrequency: Int = 0) - extends UnaryExecNode with BlockingOperatorWithCodegen { - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder - - // sort performed is local within a given partition so will retain - // child operator's partitioning - override def outputPartitioning: Partitioning = child.outputPartitioning - - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder):: Nil else UnspecifiedDistribution :: Nil - - private val enableRadixSort = conf.enableRadixSort - - override lazy val metrics = Map( - "sortTime" -> SQLMetrics.createTimingMetric(sparkContext, "sort time"), - "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"), - "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size") - ) - - protected val sorterClassName: String - - protected def newSorterInstance( - ordering: Ordering[InternalRow], - prefixComparator: PrefixComparator, - prefixComputer: PrefixComputer, - pageSize: Long, - canSortFullyWIthPrefix: Boolean): AbstractUnsafeRowSorter - - private[sql] var rowSorter: AbstractUnsafeRowSorter = _ - - /** - * This method gets invoked only once for each SortExec instance to initialize - * an AbstractUnsafeRowSorter, both 'plan.execute' and code generation are using it. - * In the code generation code path, we need to call this function outside the class - * so we should make it public - * */ - def createSorter(): AbstractUnsafeRowSorter = { - val ordering = RowOrdering.create(sortOrder, output) - - // THe comparator for comparing prefix - val boundSortExpression = BindReferences.bindReference(sortOrder.head, output) - val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) - - val canSortFullyWIthPrefix = sortOrder.length == 1 && - SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression) - - // The generator for prefix - val prefixExpr = SortPrefix(boundSortExpression) - val prefixProjection = UnsafeProjection.create(Seq(prefixExpr)) - val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { - private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix - override def computePrefix(row: InternalRow): - UnsafeExternalRowSorter.PrefixComputer.Prefix = { - val prefix = prefixProjection.apply(row) - result.isNull = prefix.isNullAt(0) - result.value = if (result.isNull) prefixExpr.nullValue else prefix.getLong(0) - result - } - } - - val pageSize = SparkEnv.get.memoryManager.pageSizeBytes - rowSorter = newSorterInstance(ordering, prefixComparator, prefixComputer, - pageSize, canSortFullyWIthPrefix) - - if (testSpillFrequency > 0) { - rowSorter.setTestSpillFrequency(testSpillFrequency) - } - rowSorter - } - - protected override def doExecute(): RDD[InternalRow] = { - val peakMemory = longMetric("peakMemory") - val spillSize = longMetric("spillSize") - val sortTime = longMetric("sortTime") - - child.execute().mapPartitionsInternal { iter => - val sorter = createSorter() - val metrics = TaskContext.get().taskMetrics() - - // Remember spill data size of this task before execute this operator, - // so that we can figure out how many bytes we spilled for this operator. 
- val spillSizeBefore = metrics.memoryBytesSpilled - val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) - sortTime += NANOSECONDS.toMillis(sorter.getSortTimeNanos) - peakMemory += sorter.getPeakMemoryUsage - spillSize += metrics.memoryBytesSpilled - spillSizeBefore - metrics.incPeakExecutionMemory(sorter.getPeakMemoryUsage) - - sortedIterator - } - } - - override def usedInputs: AttributeSet = AttributeSet(Seq.empty) - - override def inputRDDs(): Seq[RDD[InternalRow]] = { - child.asInstanceOf[CodegenSupport].inputRDDs - } - - // Name of sorter variable used in codegen - private var sorterVariable: String = _ - - override protected def doProduce(ctx: CodegenContext): String = { - val needToSort = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, - "needToSort", v => s"$v = true;") - - // Initalize the class member variables. This includes the instance of the Sorter - // and the iterator to return sorted rows. - val thisPlan = ctx.addReferenceObj("plan", this) - // Inline mutable state since not many Sort operations in a task - sorterVariable = ctx.addMutableState(sorterClassName, "sorter", - v => s"$v = $thisPlan.createSorter();", forceInline = true) - val metrics = ctx.addMutableState(classOf[TaskMetrics].getName, "metrics", - v => s"$v = org.apache.spark.TaskContext.get().taskMetrics();", forceInline = true) - val sortedIterator = ctx.addMutableState("scala.collection.Iterator", - "sortedIter", forceInline = true) - - val addToSorter = ctx.freshName("addToSorter") - val addToSorterFuncName = ctx.addNewFunction(addToSorter, - s""" - | private void $addToSorter() throws java.io.IOException { - | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} - | } - """.stripMargin.trim) - - val outputRow = ctx.freshName("outputRow") - val peakMemory = metricTerm(ctx, "peakMemory") - val spillSize = metricTerm(ctx, "spillSize") - val spillSizeBefore = ctx.freshName("spillSizeBefore") - val sortTime = metricTerm(ctx, "sortTime") - s""" - | if ($needToSort) { - | long $spillSizeBefore = $metrics.memoryBytesSpilled(); - | $addToSorterFuncName(); - | $sortedIterator = $sorterVariable.sort(); - | $sortTime.add($sorterVariable.getSortTimeNanos() / $NANOS_PER_MILLIS); - | $peakMemory.add($sorterVariable.getPeakMemoryUsage()); - | $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore); - | $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage()); - | $needToSort = false; - | } - | - | while ($limitNotReachedCond $sortedIterator.hasNext()) { - | UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next(); - | ${consume(ctx, null, outputRow)} - | if (shouldStop()) return; - | } - """.stripMargin.trim - } - - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - s""" - | ${row.code} - | $sorterVariable.insertRow((UnsafeRow)${row.value}); - """.stripMargin - } - - /** - * In BaseSortExec, we overwrites cleanupResources to close AbstractUnsafeRowSorter. 
- * */ - - override protected[sql] def cleanupResources(): Unit = { - if (rowSorter != null) { - // There's possible for rowSorter is null here, for example, in the scenario of empty - // iterator in the current task, the downstream physical node(like SortMergeJoinExec) will - // trigger cleanupResources before rowSorter initialized in createSorter - rowSorter.cleanupResources() - } - super.cleanupResources() - } -} - - -/** - * Performs (external) sorting - * */ -case class SortExec( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan, - testSpillFrequency: Int = 0) - extends SortExecBase(sortOrder, global, child, testSpillFrequency) { - private val enableRadixSort = conf.enableRadixSort - - - override val sorterClassName: String = classOf[UnsafeExternalRowSorter].getName - - override def newSorterInstance( - ordering: Ordering[InternalRow], - prefixComparator: PrefixComparator, - prefixComputer: PrefixComputer, - pageSize: Long, - canSortFullyWIthPrefix: Boolean): UnsafeExternalRowSorter = { - UnsafeExternalRowSorter.create( - schema, - ordering, - prefixComparator, - prefixComputer, - pageSize, - enableRadixSort && canSortFullyWIthPrefix) - } - - override def createSorter(): UnsafeExternalRowSorter = { - super.createSorter().asInstanceOf[UnsafeExternalRowSorter] - } - - override protected def withNewChildInternal(newChild: SparkPlan): SortExec = { - copy(child = newChild) - } -} - -/** - * Performs topN sort - * - * @param strictTopN when true it strictly returns n results. This param distinguishes - * [[RowNumber]] from [[Rank]]. [[RowNumber]] corresponds to true - * and [[Rank]] corresponds to false. - * @param partitionSpec partitionSpec of [[org.apache.spark.sql.execution.window.WindowExec]] - * @param sortOrder orderSpec of [[org.apache.spark.sql.execution.window.WindowExec]] - * */ -case class TopNSortExec( - n: Int, - strictTopN: Boolean, - partitionSpec: Seq[Expression], - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan) - extends SortExecBase(sortOrder, global, child, 0) { - - override val sorterClassName: String = classOf[UnsafeTopNRowSorter].getName - - override def newSorterInstance( - ordering: Ordering[InternalRow], - prefixComparator: PrefixComparator, - prefixComputer: PrefixComputer, - pageSize: Long, - canSortFullyWIthPrefix: Boolean): UnsafeTopNRowSorter = { - val partitionSpecProjection = UnsafeProjection.create(partitionSpec, output) - UnsafeTopNRowSorter.create( - n, - strictTopN, - schema, - partitionSpecProjection, - ordering, - prefixComparator, - prefixComputer, - pageSize, - canSortFullyWIthPrefix) - } - - override def createSorter(): UnsafeTopNRowSorter = { - super.createSorter().asInstanceOf[UnsafeTopNRowSorter] - } - - override protected def withNewChildInternal(newChild: SparkPlan): TopNSortExec = { - copy(child = newChild) - } -} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/TopNSortExec.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/TopNSortExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..ece8ac9d35c58d28760473d11805e7c41b8c3041 --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/TopNSortExec.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ + +/** + * Performs topN sort + * + * @param strictTopN when true it strictly returns n results. This param distinguishes + * [[RowNumber]] from [[Rank]]. [[RowNumber]] corresponds to true + * and [[Rank]] corresponds to false. + * @param partitionSpec partitionSpec of [[org.apache.spark.sql.execution.window.WindowExec]] + * @param sortOrder orderSpec of [[org.apache.spark.sql.execution.window.WindowExec]] + * */ +case class TopNSortExec( + n: Int, + strictTopN: Boolean, + partitionSpec: Seq[Expression], + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan) extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException("unsupported topn sort exec") + } + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = { + copy(child = newChild) + } +} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java deleted file mode 100644 index b36a424d22f54fa629e8dfd774c7d503ee75362c..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution; - -import java.io.IOException; -import java.util.function.Supplier; - -import scala.collection.Iterator; -import scala.math.Ordering; - -import org.apache.spark.SparkEnv; -import org.apache.spark.TaskContext; -import org.apache.spark.internal.config.package$; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.unsafe.Platform; -import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; -import org.apache.spark.util.collection.unsafe.sort.RecordComparator; -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter; -import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; - -public final class UnsafeExternalRowSorter extends AbstractUnsafeRowSorter { - private long numRowsInserted = 0; - private final UnsafeExternalRowSorter.PrefixComputer prefixComputer; - private final UnsafeExternalSorter sorter; - - public abstract static class PrefixComputer { - public static class Prefix { - // Key prefix value, or the null prefix value if isNull = true - public long value; - - // Whether the key is null - public boolean isNull; - } - - /** - * Computes prefix for the given row. For efficiency, the object may be reused in - * further calls to a given PrefixComputer. - * */ - public abstract Prefix computePrefix(InternalRow row); - } - - public static UnsafeExternalRowSorter createWithRecordComparator( - StructType schema, - Supplier recordComparatorSupplier, - PrefixComparator prefixComparator, - UnsafeExternalRowSorter.PrefixComputer prefixComputer, - long pageSizeBytes, - boolean canUseRadixSort) throws IOException { - return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator, - prefixComputer, pageSizeBytes, canUseRadixSort); - } - - public static UnsafeExternalRowSorter create( - StructType schema, - Ordering ordering, - PrefixComparator prefixComparator, - UnsafeExternalRowSorter.PrefixComputer prefixComputer, - long pageSizeBytes, - boolean canUseRadixSort) throws IOException { - Supplier recordComparatorSupplier = () -> new RowComparator(ordering, schema.length()); - return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator, - prefixComputer, pageSizeBytes, canUseRadixSort); - } - - private UnsafeExternalRowSorter( - StructType schema, - Supplier recordComparatorSupplier, - PrefixComparator prefixComparator, - UnsafeExternalRowSorter.PrefixComputer prefixComputer, - long pageSizeBytes, - boolean canUseRadixSort) { - super(schema); - this.prefixComputer = prefixComputer; - final SparkEnv sparkEnv = SparkEnv.get(); - final TaskContext taskContext = TaskContext.get(); - sorter = UnsafeExternalSorter.create( - taskContext.taskMemoryManager(), - sparkEnv.blockManager(), - sparkEnv.serializerManager(), - taskContext, - recordComparatorSupplier, - prefixComparator, - (int) (long) sparkEnv.conf().get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()), - pageSizeBytes, - (int) sparkEnv.conf().get( - package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()), - canUseRadixSort); - } - - @Override - public void insertRow(UnsafeRow row) throws IOException { - final PrefixComputer.Prefix prefix = prefixComputer.computePrefix(row); - sorter.insertRecord( - row.getBaseObject(), - row.getBaseOffset(), - row.getSizeInBytes(), - prefix.value, - prefix.isNull); - numRowsInserted++; - if (testSpillFrequency > 0 && (numRowsInserted 
% testSpillFrequency) == 0) { - sorter.spill(); - } - } - - @Override - public long getPeakMemoryUsage() { - return sorter.getPeakMemoryUsedBytes(); - } - - @Override - public long getSortTimeNanos() { - return sorter.getSortTimeNanos(); - } - - @Override - public void cleanupResources() { - isReleased = true; - sorter.cleanupResources(); - } - - @Override - public Iterator sort() throws IOException { - try { - final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); - if (!sortedIterator.hasNext()) { - // Since we won't ever call next() on an empty iterator, we need to clean up resources - // here in order to prevent memory leaks. - cleanupResources(); - } - return new RowIterator() { - private final int numFields = schema.length(); - private UnsafeRow row = new UnsafeRow(numFields); - - @Override - public boolean advanceNext() { - try { - if (!isReleased && sortedIterator.hasNext()) { - sortedIterator.loadNext(); - row.pointTo( - sortedIterator.getBaseObject(), - sortedIterator.getBaseOffset(), - sortedIterator.getRecordLength()); - // Here is the initial buf ifx in SPARK-9364: the bug fix of use-after-free bug - // when returning the last row from an iterator. For example, in - // [[GroupedIterator]], we still use the last row after traversing the iterator - // in 'fetchNextGroupIterator' - if (!sortedIterator.hasNext()) { - row = row.copy(); // so that we don't have dangling pointers to freed page - cleanupResources(); - } - return true; - } else { - row = null; // so that we don't keep reference to the base object - return false; - } - } catch (IOException e) { - cleanupResources(); - // Scala iterators don't declare any checked exceptions, so we need to use this hack - // to re-throw the exception. - Platform.throwException(e); - } - throw new RuntimeException("Exception should have been re-thrown in next()"); - } - - @Override - public UnsafeRow getRow() { return row; } - }.toScala(); - } catch (IOException e) { - cleanupResources(); - throw e; - } - } - - @Override - public Iterator sort(Iterator inputIterator) throws IOException { - while (inputIterator.hasNext()) { - insertRow(inputIterator.next()); - } - return sort(); - } -} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/UnsafeTopNRowSorter.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/UnsafeTopNRowSorter.java deleted file mode 100644 index 6a27c8edfa16042201f37addc0d0e0783fa81d5c..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/UnsafeTopNRowSorter.java +++ /dev/null @@ -1,256 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution; - -import java.io.IOException; -import java.util.*; -import java.util.function.Supplier; - -import scala.collection.Iterator; -import scala.math.Ordering; - -import org.apache.spark.TaskContext; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.execution.topnsort.UnsafeInMemoryTopNSorter; -import org.apache.spark.sql.execution.topnsort.UnsafePartitionedTopNSorter; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.unsafe.Platform; -import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; -import org.apache.spark.util.collection.unsafe.sort.RecordComparator; -import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; - -public final class UnsafeTopNRowSorter extends AbstractUnsafeRowSorter { - - private final UnsafePartitionedTopNSorter partitionedTopNSorter; - - // partition key - private final UnsafeProjection partitionSpecProjection; - - // order(rank) key - private final UnsafeExternalRowSorter.PrefixComputer prefixComputer; - - private long totalSortTimeNanos = 0L; - private final long timeNanosBeforeInsertRow; - - public static UnsafeTopNRowSorter create( - int n, - boolean strictTopN, - StructType schema, - UnsafeProjection partitionSpecProjection, - Ordering orderingOfRankKey, - PrefixComparator prefixComparator, - UnsafeExternalRowSorter.PrefixComputer prefixComputer, - long pageSizeBytes, - boolean canSortFullyWithPrefix) { - Supplier recordComparatorSupplier = - () -> new RowComparator(orderingOfRankKey, schema.length()); - return new UnsafeTopNRowSorter( - n, strictTopN, schema, partitionSpecProjection, recordComparatorSupplier, - prefixComparator, prefixComputer, pageSizeBytes, canSortFullyWithPrefix); - } - - private UnsafeTopNRowSorter( - int n, - boolean strictTopN, - StructType schema, - UnsafeProjection partitionSpecProjection, - Supplier recordComparatorSupplier, - PrefixComparator prefixComparator, - UnsafeExternalRowSorter.PrefixComputer prefixComputer, - long pageSizeBytes, - boolean canSortFullyWithPrefix) { - super(schema); - this.prefixComputer = prefixComputer; - final TaskContext taskContext = TaskContext.get(); - this.partitionSpecProjection = partitionSpecProjection; - this.partitionedTopNSorter = UnsafePartitionedTopNSorter.create( - n, - strictTopN, - taskContext.taskMemoryManager(), - taskContext, - recordComparatorSupplier, - prefixComparator, - pageSizeBytes, - canSortFullyWithPrefix); - timeNanosBeforeInsertRow = System.nanoTime(); - } - - @Override - public void insertRow(UnsafeRow row) throws IOException { - final UnsafeExternalRowSorter.PrefixComputer.Prefix prefix = prefixComputer.computePrefix(row); - UnsafeRow partKey = partitionSpecProjection.apply(row); - partitionedTopNSorter.insertRow(partKey, row, prefix.value); - } - - /** - * Return the peak memory used so far, in bytes. - * */ - @Override - public long getPeakMemoryUsage() { - return partitionedTopNSorter.getPeakMemoryUsedBytes(); - } - - /** - * @return the total amount of time spent sorting data (in-memory only). 
- * */ - @Override - public long getSortTimeNanos() { - return totalSortTimeNanos; - } - - @Override - public Iterator sort() throws IOException - { - try { - Map partKeyToSorter = - partitionedTopNSorter.getPartKeyToSorter(); - if (partKeyToSorter.isEmpty()) { - // Since we won't ever call next() on an empty iterator, we need to clean up resources - // here in order to prevent memory leaks. - cleanupResources(); - return emptySortedIterator(); - } - - Queue sortedIteratorsForPartitions = new LinkedList<>(); - for (Map.Entry entry : partKeyToSorter.entrySet()) { - final UnsafeInMemoryTopNSorter topNSorter = entry.getValue(); - final UnsafeSorterIterator unsafeSorterIterator = topNSorter.getSortedIterator(); - - sortedIteratorsForPartitions.add(new RowIterator() - { - private final int numFields = schema.length(); - private UnsafeRow row = new UnsafeRow(numFields); - - @Override - public boolean advanceNext() - { - try { - if (!isReleased && unsafeSorterIterator.hasNext()) { - unsafeSorterIterator.loadNext(); - row.pointTo( - unsafeSorterIterator.getBaseObject(), - unsafeSorterIterator.getBaseOffset(), - unsafeSorterIterator.getRecordLength()); - // Here is the initial buf ifx in SPARK-9364: the bug fix of use-after-free bug - // when returning the last row from an iterator. For example, in - // [[GroupedIterator]], we still use the last row after traversing the iterator - // in 'fetchNextGroupIterator' - if (!unsafeSorterIterator.hasNext()) { - row = row.copy(); // so that we don't have dangling pointers to freed page - topNSorter.freeMemory(); - } - return true; - } - else { - row = null; // so that we don't keep reference to the base object - return false; - } - } catch (IOException e) { - topNSorter.freeMemory(); - // Scala iterators don't declare any checked exceptions, so we need to use this hack - // to re-throw the exception. - Platform.throwException(e); - } - throw new RuntimeException("Exception should have been re-thrown in next()"); - } - - @Override - public UnsafeRow getRow() - { - return row; - } - }); - } - - // Update total sort time. - if (totalSortTimeNanos == 0L) { - totalSortTimeNanos = System.nanoTime() - timeNanosBeforeInsertRow; - } - final ChainedIterator chainedIterator = new ChainedIterator(sortedIteratorsForPartitions); - return chainedIterator.toScala(); - } catch (Exception e) { - cleanupResources(); - throw e; - } - } - - private Iterator emptySortedIterator() { - return new RowIterator() { - @Override - public boolean advanceNext() { - return false; - } - - @Override - public UnsafeRow getRow() { - return null; - } - }.toScala(); - } - - /** - * Chain multiple UnsafeSorterIterators from PartSorterMap as single one. 
- * */ - private static final class ChainedIterator extends RowIterator { - private final Queue iterators; - private RowIterator current; - private UnsafeRow row; - - ChainedIterator(Queue iterators) { - assert iterators.size() > 0; - this.iterators = iterators; - this.current = iterators.remove(); - } - - @Override - public boolean advanceNext() { - boolean result = this.current.advanceNext(); - while(!result && !this.iterators.isEmpty()) { - this.current = iterators.remove(); - result = this.current.advanceNext(); - } - if (!result) { - this.row = null; - } else { - this.row = (UnsafeRow) this.current.getRow(); - } - return result; - } - - @Override - public UnsafeRow getRow() { - return row; - } - } - - @Override - public Iterator sort(Iterator inputIterator) throws IOException { - while (inputIterator.hasNext()) { - insertRow(inputIterator.next()); - } - return sort(); - } - - @Override - public void cleanupResources() { - isReleased = true; - partitionedTopNSorter.cleanupResources(); - } -} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala index 65c5bac350c241dd1d94bc4c78b8002ca07b7cb0..0db345e92a34c99e7540061162a6078ca4bbc427 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.orc +import com.huawei.boostkit.spark.ColumnarPluginConfig.ENABLE_VEC_PREDICATE_FILTER import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce._ @@ -72,6 +73,7 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis val orcFilterPushDown = sparkSession.sessionState.conf.orcFilterPushDown + val vecPredicateFilter = sparkSession.sessionState.conf.getConf(ENABLE_VEC_PREDICATE_FILTER) (file: PartitionedFile) => { val conf = broadcastedConf.value.value @@ -79,11 +81,7 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ val filePath = ShimUtil.getPartitionedFilePath(file) // ORC predicate pushdown - val pushed = if (orcFilterPushDown) { - filters.reduceOption(And(_, _)) - } else { - None - } + val pushed = filters.reduceOption(And(_, _)) val taskConf = new Configuration(conf) val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty) @@ -91,7 +89,8 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) // read data from vectorized reader - val batchReader = new OmniOrcColumnarBatchReader(capacity, requiredSchema, pushed.orNull) + val batchReader = new OmniOrcColumnarBatchReader(capacity, requiredSchema, pushed.orNull, vecPredicateFilter, + orcFilterPushDown) // SPARK-23399 Register a task completion listener first to call `close()` in all cases. // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM) // after opening a file. 
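Note: the ORC hunk above now assembles the pushed-down predicate unconditionally and forwards both `orcFilterPushDown` and the new `ENABLE_VEC_PREDICATE_FILTER` flag to `OmniOrcColumnarBatchReader`, leaving the reader to decide how the predicate is applied. A minimal, self-contained sketch of how the data source filters collapse into that single predicate; the object name and column names are invented for illustration, and only Spark's public `sources.Filter` API is used:

  import org.apache.spark.sql.sources.{And, EqualTo, Filter, GreaterThan}

  object PushedFilterSketch {
    def main(args: Array[String]): Unit = {
      // Filters as Spark hands them to buildReaderWithPartitionValues.
      val filters: Seq[Filter] = Seq(GreaterThan("quantity", 10), EqualTo("status", "shipped"))
      // Same reduction as in the hunk above: fold everything into one And tree,
      // or None when nothing was pushed down.
      val pushed: Option[Filter] = filters.reduceOption(And(_, _))
      println(pushed) // Some(And(GreaterThan(quantity,10),EqualTo(status,shipped)))
    }
  }
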
diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ OmniParquetOutputWriter.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ OmniParquetOutputWriter.scala new file mode 100644 index 0000000000000000000000000000000000000000..0f330b9f07543c52f074d568178781a49aa03567 --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ OmniParquetOutputWriter.scala @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.sparkTypeToOmniType +import com.huawei.boostkit.spark.jni.{OrcColumnarBatchWriter, ParquetColumnarBatchWriter} +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.security.UserGroupInformation +import org.apache.parquet.hadoop.ParquetOutputFormat +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.datasources.{OmniInternalRow, OutputWriter} +import org.apache.spark.sql.types.StructType + +import scala.Array.{emptyBooleanArray, emptyIntArray} + +// NOTE: This class is instantiated and used on executor side only, no need to be serializable. 
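+// It wraps the native ParquetColumnarBatchWriter: initialize() records the Omni type id of
+// every column and which of the task's columns hold data, and write() hands the
+// OmniInternalRow's columnar batch straight to the native writer.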
+class OmniParquetOutputWriter(path: String, dataSchema: StructType, + context: TaskAttemptContext) + extends OutputWriter { + + val writer = new ParquetColumnarBatchWriter() + var omniTypes: Array[Int] = emptyIntArray + var dataColumnsIds: Array[Boolean] = emptyBooleanArray + var allOmniTypes: Array[Int] = emptyIntArray + + def initialize(allColumns: Seq[Attribute], dataColumns: Seq[Attribute]): Unit = { + val filePath = new Path(path) + writer.initializeSchemaJava(dataSchema) + writer.initializeWriterJava(filePath) + omniTypes = dataSchema.fields + .map(field => sparkTypeToOmniType(field.dataType, field.metadata).getId.ordinal()) + .toArray + allOmniTypes = allColumns.toStructType.fields + .map(field => sparkTypeToOmniType(field.dataType, field.metadata).getId.ordinal()) + .toArray + dataColumnsIds = allColumns.map(x => dataColumns.contains(x)).toArray + } + + override def write(row: InternalRow): Unit = { + assert(row.isInstanceOf[OmniInternalRow]) + writer.write(omniTypes, dataColumnsIds, row.asInstanceOf[OmniInternalRow].batch) + } + + def spiltWrite(row: InternalRow, startPos: Long, endPos: Long): Unit = { + assert(row.isInstanceOf[OmniInternalRow]) + writer.splitWrite(omniTypes, allOmniTypes, dataColumnsIds, + row.asInstanceOf[OmniInternalRow].batch, startPos, endPos) + } + + override def close(): Unit = { + writer.close() + } + + override def path(): String = { + path + } +} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala index ecef831d12c4bb473d4e9a43b4e842f35de5605c..b549cff662780c49e88668bebc80f05c26900cdc 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/OmniParquetFileFormat.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.parquet.hadoop.util.ContextUtil import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql._ @@ -33,7 +34,12 @@ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} +import org.apache.parquet.hadoop.codec.CodecConfig +import org.apache.parquet.hadoop.ParquetOutputFormat.JobSummaryLevel +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.util.ShimUtil import java.net.URI @@ -49,28 +55,99 @@ class OmniParquetFileFormat extends FileFormat with DataSourceRegister with Logg override def equals(other: Any): Boolean = other.isInstanceOf[OmniParquetFileFormat] override def prepareWrite( - sparkSession: SparkSession, - job: Job, - options: Map[String, String], - dataSchema: StructType): OutputWriterFactory = { - throw new UnsupportedOperationException() + sparkSession: SparkSession, + job: Job, + options: Map[String, 
String], + dataSchema: StructType): OutputWriterFactory = { + val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf) + + val conf = ContextUtil.getConfiguration(job) + + val committerClass = + conf.getClass( + SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, + classOf[ParquetOutputCommitter], + classOf[OutputCommitter]) + + if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) { + logInfo("Using default output committer for Parquet: " + + classOf[ParquetOutputCommitter].getCanonicalName) + } else { + logInfo("Using user defined output committer for Parquet: " + committerClass.getCanonicalName) + } + + conf.setClass( + SQLConf.OUTPUT_COMMITTER_CLASS.key, + committerClass, + classOf[OutputCommitter]) + + // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override + // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why + // we set it here is to setup the output committer class to `ParquetOutputCommitter`, which is + // bundled with `ParquetOutputFormat[Row]`. + job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) + + ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport]) + + // This metadata is useful for keeping UDTs like Vector/Matrix. + ParquetWriteSupport.setSchema(dataSchema, conf) + + // Sets flags for `ParquetWriteSupport`, which converts Catalyst schema to Parquet + // schema and writes actual rows to Parquet files. + conf.set( + SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, + sparkSession.sessionState.conf.writeLegacyParquetFormat.toString) + + conf.set( + SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, + sparkSession.sessionState.conf.parquetOutputTimestampType.toString) + + // Sets compression scheme + conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) + + // SPARK-15719: Disables writing Parquet summary files by default. + if (conf.get(ParquetOutputFormat.JOB_SUMMARY_LEVEL) == null + && conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null) { + conf.setEnum(ParquetOutputFormat.JOB_SUMMARY_LEVEL, JobSummaryLevel.NONE) + } + + if (ParquetOutputFormat.getJobSummaryLevel(conf) != JobSummaryLevel.NONE + && !classOf[ParquetOutputCommitter].isAssignableFrom(committerClass)) { + // output summary is requested, but the class is not a Parquet Committer + logWarning(s"Committer $committerClass is not a ParquetOutputCommitter and cannot" + + s" create job summaries. 
" + + s"Set Parquet option ${ParquetOutputFormat.JOB_SUMMARY_LEVEL} to NONE.") + } + + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new OmniParquetOutputWriter(path, dataSchema, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + CodecConfig.from(context).getCodec.getExtension + ".parquet" + } + } } override def inferSchema( - sparkSession: SparkSession, - parameters: Map[String, String], - files: Seq[FileStatus]): Option[StructType] = { + sparkSession: SparkSession, + parameters: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { ParquetUtils.inferSchema(sparkSession, parameters, files) } override def buildReaderWithPartitionValues( - sparkSession: SparkSession, - dataSchema: StructType, - partitionSchema: StructType, - requiredSchema: StructType, - filters: Seq[Filter], - options: Map[String, String], - hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) @@ -89,7 +166,7 @@ class OmniParquetFileFormat extends FileFormat with DataSourceRegister with Logg (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) - val filePath = ShimUtil.getPartitionedFilePath(file) + val filePath = new Path(new URI(file.filePath.toString)) val split = new org.apache.parquet.hadoop.ParquetInputSplit( filePath, diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala index 511f33628a3020c5e3764ecaa165778a722b042f..cb9f2061f4c1ee61d0ea6aca6538ce5d6164ebeb 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala @@ -24,7 +24,7 @@ import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{checkOmniJsonWhiteList, isSimpleColumn, isSimpleColumnForAll} import com.huawei.boostkit.spark.util.OmniAdaptorUtil import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{getExprIdForProjectList, getIndexArray, getProjectListIndex, pruneOutput, reorderOutputVecs, transColBatchToOmniVecs} -import nova.hetu.omniruntime.`type`.DataType +import nova.hetu.omniruntime.`type`.{DataType, BooleanDataType} import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} import nova.hetu.omniruntime.operator.join.{OmniHashBuilderWithExprOperatorFactory, OmniLookupJoinWithExprOperatorFactory, OmniLookupOuterJoinWithExprOperatorFactory} import nova.hetu.omniruntime.vector.VecBatch @@ -32,14 +32,15 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression} import 
org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext -import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildSide} -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi, RightOuter, ExistenceJoin} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{ExplainUtils, ShuffledHashJoinExecShim, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.util.SparkMemoryUtils import org.apache.spark.sql.execution.vectorized.OmniColumnVector import org.apache.spark.sql.vectorized.ColumnarBatch +import nova.hetu.omniruntime.constants.JoinType._ case class ColumnarShuffledHashJoinExec( leftKeys: Seq[Expression], @@ -114,7 +115,7 @@ case class ColumnarShuffledHashJoinExec( def buildCheck(): Unit = { joinType match { - case FullOuter | Inner | LeftAnti | LeftOuter | LeftSemi | RightOuter => + case FullOuter | Inner | LeftAnti | LeftOuter | LeftSemi | RightOuter | ExistenceJoin(_) => case _ => throw new UnsupportedOperationException(s"Join-type[${joinType}] is not supported " + s"in ${this.nodeName}") @@ -183,6 +184,8 @@ case class ColumnarShuffledHashJoinExec( val buildOutputCols: Array[Int] = joinType match { case Inner | FullOuter | LeftOuter | RightOuter => getIndexArray(buildOutput, projectExprIdList) + case ExistenceJoin(_) => + Array[Int](1) case LeftExistence(_) => Array[Int]() case x => @@ -195,9 +198,15 @@ case class ColumnarShuffledHashJoinExec( }.toArray val prunedBuildOutput = pruneOutput(buildOutput, projectExprIdList) - val buildOutputTypes = new Array[DataType](prunedBuildOutput.size) // {2,2}, buildOutput:col1#12,col2#13 - prunedBuildOutput.zipWithIndex.foreach { case (att, i) => - buildOutputTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata) + var buildOutputTypes: Array[DataType] = null + val lookupJoinType = OmniExpressionAdaptor.toOmniJoinType(joinType) + if (lookupJoinType == OMNI_JOIN_TYPE_EXISTENCE) { + buildOutputTypes = Array[DataType](BooleanDataType.BOOLEAN) + } else { + buildOutputTypes = new Array[DataType](prunedBuildOutput.size) // {2,2}, buildOutput:col1#12,col2#13 + prunedBuildOutput.zipWithIndex.foreach { case (att, i) => + buildOutputTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata) + } } val probeTypes = new Array[DataType](streamedOutput.size) @@ -223,7 +232,8 @@ case class ColumnarShuffledHashJoinExec( } val startBuildCodegen = System.nanoTime() val lookupJoinType = OmniExpressionAdaptor.toOmniJoinType(joinType) - val buildOpFactory = new OmniHashBuilderWithExprOperatorFactory(lookupJoinType, buildTypes, + val lookupBuildSide = OmniExpressionAdaptor.toOmniBuildSide(buildSide) + val buildOpFactory = new OmniHashBuilderWithExprOperatorFactory(lookupJoinType, lookupBuildSide, buildTypes, buildJoinColsExp, 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val buildOp = buildOpFactory.createOperator() @@ -337,7 +347,9 @@ case class ColumnarShuffledHashJoinExec( new ColumnarBatch(vecs.toArray, rowCnt) } } - if ("FULL OUTER" == joinType.sql) { + if (lookupJoinType == OMNI_JOIN_TYPE_FULL || + (lookupJoinType == OMNI_JOIN_TYPE_LEFT && buildSide.equals(BuildLeft)) 
|| + (lookupJoinType == OMNI_JOIN_TYPE_RIGHT && buildSide.equals(BuildRight))) { val lookupOuterOpFactory = new OmniLookupOuterJoinWithExprOperatorFactory(probeTypes, probeOutputCols, probeHashColsExp, buildOutputCols, buildOutputTypes, buildOpFactory, diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/topnsort/UnsafeInMemoryTopNSorter.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/topnsort/UnsafeInMemoryTopNSorter.java deleted file mode 100644 index 7b14bb6694eec58c48cef5a96aa6626ff22ec431..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/topnsort/UnsafeInMemoryTopNSorter.java +++ /dev/null @@ -1,272 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.topnsort; - -import org.apache.spark.TaskContext; -import org.apache.spark.memory.MemoryConsumer; -import org.apache.spark.memory.TaskMemoryManager; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.unsafe.UnsafeAlignedOffset; -import org.apache.spark.unsafe.array.LongArray; -import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; - -public final class UnsafeInMemoryTopNSorter { - - private final MemoryConsumer consumer; - private final TaskMemoryManager memoryManager; - private final UnsafePartitionedTopNSorter.TopNSortComparator sortComparator; - - /** - * Within this buffer, position {@code 2 * i} holds a pointer to the record at index {@code i}, - * while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. - * - * Only part of the array will be used to store the pointers, the rest part is preserved as temporary buffer for sorting. - */ - private LongArray array; - - /** - * The position in the sort buffer where new records can be inserted. - */ - private int nextEmptyPos = 0; - - // Top n. - private final int n; - private final boolean strictTopN; - - // The capacity of array. - private final int capacity; - private static final int MIN_ARRAY_CAPACITY = 64; - - public UnsafeInMemoryTopNSorter( - final int n, - final boolean strictTopN, - final MemoryConsumer consumer, - final TaskMemoryManager memoryManager, - final UnsafePartitionedTopNSorter.TopNSortComparator sortComparator) { - this.n = n; - this.strictTopN = strictTopN; - this.consumer = consumer; - this.memoryManager = memoryManager; - this.sortComparator = sortComparator; - this.capacity = Math.max(MIN_ARRAY_CAPACITY, Integer.highestOneBit(n) << 1); - // The size of Long array is equal to twice capacity because each item consists of a prefix and a pointer. 
- this.array = consumer.allocateArray(capacity << 1); - } - - /** - * Free the memory used by pointer array - */ - public void freeMemory() { - if (consumer != null) { - if (array != null) { - consumer.freeArray(array); - } - array = null; - } - nextEmptyPos = 0; - } - - public long getMemoryUsage() { - if (array == null) { - return 0L; - } - return array.size() * 8; - } - - public int insert(UnsafeRow row, long prefix) { - if (nextEmptyPos < n) { - return insertIntoArray(nextEmptyPos -1, row, prefix); - } else { - // reach n candidates - final int compareResult = nthRecordCompareTo(row, prefix); - if (compareResult < 0) { - // skip this record - return -1; - } - else if (compareResult == 0) { - if (strictTopN) { - // For rows that have duplicate values, skip it if this is strict TopN (e.g. RowNumber). - return -1; - } - // append record - checkForInsert(); - array.set((nextEmptyPos << 1) + 1, prefix); - return nextEmptyPos++; - } - else { - checkForInsert(); - // The record at position n -1 should be excluded, so we start comparing with record at position n - 2. - final int insertPosition = insertIntoArray(n - 2, row, prefix); - if (strictTopN || insertPosition == n - 1 || hasDistinctTopN()) { - nextEmptyPos = n; - } - // For other cases, 'nextEmptyPos' will move to the next empty position in 'insertIntoArray()'. - // e.g. given rank <= 4, and we already have 1, 2, 6, 6, so 'nextEmptyPos' is 4. - // If the new row is 3, then values in the array will be 1, 2, 3, 6, 6, and 'nextEmptyPos' will be 5. - return insertPosition; - } - } - } - - public void updateRecordPointer(int position, long pointer) { - array.set(position << 1, pointer); - } - - private int insertIntoArray(int position, UnsafeRow row, long prefix) { - // find insert position - while (position >= 0 && sortComparator.compare(array.get(position << 1), array.get((position << 1) + 1), row, prefix) > 0) { - --position; - } - final int insertPos = position + 1; - - // move records between 'insertPos' and 'nextEmptyPos' to next positions - for (int i = nextEmptyPos; i > insertPos; --i) { - int src = (i - 1) << 1; - int dst = i << 1; - array.set(dst, array.get(src)); - array.set(dst + 1, array.get(src + 1)); - } - - // Insert prefix of this row. Note that the address will be inserted by 'updateRecordPointer()' - // after we get its address from 'taskMemoryManager' - array.set((insertPos << 1) + 1, prefix); - ++nextEmptyPos; - return insertPos; - } - - private void checkForInsert() { - if (nextEmptyPos >= capacity) { - throw new IllegalStateException("No space for new record.\n" + - "For RANK expressions with TOP-N filter(e.g. 
rk <= 100), we maintain a fixed capacity " + - "array for TOP-N sorting for each partition, and if there are too many same rankings, " + - "the result that needs to be retained will exceed the capacity of the array.\n" + - "Please consider using ROW_NUMBER expression or disabling TOP-N sorting by setting " + - "saprk.sql.execution.topNPushDownFOrWindow.enabled to false."); - } - } - - private int nthRecordCompareTo(UnsafeRow row, long prefix) { - int nthPos = n - 1; - return sortComparator.compare(array.get(nthPos << 1), array.get((nthPos << 1) + 1), row, prefix); - } - - private boolean hasDistinctTopN() { - int nthPosition = (n - 1) << 1; - return sortComparator.compare(array.get(nthPosition), array.get(nthPosition + 1), // nth record - array.get(nthPosition + 2), array.get(nthPosition + 3)) // (n + 1)th record - != 0; // not eq - } - - /** - * This is copied from - * {@link org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.SortedIterator}. - * */ - public final class TopNSortedIterator extends UnsafeSorterIterator implements Cloneable { - private final int numRecords; - private int position; - private int offset; - private Object baseObject; - private long baseOffset; - private long keyPrefix; - private int recordLength; - private long currentPageNumber; - private final TaskContext taskContext = TaskContext.get(); - - private TopNSortedIterator(int numRecords, int offset) { - this.numRecords = numRecords; - this.position = 0; - this.offset = offset; - } - - public TopNSortedIterator clone() { - TopNSortedIterator iter = new TopNSortedIterator(numRecords, offset); - iter.position = position; - iter.baseObject = baseObject; - iter.baseOffset = baseOffset; - iter.keyPrefix = keyPrefix; - iter.recordLength = recordLength; - iter.currentPageNumber = currentPageNumber; - return iter; - } - - @Override - public int getNumRecords() { - return numRecords; - } - - @Override - public boolean hasNext() { - return position / 2 < numRecords; - } - - @Override - public void loadNext() { - // Kill the task in case it has been marked as killed. This logic is from - // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order - // to avoid performance overhead. This check is added here in 'loadNext()' instead of in - // 'hasNext()' because it's technically possible for the caller to be relying on - // 'getNumRecords()' instead of 'hasNext()' to know when to stop. - if (taskContext != null) { - taskContext.killTaskIfInterrupted(); - } - // This pointer points to a 4-byte record length, followed by the record's bytes - final long recordPointer = array.get(offset + position); - currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer); - int uaoSize = UnsafeAlignedOffset.getUaoSize(); - baseObject = memoryManager.getPage(recordPointer); - // Skip over record length - baseOffset = memoryManager.getOffsetInPage(recordPointer) + uaoSize; - recordLength = UnsafeAlignedOffset.getSize(baseObject, baseOffset - uaoSize); - keyPrefix = array.get(offset + position + 1); - position += 2; - } - - @Override - public Object getBaseObject() { - return baseObject; - } - - @Override - public long getBaseOffset() { - return baseOffset; - } - - @Override - public long getCurrentPageNumber() { - return currentPageNumber; - } - - @Override - public int getRecordLength() { - return recordLength; - } - - @Override - public long getKeyPrefix() { - return keyPrefix; - } - } - - /** - * Return an iterator over record pointers in sorted order. 
For efficiency, all calls to - * {@code next()} will return the same mutable object. - * */ - public UnsafeSorterIterator getSortedIterator() { - return new TopNSortedIterator(nextEmptyPos, 0); - } -} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/topnsort/UnsafePartitionedTopNSorter.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/topnsort/UnsafePartitionedTopNSorter.java deleted file mode 100644 index 57941aefb4fc8a3234c0ab22b6d45294ae09c639..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/topnsort/UnsafePartitionedTopNSorter.java +++ /dev/null @@ -1,263 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.topnsort; - -import java.util.*; -import java.util.function.Supplier; - -import com.google.common.annotations.VisibleForTesting; - -import org.apache.spark.TaskContext; -import org.apache.spark.memory.MemoryConsumer; -import org.apache.spark.memory.TaskMemoryManager; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.UnsafeAlignedOffset; -import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.util.collection.unsafe.sort.*; - -/** - * Partitioned top n sorter based on {@link org.apache.spark.sql.execution.topnsort.UnsafeInMemoryTopNSorter}. - * The implementation mostly refers to {@link UnsafeExternalSorter}. - * */ -public final class UnsafePartitionedTopNSorter extends MemoryConsumer { - private final TaskMemoryManager taskMemoryManager; - private TopNSortComparator sortComparator; - - /** - * Memory pages that hold the records being sorted. The pages in this list are freed when - * spilling, although in principle we could recycle these pages across spills (on the other hand, - * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager itself). 
- * */ - private final LinkedList allocatedPages = new LinkedList<>(); - private final Map partToSorters = new LinkedHashMap<>(); - - private final int n; - private final boolean strictTopN; - private MemoryBlock currentPage = null; - private long pageCursor = -1; - private long peakMemoryUsedBytes = 0; - - public static UnsafePartitionedTopNSorter create( - int n, - boolean strictTopN, - TaskMemoryManager taskMemoryManager, - TaskContext taskContext, - Supplier recordComparatorSupplier, - PrefixComparator prefixComparator, - long pageSizeBytes, - boolean canSortFullyWithPrefix) { - assert n > 0 : "Top n must be positive"; - assert recordComparatorSupplier != null; - return new UnsafePartitionedTopNSorter(n, strictTopN, taskMemoryManager, taskContext, - recordComparatorSupplier, prefixComparator, pageSizeBytes, canSortFullyWithPrefix); - } - - private UnsafePartitionedTopNSorter( - int n, - boolean strictTopN, - TaskMemoryManager taskMemoryManager, - TaskContext taskContext, - Supplier recordComparatorSupplier, - PrefixComparator prefixComparator, - long pageSizeBytes, - boolean canSortFullyWithPrefix) { - super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode()); - this.n = n; - this.strictTopN = strictTopN; - this.taskMemoryManager = taskMemoryManager; - this.sortComparator = new TopNSortComparator(recordComparatorSupplier.get(), - prefixComparator, taskMemoryManager, canSortFullyWithPrefix); - - // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at - // the end of the task. This is necessary to avoid memory leaks in when the downstream operator - // does not fully consume the sorter's output (e.g. sort followed by limit). - taskContext.addTaskCompletionListener(context -> { - cleanupResources(); - }); - } - - @Override - public long spill(long size, MemoryConsumer trigger) { - throw new UnsupportedOperationException("Spill is unsupported operation in topN in-memory sorter"); - } - - /** - * Return the total memory usage of this sorter, including the data pages and the sorter's pointer array. - * */ - private long getMemoryUsage() { - long totalPageSize = 0; - for (MemoryBlock page : allocatedPages) { - totalPageSize += page.size(); - } - for (UnsafeInMemoryTopNSorter sorter : partToSorters.values()) { - totalPageSize += sorter.getMemoryUsage(); - } - return totalPageSize; - } - - private void updatePeakMemoryUsed() { - long mem = getMemoryUsage(); - if (mem > peakMemoryUsedBytes) { - peakMemoryUsedBytes = mem; - } - } - - /** - * Return the peak memory used so far, in bytes. - * */ - public long getPeakMemoryUsedBytes() { - updatePeakMemoryUsed(); - return peakMemoryUsedBytes; - } - - @VisibleForTesting - public int getNumberOfAllocatedPages() { - return allocatedPages.size(); - } - - /** - * Free this sorter's data pages. - * - * @return the number of bytes freed. - * */ - private long freeMemory() { - updatePeakMemoryUsed(); - long memoryFreed = 0; - for (MemoryBlock block : allocatedPages) { - memoryFreed += block.size(); - freePage(block); - } - allocatedPages.clear(); - currentPage = null; - pageCursor = 0; - for (UnsafeInMemoryTopNSorter sorter: partToSorters.values()) { - memoryFreed += sorter.getMemoryUsage(); - sorter.freeMemory(); - } - partToSorters.clear(); - sortComparator = null; - return memoryFreed; - } - - /** - * Frees this sorter's in-memory data structures and cleans up its spill files. 
- * */ - public void cleanupResources() { - synchronized (this) { - freeMemory(); - } - } - - /** - * Allocates an additional page in order to insert an additional record. This will request - * additional memory from the memory manager and spill if the requested memory can not be obtained. - * - * @param required the required space in the data page, in bytes, including space for storing the record size - * */ - private void acquireNewPageIfNecessary(int required) { - if (currentPage == null || - pageCursor + required > currentPage.getBaseOffset() + currentPage.size()) { - currentPage = allocatePage(required); - pageCursor = currentPage.getBaseOffset(); - allocatedPages.add(currentPage); - } - } - - public void insertRow(UnsafeRow partKey, UnsafeRow row, long prefix) { - UnsafeInMemoryTopNSorter sorter = - partToSorters.computeIfAbsent( - partKey, - k -> new UnsafeInMemoryTopNSorter(n, strictTopN, this, taskMemoryManager, sortComparator) - ); - final int position = sorter.insert(row, prefix); - if (position >= 0) { - final int uaoSize = UnsafeAlignedOffset.getUaoSize(); - // Need 4 or 8 bytes to store the record length. - final int length = row.getSizeInBytes(); - final int required = length + uaoSize; - acquireNewPageIfNecessary(required); - - final Object base = currentPage.getBaseObject(); - final long recordAddress = - taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); - UnsafeAlignedOffset.putSize(base, pageCursor, length); - pageCursor += uaoSize; - Platform.copyMemory(row.getBaseObject(), row.getBaseOffset(), base, pageCursor, length); - pageCursor += length; - - sorter.updateRecordPointer(position, recordAddress); - } - } - - public Map getPartKeyToSorter() { - return partToSorters; - } - - static final class TopNSortComparator { - private final RecordComparator recordComparator; - private final PrefixComparator prefixComparator; - private final TaskMemoryManager memoryManager; - private final boolean needCompareFully; - - TopNSortComparator( - RecordComparator recordComparator, - PrefixComparator prefixComparator, - TaskMemoryManager memoryManager, - boolean canSortFullyWithPrefix) { - this.recordComparator = recordComparator; - this.prefixComparator = prefixComparator; - this.memoryManager = memoryManager; - this.needCompareFully = !canSortFullyWithPrefix; - } - - public int compare(long pointer1, long prefix1, long pointer2, long prefix2) { - final int prefixComparisonResult = prefixComparator.compare(prefix1, prefix2); - if (needCompareFully && prefixComparisonResult == 0) { - final int uaoSize = UnsafeAlignedOffset.getUaoSize(); - final Object baseObject1 = memoryManager.getPage(pointer1); - final long baseOffset1 = memoryManager.getOffsetInPage(pointer1) + uaoSize; - final int baseLength1 = UnsafeAlignedOffset.getSize(baseObject1, baseOffset1 - uaoSize); - final Object baseObject2 = memoryManager.getPage(pointer2); - final long baseOffset2 = memoryManager.getOffsetInPage(pointer2) + uaoSize; - final int baseLength2 = UnsafeAlignedOffset.getSize(baseObject2, baseOffset2 - uaoSize); - return recordComparator.compare(baseObject1, baseOffset1, baseLength1, baseObject2, - baseOffset2, baseLength2); - } else { - return prefixComparisonResult; - } - } - - public int compare(long pointer, long prefix1, UnsafeRow row, long prefix2) { - final int prefixComparisonResult = prefixComparator.compare(prefix1, prefix2); - if (needCompareFully && prefixComparisonResult == 0) { - final int uaoSize = UnsafeAlignedOffset.getUaoSize(); - final Object baseObject1 = 
memoryManager.getPage(pointer); - final long baseOffset1 = memoryManager.getOffsetInPage(pointer) + uaoSize; - final int baseLength1 = UnsafeAlignedOffset.getSize(baseObject1, baseOffset1 - uaoSize); - final Object baseObject2 = row.getBaseObject(); - final long baseOffset2 = row.getBaseOffset(); - final int baseLength2 = row.getSizeInBytes(); - return recordComparator.compare(baseObject1, baseOffset1, baseLength1, baseObject2, - baseOffset2, baseLength2); - } else { - return prefixComparisonResult; - } - } - } -} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/util/ExecutorManager.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/util/ExecutorManager.scala new file mode 100644 index 0000000000000000000000000000000000000000..7ddb09f98f6a623088514af0348e4319a338f7d1 --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/util/ExecutorManager.scala @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.util + +import com.huawei.boostkit.spark.{ColumnarPluginConfig, NumaBindingInfo} +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils +import org.apache.spark.{SparkContext, SparkEnv} + +import java.lang.management.ManagementFactory + +object ExecutorManager extends Logging { + var isTaskSet: Boolean = false + + def getExecutorIds(sc: SparkContext): Seq[String] = sc.getExecutorIds + + def tryTaskSet(numaInfo: NumaBindingInfo): Any = synchronized { + if (numaInfo.enableNumaBinding && !isTaskSet) { + val cmd_output = Utils.executeAndGetOutput(Seq("bash", "-c", "ps -ef | grep YarnCoarseGrainedExecutorBackend")) + val getExecutorId = """--executor-id (\d+)""".r + val executorIdOnLocalNode = { + val tmp = for (m <- getExecutorId.findAllMatchIn(cmd_output)) yield m.group(1) + tmp.toList.distinct + } + val executorId = SparkEnv.get.executorId + val coreRange = numaInfo.totalCoreRange + val shouldBindNumaIdx = executorIdOnLocalNode.indexOf(executorId) % coreRange.size + logInfo( + s"executorId is $executorId, executorIdOnLocalNode is $executorIdOnLocalNode") + val taskSetCmd = s"taskset -cpa ${coreRange(shouldBindNumaIdx)} ${getProcessId()}" + System.out.println(taskSetCmd) + + isTaskSet = true + Utils.executeCommand(Seq("bash", "-c", taskSetCmd)) + } + } + + def getProcessId(): Int = { + val runtimeMXBean = ManagementFactory.getRuntimeMXBean() + runtimeMXBean.getName().split("@")(0).toInt + } + +} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/window/TopNPushDownForWindow.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/window/TopNPushDownForWindow.scala deleted file mode 100644 index d53c6e0286c21e026c5073335e96a5a00010a71a..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/window/TopNPushDownForWindow.scala +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.window; - -import com.huawei.boostkit.spark.ColumnarPluginConfig -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{FilterExec, SortExec, SparkPlan, TopNSortExec} - -object TopNPushDownForWindow extends Rule[SparkPlan] with PredicateHelper { - override def apply(plan: SparkPlan): SparkPlan = { - if (!ColumnarPluginConfig.getConf.topNPushDownForWindowEnable) { - return plan - } - - plan.transform { - case f @ FilterExec(condition, - w @ WindowExec(Seq(windowExpression), _, orderSpec, sort: SortExec)) - if orderSpec.nonEmpty && isTopNExpression(windowExpression) => - var topn = Int.MaxValue - val nonTopNConditions = splitConjunctivePredicates(condition).filter { - case LessThan(e: NamedExpression, IntegerLiteral(n)) - if e.exprId == windowExpression.exprId => - topn = Math.min(topn, n - 1) - false - case LessThanOrEqual(e: NamedExpression, IntegerLiteral(n)) - if e.exprId == windowExpression.exprId => - topn = Math.min(topn, n) - false - case GreaterThan(IntegerLiteral(n), e: NamedExpression) - if e.exprId == windowExpression.exprId => - topn = Math.min(topn, n - 1) - false - case GreaterThanOrEqual(IntegerLiteral(n), e: NamedExpression) - if e.exprId == windowExpression.exprId => - topn = Math.min(topn, n) - false - case EqualTo(e: NamedExpression, IntegerLiteral(n)) - if n == 1 && e.exprId == windowExpression.exprId => - topn = 1 - false - case EqualTo(IntegerLiteral(n), e: NamedExpression) - if n == 1 && e.exprId == windowExpression.exprId => - topn = 1 - false - case _ => true - } - - // topn <= SQLConf.get.topNPushDownForWindowThreshold 100. - if (topn> 0 && topn <= ColumnarPluginConfig.getConf.topNPushDownForWindowThreshold) { - val strictTopN = isStrictTopN(windowExpression) - val topNSortExec = TopNSortExec( - topn, strictTopN, w.partitionSpec, w.orderSpec, sort.global, sort.child) - val newCondition = if (nonTopNConditions.isEmpty) { - Literal.TrueLiteral - } else { - nonTopNConditions.reduce(And) - } - FilterExec(newCondition, w.copy(child = topNSortExec)) - } else { - f - } - } - } - - private def isTopNExpression(e: Expression): Boolean = e match { - case Alias(child, _) => isTopNExpression(child) - case WindowExpression(windowFunction, _) - if windowFunction.isInstanceOf[Rank] => true - case _ => false - } - - private def isStrictTopN(e: Expression): Boolean = e match { - case Alias(child, _) => isStrictTopN(child) - case WindowExpression(windowFunction, _) => windowFunction.isInstanceOf[RowNumber] - } -} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java index 19f23db00d133ddb47c475c42003ace1ebecf675..9b28daf37b2529634ec23f14affccb8b2ae47adc 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java @@ -125,7 +125,7 @@ public class OrcColumnarBatchJniReaderTest extends TestCase { private void initRecordReaderJava() { orcColumnarBatchScanReader.recordReader = orcColumnarBatchScanReader. 
- initializeRecordReaderJava(offset, length, null, requiredSchema); + initializeRecordReaderJava(offset, length, null, requiredSchema, false, true); assertTrue(orcColumnarBatchScanReader.recordReader != 0); } diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderVecPredicateTest.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderVecPredicateTest.java new file mode 100644 index 0000000000000000000000000000000000000000..ca5572ff7209fabbc884051407115db3f6d571a4 --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderVecPredicateTest.java @@ -0,0 +1,119 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.jni; + +import com.huawei.boostkit.spark.predicate.AndPredicateCondition; +import com.huawei.boostkit.spark.predicate.LeafPredicateCondition; +import com.huawei.boostkit.spark.predicate.PredicateCondition; +import com.huawei.boostkit.spark.predicate.PredicateOperatorType; +import junit.framework.TestCase; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.VarcharVec; +import org.json.JSONObject; +import org.junit.After; +import org.junit.Before; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.runners.MethodSorters; + +import java.io.File; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.ArrayList; + +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; + +@FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) +public class OrcColumnarBatchJniReaderVecPredicateTest extends TestCase { + public OrcColumnarBatchScanReader orcColumnarBatchScanReader; + + @Before + public void setUp() throws Exception { + orcColumnarBatchScanReader = new OrcColumnarBatchScanReader(); + initReaderJava(); + initRecordReaderJava(); + initBatch(); + } + + @After + public void tearDown() throws Exception { + System.out.println("orcColumnarBatchScanReader test finished"); + } + + public void initReaderJava() { + File directory = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0"); + String absolutePath = directory.getAbsolutePath(); + System.out.println(absolutePath); + URI uri = null; + try { + uri = new URI(absolutePath); + } catch (URISyntaxException ignore) { + // if URISyntaxException thrown, next line assertNotNull will interrupt the test + } + 
assertNotNull(uri); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri); + assertTrue(orcColumnarBatchScanReader.reader != 0); + } + + public void initRecordReaderJava() { + JSONObject job = new JSONObject(); + job.put("include",""); + job.put("offset", 0); + job.put("length", 3345152); + + PredicateCondition leafLess = new LeafPredicateCondition(PredicateOperatorType.LESS_THAN_OR_EQUAL, 0, OMNI_LONG, "100"); + PredicateCondition leafIsNotNull = new LeafPredicateCondition(PredicateOperatorType.IS_NOT_NULL, 0, OMNI_LONG, "-1"); + PredicateCondition vecPredicateCondition = new AndPredicateCondition(leafLess, leafIsNotNull).reduce(); + job.put("vecPredicateCondition", vecPredicateCondition.toString()); + + ArrayList includedColumns = new ArrayList(); + includedColumns.add("i_item_sk"); + includedColumns.add("i_item_id"); + job.put("includedColumns", includedColumns.toArray()); + + orcColumnarBatchScanReader.recordReader = orcColumnarBatchScanReader.jniReader.initializeRecordReader(orcColumnarBatchScanReader.reader, job); + assertTrue(orcColumnarBatchScanReader.recordReader != 0); + } + + public void initBatch() { + orcColumnarBatchScanReader.batchReader = orcColumnarBatchScanReader.jniReader.initializeBatch(orcColumnarBatchScanReader.recordReader, 4096); + assertTrue(orcColumnarBatchScanReader.batchReader != 0); + } + + @Test + public void testNext() { + int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal()}; + long[] vecNativeId = new long[2]; + long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId); + assertTrue(rtn == 100); + LongVec vec1 = new LongVec(vecNativeId[0]); + VarcharVec vec2 = new VarcharVec(vecNativeId[1]); + assertTrue(11 == vec1.get(10)); + assertTrue(21 == vec1.get(20)); + String tmp1 = new String(vec2.get(10)); + String tmp2 = new String(vec2.get(20)); + assertTrue(tmp1.equals("AAAAAAAAKAAAAAAA")); + assertTrue(tmp2.equals("AAAAAAAAEBAAAAAA")); + vec1.close(); + vec2.close(); + } + +} diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReaderTest.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReaderTest.java index 3fb5b29284e197c664cc0090392e14b357a37a05..2c2f3a845df8db1274d30b0996f78977b264adda 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReaderTest.java +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReaderTest.java @@ -76,8 +76,7 @@ public class ParquetColumnarBatchJniReaderTest extends TestCase { .add("c8", createDecimalType(9, 8)) .add("c9", createDecimalType(18, 5)) .add("c10", BooleanType) - .add("c11", ShortType) - .add("c13", DateType); + .add("c11", ShortType); types = new ArrayList<>(); for (StructField f: schema.fields()) { diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark32-modify/src/main/scala/org/apache/spark/sql/util/ModifyUtil.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark32-modify/src/main/scala/org/apache/spark/sql/util/ModifyUtil.scala index 3ef8a993ea76c62d9298d24f5c14d166f2b5ea59..9d642404370456478c040269cc69febc6c45e957 100644 --- 
a/omnioperator/omniop-spark-extension/spark-extension-shims/spark32-modify/src/main/scala/org/apache/spark/sql/util/ModifyUtil.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark32-modify/src/main/scala/org/apache/spark/sql/util/ModifyUtil.scala @@ -70,10 +70,13 @@ object ModifyUtil extends Logging { extensions.injectOptimizerRule(_ => CombineJoinedAggregates) } + private def addTransformableTagAdapter(plan: SparkPlan): Unit = {} + def registerFunc(): Unit = { ModifyUtilAdaptor.configRewriteJsonFunc(rewriteToOmniJsonExpressionAdapter) ModifyUtilAdaptor.configPreReplacePlanFunc(preReplaceSparkPlanAdapter) ModifyUtilAdaptor.configPostReplacePlanFunc(postReplaceSparkPlanAdapter) ModifyUtilAdaptor.configInjectRuleFunc(injectRuleAdapter) + ModifyUtilAdaptor.configAddTransformableFunc(addTransformableTagAdapter) } } diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark32-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark32-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala new file mode 100644 index 0000000000000000000000000000000000000000..1052e70bf1f692d6b4a8e1440f593a468309f40d --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark32-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala @@ -0,0 +1,27 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions.{NamedExpression, SortOrder} + +abstract class UnaryExecNodeShim(sortOrder: Seq[SortOrder], projectList: Seq[NamedExpression]) + extends UnaryExecNode { + + override def outputOrdering: Seq[SortOrder] = sortOrder +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark32-shim/src/main/scala/org/apache/spark/sql/util/ShimUtil.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark32-shim/src/main/scala/org/apache/spark/sql/util/ShimUtil.scala index 06bf24dc2ada8c38916c89d6ccae4d7b37141f86..d82e9ba1f0ea64350680110458f992f5864fdd38 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-shims/spark32-shim/src/main/scala/org/apache/spark/sql/util/ShimUtil.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark32-shim/src/main/scala/org/apache/spark/sql/util/ShimUtil.scala @@ -29,7 +29,7 @@ import org.apache.spark.shuffle.sort.SortShuffleWriter import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriteMetricsReporter, ShuffleWriter} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, Count, Max, Min, Sum} import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} -import org.apache.spark.sql.catalyst.expressions.{Attribute, BinaryOperator, CastBase, Expression} +import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, BinaryOperator, CastBase, Divide, Expression, Multiply, Subtract} import org.apache.spark.sql.catalyst.optimizer.BuildSide import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, CTERelationDef, CTERelationRef, LogicalPlan, Statistics} @@ -39,11 +39,15 @@ import org.apache.spark.sql.types.{DataType, DateType, DecimalType, DoubleType, import java.net.URI import java.util.{Locale, Properties} +import nova.hetu.omniruntime.constants.FunctionType +import nova.hetu.omniruntime.constants.FunctionType.{OMNI_AGGREGATION_TYPE_SUM,OMNI_AGGREGATION_TYPE_AVG} object ShimUtil { def isSupportDataWriter: Boolean = false + def isNeedModifyBuildSide: Boolean = false + def createCTERelationRef(cteId: Long, resolved: Boolean, output: Seq[Attribute], isStreaming: Boolean, tatsOpt: Option[Statistics] = None): CTERelationRef = { CTERelationRef(cteId, resolved, output, tatsOpt) @@ -60,7 +64,23 @@ object ShimUtil { new TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber, taskMemoryManager, localProperties, metricsSystem, taskMetrics, resources) } - def unsupportedEvalModeCheck(expr: Expression): Unit = {} + def transformExpressionByEvalMode(expr: Expression): String = { + expr match { + case add: Add => "ADD" + case sub: Subtract => "SUBTRACT" + case mult: Multiply => "MULTIPLY" + case divide: Divide => "DIVIDE" + case _ => throw new UnsupportedOperationException(s"Unsupported Operation for $expr") + } + } + + def transformFuncTypeByEvalMode(expr: Expression):FunctionType = { + expr match { + case sum: Sum => OMNI_AGGREGATION_TYPE_SUM + case avg: Average => OMNI_AGGREGATION_TYPE_AVG + case _ => throw new UnsupportedOperationException(s"Unsupported Operation for $expr") + } + } def binaryOperatorAdjust(expr: BinaryOperator, returnDataType: DataType): (Expression, Expression) = { (expr.left, expr.right) @@ -118,10 +138,6 @@ object ShimUtil { ) } - def buildBuildSide(buildSide: BuildSide, joinType: JoinType): BuildSide = { - buildSide - } - def 
createSortShuffleWriter[K, V, C](handle: BaseShuffleHandle[K, V, C], mapId: Long, context: TaskContext, diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/catalyst/catalog/CatalogTableCache.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/catalyst/catalog/CatalogTableCache.scala new file mode 100644 index 0000000000000000000000000000000000000000..9b78a282698c293a296fd6fd2db19fd199882705 --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/catalyst/catalog/CatalogTableCache.scala @@ -0,0 +1,156 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import com.google.common.cache.{Cache, CacheBuilder} +import org.apache.spark.internal.Logging +import org.apache.spark.metrics.source.HiveCatalogMetrics +import org.apache.spark.sql.catalyst.QualifiedTableName +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId, Expression, LessThan, Literal} + +import java.util +import java.util.Objects +import java.util.concurrent.{Callable, TimeUnit} +import scala.collection.mutable + +case class CachedPrunedPartitionKey(qualifiedTableName: QualifiedTableName, predicates: Seq[Expression]) { + + private var hashComputed = false + private var hash = 0 + override def hashCode(): Int = { + if(hashComputed) { + return hash + } + var result = 1 + + for (element <- predicates) { + result = 31 * result + (if (element == null) 0 else element.hashCode) + } + + result = 31 * result + (if (qualifiedTableName == null) 0 else qualifiedTableName.hashCode) + + hash = result + hashComputed = true + hash + } + + override def equals(obj: Any): Boolean = { + if(obj == null) { + return false + } + if(!obj.isInstanceOf[CachedPrunedPartitionKey]) { + return false + } + val other = obj.asInstanceOf[CachedPrunedPartitionKey] + this.qualifiedTableName == other.qualifiedTableName && this.predicates.size == other.predicates.size && this.predicates.zip(other.predicates).forall((e1) => e1._1 == e1._2) + } +} + +class CatalogTableCache(cacheSize: Int, expire: Int) extends Logging { + + private var tableRelationCache: Cache[QualifiedTableName, CatalogTable] = _ + + private var prunedPartitionsCache: + Cache[CachedPrunedPartitionKey, + Seq[CatalogTablePartition]] = _ + + initCache(cacheSize, expire) + + private def initCache(cacheSize: Int, expire: Int): Unit = { + if(cacheSize <= 0) { + return + } + tableRelationCache = + CacheBuilder.newBuilder() + .maximumSize(cacheSize) + .build[QualifiedTableName, CatalogTable]() + 
prunedPartitionsCache = CacheBuilder.newBuilder() + .maximumSize(cacheSize) + .expireAfterAccess(expire, TimeUnit.SECONDS) + .build[CachedPrunedPartitionKey, Seq[CatalogTablePartition]]() + } + + def getCachedTable(db: String, table: String): CatalogTable = { + if (cacheSize <= 0) { + return null + } + tableRelationCache.getIfPresent(QualifiedTableName(db, table)) + } + + private val dummyExprId = ExprId(0, null) + + def cacheTable(db: String, table: String, value: CatalogTable): Unit = { + if (cacheSize <= 0) { + return + } + tableRelationCache.put(QualifiedTableName(db, table), value) + } + + def getCachedPartitions( + db: String, + table: String, + predicate: Seq[Expression], + callable: Callable[Seq[CatalogTablePartition]]): Seq[CatalogTablePartition] = { + if (cacheSize <= 0) { + return callable.call() + } + val newPredicates = predicate.map(e => e.transform { + case a: AttributeReference => + AttributeReference(a.name, a.dataType.asNullable)(exprId = dummyExprId) + }) + val key = CachedPrunedPartitionKey(QualifiedTableName(db, table), newPredicates) + prunedPartitionsCache.get(key, callable) + } + + def invalidateCachedTable(key: QualifiedTableName): Unit = { + if (cacheSize <= 0) { + return + } + tableRelationCache.invalidate(key) + } + + /** This method provides a way to invalidate all the cached CatalogTable. */ + def invalidateAllCachedTables(): Unit = { + if (cacheSize <= 0) { + return + } + tableRelationCache.invalidateAll() + } + + + /** This method provides a way to invalidate a cached plan. */ + def invalidateCachedPartition( + db: String, + table: String): Unit = { + if (cacheSize <= 0) { + return + } + prunedPartitionsCache.invalidate(QualifiedTableName(db, table)) + } + + + /** This method provides a way to invalidate all the cached plans. */ + def invalidateAllCachedPartition(): Unit = { + if (cacheSize <= 0) { + return + } + prunedPartitionsCache.invalidateAll() + } + +} diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala new file mode 100644 index 0000000000000000000000000000000000000000..72e913a901bfb2eec7459d4e9c91f81117bd5b4d --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -0,0 +1,1906 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.catalog + +import java.net.URI +import java.util.Locale +import java.util.concurrent.Callable +import java.util.concurrent.TimeUnit +import javax.annotation.concurrent.GuardedBy +import scala.collection.mutable +import scala.util.{Failure, Success, Try} +import com.google.common.cache.{Cache, CacheBuilder} +import com.huawei.boostkit.spark.ColumnarPluginConfig +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, ExpressionInfo, UpCast} +import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserInterface} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, SubqueryAlias, View} +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils} +import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.StaticSQLConf.GLOBAL_TEMP_DATABASE +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.{CaseInsensitiveStringMap, PartitioningUtils} +import org.apache.spark.util.Utils + +object SessionCatalog { + val DEFAULT_DATABASE = "default" +} + +/** + * An internal catalog that is used by a Spark Session. This internal catalog serves as a + * proxy to the underlying metastore (e.g. Hive Metastore) and it also manages temporary + * views and functions of the Spark Session that it belongs to. + * + * This class must be thread-safe. + */ +class SessionCatalog( + externalCatalogBuilder: () => ExternalCatalog, + globalTempViewManagerBuilder: () => GlobalTempViewManager, + functionRegistry: FunctionRegistry, + tableFunctionRegistry: TableFunctionRegistry, + hadoopConf: Configuration, + parser: ParserInterface, + functionResourceLoader: FunctionResourceLoader, + functionExpressionBuilder: FunctionExpressionBuilder, + cacheSize: Int = SQLConf.get.tableRelationCacheSize, + cacheTTL: Long = SQLConf.get.metadataCacheTTL) extends SQLConfHelper with Logging { + import SessionCatalog._ + import CatalogTypes.TablePartitionSpec + + // For testing only. + def this( + externalCatalog: ExternalCatalog, + functionRegistry: FunctionRegistry, + tableFunctionRegistry: TableFunctionRegistry, + conf: SQLConf) = { + this( + () => externalCatalog, + () => new GlobalTempViewManager(conf.getConf(GLOBAL_TEMP_DATABASE)), + functionRegistry, + tableFunctionRegistry, + new Configuration(), + new CatalystSqlParser(), + DummyFunctionResourceLoader, + DummyFunctionExpressionBuilder, + conf.tableRelationCacheSize, + conf.metadataCacheTTL) + } + + // For testing only. + def this( + externalCatalog: ExternalCatalog, + functionRegistry: FunctionRegistry, + conf: SQLConf) = { + this(externalCatalog, functionRegistry, new SimpleTableFunctionRegistry, conf) + } + + // For testing only. + def this( + externalCatalog: ExternalCatalog, + functionRegistry: FunctionRegistry, + tableFunctionRegistry: TableFunctionRegistry) = { + this(externalCatalog, functionRegistry, tableFunctionRegistry, SQLConf.get) + } + + // For testing only. 
+ def this(externalCatalog: ExternalCatalog, functionRegistry: FunctionRegistry) = { + this(externalCatalog, functionRegistry, SQLConf.get) + } + + // For testing only. + def this(externalCatalog: ExternalCatalog) = { + this(externalCatalog, new SimpleFunctionRegistry) + } + + lazy val externalCatalog = externalCatalogBuilder() + + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + + lazy val externalCatalogCache = new CatalogTableCache(columnarConf.catalogCacheSize, columnarConf.catalogCacheExpireTime) + + lazy val globalTempViewManager = globalTempViewManagerBuilder() + + /** List of temporary views, mapping from table name to their logical plan. */ + @GuardedBy("this") + protected val tempViews = new mutable.HashMap[String, TemporaryViewRelation] + + // Note: we track current database here because certain operations do not explicitly + // specify the database (e.g. DROP TABLE my_table). In these cases we must first + // check whether the temporary view or function exists, then, if not, operate on + // the corresponding item in the current database. + @GuardedBy("this") + protected var currentDb: String = formatDatabaseName(DEFAULT_DATABASE) + + private val validNameFormat = "([\\w_]+)".r + + /** + * Checks if the given name conforms the Hive standard ("[a-zA-Z_0-9]+"), + * i.e. if this name only contains characters, numbers, and _. + * + * This method is intended to have the same behavior of + * org.apache.hadoop.hive.metastore.MetaStoreUtils.validateName. + */ + private def validateName(name: String): Unit = { + if (!validNameFormat.pattern.matcher(name).matches()) { + throw QueryCompilationErrors.invalidNameForTableOrDatabaseError(name) + } + } + + /** + * Format table name, taking into account case sensitivity. + */ + protected[this] def formatTableName(name: String): String = { + if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) + } + + /** + * Format database name, taking into account case sensitivity. + */ + protected[this] def formatDatabaseName(name: String): String = { + if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) + } + + private val tableRelationCache: Cache[QualifiedTableName, LogicalPlan] = { + var builder = CacheBuilder.newBuilder() + .maximumSize(cacheSize) + + if (cacheTTL > 0) { + builder = builder.expireAfterWrite(cacheTTL, TimeUnit.SECONDS) + } + + builder.build[QualifiedTableName, LogicalPlan]() + } + + /** This method provides a way to get a cached plan. */ + def getCachedPlan(t: QualifiedTableName, c: Callable[LogicalPlan]): LogicalPlan = { + tableRelationCache.get(t, c) + } + + /** This method provides a way to get a cached plan if the key exists. */ + def getCachedTable(key: QualifiedTableName): LogicalPlan = { + tableRelationCache.getIfPresent(key) + } + + /** This method provides a way to cache a plan. */ + def cacheTable(t: QualifiedTableName, l: LogicalPlan): Unit = { + tableRelationCache.put(t, l) + } + + /** This method provides a way to invalidate a cached plan. */ + def invalidateCachedTable(key: QualifiedTableName): Unit = { + tableRelationCache.invalidate(key) + externalCatalogCache.invalidateCachedTable(key) + externalCatalogCache.invalidateCachedPartition(key.database, key.name) + } + + /** This method discards any cached table relation plans for the given table identifier. 
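+   * It delegates to the `QualifiedTableName` overload above, which in this shim also evicts the
+   * cached `CatalogTable` metadata and any pruned-partition lists for the same table from
+   * `externalCatalogCache`.
+   *
+   * Illustrative usage only (the catalog instance and the database/table names are hypothetical):
+   * {{{
+   *   catalog.invalidateCachedTable(TableIdentifier("tbl", Some("db")))
+   * }}}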
*/ + def invalidateCachedTable(name: TableIdentifier): Unit = { + val dbName = formatDatabaseName(name.database.getOrElse(currentDb)) + val tableName = formatTableName(name.table) + invalidateCachedTable(QualifiedTableName(dbName, tableName)) + } + + /** This method provides a way to invalidate all the cached plans. */ + def invalidateAllCachedTables(): Unit = { + tableRelationCache.invalidateAll() + externalCatalogCache.invalidateAllCachedTables() + externalCatalogCache.invalidateAllCachedPartition() + } + + /** + * This method is used to make the given path qualified before we + * store this path in the underlying external catalog. So, when a path + * does not contain a scheme, this path will not be changed after the default + * FileSystem is changed. + */ + private def makeQualifiedPath(path: URI): URI = { + CatalogUtils.makeQualifiedPath(path, hadoopConf) + } + + private def requireDbExists(db: String): Unit = { + if (!databaseExists(db)) { + throw new NoSuchDatabaseException(db) + } + } + + private def requireTableExists(name: TableIdentifier): Unit = { + if (!tableExists(name)) { + val db = name.database.getOrElse(currentDb) + throw new NoSuchTableException(db = db, table = name.table) + } + } + + private def requireTableNotExists(name: TableIdentifier): Unit = { + if (tableExists(name)) { + val db = name.database.getOrElse(currentDb) + throw new TableAlreadyExistsException(db = db, table = name.table) + } + } + + // ---------------------------------------------------------------------------- + // Databases + // ---------------------------------------------------------------------------- + // All methods in this category interact directly with the underlying catalog. + // ---------------------------------------------------------------------------- + + def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { + val dbName = formatDatabaseName(dbDefinition.name) + if (dbName == globalTempViewManager.database) { + throw QueryCompilationErrors.cannotCreateDatabaseWithSameNameAsPreservedDatabaseError( + globalTempViewManager.database) + } + validateName(dbName) + externalCatalog.createDatabase( + dbDefinition.copy(name = dbName, locationUri = makeQualifiedDBPath(dbDefinition.locationUri)), + ignoreIfExists) + } + + private def makeQualifiedDBPath(locationUri: URI): URI = { + CatalogUtils.makeQualifiedDBObjectPath(locationUri, conf.warehousePath, hadoopConf) + } + + def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { + val dbName = formatDatabaseName(db) + if (dbName == DEFAULT_DATABASE) { + throw QueryCompilationErrors.cannotDropDefaultDatabaseError + } + if (!ignoreIfNotExists) { + requireDbExists(dbName) + } + if (cascade && databaseExists(dbName)) { + listTables(dbName).foreach { t => + invalidateCachedTable(QualifiedTableName(dbName, t.table)) + } + } + externalCatalog.dropDatabase(dbName, ignoreIfNotExists, cascade) + } + + def alterDatabase(dbDefinition: CatalogDatabase): Unit = { + val dbName = formatDatabaseName(dbDefinition.name) + requireDbExists(dbName) + externalCatalog.alterDatabase(dbDefinition.copy( + name = dbName, locationUri = makeQualifiedDBPath(dbDefinition.locationUri))) + } + + def getDatabaseMetadata(db: String): CatalogDatabase = { + val dbName = formatDatabaseName(db) + requireDbExists(dbName) + externalCatalog.getDatabase(dbName) + } + + def databaseExists(db: String): Boolean = { + val dbName = formatDatabaseName(db) + externalCatalog.databaseExists(dbName) + } + + def listDatabases(): Seq[String] = { + 
externalCatalog.listDatabases() + } + + def listDatabases(pattern: String): Seq[String] = { + externalCatalog.listDatabases(pattern) + } + + def getCurrentDatabase: String = synchronized { currentDb } + + def setCurrentDatabase(db: String): Unit = { + val dbName = formatDatabaseName(db) + if (dbName == globalTempViewManager.database) { + throw QueryCompilationErrors.cannotUsePreservedDatabaseAsCurrentDatabaseError( + globalTempViewManager.database) + } + requireDbExists(dbName) + synchronized { currentDb = dbName } + } + + /** + * Get the path for creating a non-default database when database location is not provided + * by users. + */ + def getDefaultDBPath(db: String): URI = { + CatalogUtils.stringToURI(formatDatabaseName(db) + ".db") + } + + // ---------------------------------------------------------------------------- + // Tables + // ---------------------------------------------------------------------------- + // There are two kinds of tables, temporary views and metastore tables. + // Temporary views are isolated across sessions and do not belong to any + // particular database. Metastore tables can be used across multiple + // sessions as their metadata is persisted in the underlying catalog. + // ---------------------------------------------------------------------------- + + // ---------------------------------------------------- + // | Methods that interact with metastore tables only | + // ---------------------------------------------------- + + /** + * Create a metastore table in the database specified in `tableDefinition`. + * If no such database is specified, create it in the current database. + */ + def createTable( + tableDefinition: CatalogTable, + ignoreIfExists: Boolean, + validateLocation: Boolean = true): Unit = { + val isExternal = tableDefinition.tableType == CatalogTableType.EXTERNAL + if (isExternal && tableDefinition.storage.locationUri.isEmpty) { + throw QueryCompilationErrors.createExternalTableWithoutLocationError + } + + val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(tableDefinition.identifier.table) + val tableIdentifier = TableIdentifier(table, Some(db)) + validateName(table) + + val newTableDefinition = if (tableDefinition.storage.locationUri.isDefined + && !tableDefinition.storage.locationUri.get.isAbsolute) { + // make the location of the table qualified. + val qualifiedTableLocation = + makeQualifiedTablePath(tableDefinition.storage.locationUri.get, db) + tableDefinition.copy( + storage = tableDefinition.storage.copy(locationUri = Some(qualifiedTableLocation)), + identifier = tableIdentifier) + } else { + tableDefinition.copy(identifier = tableIdentifier) + } + + requireDbExists(db) + if (tableExists(newTableDefinition.identifier)) { + if (!ignoreIfExists) { + throw new TableAlreadyExistsException(db = db, table = table) + } + } else { + if (validateLocation) { + validateTableLocation(newTableDefinition) + } + externalCatalog.createTable(newTableDefinition, ignoreIfExists) + } + } + + def validateTableLocation(table: CatalogTable): Unit = { + // SPARK-19724: the default location of a managed table should be non-existent or empty. 
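+    // (Only MANAGED tables are checked below; EXTERNAL tables may legitimately point at a
+    // location that already contains data.)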
+ if (table.tableType == CatalogTableType.MANAGED) { + val tableLocation = + new Path(table.storage.locationUri.getOrElse(defaultTablePath(table.identifier))) + val fs = tableLocation.getFileSystem(hadoopConf) + + if (fs.exists(tableLocation) && fs.listStatus(tableLocation).nonEmpty) { + throw QueryCompilationErrors.cannotOperateManagedTableWithExistingLocationError( + "create", table.identifier, tableLocation) + } + } + } + + private def makeQualifiedTablePath(locationUri: URI, database: String): URI = { + if (locationUri.isAbsolute) { + locationUri + } else if (new Path(locationUri).isAbsolute) { + makeQualifiedPath(locationUri) + } else { + val dbName = formatDatabaseName(database) + val dbLocation = makeQualifiedDBPath(getDatabaseMetadata(dbName).locationUri) + new Path(new Path(dbLocation), CatalogUtils.URIToString(locationUri)).toUri + } + } + + /** + * Alter the metadata of an existing metastore table identified by `tableDefinition`. + * + * If no database is specified in `tableDefinition`, assume the table is in the + * current database. + * + * Note: If the underlying implementation does not support altering a certain field, + * this becomes a no-op. + */ + def alterTable(tableDefinition: CatalogTable): Unit = { + val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(tableDefinition.identifier.table) + val tableIdentifier = TableIdentifier(table, Some(db)) + requireDbExists(db) + requireTableExists(tableIdentifier) + val newTableDefinition = if (tableDefinition.storage.locationUri.isDefined + && !tableDefinition.storage.locationUri.get.isAbsolute) { + // make the location of the table qualified. + val qualifiedTableLocation = + makeQualifiedTablePath(tableDefinition.storage.locationUri.get, db) + tableDefinition.copy( + storage = tableDefinition.storage.copy(locationUri = Some(qualifiedTableLocation)), + identifier = tableIdentifier) + } else { + tableDefinition.copy(identifier = tableIdentifier) + } + + externalCatalog.alterTable(newTableDefinition) + } + + /** + * Alter the data schema of a table identified by the provided table identifier. The new data + * schema should not have conflict column names with the existing partition columns, and should + * still contain all the existing data columns. + * + * @param identifier TableIdentifier + * @param newDataSchema Updated data schema to be used for the table + */ + def alterTableDataSchema( + identifier: TableIdentifier, + newDataSchema: StructType): Unit = { + val db = formatDatabaseName(identifier.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(identifier.table) + val tableIdentifier = TableIdentifier(table, Some(db)) + requireDbExists(db) + requireTableExists(tableIdentifier) + + val catalogTable = externalCatalog.getTable(db, table) + val oldDataSchema = catalogTable.dataSchema + // not supporting dropping columns yet + val nonExistentColumnNames = + oldDataSchema.map(_.name).filterNot(columnNameResolved(newDataSchema, _)) + if (nonExistentColumnNames.nonEmpty) { + throw QueryCompilationErrors.dropNonExistentColumnsNotSupportedError(nonExistentColumnNames) + } + + externalCatalog.alterTableDataSchema(db, table, newDataSchema) + } + + private def columnNameResolved(schema: StructType, colName: String): Boolean = { + schema.fields.map(_.name).exists(conf.resolver(_, colName)) + } + + /** + * Alter Spark's statistics of an existing metastore table identified by the provided table + * identifier. 
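+   * @note After the stats are written to the external catalog this also calls `refreshTable`,
+   * which in this shim drops the table's entries from both the relation cache and
+   * `externalCatalogCache`.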
+ */ + def alterTableStats(identifier: TableIdentifier, newStats: Option[CatalogStatistics]): Unit = { + val db = formatDatabaseName(identifier.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(identifier.table) + val tableIdentifier = TableIdentifier(table, Some(db)) + requireDbExists(db) + requireTableExists(tableIdentifier) + externalCatalog.alterTableStats(db, table, newStats) + // Invalidate the table relation cache + refreshTable(identifier) + } + + /** + * Return whether a table/view with the specified name exists. If no database is specified, check + * with current database. + */ + def tableExists(name: TableIdentifier): Boolean = { + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(name.table) + externalCatalog.tableExists(db, table) + } + + /** + * Retrieve the metadata of an existing permanent table/view. If no database is specified, + * assume the table/view is in the current database. + * We replace char/varchar with "annotated" string type in the table schema, as the query + * engine doesn't support char/varchar yet. + */ + @throws[NoSuchDatabaseException] + @throws[NoSuchTableException] + def getTableMetadata(name: TableIdentifier): CatalogTable = { + val t = getTableRawMetadata(name) + t.copy(schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(t.schema)) + } + + /** + * Retrieve the metadata of an existing permanent table/view. If no database is specified, + * assume the table/view is in the current database. + */ + @throws[NoSuchDatabaseException] + @throws[NoSuchTableException] + def getTableRawMetadata(name: TableIdentifier): CatalogTable = { + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(name.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Some(db))) + externalCatalog.getTable(db, table) + } + + /** + * Retrieve all metadata of existing permanent tables/views. If no database is specified, + * assume the table/view is in the current database. + * Only the tables/views belong to the same database that can be retrieved are returned. + * For example, if none of the requested tables could be retrieved, an empty list is returned. + * There is no guarantee of ordering of the returned tables. + */ + @throws[NoSuchDatabaseException] + def getTablesByName(names: Seq[TableIdentifier]): Seq[CatalogTable] = { + if (names.nonEmpty) { + val dbs = names.map(_.database.getOrElse(getCurrentDatabase)) + if (dbs.distinct.size != 1) { + val tables = names.map(name => formatTableName(name.table)) + val qualifiedTableNames = dbs.zip(tables).map { case (d, t) => QualifiedTableName(d, t)} + throw QueryCompilationErrors.cannotRetrieveTableOrViewNotInSameDatabaseError( + qualifiedTableNames) + } + val db = formatDatabaseName(dbs.head) + requireDbExists(db) + val tables = names.map(name => formatTableName(name.table)) + externalCatalog.getTablesByName(db, tables) + } else { + Seq.empty + } + } + + /** + * Load files stored in given path into an existing metastore table. + * If no database is specified, assume the table is in the current database. + * If the specified table is not found in the database then a [[NoSuchTableException]] is thrown. 
+ */ + def loadTable( + name: TableIdentifier, + loadPath: String, + isOverwrite: Boolean, + isSrcLocal: Boolean): Unit = { + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(name.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Some(db))) + externalCatalog.loadTable(db, table, loadPath, isOverwrite, isSrcLocal) + } + + /** + * Load files stored in given path into the partition of an existing metastore table. + * If no database is specified, assume the table is in the current database. + * If the specified table is not found in the database then a [[NoSuchTableException]] is thrown. + */ + def loadPartition( + name: TableIdentifier, + loadPath: String, + spec: TablePartitionSpec, + isOverwrite: Boolean, + inheritTableSpecs: Boolean, + isSrcLocal: Boolean): Unit = { + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(name.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Some(db))) + requireNonEmptyValueInPartitionSpec(Seq(spec)) + externalCatalog.loadPartition( + db, table, loadPath, spec, isOverwrite, inheritTableSpecs, isSrcLocal) + } + + def defaultTablePath(tableIdent: TableIdentifier): URI = { + val dbName = formatDatabaseName(tableIdent.database.getOrElse(getCurrentDatabase)) + val dbLocation = getDatabaseMetadata(dbName).locationUri + + new Path(new Path(dbLocation), formatTableName(tableIdent.table)).toUri + } + + // ---------------------------------------------- + // | Methods that interact with temp views only | + // ---------------------------------------------- + + /** + * Create a local temporary view. + */ + def createTempView( + name: String, + viewDefinition: TemporaryViewRelation, + overrideIfExists: Boolean): Unit = synchronized { + val table = formatTableName(name) + if (tempViews.contains(table) && !overrideIfExists) { + throw new TempTableAlreadyExistsException(name) + } + tempViews.put(table, viewDefinition) + } + + /** + * Create a global temporary view. + */ + def createGlobalTempView( + name: String, + viewDefinition: TemporaryViewRelation, + overrideIfExists: Boolean): Unit = { + globalTempViewManager.create(formatTableName(name), viewDefinition, overrideIfExists) + } + + /** + * Alter the definition of a local/global temp view matching the given name, returns true if a + * temp view is matched and altered, false otherwise. + */ + def alterTempViewDefinition( + name: TableIdentifier, + viewDefinition: TemporaryViewRelation): Boolean = synchronized { + val viewName = formatTableName(name.table) + if (name.database.isEmpty) { + if (tempViews.contains(viewName)) { + createTempView(viewName, viewDefinition, overrideIfExists = true) + true + } else { + false + } + } else if (formatDatabaseName(name.database.get) == globalTempViewManager.database) { + globalTempViewManager.update(viewName, viewDefinition) + } else { + false + } + } + + /** + * Return a local temporary view exactly as it was stored. + */ + def getRawTempView(name: String): Option[TemporaryViewRelation] = synchronized { + tempViews.get(formatTableName(name)) + } + + /** + * Generate a [[View]] operator from the temporary view stored. + */ + def getTempView(name: String): Option[View] = synchronized { + getRawTempView(name).map(getTempViewPlan) + } + + def getTempViewNames(): Seq[String] = synchronized { + tempViews.keySet.toSeq + } + + /** + * Return a global temporary view exactly as it was stored. 
+ */ + def getRawGlobalTempView(name: String): Option[TemporaryViewRelation] = { + globalTempViewManager.get(formatTableName(name)) + } + + /** + * Generate a [[View]] operator from the global temporary view stored. + */ + def getGlobalTempView(name: String): Option[View] = { + getRawGlobalTempView(name).map(getTempViewPlan) + } + + /** + * Drop a local temporary view. + * + * Returns true if this view is dropped successfully, false otherwise. + */ + def dropTempView(name: String): Boolean = synchronized { + tempViews.remove(formatTableName(name)).isDefined + } + + /** + * Drop a global temporary view. + * + * Returns true if this view is dropped successfully, false otherwise. + */ + def dropGlobalTempView(name: String): Boolean = { + globalTempViewManager.remove(formatTableName(name)) + } + + // ------------------------------------------------------------- + // | Methods that interact with temporary and metastore tables | + // ------------------------------------------------------------- + + /** + * Retrieve the metadata of an existing temporary view or permanent table/view. + * + * If a database is specified in `name`, this will return the metadata of table/view in that + * database. + * If no database is specified, this will first attempt to get the metadata of a temporary view + * with the same name, then, if that does not exist, return the metadata of table/view in the + * current database. + */ + def getTempViewOrPermanentTableMetadata(name: TableIdentifier): CatalogTable = synchronized { + val table = formatTableName(name.table) + if (name.database.isEmpty) { + tempViews.get(table).map(_.tableMeta).getOrElse(getTableMetadata(name)) + } else if (formatDatabaseName(name.database.get) == globalTempViewManager.database) { + globalTempViewManager.get(table).map(_.tableMeta) + .getOrElse(throw new NoSuchTableException(globalTempViewManager.database, table)) + } else { + getTableMetadata(name) + } + } + + /** + * Rename a table. + * + * If a database is specified in `oldName`, this will rename the table in that database. + * If no database is specified, this will first attempt to rename a temporary view with + * the same name, then, if that does not exist, rename the table in the current database. + * + * This assumes the database specified in `newName` matches the one in `oldName`. 
+ */ + def renameTable(oldName: TableIdentifier, newName: TableIdentifier): Unit = synchronized { + val db = formatDatabaseName(oldName.database.getOrElse(currentDb)) + newName.database.map(formatDatabaseName).foreach { newDb => + if (db != newDb) { + throw QueryCompilationErrors.renameTableSourceAndDestinationMismatchError(db, newDb) + } + } + + val oldTableName = formatTableName(oldName.table) + val newTableName = formatTableName(newName.table) + if (db == globalTempViewManager.database) { + globalTempViewManager.rename(oldTableName, newTableName) + } else { + requireDbExists(db) + if (oldName.database.isDefined || !tempViews.contains(oldTableName)) { + validateName(newTableName) + validateNewLocationOfRename( + TableIdentifier(oldTableName, Some(db)), TableIdentifier(newTableName, Some(db))) + externalCatalog.renameTable(db, oldTableName, newTableName) + } else { + if (newName.database.isDefined) { + throw QueryCompilationErrors.cannotRenameTempViewWithDatabaseSpecifiedError( + oldName, newName) + } + if (tempViews.contains(newTableName)) { + throw QueryCompilationErrors.cannotRenameTempViewToExistingTableError( + oldName, newName) + } + val table = tempViews(oldTableName) + tempViews.remove(oldTableName) + tempViews.put(newTableName, table) + } + } + } + + /** + * Drop a table. + * + * If a database is specified in `name`, this will drop the table from that database. + * If no database is specified, this will first attempt to drop a temporary view with + * the same name, then, if that does not exist, drop the table from the current database. + */ + def dropTable( + name: TableIdentifier, + ignoreIfNotExists: Boolean, + purge: Boolean): Unit = synchronized { + val db = formatDatabaseName(name.database.getOrElse(currentDb)) + val table = formatTableName(name.table) + if (db == globalTempViewManager.database) { + val viewExists = globalTempViewManager.remove(table) + if (!viewExists && !ignoreIfNotExists) { + throw new NoSuchTableException(globalTempViewManager.database, table) + } + } else { + if (name.database.isDefined || !tempViews.contains(table)) { + requireDbExists(db) + // When ignoreIfNotExists is false, no exception is issued when the table does not exist. + // Instead, log it as an error message. + if (tableExists(TableIdentifier(table, Option(db)))) { + externalCatalog.dropTable(db, table, ignoreIfNotExists = true, purge = purge) + } else if (!ignoreIfNotExists) { + throw new NoSuchTableException(db = db, table = table) + } + } else { + tempViews.remove(table) + } + } + } + + /** + * Return a [[LogicalPlan]] that represents the given table or view. + * + * If a database is specified in `name`, this will return the table/view from that database. + * If no database is specified, this will first attempt to return a temporary view with + * the same name, then, if that does not exist, return the table/view from the current database. + * + * Note that, the global temp view database is also valid here, this will return the global temp + * view matching the given name. + * + * If the relation is a view, we generate a [[View]] operator from the view description, and + * wrap the logical plan in a [[SubqueryAlias]] which will track the name of the view. + * [[SubqueryAlias]] will also keep track of the name and database(optional) of the table/view + * + * @param name The name of the table/view that we look up. 
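+   * @note In this OmniOperator shim the permanent-table branch first consults
+   * `externalCatalogCache.getCachedTable(db, table)`; only on a miss is `externalCatalog.getTable`
+   * called and the returned `CatalogTable` cached, avoiding repeated metastore round trips for
+   * frequently resolved tables.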
+ */ + def lookupRelation(name: TableIdentifier): LogicalPlan = { + synchronized { + val db = formatDatabaseName(name.database.getOrElse(currentDb)) + val table = formatTableName(name.table) + if (db == globalTempViewManager.database) { + globalTempViewManager.get(table).map { viewDef => + SubqueryAlias(table, db, getTempViewPlan(viewDef)) + }.getOrElse(throw new NoSuchTableException(db, table)) + } else if (name.database.isDefined || !tempViews.contains(table)) { + val cacheMetadata = externalCatalogCache.getCachedTable(db, table) + if(cacheMetadata == null) { + val metadata = externalCatalog.getTable(db, table) + externalCatalogCache.cacheTable(db, table, metadata) + getRelation(metadata) + } else { + getRelation(cacheMetadata) + } + } else { + SubqueryAlias(table, getTempViewPlan(tempViews(table))) + } + } + } + + def getRelation( + metadata: CatalogTable, + options: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty()): LogicalPlan = { + val name = metadata.identifier + val db = formatDatabaseName(name.database.getOrElse(currentDb)) + val table = formatTableName(name.table) + val multiParts = Seq(CatalogManager.SESSION_CATALOG_NAME, db, table) + + if (metadata.tableType == CatalogTableType.VIEW) { + // The relation is a view, so we wrap the relation by: + // 1. Add a [[View]] operator over the relation to keep track of the view desc; + // 2. Wrap the logical plan in a [[SubqueryAlias]] which tracks the name of the view. + SubqueryAlias(multiParts, fromCatalogTable(metadata, isTempView = false)) + } else { + SubqueryAlias(multiParts, UnresolvedCatalogRelation(metadata, options)) + } + } + + private def getTempViewPlan(viewInfo: TemporaryViewRelation): View = viewInfo.plan match { + case Some(p) => View(desc = viewInfo.tableMeta, isTempView = true, child = p) + case None => fromCatalogTable(viewInfo.tableMeta, isTempView = true) + } + + private def buildViewDDL(metadata: CatalogTable, isTempView: Boolean): Option[String] = { + if (isTempView) { + None + } else { + val viewName = metadata.identifier.unquotedString + val viewText = metadata.viewText.get + val userSpecifiedColumns = + if (metadata.schema.fieldNames.toSeq == metadata.viewQueryColumnNames) { + "" + } else { + s"(${metadata.schema.fieldNames.mkString(", ")})" + } + Some(s"CREATE OR REPLACE VIEW $viewName $userSpecifiedColumns AS $viewText") + } + } + + private def isHiveCreatedView(metadata: CatalogTable): Boolean = { + // For views created by hive without explicit column names, there will be auto-generated + // column names like "_c0", "_c1", "_c2"... 
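+    // Both conditions below must hold: no recorded view query column names and at least one
+    // auto-generated "_c<N>" field in the schema.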
+ metadata.viewQueryColumnNames.isEmpty && + metadata.schema.fieldNames.exists(_.matches("_c[0-9]+")) + } + + private def fromCatalogTable(metadata: CatalogTable, isTempView: Boolean): View = { + val viewText = metadata.viewText.getOrElse { + throw new IllegalStateException("Invalid view without text.") + } + val viewConfigs = metadata.viewSQLConfigs + val origin = Origin( + objectType = Some("VIEW"), + objectName = Some(metadata.qualifiedName) + ) + val parsedPlan = SQLConf.withExistingConf(View.effectiveSQLConf(viewConfigs, isTempView)) { + try { + CurrentOrigin.withOrigin(origin) { + parser.parseQuery(viewText) + } + } catch { + case _: ParseException => + throw QueryCompilationErrors.invalidViewText(viewText, metadata.qualifiedName) + } + } + val projectList = if (!isHiveCreatedView(metadata)) { + val viewColumnNames = if (metadata.viewQueryColumnNames.isEmpty) { + // For view created before Spark 2.2.0, the view text is already fully qualified, the plan + // output is the same with the view output. + metadata.schema.fieldNames.toSeq + } else { + assert(metadata.viewQueryColumnNames.length == metadata.schema.length) + metadata.viewQueryColumnNames + } + + // For view queries like `SELECT * FROM t`, the schema of the referenced table/view may + // change after the view has been created. We need to add an extra SELECT to pick the columns + // according to the recorded column names (to get the correct view column ordering and omit + // the extra columns that we don't require), with UpCast (to make sure the type change is + // safe) and Alias (to respect user-specified view column names) according to the view schema + // in the catalog. + // Note that, the column names may have duplication, e.g. `CREATE VIEW v(x, y) AS + // SELECT 1 col, 2 col`. We need to make sure that the matching attributes have the same + // number of duplications, and pick the corresponding attribute by ordinal. + val viewConf = View.effectiveSQLConf(metadata.viewSQLConfigs, isTempView) + val normalizeColName: String => String = if (viewConf.caseSensitiveAnalysis) { + identity + } else { + _.toLowerCase(Locale.ROOT) + } + val nameToCounts = viewColumnNames.groupBy(normalizeColName).mapValues(_.length) + val nameToCurrentOrdinal = scala.collection.mutable.HashMap.empty[String, Int] + val viewDDL = buildViewDDL(metadata, isTempView) + + viewColumnNames.zip(metadata.schema).map { case (name, field) => + val normalizedName = normalizeColName(name) + val count = nameToCounts(normalizedName) + val ordinal = nameToCurrentOrdinal.getOrElse(normalizedName, 0) + nameToCurrentOrdinal(normalizedName) = ordinal + 1 + val col = GetViewColumnByNameAndOrdinal( + metadata.identifier.toString, name, ordinal, count, viewDDL) + Alias(UpCast(col, field.dataType), field.name)(explicitMetadata = Some(field.metadata)) + } + } else { + // For view created by hive, the parsed view plan may have different output columns with + // the schema stored in metadata. 
For example: `CREATE VIEW v AS SELECT 1 FROM t` + // the schema in metadata will be `_c0` while the parsed view plan has column named `1` + metadata.schema.zipWithIndex.map { case (field, index) => + val col = GetColumnByOrdinal(index, field.dataType) + Alias(UpCast(col, field.dataType), field.name)(explicitMetadata = Some(field.metadata)) + } + } + View(desc = metadata, isTempView = isTempView, child = Project(projectList, parsedPlan)) + } + + def lookupTempView(table: String): Option[SubqueryAlias] = { + val formattedTable = formatTableName(table) + getTempView(formattedTable).map { view => + SubqueryAlias(formattedTable, view) + } + } + + def lookupGlobalTempView(db: String, table: String): Option[SubqueryAlias] = { + val formattedDB = formatDatabaseName(db) + if (formattedDB == globalTempViewManager.database) { + val formattedTable = formatTableName(table) + getGlobalTempView(formattedTable).map { view => + SubqueryAlias(formattedTable, formattedDB, view) + } + } else { + None + } + } + + /** + * Return whether the given name parts belong to a temporary or global temporary view. + */ + def isTempView(nameParts: Seq[String]): Boolean = { + if (nameParts.length > 2) return false + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + isTempView(nameParts.asTableIdentifier) + } + + def lookupTempView(name: TableIdentifier): Option[View] = { + val tableName = formatTableName(name.table) + if (name.database.isEmpty) { + tempViews.get(tableName).map(getTempViewPlan) + } else if (formatDatabaseName(name.database.get) == globalTempViewManager.database) { + globalTempViewManager.get(tableName).map(getTempViewPlan) + } else { + None + } + } + + /** + * Return whether a table with the specified name is a temporary view. + * + * Note: The temporary view cache is checked only when database is not + * explicitly specified. + */ + def isTempView(name: TableIdentifier): Boolean = synchronized { + lookupTempView(name).isDefined + } + + def isView(nameParts: Seq[String]): Boolean = { + nameParts.length <= 2 && { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + val ident = nameParts.asTableIdentifier + try { + getTempViewOrPermanentTableMetadata(ident).tableType == CatalogTableType.VIEW + } catch { + case _: NoSuchTableException => false + case _: NoSuchDatabaseException => false + case _: NoSuchNamespaceException => false + } + } + } + + /** + * List all tables in the specified database, including local temporary views. + * + * Note that, if the specified database is global temporary view database, we will list global + * temporary views. + */ + def listTables(db: String): Seq[TableIdentifier] = listTables(db, "*") + + /** + * List all matching tables in the specified database, including local temporary views. + * + * Note that, if the specified database is global temporary view database, we will list global + * temporary views. + */ + def listTables(db: String, pattern: String): Seq[TableIdentifier] = listTables(db, pattern, true) + + /** + * List all matching tables in the specified database, including local temporary views + * if includeLocalTempViews is enabled. + * + * Note that, if the specified database is global temporary view database, we will list global + * temporary views. 
+ */ + def listTables( + db: String, + pattern: String, + includeLocalTempViews: Boolean): Seq[TableIdentifier] = { + val dbName = formatDatabaseName(db) + val dbTables = if (dbName == globalTempViewManager.database) { + globalTempViewManager.listViewNames(pattern).map { name => + TableIdentifier(name, Some(globalTempViewManager.database)) + } + } else { + requireDbExists(dbName) + externalCatalog.listTables(dbName, pattern).map { name => + TableIdentifier(name, Some(dbName)) + } + } + + if (includeLocalTempViews) { + dbTables ++ listLocalTempViews(pattern) + } else { + dbTables + } + } + + /** + * List all matching views in the specified database, including local temporary views. + */ + def listViews(db: String, pattern: String): Seq[TableIdentifier] = { + val dbName = formatDatabaseName(db) + val dbViews = if (dbName == globalTempViewManager.database) { + globalTempViewManager.listViewNames(pattern).map { name => + TableIdentifier(name, Some(globalTempViewManager.database)) + } + } else { + requireDbExists(dbName) + externalCatalog.listViews(dbName, pattern).map { name => + TableIdentifier(name, Some(dbName)) + } + } + + dbViews ++ listLocalTempViews(pattern) + } + + /** + * List all matching local temporary views. + */ + def listLocalTempViews(pattern: String): Seq[TableIdentifier] = { + synchronized { + StringUtils.filterPattern(tempViews.keys.toSeq, pattern).map { name => + TableIdentifier(name) + } + } + } + + /** + * Refresh table entries in structures maintained by the session catalog such as: + * - The map of temporary or global temporary view names to their logical plans + * - The relation cache which maps table identifiers to their logical plans + * + * For temp views, it refreshes their logical plans, and as a consequence of that it can refresh + * the file indexes of the base relations (`HadoopFsRelation` for instance) used in the views. + * The method still keeps the views in the internal lists of session catalog. + * + * For tables/views, it removes their entries from the relation cache. + * + * The method is supposed to use in the following situations: + * 1. The logical plan of a table/view was changed, and cached table/view data is cleared + * explicitly. For example, like in `AlterTableRenameCommand` which re-caches the table + * itself. Otherwise if you need to refresh cached data, consider using of + * `CatalogImpl.refreshTable()`. + * 2. A table/view doesn't exist, and need to only remove its entry in the relation cache since + * the cached data is invalidated explicitly like in `DropTableCommand` which uncaches + * table/view data itself. + * 3. Meta-data (such as file indexes) of any relation used in a temporary view should be + * updated. + */ + def refreshTable(name: TableIdentifier): Unit = synchronized { + lookupTempView(name).map(_.refresh).getOrElse { + val dbName = formatDatabaseName(name.database.getOrElse(currentDb)) + val tableName = formatTableName(name.table) + val qualifiedTableName = QualifiedTableName(dbName, tableName) + tableRelationCache.invalidate(qualifiedTableName) + externalCatalogCache.invalidateCachedTable(qualifiedTableName) + externalCatalogCache.invalidateCachedPartition(dbName, tableName) + } + } + + /** + * Drop all existing temporary views. + * For testing only. 
+ */ + def clearTempTables(): Unit = synchronized { + tempViews.clear() + } + + // ---------------------------------------------------------------------------- + // Partitions + // ---------------------------------------------------------------------------- + // All methods in this category interact directly with the underlying catalog. + // These methods are concerned with only metastore tables. + // ---------------------------------------------------------------------------- + + // TODO: We need to figure out how these methods interact with our data source + // tables. For such tables, we do not store values of partitioning columns in + // the metastore. For now, partition values of a data source table will be + // automatically discovered when we load the table. + + /** + * Create partitions in an existing table, assuming it exists. + * If no database is specified, assume the table is in the current database. + */ + def createPartitions( + tableName: TableIdentifier, + parts: Seq[CatalogTablePartition], + ignoreIfExists: Boolean): Unit = { + val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(tableName.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Option(db))) + requireExactMatchedPartitionSpec(parts.map(_.spec), getTableMetadata(tableName)) + requireNonEmptyValueInPartitionSpec(parts.map(_.spec)) + externalCatalog.createPartitions( + db, table, partitionWithQualifiedPath(tableName, parts), ignoreIfExists) + } + + /** + * Drop partitions from a table, assuming they exist. + * If no database is specified, assume the table is in the current database. + */ + def dropPartitions( + tableName: TableIdentifier, + specs: Seq[TablePartitionSpec], + ignoreIfNotExists: Boolean, + purge: Boolean, + retainData: Boolean): Unit = { + val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(tableName.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Option(db))) + requirePartialMatchedPartitionSpec(specs, getTableMetadata(tableName)) + requireNonEmptyValueInPartitionSpec(specs) + externalCatalog.dropPartitions(db, table, specs, ignoreIfNotExists, purge, retainData) + } + + /** + * Override the specs of one or many existing table partitions, assuming they exist. + * + * This assumes index i of `specs` corresponds to index i of `newSpecs`. + * If no database is specified, assume the table is in the current database. + */ + def renamePartitions( + tableName: TableIdentifier, + specs: Seq[TablePartitionSpec], + newSpecs: Seq[TablePartitionSpec]): Unit = { + val tableMetadata = getTableMetadata(tableName) + val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(tableName.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Option(db))) + requireExactMatchedPartitionSpec(specs, tableMetadata) + requireExactMatchedPartitionSpec(newSpecs, tableMetadata) + requireNonEmptyValueInPartitionSpec(specs) + requireNonEmptyValueInPartitionSpec(newSpecs) + externalCatalog.renamePartitions(db, table, specs, newSpecs) + } + + /** + * Alter one or many table partitions whose specs that match those specified in `parts`, + * assuming the partitions exist. + * + * If no database is specified, assume the table is in the current database. + * + * Note: If the underlying implementation does not support altering a certain field, + * this becomes a no-op. 
+ */ + def alterPartitions(tableName: TableIdentifier, parts: Seq[CatalogTablePartition]): Unit = { + val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(tableName.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Option(db))) + requireExactMatchedPartitionSpec(parts.map(_.spec), getTableMetadata(tableName)) + requireNonEmptyValueInPartitionSpec(parts.map(_.spec)) + externalCatalog.alterPartitions(db, table, partitionWithQualifiedPath(tableName, parts)) + } + + /** + * Retrieve the metadata of a table partition, assuming it exists. + * If no database is specified, assume the table is in the current database. + */ + def getPartition(tableName: TableIdentifier, spec: TablePartitionSpec): CatalogTablePartition = { + val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(tableName.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Option(db))) + requireExactMatchedPartitionSpec(Seq(spec), getTableMetadata(tableName)) + requireNonEmptyValueInPartitionSpec(Seq(spec)) + externalCatalog.getPartition(db, table, spec) + } + + /** + * List the names of all partitions that belong to the specified table, assuming it exists. + * + * A partial partition spec may optionally be provided to filter the partitions returned. + * For instance, if there exist partitions (a='1', b='2'), (a='1', b='3') and (a='2', b='4'), + * then a partial spec of (a='1') will return the first two only. + */ + def listPartitionNames( + tableName: TableIdentifier, + partialSpec: Option[TablePartitionSpec] = None): Seq[String] = { + val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(tableName.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Option(db))) + partialSpec.foreach { spec => + requirePartialMatchedPartitionSpec(Seq(spec), getTableMetadata(tableName)) + requireNonEmptyValueInPartitionSpec(Seq(spec)) + } + externalCatalog.listPartitionNames(db, table, partialSpec) + } + + /** + * List the metadata of all partitions that belong to the specified table, assuming it exists. + * + * A partial partition spec may optionally be provided to filter the partitions returned. + * For instance, if there exist partitions (a='1', b='2'), (a='1', b='3') and (a='2', b='4'), + * then a partial spec of (a='1') will return the first two only. + */ + def listPartitions( + tableName: TableIdentifier, + partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = { + val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(tableName.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Option(db))) + partialSpec.foreach { spec => + requirePartialMatchedPartitionSpec(Seq(spec), getTableMetadata(tableName)) + requireNonEmptyValueInPartitionSpec(Seq(spec)) + } + externalCatalog.listPartitions(db, table, partialSpec) + } + + /** + * List the metadata of partitions that belong to the specified table, assuming it exists, that + * satisfy the given partition-pruning predicate expressions. 
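+   * @note In this OmniOperator shim the result is served from `externalCatalogCache` when an
+   * equivalent (table, predicates) key is already cached; the cache key normalizes predicate
+   * `AttributeReference`s to a dummy ExprId (see `CatalogTableCache.getCachedPartitions`) so that
+   * semantically equal filters from different plans hit the same entry. The metastore is only
+   * queried, and the existence checks run, on a cache miss.
+   *
+   * Illustrative call only (the table name and `filterExpr` are hypothetical):
+   * {{{
+   *   catalog.listPartitionsByFilter(TableIdentifier("sales", Some("db")), Seq(filterExpr))
+   * }}}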
+ */ + def listPartitionsByFilter( + tableName: TableIdentifier, + predicates: Seq[Expression]): Seq[CatalogTablePartition] = { + val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(tableName.table) + val cachedPartition = externalCatalogCache.getCachedPartitions(db, table, predicates, () => { + requireDbExists(db) + requireTableExists(TableIdentifier(table, Option(db))) + val partitions = externalCatalog.listPartitionsByFilter(db, table, predicates, conf.sessionLocalTimeZone) + partitions + }) + cachedPartition + } + + /** + * Verify if the input partition spec has any empty value. + */ + private def requireNonEmptyValueInPartitionSpec(specs: Seq[TablePartitionSpec]): Unit = { + specs.foreach { s => + if (s.values.exists(v => v != null && v.isEmpty)) { + val spec = s.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]") + throw QueryCompilationErrors.invalidPartitionSpecError( + s"The spec ($spec) contains an empty partition column value") + } + } + } + + /** + * Verify if the input partition spec exactly matches the existing defined partition spec + * The columns must be the same but the orders could be different. + */ + private def requireExactMatchedPartitionSpec( + specs: Seq[TablePartitionSpec], + table: CatalogTable): Unit = { + specs.foreach { spec => + PartitioningUtils.requireExactMatchedPartitionSpec( + table.identifier.toString, + spec, + table.partitionColumnNames) + } + } + + /** + * Verify if the input partition spec partially matches the existing defined partition spec + * That is, the columns of partition spec should be part of the defined partition spec. + */ + private def requirePartialMatchedPartitionSpec( + specs: Seq[TablePartitionSpec], + table: CatalogTable): Unit = { + val defined = table.partitionColumnNames + specs.foreach { s => + if (!s.keys.forall(defined.contains)) { + throw QueryCompilationErrors.invalidPartitionSpecError( + s"The spec (${s.keys.mkString(", ")}) must be contained " + + s"within the partition spec (${table.partitionColumnNames.mkString(", ")}) defined " + + s"in table '${table.identifier}'") + } + } + } + + /** + * Make the partition path qualified. + * If the partition path is relative, e.g. 'paris', it will be qualified with + * parent path using table location, e.g. 'file:/warehouse/table/paris' + */ + private def partitionWithQualifiedPath( + tableIdentifier: TableIdentifier, + parts: Seq[CatalogTablePartition]): Seq[CatalogTablePartition] = { + lazy val tbl = getTableMetadata(tableIdentifier) + parts.map { part => + if (part.storage.locationUri.isDefined && !part.storage.locationUri.get.isAbsolute) { + val partPath = new Path(new Path(tbl.location), new Path(part.storage.locationUri.get)) + val qualifiedPartPath = makeQualifiedPath(CatalogUtils.stringToURI(partPath.toString)) + part.copy(storage = part.storage.copy(locationUri = Some(qualifiedPartPath))) + } else part + } + } + // ---------------------------------------------------------------------------- + // Functions + // ---------------------------------------------------------------------------- + // There are two kinds of functions, temporary functions and metastore + // functions (permanent UDFs). Temporary functions are isolated across + // sessions. Metastore functions can be used across multiple sessions as + // their metadata is persisted in the underlying catalog. 
+ // ---------------------------------------------------------------------------- + + // ------------------------------------------------------- + // | Methods that interact with metastore functions only | + // ------------------------------------------------------- + + /** + * Create a function in the database specified in `funcDefinition`. + * If no such database is specified, create it in the current database. + * + * @param ignoreIfExists: When true, ignore if the function with the specified name exists + * in the specified database. + */ + def createFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = { + val db = formatDatabaseName(funcDefinition.identifier.database.getOrElse(getCurrentDatabase)) + requireDbExists(db) + val identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db)) + val newFuncDefinition = funcDefinition.copy(identifier = identifier) + if (!functionExists(identifier)) { + externalCatalog.createFunction(db, newFuncDefinition) + } else if (!ignoreIfExists) { + throw new FunctionAlreadyExistsException(db = db, func = identifier.toString) + } + } + + /** + * Drop a metastore function. + * If no database is specified, assume the function is in the current database. + */ + def dropFunction(name: FunctionIdentifier, ignoreIfNotExists: Boolean): Unit = { + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) + requireDbExists(db) + val identifier = name.copy(database = Some(db)) + if (functionExists(identifier)) { + if (functionRegistry.functionExists(identifier)) { + // If we have loaded this function into the FunctionRegistry, + // also drop it from there. + // For a permanent function, because we loaded it to the FunctionRegistry + // when it's first used, we also need to drop it from the FunctionRegistry. + functionRegistry.dropFunction(identifier) + } + externalCatalog.dropFunction(db, name.funcName) + } else if (!ignoreIfNotExists) { + throw new NoSuchPermanentFunctionException(db = db, func = identifier.toString) + } + } + + /** + * overwrite a metastore function in the database specified in `funcDefinition`.. + * If no database is specified, assume the function is in the current database. + */ + def alterFunction(funcDefinition: CatalogFunction): Unit = { + val db = formatDatabaseName(funcDefinition.identifier.database.getOrElse(getCurrentDatabase)) + requireDbExists(db) + val identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db)) + val newFuncDefinition = funcDefinition.copy(identifier = identifier) + if (functionExists(identifier)) { + if (functionRegistry.functionExists(identifier)) { + // If we have loaded this function into the FunctionRegistry, + // also drop it from there. + // For a permanent function, because we loaded it to the FunctionRegistry + // when it's first used, we also need to drop it from the FunctionRegistry. + functionRegistry.dropFunction(identifier) + } + externalCatalog.alterFunction(db, newFuncDefinition) + } else { + throw new NoSuchPermanentFunctionException(db = db, func = identifier.toString) + } + } + + /** + * Retrieve the metadata of a metastore function. + * + * If a database is specified in `name`, this will return the function in that database. + * If no database is specified, this will return the function in the current database. 
+ */ + def getFunctionMetadata(name: FunctionIdentifier): CatalogFunction = { + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) + requireDbExists(db) + externalCatalog.getFunction(db, name.funcName) + } + + /** + * Check if the function with the specified name exists + */ + def functionExists(name: FunctionIdentifier): Boolean = { + functionRegistry.functionExists(name) || tableFunctionRegistry.functionExists(name) || { + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) + requireDbExists(db) + externalCatalog.functionExists(db, name.funcName) + } + } + + // ---------------------------------------------------------------- + // | Methods that interact with temporary and metastore functions | + // ---------------------------------------------------------------- + + /** + * Constructs a [[FunctionBuilder]] based on the provided function metadata. + */ + private def makeFunctionBuilder(func: CatalogFunction): FunctionBuilder = { + val className = func.className + if (!Utils.classIsLoadable(className)) { + throw QueryCompilationErrors.cannotLoadClassWhenRegisteringFunctionError( + className, func.identifier) + } + val clazz = Utils.classForName(className) + val name = func.identifier.unquotedString + (input: Seq[Expression]) => functionExpressionBuilder.makeExpression(name, clazz, input) + } + + /** + * Loads resources such as JARs and Files for a function. Every resource is represented + * by a tuple (resource type, resource uri). + */ + def loadFunctionResources(resources: Seq[FunctionResource]): Unit = { + resources.foreach(functionResourceLoader.loadResource) + } + + /** + * Registers a temporary or permanent scalar function into a session-specific [[FunctionRegistry]] + */ + def registerFunction( + funcDefinition: CatalogFunction, + overrideIfExists: Boolean, + functionBuilder: Option[FunctionBuilder] = None): Unit = { + val builder = functionBuilder.getOrElse(makeFunctionBuilder(funcDefinition)) + registerFunction(funcDefinition, overrideIfExists, functionRegistry, builder) + } + + private def registerFunction[T]( + funcDefinition: CatalogFunction, + overrideIfExists: Boolean, + registry: FunctionRegistryBase[T], + functionBuilder: FunctionRegistryBase[T]#FunctionBuilder): Unit = { + val func = funcDefinition.identifier + if (registry.functionExists(func) && !overrideIfExists) { + throw QueryCompilationErrors.functionAlreadyExistsError(func) + } + val info = makeExprInfoForHiveFunction(funcDefinition) + registry.registerFunction(func, info, functionBuilder) + } + + private def makeExprInfoForHiveFunction(func: CatalogFunction): ExpressionInfo = { + new ExpressionInfo( + func.className, + func.identifier.database.orNull, + func.identifier.funcName, + null, + "", + "", + "", + "", + "", + "", + "hive") + } + + /** + * Unregister a temporary or permanent function from a session-specific [[FunctionRegistry]] + * Return true if function exists. + */ + def unregisterFunction(name: FunctionIdentifier): Boolean = { + functionRegistry.dropFunction(name) + } + + /** + * Drop a temporary function. + */ + def dropTempFunction(name: String, ignoreIfNotExists: Boolean): Unit = { + if (!functionRegistry.dropFunction(FunctionIdentifier(name)) && + !tableFunctionRegistry.dropFunction(FunctionIdentifier(name)) && + !ignoreIfNotExists) { + throw new NoSuchTempFunctionException(name) + } + } + + /** + * Returns whether it is a temporary function. If not existed, returns false. 
+ */ + def isTemporaryFunction(name: FunctionIdentifier): Boolean = { + // A temporary function is a function that has been registered in functionRegistry + // without a database name, and is neither a built-in function nor a Hive function + name.database.isEmpty && isRegisteredFunction(name) && !isBuiltinFunction(name) + } + + /** + * Return whether this function has been registered in the function registry of the current + * session. If not existed, return false. + */ + def isRegisteredFunction(name: FunctionIdentifier): Boolean = { + functionRegistry.functionExists(name) || tableFunctionRegistry.functionExists(name) + } + + /** + * Returns whether it is a persistent function. If not existed, returns false. + */ + def isPersistentFunction(name: FunctionIdentifier): Boolean = { + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) + databaseExists(db) && externalCatalog.functionExists(db, name.funcName) + } + + /** + * Returns whether it is a built-in function. + */ + def isBuiltinFunction(name: FunctionIdentifier): Boolean = { + FunctionRegistry.builtin.functionExists(name) || + TableFunctionRegistry.builtin.functionExists(name) + } + + protected[sql] def failFunctionLookup( + name: FunctionIdentifier, cause: Option[Throwable] = None): Nothing = { + throw new NoSuchFunctionException( + db = name.database.getOrElse(getCurrentDatabase), func = name.funcName, cause) + } + + /** + * Look up the `ExpressionInfo` of the given function by name if it's a built-in or temp function. + * This only supports scalar functions. + */ + def lookupBuiltinOrTempFunction(name: String): Option[ExpressionInfo] = { + FunctionRegistry.builtinOperators.get(name.toLowerCase(Locale.ROOT)).orElse { + synchronized(lookupTempFuncWithViewContext( + name, FunctionRegistry.builtin.functionExists, functionRegistry.lookupFunction)) + } + } + + /** + * Look up the `ExpressionInfo` of the given function by name if it's a built-in or + * temp table function. + */ + def lookupBuiltinOrTempTableFunction(name: String): Option[ExpressionInfo] = synchronized { + lookupTempFuncWithViewContext( + name, TableFunctionRegistry.builtin.functionExists, tableFunctionRegistry.lookupFunction) + } + + /** + * Look up a built-in or temp scalar function by name and resolves it to an Expression if such + * a function exists. + */ + def resolveBuiltinOrTempFunction(name: String, arguments: Seq[Expression]): Option[Expression] = { + resolveBuiltinOrTempFunctionInternal( + name, arguments, FunctionRegistry.builtin.functionExists, functionRegistry) + } + + /** + * Look up a built-in or temp table function by name and resolves it to a LogicalPlan if such + * a function exists. 
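+ * A sketch of the expected behaviour (assuming the built-in `range` table function):
+ * {{{
+ *   resolveBuiltinOrTempTableFunction("range", Seq(Literal(10)))   // Some(plan for range(10))
+ *   resolveBuiltinOrTempTableFunction("no_such_func", Nil)         // None
+ * }}}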
+ */ + def resolveBuiltinOrTempTableFunction( + name: String, arguments: Seq[Expression]): Option[LogicalPlan] = { + resolveBuiltinOrTempFunctionInternal( + name, arguments, TableFunctionRegistry.builtin.functionExists, tableFunctionRegistry) + } + + private def resolveBuiltinOrTempFunctionInternal[T]( + name: String, + arguments: Seq[Expression], + isBuiltin: FunctionIdentifier => Boolean, + registry: FunctionRegistryBase[T]): Option[T] = synchronized { + val funcIdent = FunctionIdentifier(name) + if (!registry.functionExists(funcIdent)) { + None + } else { + lookupTempFuncWithViewContext( + name, isBuiltin, ident => Option(registry.lookupFunction(ident, arguments))) + } + } + + private def lookupTempFuncWithViewContext[T]( + name: String, + isBuiltin: FunctionIdentifier => Boolean, + lookupFunc: FunctionIdentifier => Option[T]): Option[T] = { + val funcIdent = FunctionIdentifier(name) + if (isBuiltin(funcIdent)) { + lookupFunc(funcIdent) + } else { + val isResolvingView = AnalysisContext.get.catalogAndNamespace.nonEmpty + val referredTempFunctionNames = AnalysisContext.get.referredTempFunctionNames + if (isResolvingView) { + // When resolving a view, only return a temp function if it's referred by this view. + if (referredTempFunctionNames.contains(name)) { + lookupFunc(funcIdent) + } else { + None + } + } else { + val result = lookupFunc(funcIdent) + if (result.isDefined) { + // We are not resolving a view and the function is a temp one, add it to + // `AnalysisContext`, so during the view creation, we can save all referred temp + // functions to view metadata. + AnalysisContext.get.referredTempFunctionNames.add(name) + } + result + } + } + } + + /** + * Look up the `ExpressionInfo` of the given function by name if it's a persistent function. + * This supports both scalar and table functions. + */ + def lookupPersistentFunction(name: FunctionIdentifier): ExpressionInfo = { + val database = name.database.orElse(Some(currentDb)).map(formatDatabaseName) + val qualifiedName = name.copy(database = database) + functionRegistry.lookupFunction(qualifiedName) + .orElse(tableFunctionRegistry.lookupFunction(qualifiedName)) + .getOrElse { + val db = qualifiedName.database.get + requireDbExists(db) + if (externalCatalog.functionExists(db, name.funcName)) { + val metadata = externalCatalog.getFunction(db, name.funcName) + makeExprInfoForHiveFunction(metadata.copy(identifier = qualifiedName)) + } else { + failFunctionLookup(name) + } + } + } + + /** + * Look up a persistent scalar function by name and resolves it to an Expression. + */ + def resolvePersistentFunction( + name: FunctionIdentifier, arguments: Seq[Expression]): Expression = { + resolvePersistentFunctionInternal(name, arguments, functionRegistry, makeFunctionBuilder) + } + + /** + * Look up a persistent table function by name and resolves it to a LogicalPlan. + */ + def resolvePersistentTableFunction( + name: FunctionIdentifier, + arguments: Seq[Expression]): LogicalPlan = { + // We don't support persistent table functions yet. 
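+ // The builder below unconditionally fails the lookup, so a table function that exists only in
+ // the metastore is rejected with NoSuchFunctionException instead of being registered here.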
+ val builder = (func: CatalogFunction) => failFunctionLookup(name) + resolvePersistentFunctionInternal(name, arguments, tableFunctionRegistry, builder) + } + + private def resolvePersistentFunctionInternal[T]( + name: FunctionIdentifier, + arguments: Seq[Expression], + registry: FunctionRegistryBase[T], + createFunctionBuilder: CatalogFunction => FunctionRegistryBase[T]#FunctionBuilder): T = { + val database = formatDatabaseName(name.database.getOrElse(currentDb)) + val qualifiedName = name.copy(database = Some(database)) + if (registry.functionExists(qualifiedName)) { + // This function has been already loaded into the function registry. + registry.lookupFunction(qualifiedName, arguments) + } else { + // The function has not been loaded to the function registry, which means + // that the function is a persistent function (if it actually has been registered + // in the metastore). We need to first put the function in the function registry. + val catalogFunction = try { + externalCatalog.getFunction(database, qualifiedName.funcName) + } catch { + case _: AnalysisException => failFunctionLookup(qualifiedName) + } + loadFunctionResources(catalogFunction.resources) + // Please note that qualifiedName is provided by the user. However, + // catalogFunction.identifier.unquotedString is returned by the underlying + // catalog. So, it is possible that qualifiedName is not exactly the same as + // catalogFunction.identifier.unquotedString (difference is on case-sensitivity). + // At here, we preserve the input from the user. + val funcMetadata = catalogFunction.copy(identifier = qualifiedName) + registerFunction( + funcMetadata, + overrideIfExists = false, + registry = registry, + functionBuilder = createFunctionBuilder(funcMetadata)) + // Now, we need to create the Expression. + registry.lookupFunction(qualifiedName, arguments) + } + } + + /** + * Look up the [[ExpressionInfo]] associated with the specified function, assuming it exists. + */ + def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = synchronized { + if (name.database.isEmpty) { + lookupBuiltinOrTempFunction(name.funcName) + .orElse(lookupBuiltinOrTempTableFunction(name.funcName)) + .getOrElse(lookupPersistentFunction(name)) + } else { + lookupPersistentFunction(name) + } + } + + // The actual function lookup logic looks up temp/built-in function first, then persistent + // function from either v1 or v2 catalog. This method only look up v1 catalog. + def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = { + if (name.database.isEmpty) { + resolveBuiltinOrTempFunction(name.funcName, children) + .getOrElse(resolvePersistentFunction(name, children)) + } else { + resolvePersistentFunction(name, children) + } + } + + def lookupTableFunction(name: FunctionIdentifier, children: Seq[Expression]): LogicalPlan = { + if (name.database.isEmpty) { + resolveBuiltinOrTempTableFunction(name.funcName, children) + .getOrElse(resolvePersistentTableFunction(name, children)) + } else { + resolvePersistentTableFunction(name, children) + } + } + + /** + * List all registered functions in a database with the given pattern. + */ + private def listRegisteredFunctions(db: String, pattern: String): Seq[FunctionIdentifier] = { + val functions = (functionRegistry.listFunction() ++ tableFunctionRegistry.listFunction()) + .filter(_.database.forall(_ == db)) + StringUtils.filterPattern(functions.map(_.unquotedString), pattern).map { f => + // In functionRegistry, function names are stored as an unquoted format. 
+ Try(parser.parseFunctionIdentifier(f)) match { + case Success(e) => e + case Failure(_) => + // The names of some built-in functions are not parsable by our parser, e.g., % + FunctionIdentifier(f) + } + } + } + + /** + * List all functions in the specified database, including temporary functions. This + * returns the function identifier and the scope in which it was defined (system or user + * defined). + */ + def listFunctions(db: String): Seq[(FunctionIdentifier, String)] = listFunctions(db, "*") + + /** + * List all matching functions in the specified database, including temporary functions. This + * returns the function identifier and the scope in which it was defined (system or user + * defined). + */ + def listFunctions(db: String, pattern: String): Seq[(FunctionIdentifier, String)] = { + val dbName = formatDatabaseName(db) + requireDbExists(dbName) + val dbFunctions = externalCatalog.listFunctions(dbName, pattern).map { f => + FunctionIdentifier(f, Some(dbName)) } + val loadedFunctions = listRegisteredFunctions(db, pattern) + val functions = dbFunctions ++ loadedFunctions + // The session catalog caches some persistent functions in the FunctionRegistry + // so there can be duplicates. + functions.map { + case f if FunctionRegistry.functionSet.contains(f) => (f, "SYSTEM") + case f if TableFunctionRegistry.functionSet.contains(f) => (f, "SYSTEM") + case f => (f, "USER") + }.distinct + } + + + // ----------------- + // | Other methods | + // ----------------- + + /** + * Drop all existing databases (except "default"), tables, partitions and functions, + * and set the current database to "default". + * + * This is mainly used for tests. + */ + def reset(): Unit = synchronized { + setCurrentDatabase(DEFAULT_DATABASE) + externalCatalog.setCurrentDatabase(DEFAULT_DATABASE) + listDatabases().filter(_ != DEFAULT_DATABASE).foreach { db => + dropDatabase(db, ignoreIfNotExists = false, cascade = true) + } + listTables(DEFAULT_DATABASE).foreach { table => + dropTable(table, ignoreIfNotExists = false, purge = false) + } + listFunctions(DEFAULT_DATABASE).map(_._1).foreach { func => + if (func.database.isDefined) { + dropFunction(func, ignoreIfNotExists = false) + } else { + dropTempFunction(func.funcName, ignoreIfNotExists = false) + } + } + clearTempTables() + globalTempViewManager.clear() + functionRegistry.clear() + tableFunctionRegistry.clear() + tableRelationCache.invalidateAll() + externalCatalogCache.invalidateAllCachedTables() + // restore built-in functions + FunctionRegistry.builtin.listFunction().foreach { f => + val expressionInfo = FunctionRegistry.builtin.lookupFunction(f) + val functionBuilder = FunctionRegistry.builtin.lookupFunctionBuilder(f) + require(expressionInfo.isDefined, s"built-in function '$f' is missing expression info") + require(functionBuilder.isDefined, s"built-in function '$f' is missing function builder") + functionRegistry.registerFunction(f, expressionInfo.get, functionBuilder.get) + } + // restore built-in table functions + TableFunctionRegistry.builtin.listFunction().foreach { f => + val expressionInfo = TableFunctionRegistry.builtin.lookupFunction(f) + val functionBuilder = TableFunctionRegistry.builtin.lookupFunctionBuilder(f) + require(expressionInfo.isDefined, s"built-in function '$f' is missing expression info") + require(functionBuilder.isDefined, s"built-in function '$f' is missing function builder") + tableFunctionRegistry.registerFunction(f, expressionInfo.get, functionBuilder.get) + } + } + + /** + * Copy the current state of the catalog to another 
catalog. + * + * This function is synchronized on this [[SessionCatalog]] (the source) to make sure the copied + * state is consistent. The target [[SessionCatalog]] is not synchronized, and should not be + * because the target [[SessionCatalog]] should not be published at this point. The caller must + * synchronize on the target if this assumption does not hold. + */ + private[sql] def copyStateTo(target: SessionCatalog): Unit = synchronized { + target.currentDb = currentDb + // copy over temporary views + tempViews.foreach(kv => target.tempViews.put(kv._1, kv._2)) + } + + /** + * Validate the new location before renaming a managed table, which should be non-existent. + */ + private def validateNewLocationOfRename( + oldName: TableIdentifier, + newName: TableIdentifier): Unit = { + requireTableExists(oldName) + requireTableNotExists(newName) + val oldTable = getTableMetadata(oldName) + if (oldTable.tableType == CatalogTableType.MANAGED) { + assert(oldName.database.nonEmpty) + val databaseLocation = + externalCatalog.getDatabase(oldName.database.get).locationUri + val newTableLocation = new Path(new Path(databaseLocation), formatTableName(newName.table)) + val fs = newTableLocation.getFileSystem(hadoopConf) + if (fs.exists(newTableLocation)) { + throw QueryCompilationErrors.cannotOperateManagedTableWithExistingLocationError( + "rename", oldName, newTableLocation) + } + } + } +} diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatDataWriter.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatDataWriter.scala index f3fe865e0b85e88277ba10776bde67092415a245..4f35fc45cb55cf93b31a5344a1d4423f853e7fc3 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatDataWriter.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatDataWriter.scala @@ -26,12 +26,15 @@ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils import org.apache.spark.sql.catalyst.expressions.{Cast, Concat, Expression, Literal, ScalaUDF, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.connector.write.DataWriter import org.apache.spark.sql.execution.datasources.orc.OmniOrcOutputWriter +import org.apache.spark.sql.execution.datasources.parquet.OmniParquetOutputWriter import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric} import org.apache.spark.sql.types.StringType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.Utils import scala.collection.mutable +import scala.math.max +import scala.math.abs /** Writes data to a single directory (used for non-dynamic-partition writes). 
*/ @@ -62,6 +65,9 @@ class OmniSingleDirectoryDataWriter( context = taskAttemptContext) currentWriter match { + case _: OmniParquetOutputWriter => + currentWriter.asInstanceOf[OmniParquetOutputWriter] + .initialize(description.allColumns, description.dataColumns) case _: OmniOrcOutputWriter => currentWriter.asInstanceOf[OmniOrcOutputWriter] .initialize(description.allColumns, description.dataColumns) @@ -235,8 +241,17 @@ abstract class OmniBaseDynamicPartitionDataWriter( path = currentPath, dataSchema = description.dataColumns.toStructType, context = taskAttemptContext) - currentWriter.asInstanceOf[OmniOrcOutputWriter] - .initialize(description.allColumns, description.dataColumns) + currentWriter match { + case _: OmniParquetOutputWriter => + currentWriter.asInstanceOf[OmniParquetOutputWriter] + .initialize(description.allColumns, description.dataColumns) + case _: OmniOrcOutputWriter => + currentWriter.asInstanceOf[OmniOrcOutputWriter] + .initialize(description.allColumns, description.dataColumns) + case _ => + throw new UnsupportedOperationException + (s"Unsupported ${currentWriter.getClass} Output writer!") + } statsTrackers.foreach(_.newFile(currentPath)) } @@ -266,8 +281,17 @@ abstract class OmniBaseDynamicPartitionDataWriter( protected def writeRecord(record: InternalRow, startPos: Long, endPos: Long): Unit = { // TODO After add OmniParquetOutPutWriter need extract // a abstract interface named OmniOutPutWriter - assert(currentWriter.isInstanceOf[OmniOrcOutputWriter]) - currentWriter.asInstanceOf[OmniOrcOutputWriter].spiltWrite(record, startPos, endPos) + currentWriter match { + case _: OmniParquetOutputWriter => + assert(currentWriter.isInstanceOf[OmniParquetOutputWriter]) + currentWriter.asInstanceOf[OmniParquetOutputWriter].spiltWrite(record, startPos, endPos) + case _: OmniOrcOutputWriter => + assert(currentWriter.isInstanceOf[OmniOrcOutputWriter]) + currentWriter.asInstanceOf[OmniOrcOutputWriter].spiltWrite(record, startPos, endPos) + case _ => + throw new UnsupportedOperationException + (s"writeRecord Unsupported ${currentWriter.getClass} Output writer!") + } statsTrackers.foreach(_.newRow(currentWriter.path, record)) recordsInFile += record.asInstanceOf[OmniInternalRow].batch.numRows() @@ -287,56 +311,83 @@ class OmniDynamicPartitionDataSingleWriter( extends OmniBaseDynamicPartitionDataWriter(description, taskAttemptContext, committer, customMetrics) { - private var currentPartitionValues: Option[UnsafeRow] = None - private var currentBucketId: Option[Int] = None + private var firstPartitionValues: Option[UnsafeRow] = None + private var firstBucketId: Option[Int] = None + private var firstFilePath = "" override def write(record: InternalRow): Unit = { assert(record.isInstanceOf[OmniInternalRow]) splitWrite(record) } + private def judgePartitionBucketSame(batch: ColumnarBatch, firstIndex: Int, lastIndex: Int): Boolean = { + val firstRecord = batch.getRow(firstIndex) + val lastRecord = batch.getRow(lastIndex) + var lastPartitionValues: Option[UnsafeRow] = None + var lastBucketId: Option[Int] = None + if (isPartitioned) { + firstPartitionValues = Some(getPartitionValues(batch.getRow(firstIndex)).copy()) + lastPartitionValues = Some(getPartitionValues(batch.getRow(lastIndex)).copy()) + } + if (isBucketed) { + firstBucketId = Some(getBucketId(firstRecord)) + lastBucketId = Some(getBucketId(lastRecord)) + } + return (firstPartitionValues == lastPartitionValues && firstBucketId == lastBucketId) + } + private def splitWrite(omniInternalRow: InternalRow): Unit = { val batch = 
omniInternalRow.asInstanceOf[OmniInternalRow].batch val numRows = batch.numRows() - var lastIndex = 0 - for (i <- 0 until numRows) { - val record = batch.getRow(i) - val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None - val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None - - if (currentPartitionValues != nextPartitionValues || currentBucketId != nextBucketId) { - val isFilePathSame = getPartitionPath(currentPartitionValues, - currentBucketId) == getPartitionPath(nextPartitionValues, nextBucketId) - if (!isFilePathSame) { - // See a new partition or bucket - write to a new partition dir (or a new bucket file). - if (isPartitioned && currentPartitionValues != nextPartitionValues) { - currentPartitionValues = Some(nextPartitionValues.get.copy()) - statsTrackers.foreach(_.newPartition(currentPartitionValues.get)) - } - if (isBucketed) { - currentBucketId = nextBucketId + var firstIndex = 0 + var lastIndex = numRows - 1 + // when the first and last rows of the batch belong to the same partition and bucket, write the whole batch with one writer + if (judgePartitionBucketSame(batch, firstIndex, lastIndex)) { + if (!(currentWriter.isInstanceOf[OmniOrcOutputWriter] || currentWriter.isInstanceOf[OmniParquetOutputWriter])) { + renewCurrentWriter(firstPartitionValues, firstBucketId, closeCurrentWriter = true) + firstFilePath = getPartitionPath(firstPartitionValues, firstBucketId) + } + writeRecord(omniInternalRow, firstIndex, numRows) + } + // otherwise, locate the boundary between partitions with a binary search + else { + while (!judgePartitionBucketSame(batch, firstIndex, lastIndex)) { + var tmpIndex = lastIndex + lastIndex = firstIndex + (lastIndex - firstIndex) / 2 + /* + the boundary index must satisfy two conditions: + 1. its partition and bucket equal those of the first row of the batch + 2. they differ from those of the next row + */ + while (!judgePartitionBucketSame(batch, firstIndex, lastIndex) || judgePartitionBucketSame(batch, lastIndex, lastIndex + 1)) { + if (!judgePartitionBucketSame(batch, firstIndex, lastIndex)) { + var tmp1Index = lastIndex + lastIndex = lastIndex - max(abs(tmpIndex - lastIndex) / 2, 1) + tmpIndex = tmp1Index } - - fileCounter = 0 - if (i != 0) { - writeRecord(omniInternalRow, lastIndex, i) - lastIndex = i + else if (judgePartitionBucketSame(batch, lastIndex, lastIndex + 1)) { + var tmp1Index = lastIndex + lastIndex = lastIndex + max(abs(tmpIndex - lastIndex) / 2, 1) + tmpIndex = tmp1Index } - renewCurrentWriter(currentPartitionValues, currentBucketId, closeCurrentWriter = true) } - } else if ( - description.maxRecordsPerFile > 0 && - recordsInFile >= description.maxRecordsPerFile - ) { - if (i != 0) { - writeRecord(omniInternalRow, lastIndex, i) - lastIndex = i + if (getPartitionPath(firstPartitionValues, firstBucketId) != firstFilePath) { + renewCurrentWriter(firstPartitionValues, firstBucketId, closeCurrentWriter = true) + firstFilePath = getPartitionPath(firstPartitionValues, firstBucketId) } - renewCurrentWriterIfTooManyRecords(currentPartitionValues, currentBucketId) + writeRecord(omniInternalRow, firstIndex, lastIndex + 1) + firstIndex = lastIndex + 1; + lastIndex = numRows - 1 } - } - if (lastIndex < batch.numRows()) { - writeRecord(omniInternalRow, lastIndex, numRows) + if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { + renewCurrentWriterIfTooManyRecords(firstPartitionValues, firstBucketId) + } + if (getPartitionPath(firstPartitionValues, firstBucketId) != firstFilePath) { + 
renewCurrentWriter(firstPartitionValues, firstBucketId, closeCurrentWriter = true) + firstFilePath = getPartitionPath(firstPartitionValues, firstBucketId) + } + // write the remaining data in the partition + writeRecord(omniInternalRow, firstIndex, numRows) } } } diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala index 465123d8730d39656d242bd4785e22680312ce05..217eac6f856ce07e302140065a9d459933fc79bf 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.fs.{FileAlreadyExistsException, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.parquet.hadoop.util.ContextUtil import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils} @@ -39,7 +40,7 @@ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.FileFormatWriter.ConcurrentOutputWriterSpec -import org.apache.spark.sql.execution.{ColumnarProjectExec, ColumnarSortExec, OmniColumnarToRowExec, ProjectExec, SQLExecution, SortExec, SparkPlan, UnsafeExternalRowSorter} +import org.apache.spark.sql.execution.{ColumnarProjectExec, ColumnarSortExec, OmniColumnarToRowExec, ProjectExec, SQLExecution, SortExec, SparkPlan} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StringType import org.apache.spark.sql.vectorized.ColumnarBatch @@ -156,6 +157,11 @@ object OmniFileFormatWriter extends Logging { val dataSchema = dataColumns.toStructType DataSourceUtils.verifySchema(fileFormat, dataSchema) // Note: prepareWrite has side effect. It sets "job". 
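+ // Propagate the session's Parquet field-id write setting into the job's Hadoop configuration,
+ // so output writers created from this job observe the same value as the SQL conf.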
+ + val conf = ContextUtil.getConfiguration(job) + conf.set( + SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key, + sparkSession.sessionState.conf.parquetFieldIdWriteEnabled.toString) val outputWriterFactory = fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataSchema) diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/util/ModifyUtil.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/util/ModifyUtil.scala index baa8279d1cac30b4120c5b283174c176826044d3..1e31047933b7d8daf3569c5cdba32055201ac574 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/util/ModifyUtil.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/sql/util/ModifyUtil.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.command.{DataWritingCommand, DataWritingCo import org.apache.spark.sql.execution.datasources.{FileFormat, InsertIntoHadoopFsRelationCommand, OmniInsertIntoHadoopFsRelationCommand} import org.apache.spark.sql.execution.datasources.OmniFileFormatWriter.Empty2Null import org.apache.spark.sql.execution.datasources.orc.{OmniOrcFileFormat, OrcFileFormat} +import org.apache.spark.sql.execution.datasources.parquet.{OmniParquetFileFormat, ParquetFileFormat} import org.apache.spark.sql.expression.ColumnarExpressionConverter import org.apache.spark.sql.types.DataType import com.google.gson.{JsonArray, JsonObject} @@ -84,6 +85,7 @@ object ModifyUtil extends Logging { logInfo(s"Columnar Processing for ${cmd.getClass} is currently supported.") val fileFormat: FileFormat = cmd.fileFormat match { case _: OrcFileFormat => new OmniOrcFileFormat() + case _: ParquetFileFormat => new OmniParquetFileFormat() case format => logInfo(s"Unsupported ${format.getClass} file " + s"format for columnar data write command.") @@ -145,10 +147,13 @@ object ModifyUtil extends Logging { extensions.injectOptimizerRule(_ => CombineJoinedAggregates) } + private def addTransformableTagAdapter(plan: SparkPlan): Unit = {} + def registerFunc(): Unit = { ModifyUtilAdaptor.configRewriteJsonFunc(rewriteToOmniJsonExpressionAdapter) ModifyUtilAdaptor.configPreReplacePlanFunc(preReplaceSparkPlanAdapter) ModifyUtilAdaptor.configPostReplacePlanFunc(postReplaceSparkPlanAdapter) ModifyUtilAdaptor.configInjectRuleFunc(injectRuleAdapter) + ModifyUtilAdaptor.configAddTransformableFunc(addTransformableTagAdapter) } } diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/util/CatalogFileIndex.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/util/CatalogFileIndex.scala new file mode 100644 index 0000000000000000000000000000000000000000..a5d36a3a61f9cce6547c0ba32f8cea5108080a58 --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/util/CatalogFileIndex.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.net.URI + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, ExternalCatalogUtils} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.StructType + + +/** + * A [[FileIndex]] for a metastore catalog table. + * + * @param sparkSession a [[SparkSession]] + * @param table the metadata of the table + * @param sizeInBytes the table's data size in bytes + */ +class CatalogFileIndex( + sparkSession: SparkSession, + val table: CatalogTable, + override val sizeInBytes: Long) extends FileIndex { + + protected val hadoopConf: Configuration = sparkSession.sessionState.newHadoopConf() + + /** Globally shared (not exclusive to this table) cache for file statuses to speed up listing. */ + private val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) + + assert(table.identifier.database.isDefined, + "The table identifier must be qualified in CatalogFileIndex") + + private val baseLocation: Option[URI] = table.storage.locationUri + + override def partitionSchema: StructType = table.partitionSchema + + override def rootPaths: Seq[Path] = baseLocation.map(new Path(_)).toSeq + + override def listFiles( + partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): Seq[PartitionDirectory] = { + filterPartitions(partitionFilters).listFiles(Nil, dataFilters) + } + + override def refresh(): Unit = fileStatusCache.invalidateAll() + + /** + * Returns a [[InMemoryFileIndex]] for this table restricted to the subset of partitions + * specified by the given partition-pruning filters. 
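+ * For example (the partition column name is illustrative), a filter equivalent to `dt = '2024-01-01'`
+ * restricts file listing to that partition's directory instead of the whole table.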
+ * + * @param filters partition-pruning filters + */ + def filterPartitions(filters: Seq[Expression]): InMemoryFileIndex = { + if (table.partitionColumnNames.nonEmpty) { + val startTime = System.nanoTime() + val selectedPartitions = ExternalCatalogUtils.listPartitionsByFilter( + sparkSession.sessionState.conf, sparkSession.sessionState.catalog, table, filters) + val partitions = selectedPartitions.map { p => + var path = new Path(p.location) + val pathUri = path.toUri + val scheme = Option(pathUri.getScheme) + val authority = Option(pathUri.getAuthority) + path = (path.isAbsolute, scheme, authority) match { + case (true, Some(_), Some(_)) => path + case _ => + val fs = path.getFileSystem(hadoopConf) + path.makeQualified(fs.getUri, fs.getWorkingDirectory) + } + PartitionPath( + p.toRow(partitionSchema, sparkSession.sessionState.conf.sessionLocalTimeZone), + path) + } + val partitionSpec = PartitionSpec(partitionSchema, partitions) + val timeNs = System.nanoTime() - startTime + new InMemoryFileIndex(sparkSession, + rootPathsSpecified = partitionSpec.partitions.map(_.path), + parameters = Map.empty, + userSpecifiedSchema = Some(partitionSpec.partitionColumns), + fileStatusCache = fileStatusCache, + userSpecifiedPartitionSpec = Some(partitionSpec), + metadataOpsTimeNs = Some(timeNs)) + } else { + new InMemoryFileIndex(sparkSession, rootPaths, parameters = table.storage.properties, + userSpecifiedSchema = None, fileStatusCache = fileStatusCache) + } + } + + override def inputFiles: Array[String] = filterPartitions(Nil).inputFiles + + // `CatalogFileIndex` may be a member of `HadoopFsRelation`, `HadoopFsRelation` may be a member + // of `LogicalRelation`, and `LogicalRelation` may be used as the cache key. So we need to + // implement `equals` and `hashCode` here, to make it work with cache lookup. + override def equals(o: Any): Boolean = o match { + case other: CatalogFileIndex => this.table.identifier == other.table.identifier + case _ => false + } + + override def hashCode(): Int = table.identifier.hashCode() +} diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/util/FileStreamSink.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/util/FileStreamSink.scala new file mode 100644 index 0000000000000000000000000000000000000000..0685cae8056beca2819403ea17478029ec719fde --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-modify/src/main/scala/org/apache/spark/util/FileStreamSink.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging +import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FileFormat, FileFormatWriter} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.{SerializableConfiguration, Utils} + +object FileStreamSink extends Logging { + // The name of the subdirectory that is used to store metadata about which files are valid. + val metadataDir = "_spark_metadata" + + /** + * Returns true if there is a single path that has a metadata log indicating which files should + * be read. + */ + def hasMetadata(path: Seq[String], hadoopConf: Configuration, sqlConf: SQLConf): Boolean = { + // User explicitly configs to ignore sink metadata. + if (sqlConf.fileStreamSinkMetadataIgnored) { + return false + } + + path match { + case Seq(singlePath) => + val hdfsPath = new Path(singlePath) + try { + val fs = hdfsPath.getFileSystem(hadoopConf) + if (fs.isDirectory(hdfsPath)) { + val metadataPath = getMetadataLogPath(fs, hdfsPath, sqlConf) + fs.exists(metadataPath) + } else { + false + } + } catch { + case e: SparkException => throw e + case NonFatal(e) => + logWarning(s"Assume no metadata directory. Error while looking for " + + s"metadata directory in the path: $singlePath.", e) + false + } + case _ => false + } + } + + def getMetadataLogPath(fs: FileSystem, path: Path, sqlConf: SQLConf): Path = { + val metadataDir = new Path(path, FileStreamSink.metadataDir) + FileStreamSink.checkEscapedMetadataPath(fs, metadataDir, sqlConf) + metadataDir + } + + def checkEscapedMetadataPath(fs: FileSystem, metadataPath: Path, sqlConf: SQLConf): Unit = { + if (sqlConf.getConf(SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED) + && StreamExecution.containsSpecialCharsInPath(metadataPath)) { + val legacyMetadataPath = new Path(metadataPath.toUri.toString) + val legacyMetadataPathExists = + try { + fs.exists(legacyMetadataPath) + } catch { + case NonFatal(e) => + // We may not have access to this directory. Don't fail the query if that happens. + logWarning(e.getMessage, e) + false + } + if (legacyMetadataPathExists) { + throw QueryExecutionErrors.legacyMetadataPathExistsError(metadataPath, legacyMetadataPath) + } + } + } + + /** + * Returns true if the path is the metadata dir or its ancestor is the metadata dir. 
+ * E.g.: + * - ancestorIsMetadataDirectory(/.../_spark_metadata) => true + * - ancestorIsMetadataDirectory(/.../_spark_metadata/0) => true + * - ancestorIsMetadataDirectory(/a/b/c) => false + */ + def ancestorIsMetadataDirectory(path: Path, hadoopConf: Configuration): Boolean = { + val pathUri = path.toUri + val scheme = Option(pathUri.getScheme) + val authority = Option(pathUri.getAuthority) + var currentPath = (path.isAbsolute, scheme, authority) match { + case (true, Some(_), Some(_)) => path + case _ => + val fs = path.getFileSystem(hadoopConf) + path.makeQualified(fs.getUri, fs.getWorkingDirectory) + } + + while (currentPath != null) { + if (currentPath.getName == FileStreamSink.metadataDir) { + return true + } else { + currentPath = currentPath.getParent + } + } + return false + } +} + +/** + * A sink that writes out results to parquet files. Each batch is written out to a unique + * directory. After all of the files in a batch have been successfully written, the list of + * file paths is appended to the log atomically. In the case of partial failures, some duplicate + * data may be present in the target directory, but only one copy of each file will be present + * in the log. + */ +class FileStreamSink( + sparkSession: SparkSession, + path: String, + fileFormat: FileFormat, + partitionColumnNames: Seq[String], + options: Map[String, String]) extends Sink with Logging { + + import FileStreamSink._ + + private val hadoopConf = sparkSession.sessionState.newHadoopConf() + private val basePath = new Path(path) + private val logPath = getMetadataLogPath(basePath.getFileSystem(hadoopConf), basePath, + sparkSession.sessionState.conf) + private val retention = options.get("retention").map(Utils.timeStringAsMs) + private val fileLog = new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, + logPath.toString, retention) + + private def basicWriteJobStatsTracker: BasicWriteJobStatsTracker = { + val serializableHadoopConf = new SerializableConfiguration(hadoopConf) + new BasicWriteJobStatsTracker(serializableHadoopConf, BasicWriteJobStatsTracker.metrics) + } + + override def addBatch(batchId: Long, data: DataFrame): Unit = { + if (batchId <= fileLog.getLatestBatchId().getOrElse(-1L)) { + logInfo(s"Skipping already committed batch $batchId") + } else { + val committer = FileCommitProtocol.instantiate( + className = sparkSession.sessionState.conf.streamingFileCommitProtocolClass, + jobId = batchId.toString, + outputPath = path) + + committer match { + case manifestCommitter: ManifestFileCommitProtocol => + manifestCommitter.setupManifestOptions(fileLog, batchId) + case _ => // Do nothing + } + + // Get the actual partition columns as attributes after matching them by name with + // the given columns names. 
+ val partitionColumns: Seq[Attribute] = partitionColumnNames.map { col => + val nameEquality = data.sparkSession.sessionState.conf.resolver + data.logicalPlan.output.find(f => nameEquality(f.name, col)).getOrElse { + throw QueryExecutionErrors.partitionColumnNotFoundInSchemaError(col, data.schema) + } + } + val qe = data.queryExecution + + FileFormatWriter.write( + sparkSession = sparkSession, + plan = qe.executedPlan, + fileFormat = fileFormat, + committer = committer, + outputSpec = FileFormatWriter.OutputSpec(path, Map.empty, qe.analyzed.output), + hadoopConf = hadoopConf, + partitionColumns = partitionColumns, + bucketSpec = None, + statsTrackers = Seq(basicWriteJobStatsTracker), + options = options) + } + } + + override def toString: String = s"FileSink[$path]" +} diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala new file mode 100644 index 0000000000000000000000000000000000000000..1052e70bf1f692d6b4a8e1440f593a468309f40d --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala @@ -0,0 +1,27 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions.{NamedExpression, SortOrder} + +abstract class UnaryExecNodeShim(sortOrder: Seq[SortOrder], projectList: Seq[NamedExpression]) + extends UnaryExecNode { + + override def outputOrdering: Seq[SortOrder] = sortOrder +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-shim/src/main/scala/org/apache/spark/sql/util/ShimUtil.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-shim/src/main/scala/org/apache/spark/sql/util/ShimUtil.scala index c1d20df058d30ed785c271b973f6aa0aa1f3ec0f..13df736039782b30aa21ca9a6c90a0d6c7b3cfe6 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-shim/src/main/scala/org/apache/spark/sql/util/ShimUtil.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark33-shim/src/main/scala/org/apache/spark/sql/util/ShimUtil.scala @@ -29,7 +29,7 @@ import org.apache.spark.shuffle.sort.SortShuffleWriter import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriteMetricsReporter, ShuffleWriter} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, Count, Max, Min, Sum} import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} -import org.apache.spark.sql.catalyst.expressions.{Attribute, BinaryOperator, CastBase, Expression} +import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, BinaryOperator, CastBase, Divide, Expression, Multiply, Subtract} import org.apache.spark.sql.catalyst.optimizer.BuildSide import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, CTERelationDef, CTERelationRef, LogicalPlan, Statistics} @@ -38,11 +38,15 @@ import org.apache.spark.sql.types.{DataType, DateType, DecimalType, DoubleType, import java.net.URI import java.util.{Locale, Properties} +import nova.hetu.omniruntime.constants.FunctionType +import nova.hetu.omniruntime.constants.FunctionType.{OMNI_AGGREGATION_TYPE_SUM,OMNI_AGGREGATION_TYPE_AVG} object ShimUtil { def isSupportDataWriter: Boolean = true + def isNeedModifyBuildSide: Boolean = false + def createCTERelationRef(cteId: Long, resolved: Boolean, output: Seq[Attribute], isStreaming: Boolean, tatsOpt: Option[Statistics] = None): CTERelationRef = { CTERelationRef(cteId, resolved, output, tatsOpt) @@ -59,7 +63,23 @@ object ShimUtil { new TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber, taskMemoryManager, localProperties, metricsSystem, taskMetrics, cpus, resources) } - def unsupportedEvalModeCheck(expr: Expression): Unit = {} + def transformExpressionByEvalMode(expr: Expression): String = { + expr match { + case add: Add => "ADD" + case sub: Subtract => "SUBTRACT" + case mult: Multiply => "MULTIPLY" + case divide: Divide => "DIVIDE" + case _ => throw new UnsupportedOperationException(s"Unsupported Operation for $expr") + } + } + + def transformFuncTypeByEvalMode(expr: Expression):FunctionType = { + expr match { + case sum: Sum => OMNI_AGGREGATION_TYPE_SUM + case avg: Average => OMNI_AGGREGATION_TYPE_AVG + case _ => throw new UnsupportedOperationException(s"Unsupported Operation for $expr") + } + } def binaryOperatorAdjust(expr: BinaryOperator, returnDataType: DataType): (Expression, Expression) = { (expr.left, expr.right) @@ -117,10 +137,6 @@ object ShimUtil { ) } - def buildBuildSide(buildSide: BuildSide, joinType: JoinType): BuildSide = { - buildSide - } - 
def createSortShuffleWriter[K, V, C](handle: BaseShuffleHandle[K, V, C], mapId: Long, context: TaskContext, diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark34-modify/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatDataWriter.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark34-modify/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatDataWriter.scala index f3fe865e0b85e88277ba10776bde67092415a245..ccb312da1a9380718f0dac74f7b8fc35d0e9e60c 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-shims/spark34-modify/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatDataWriter.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark34-modify/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatDataWriter.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils import org.apache.spark.sql.catalyst.expressions.{Cast, Concat, Expression, Literal, ScalaUDF, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.connector.write.DataWriter import org.apache.spark.sql.execution.datasources.orc.OmniOrcOutputWriter +import org.apache.spark.sql.execution.datasources.parquet.OmniParquetOutputWriter import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric} import org.apache.spark.sql.types.StringType import org.apache.spark.sql.vectorized.ColumnarBatch @@ -62,6 +63,9 @@ class OmniSingleDirectoryDataWriter( context = taskAttemptContext) currentWriter match { + case _: OmniParquetOutputWriter => + currentWriter.asInstanceOf[OmniParquetOutputWriter] + .initialize(description.allColumns, description.dataColumns) case _: OmniOrcOutputWriter => currentWriter.asInstanceOf[OmniOrcOutputWriter] .initialize(description.allColumns, description.dataColumns) @@ -153,9 +157,6 @@ abstract class OmniBaseDynamicPartitionDataWriter( row => proj(row).getInt(0) } - /** Returns the data columns to be written given an input row */ - protected val getOutputRow = - UnsafeProjection.create(description.dataColumns, description.allColumns) protected def getPartitionPath(partitionValues: Option[InternalRow], bucketId: Option[Int]): String = { @@ -186,6 +187,10 @@ abstract class OmniBaseDynamicPartitionDataWriter( currentPath } + /** Returns the data columns to be written given an input row */ + protected val getOutputRow = + UnsafeProjection.create(description.dataColumns, description.allColumns) + /** * Opens a new OutputWriter given a partition key and/or a bucket id. 
* If bucket id is specified, we will append it to the end of the file name, but before the @@ -235,8 +240,18 @@ abstract class OmniBaseDynamicPartitionDataWriter( path = currentPath, dataSchema = description.dataColumns.toStructType, context = taskAttemptContext) - currentWriter.asInstanceOf[OmniOrcOutputWriter] - .initialize(description.allColumns, description.dataColumns) + + currentWriter match { + case _: OmniParquetOutputWriter => + currentWriter.asInstanceOf[OmniParquetOutputWriter] + .initialize(description.allColumns, description.dataColumns) + case _: OmniOrcOutputWriter => + currentWriter.asInstanceOf[OmniOrcOutputWriter] + .initialize(description.allColumns, description.dataColumns) + case _ => + throw new UnsupportedOperationException + (s"Unsupported ${currentWriter.getClass} Output writer!") + } statsTrackers.foreach(_.newFile(currentPath)) } @@ -266,8 +281,17 @@ abstract class OmniBaseDynamicPartitionDataWriter( protected def writeRecord(record: InternalRow, startPos: Long, endPos: Long): Unit = { // TODO After add OmniParquetOutPutWriter need extract // a abstract interface named OmniOutPutWriter - assert(currentWriter.isInstanceOf[OmniOrcOutputWriter]) - currentWriter.asInstanceOf[OmniOrcOutputWriter].spiltWrite(record, startPos, endPos) + currentWriter match { + case _: OmniParquetOutputWriter => + assert(currentWriter.isInstanceOf[OmniParquetOutputWriter]) + currentWriter.asInstanceOf[OmniParquetOutputWriter].spiltWrite(record, startPos, endPos) + case _: OmniOrcOutputWriter => + assert(currentWriter.isInstanceOf[OmniOrcOutputWriter]) + currentWriter.asInstanceOf[OmniOrcOutputWriter].spiltWrite(record, startPos, endPos) + case _ => + throw new UnsupportedOperationException + (s"writeRecord Unsupported ${currentWriter.getClass} Output writer!") + } statsTrackers.foreach(_.newRow(currentWriter.path, record)) recordsInFile += record.asInstanceOf[OmniInternalRow].batch.numRows() diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark34-modify/src/main/scala/org/apache/spark/sql/util/ModifyUtil.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark34-modify/src/main/scala/org/apache/spark/sql/util/ModifyUtil.scala index ba1cbda834e56b25637f036547649923778d3f60..90875c2c2902bb78e52980f3d4f9e0e84b514508 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-shims/spark34-modify/src/main/scala/org/apache/spark/sql/util/ModifyUtil.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark34-modify/src/main/scala/org/apache/spark/sql/util/ModifyUtil.scala @@ -22,7 +22,7 @@ import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{JsonObjectExt import com.huawei.boostkit.spark.util.ModifyUtilAdaptor import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSessionExtensions -import org.apache.spark.sql.catalyst.expressions.{BloomFilterMightContain, ExprId, Expression} +import org.apache.spark.sql.catalyst.expressions.{BloomFilterMightContain, ExprId, Expression, TryEval} import org.apache.spark.sql.catalyst.optimizer.{CombineJoinedAggregates, MergeSubqueryFilters} import org.apache.spark.sql.execution.adaptive.QueryStageExec import org.apache.spark.sql.execution.{BroadcastExchangeExecProxy, ColumnarBloomFilterSubquery, ColumnarToRowExec, OmniColumnarToRowExec, SparkPlan} @@ -60,7 +60,8 @@ object ModifyUtil extends Logging { .put("isNull", bfAddress == 0L) .put("dataType", 2) .put("value", bfAddress) - + case tryEval: TryEval => + func(tryEval.child, exprsIndexMap, returnDatatype) case _ => 
null } @@ -102,10 +103,13 @@ object ModifyUtil extends Logging { extensions.injectOptimizerRule(_ => CombineJoinedAggregates) } + private def addTransformableTagAdapter(plan: SparkPlan): Unit = {} + def registerFunc(): Unit = { ModifyUtilAdaptor.configRewriteJsonFunc(rewriteToOmniJsonExpressionAdapter) ModifyUtilAdaptor.configPreReplacePlanFunc(preReplaceSparkPlanAdapter) ModifyUtilAdaptor.configPostReplacePlanFunc(postReplaceSparkPlanAdapter) ModifyUtilAdaptor.configInjectRuleFunc(injectRuleAdapter) + ModifyUtilAdaptor.configAddTransformableFunc(addTransformableTagAdapter) } } diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark34-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark34-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala new file mode 100644 index 0000000000000000000000000000000000000000..8e49f1d388f776a969f917f4d0a6493b1242a320 --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark34-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions.{NamedExpression, SortOrder} + +abstract class UnaryExecNodeShim(sortOrder: Seq[SortOrder], projectList: Seq[NamedExpression]) + extends OrderPreservingUnaryExecNode { + + override def outputExpressions: Seq[NamedExpression] = projectList + + override def orderingExpressions: Seq[SortOrder] = sortOrder +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark34-shim/src/main/scala/org/apache/spark/sql/util/ShimUtil.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark34-shim/src/main/scala/org/apache/spark/sql/util/ShimUtil.scala index c55e240d6fec7e05daff7fda9c8baed6e1381fa7..e703a7ffa594d3b17b51471da2e618ee917ce369 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-shims/spark34-shim/src/main/scala/org/apache/spark/sql/util/ShimUtil.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark34-shim/src/main/scala/org/apache/spark/sql/util/ShimUtil.scala @@ -37,6 +37,8 @@ import org.apache.spark.sql.execution.datasources.PartitionedFile import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DecimalType.{MAX_PRECISION, MAX_SCALE} import org.apache.spark.sql.types.{DataType, DateType, DecimalType, DoubleType, IntegerType, LongType, NullType, StringType} +import nova.hetu.omniruntime.constants.FunctionType +import nova.hetu.omniruntime.constants.FunctionType.{OMNI_AGGREGATION_TYPE_SUM,OMNI_AGGREGATION_TYPE_AVG,OMNI_AGGREGATION_TYPE_TRY_SUM,OMNI_AGGREGATION_TYPE_TRY_AVG} import java.lang.reflect.Constructor import java.util.{Locale, Properties} @@ -45,6 +47,8 @@ object ShimUtil { def isSupportDataWriter: Boolean = false + def isNeedModifyBuildSide: Boolean = false + private val LOW_SPARK_VERSION_SET: Set[String] = Set("3.4.0", "3.4.1") private val cteRelationDefConstructor: Constructor[CTERelationRef] = { @@ -76,25 +80,23 @@ object ShimUtil { new TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber, numPartitions, taskMemoryManager, localProperties, metricsSystem, taskMetrics, cpus, resources) } - def unsupportedEvalModeCheck(expr: Expression): Unit = { + def transformExpressionByEvalMode(expr: Expression): String = { expr match { - case add: Add if add.evalMode == EvalMode.TRY => - throw new UnsupportedOperationException(s"Unsupported EvalMode TRY for $add") - case sub: Subtract if sub.evalMode == EvalMode.TRY => - throw new UnsupportedOperationException(s"Unsupported EvalMode TRY for $sub") - case mult: Multiply if mult.evalMode == EvalMode.TRY => - throw new UnsupportedOperationException(s"Unsupported EvalMode TRY for $mult") - case divide: Divide if divide.evalMode == EvalMode.TRY => - throw new UnsupportedOperationException(s"Unsupported EvalMode TRY for $divide") - case mod: Remainder if mod.evalMode == EvalMode.TRY => - throw new UnsupportedOperationException(s"Unsupported EvalMode TRY for $mod") - case sum: Sum if sum.evalMode == EvalMode.TRY => - throw new UnsupportedOperationException(s"Unsupported EvalMode TRY for $sum") - case avg: Average if avg.evalMode == EvalMode.TRY => - throw new UnsupportedOperationException(s"Unsupported EvalMode TRY for $avg") - case _ => + case add: Add => if (add.evalMode == EvalMode.TRY) "TRY_ADD" else "ADD" + case sub: Subtract => if (sub.evalMode == EvalMode.TRY) "TRY_SUBTRACT" else "SUBTRACT" + case mult: Multiply => if (mult.evalMode == EvalMode.TRY) "TRY_MULTIPLY" else "MULTIPLY" + case divide: 
Divide => if (divide.evalMode == EvalMode.TRY) "TRY_DIVIDE" else "DIVIDE" + case _ => throw new UnsupportedOperationException(s"Unsupported Operation for $expr") } - } + } + + def transformFuncTypeByEvalMode(expr: Expression):FunctionType = { + expr match { + case sum: Sum => if (sum.evalMode == EvalMode.TRY) OMNI_AGGREGATION_TYPE_TRY_SUM else OMNI_AGGREGATION_TYPE_SUM + case avg: Average => if (avg.evalMode == EvalMode.TRY) OMNI_AGGREGATION_TYPE_TRY_AVG else OMNI_AGGREGATION_TYPE_AVG + case _ => throw new UnsupportedOperationException(s"Unsupported Operation for $expr") + } + } def binaryOperatorAdjust(expr: BinaryOperator, returnDataType: DataType): (Expression, Expression) = { import scala.math.{max, min} @@ -199,10 +201,6 @@ object ShimUtil { ) } - def buildBuildSide(buildSide: BuildSide, joinType: JoinType): BuildSide = { - buildSide - } - def createSortShuffleWriter[K, V, C](handle: BaseShuffleHandle[K, V, C], mapId: Long, context: TaskContext, diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark35-modify/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowGroupLimitExec.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark35-modify/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowGroupLimitExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..ea3855d95b7134af022d4bfe61ccd340d00db6f4 --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark35-modify/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowGroupLimitExec.scala @@ -0,0 +1,144 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
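The spark34 ShimUtil hunk above replaces the old unsupportedEvalModeCheck, which threw on EvalMode.TRY, with helpers that map TRY-mode arithmetic to TRY_-prefixed operator names and TRY-mode aggregates to OMNI_AGGREGATION_TYPE_TRY_* function types. A hedged usage sketch, assuming the Spark 3.4+ Add/Sum constructors that take an EvalMode; the literals and value names are illustrative:

    import org.apache.spark.sql.catalyst.expressions.{Add, EvalMode, Literal}
    import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
    import org.apache.spark.sql.util.ShimUtil

    // try_add(1, 2) is mapped to the "TRY_ADD" operator name instead of being rejected.
    val opName = ShimUtil.transformExpressionByEvalMode(
      Add(Literal(1), Literal(2), EvalMode.TRY))

    // try_sum(...) resolves to the TRY aggregation function type.
    val funcType = ShimUtil.transformFuncTypeByEvalMode(
      Sum(Literal(1L), EvalMode.TRY))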
+ */ + +package org.apache.spark.sql.execution + +import java.util.concurrent.TimeUnit.NANOSECONDS +import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._ +import com.huawei.boostkit.spark.util.OmniAdaptorUtil +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{addAllAndGetIterator, genSortParam} +import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} +import nova.hetu.omniruntime.operator.window.OmniWindowGroupLimitWithExprOperatorFactory +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder, Rank, RowNumber, DenseRank} +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.util.SparkMemoryUtils +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.sql.execution.window.{WindowGroupLimitMode, Partial, Final} + +/** + * This operator is designed to filter out unnecessary rows before WindowExec + * for top-k computation. + * @param partitionSpec Should be the same as [[WindowExec#partitionSpec]]. + * @param orderSpec Should be the same as [[WindowExec#orderSpec]]. + * @param rankLikeFunction The function to compute row rank, should be RowNumber/Rank/DenseRank. + * @param limit The limit for rank value. + * @param mode The mode describes [[WindowGroupLimitExec]] before or after shuffle. + * @param child The child spark plan. + */ +case class ColumnarWindowGroupLimitExec( + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + rankLikeFunction: Expression, + limit: Int, + mode: WindowGroupLimitMode, + child: SparkPlan) extends UnaryExecNode { + + override def supportsColumnar: Boolean = true + + override def nodeName: String = "OmniColumnarWindowGroupLimit" + + override def outputOrdering: Seq[SortOrder] = orderSpec + + override def output: Seq[Attribute] = child.output + + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarWindowGroupLimitExec = + copy(child = newChild) + + override def requiredChildDistribution: Seq[Distribution] = mode match { + case Partial => super.requiredChildDistribution + case Final => + if (partitionSpec.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(partitionSpec) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override lazy val metrics = Map( + "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), + "numInputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatches"), + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "omniCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), + "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "outputDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "output data size"), + "numOutputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatches")) + + override protected def doExecute(): RDD[InternalRow] = { + throw new 
UnsupportedOperationException(s"This operator doesn't support doExecute().") + } + + def buildCheck(): Unit = { + rankLikeFunction match { + case _: RowNumber => + case _: Rank => + case _: DenseRank => + throw new UnsupportedOperationException(s"This operator doesn't support DenseRank() doExecuteColumnar().") + } + val omniAttrExpsIdMap = getExprIdMap(child.output) + val omniPartitionChanels: Array[AnyRef] = partitionSpec.map( + exp => rewriteToOmniJsonExpressionLiteral(exp, omniAttrExpsIdMap)).toArray + checkOmniJsonWhiteList("", omniPartitionChanels) + genSortParam(child.output, orderSpec) + } + + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val omniCodegenTime = longMetric("omniCodegenTime") + val omniAttrExpsIdMap = getExprIdMap(child.output) + val omniPartitionChanels = partitionSpec.map( + exp => rewriteToOmniJsonExpressionLiteral(exp, omniAttrExpsIdMap)).toArray + val (sourceTypes, ascending, nullFirsts, sortColsExp) = genSortParam(child.output, orderSpec) + + child.executeColumnar().mapPartitionsWithIndexInternal { (_, iter) => + val startCodegen = System.nanoTime() + val windowGroupLimitOperatorFactory = rankLikeFunction match { + case _: RowNumber => + new OmniWindowGroupLimitWithExprOperatorFactory(sourceTypes, limit, + "row_number", omniPartitionChanels, sortColsExp, ascending, nullFirsts, + new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + case _: Rank => + new OmniWindowGroupLimitWithExprOperatorFactory(sourceTypes, limit, + "rank", omniPartitionChanels, sortColsExp, ascending, nullFirsts, + new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + case _: DenseRank => + throw new UnsupportedOperationException(s"This operator doesn't support DenseRank() doExecuteColumnar().") + } + val windowGroupLimitOperator = windowGroupLimitOperatorFactory.createOperator + omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + windowGroupLimitOperator.close() + windowGroupLimitOperatorFactory.close() + }) + addAllAndGetIterator(windowGroupLimitOperator, iter, this.schema, + longMetric("addInputTime"), longMetric("numInputVecBatches"), longMetric("numInputRows"), + longMetric("getOutputTime"), longMetric("numOutputVecBatches"), longMetric("numOutputRows"), + longMetric("outputDataSize")) + } + } + +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark35-modify/src/main/scala/org/apache/spark/sql/util/ModifyUtil.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark35-modify/src/main/scala/org/apache/spark/sql/util/ModifyUtil.scala index ba1cbda834e56b25637f036547649923778d3f60..3ba2d82773a84ed93b95fdad5dcaff9751a9541b 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-shims/spark35-modify/src/main/scala/org/apache/spark/sql/util/ModifyUtil.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark35-modify/src/main/scala/org/apache/spark/sql/util/ModifyUtil.scala @@ -22,7 +22,7 @@ import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{JsonObjectExt import com.huawei.boostkit.spark.util.ModifyUtilAdaptor import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSessionExtensions -import org.apache.spark.sql.catalyst.expressions.{BloomFilterMightContain, ExprId, Expression} +import org.apache.spark.sql.catalyst.expressions.{BloomFilterMightContain, 
ExprId, Expression, TryEval} import org.apache.spark.sql.catalyst.optimizer.{CombineJoinedAggregates, MergeSubqueryFilters} import org.apache.spark.sql.execution.adaptive.QueryStageExec import org.apache.spark.sql.execution.{BroadcastExchangeExecProxy, ColumnarBloomFilterSubquery, ColumnarToRowExec, OmniColumnarToRowExec, SparkPlan} @@ -30,6 +30,9 @@ import org.apache.spark.sql.execution.datasources.OmniFileFormatWriter.Empty2Nul import org.apache.spark.sql.expression.ColumnarExpressionConverter import org.apache.spark.sql.types.DataType import com.google.gson.{JsonArray, JsonObject} +import com.huawei.boostkit.spark.{TransformHints, ColumnarPreOverrides, ColumnarPluginConfig} +import org.apache.spark.sql.execution.ColumnarWindowGroupLimitExec +import org.apache.spark.sql.execution.window.WindowGroupLimitExec object ModifyUtil extends Logging { @@ -60,7 +63,8 @@ object ModifyUtil extends Logging { .put("isNull", bfAddress == 0L) .put("dataType", 2) .put("value", bfAddress) - + case tryEval: TryEval => + func(tryEval.child, exprsIndexMap, returnDatatype) case _ => null } @@ -68,6 +72,10 @@ object ModifyUtil extends Logging { private def preReplaceSparkPlanAdapter(plan: SparkPlan, func: (SparkPlan => SparkPlan)): SparkPlan = { plan match { + case plan: WindowGroupLimitExec => + val child = func(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarWindowGroupLimitExec(plan.partitionSpec, plan.orderSpec, plan.rankLikeFunction, plan.limit, plan.mode, child) case _ => null } } @@ -102,10 +110,24 @@ object ModifyUtil extends Logging { extensions.injectOptimizerRule(_ => CombineJoinedAggregates) } + private def addTransformableTagAdapter(plan: SparkPlan): Unit = + plan match { + case plan: WindowGroupLimitExec => + if (!ColumnarPluginConfig.getSessionConf.enableColumnarWindowGroupLimit) { + TransformHints.tagNotTransformable( + plan, "columnar Project is not enabled in WindowGroupLimitExec") + return + } + ColumnarWindowGroupLimitExec(plan.partitionSpec, plan.orderSpec, + plan.rankLikeFunction, plan.limit, plan.mode, plan.child).buildCheck() + case _ => + } + def registerFunc(): Unit = { ModifyUtilAdaptor.configRewriteJsonFunc(rewriteToOmniJsonExpressionAdapter) ModifyUtilAdaptor.configPreReplacePlanFunc(preReplaceSparkPlanAdapter) ModifyUtilAdaptor.configPostReplacePlanFunc(postReplaceSparkPlanAdapter) ModifyUtilAdaptor.configInjectRuleFunc(injectRuleAdapter) + ModifyUtilAdaptor.configAddTransformableFunc(addTransformableTagAdapter) } } diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark35-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark35-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala new file mode 100644 index 0000000000000000000000000000000000000000..8e49f1d388f776a969f917f4d0a6493b1242a320 --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark35-shim/src/main/scala/org/apache/spark/sql/execution/UnaryExecNodeShim.scala @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
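The spark35-modify ModifyUtil hunk above routes Spark 3.5's WindowGroupLimitExec to the new ColumnarWindowGroupLimitExec: preReplaceSparkPlanAdapter builds the columnar node, while addTransformableTagAdapter either tags the plan as not transformable (when enableColumnarWindowGroupLimit is off) or runs buildCheck. An illustrative top-k query shape that Spark 3.5 typically plans with WindowGroupLimitExec; the table and column names are hypothetical:

    // rn <= 3 over a row_number() window is the pattern the operator targets
    // (row_number and rank are supported, dense_rank is rejected by buildCheck).
    val topThreePerDept = spark.sql(
      """
        |SELECT name, dept, salary
        |FROM (
        |  SELECT name, dept, salary,
        |         row_number() OVER (PARTITION BY dept ORDER BY salary DESC) AS rn
        |  FROM employees
        |)
        |WHERE rn <= 3
        |""".stripMargin)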
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions.{NamedExpression, SortOrder} + +abstract class UnaryExecNodeShim(sortOrder: Seq[SortOrder], projectList: Seq[NamedExpression]) + extends OrderPreservingUnaryExecNode { + + override def outputExpressions: Seq[NamedExpression] = projectList + + override def orderingExpressions: Seq[SortOrder] = sortOrder +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/spark-extension-shims/spark35-shim/src/main/scala/org/apache/spark/sql/util/ShimUtil.scala b/omnioperator/omniop-spark-extension/spark-extension-shims/spark35-shim/src/main/scala/org/apache/spark/sql/util/ShimUtil.scala index 61d0206bd472954c02b602a5dc4a8b3ab6a60383..13a878093cf6e14eaf8cd65990560f4c5d1965cc 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-shims/spark35-shim/src/main/scala/org/apache/spark/sql/util/ShimUtil.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-shims/spark35-shim/src/main/scala/org/apache/spark/sql/util/ShimUtil.scala @@ -37,6 +37,8 @@ import org.apache.spark.sql.execution.datasources.PartitionedFile import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DecimalType.{MAX_PRECISION, MAX_SCALE} import org.apache.spark.sql.types.{DataType, DateType, DecimalType, DoubleType, IntegerType, LongType, NullType, StringType} +import nova.hetu.omniruntime.constants.FunctionType +import nova.hetu.omniruntime.constants.FunctionType.{OMNI_AGGREGATION_TYPE_SUM,OMNI_AGGREGATION_TYPE_AVG,OMNI_AGGREGATION_TYPE_TRY_SUM,OMNI_AGGREGATION_TYPE_TRY_AVG} import java.util.{Locale, Properties} @@ -44,6 +46,8 @@ object ShimUtil { def isSupportDataWriter: Boolean = false + def isNeedModifyBuildSide: Boolean = true + def createCTERelationRef(cteId: Long, resolved: Boolean, output: Seq[Attribute], isStreaming: Boolean, tatsOpt: Option[Statistics] = None): CTERelationRef = { CTERelationRef(cteId, resolved, output, isStreaming, tatsOpt) @@ -60,25 +64,23 @@ object ShimUtil { new TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber, numPartitions, taskMemoryManager, localProperties, metricsSystem, taskMetrics, cpus, resources) } - def unsupportedEvalModeCheck(expr: Expression): Unit = { + def transformExpressionByEvalMode(expr: Expression): String = { expr match { - case add: Add if add.evalMode == EvalMode.TRY => - throw new UnsupportedOperationException(s"Unsupported EvalMode TRY for $add") - case sub: Subtract if sub.evalMode == EvalMode.TRY => - throw new UnsupportedOperationException(s"Unsupported EvalMode TRY for $sub") - case mult: Multiply if mult.evalMode == EvalMode.TRY => - throw new UnsupportedOperationException(s"Unsupported EvalMode TRY for $mult") - case divide: Divide if divide.evalMode == EvalMode.TRY => - throw new UnsupportedOperationException(s"Unsupported EvalMode TRY for $divide") - case mod: Remainder if mod.evalMode == EvalMode.TRY => - 
throw new UnsupportedOperationException(s"Unsupported EvalMode TRY for $mod") - case sum: Sum if sum.evalMode == EvalMode.TRY => - throw new UnsupportedOperationException(s"Unsupported EvalMode TRY for $sum") - case avg: Average if avg.evalMode == EvalMode.TRY => - throw new UnsupportedOperationException(s"Unsupported EvalMode TRY for $avg") - case _ => + case add: Add => if (add.evalMode == EvalMode.TRY) "TRY_ADD" else "ADD" + case sub: Subtract => if (sub.evalMode == EvalMode.TRY) "TRY_SUBTRACT" else "SUBTRACT" + case mult: Multiply => if (mult.evalMode == EvalMode.TRY) "TRY_MULTIPLY" else "MULTIPLY" + case divide: Divide => if (divide.evalMode == EvalMode.TRY) "TRY_DIVIDE" else "DIVIDE" + case _ => throw new UnsupportedOperationException(s"Unsupported Operation for $expr") } - } + } + + def transformFuncTypeByEvalMode(expr: Expression):FunctionType = { + expr match { + case sum: Sum => if (sum.evalMode == EvalMode.TRY) OMNI_AGGREGATION_TYPE_TRY_SUM else OMNI_AGGREGATION_TYPE_SUM + case avg: Average => if (avg.evalMode == EvalMode.TRY) OMNI_AGGREGATION_TYPE_TRY_AVG else OMNI_AGGREGATION_TYPE_AVG + case _ => throw new UnsupportedOperationException(s"Unsupported Operation for $expr") + } + } def binaryOperatorAdjust(expr: BinaryOperator, returnDataType: DataType): (Expression, Expression) = { import scala.math.{max, min} @@ -183,16 +185,6 @@ object ShimUtil { ) } - def buildBuildSide(buildSide: BuildSide, joinType: JoinType): BuildSide = { - if (buildSide == BuildLeft && joinType == LeftOuter) { - BuildRight - } else if (buildSide == BuildRight && joinType == RightOuter) { - BuildLeft - } else { - buildSide - } - } - def createSortShuffleWriter[K, V, C](handle: BaseShuffleHandle[K, V, C], mapId: Long, context: TaskContext, diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/com/huawei/boostkit/spark/TableWriteBasicFunctionSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/com/huawei/boostkit/spark/TableWriteBasicFunctionSuite.scala index b32c3983d8331a4fa94cc5351f21f7d78da6727b..c8888db59230c9a3f68b419344cf78971086cd27 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/com/huawei/boostkit/spark/TableWriteBasicFunctionSuite.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/com/huawei/boostkit/spark/TableWriteBasicFunctionSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ColumnarBroa import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarFilterExec, ColumnarProjectExec, ColumnarTakeOrderedAndProjectExec, CommandResultExec, LeafExecNode, OmniColumnarToRowExec, ProjectExec, RowToOmniColumnarExec, SparkPlan, TakeOrderedAndProjectExec, UnaryExecNode} import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.execution.ColumnarDataWritingCommandExec import scala.concurrent.Future @@ -63,6 +64,24 @@ class TableWriteBasicFunctionSuite extends QueryTest with SharedSparkSession { val runRows = select.collect() val expectedRows = Seq(Row("Lisa", "Sales", 10000, 35), Row("Maggie", "Sales", 1, 2)) assert(QueryTest.sameRows(runRows, expectedRows).isEmpty, "the run value is error") + + } + + test("Insert basic data parquet (Non-partitioned table)") { + val dropParquet = spark.sql("drop table if exists employees_for_parquet_table_write_ut_test") + dropParquet.collect() + val 
employeesParquet = Seq[(String, String, Int, Int)]( + ("Lisa", "Sales", 10000, 35), + ).toDF("name", "dept", "salary", "age") + employeesParquet.write.format("parquet").saveAsTable("employees_for_parquet_table_write_ut_test") + + val insertParquet = spark.sql("insert into " + + "employees_for_parquet_table_write_ut_test values('Maggie', 'Sales', 1, 2)") + insertParquet.collect() + val selectParquet = spark.sql("select * from employees_for_parquet_table_write_ut_test") + val runRowsParquet = selectParquet.collect() + val expectedRowsParquet = Seq(Row("Lisa", "Sales", 10000, 35), Row("Maggie", "Sales", 1, 2)) + assert(QueryTest.sameRows(runRowsParquet, expectedRowsParquet).isEmpty, "the run value is error") } test("Insert Basic data (Partitioned table)") { @@ -81,6 +100,22 @@ class TableWriteBasicFunctionSuite extends QueryTest with SharedSparkSession { assert(QueryTest.sameRows(runRows, expectedRows).isEmpty, "the run value is error") } + test("Insert Basic data parquet (Partitioned table)") { + val drop = spark.sql("drop table if exists employees_for_parquet_table_write_ut_partition_test") + drop.collect() + val employees = Seq(("Lisa", "Sales", 10000, 35)).toDF("name", "dept", "salary", "age") + employees.write.format("parquet").partitionBy("age") + .saveAsTable("employees_for_parquet_table_write_ut_partition_test") + val insert = spark.sql("insert into employees_for_parquet_table_write_ut_partition_test " + + "values('Maggie','Sales',200,30),('Bob','Sales',2000,30),('Tom','Sales',5000,20)") + insert.collect() + val select = spark.sql("select * from employees_for_parquet_table_write_ut_partition_test") + val runRows = select.collect() + val expectedRows = Seq(Row("Lisa", "Sales", 10000, 35), Row("Maggie", "Sales", 200, 30), + Row("Bob", "Sales", 2000, 30), Row("Tom", "Sales", 5000, 20)) + assert(QueryTest.sameRows(runRows, expectedRows).isEmpty, "the run value is error") + } + test("Unsupported Scenarios") { val data = Seq[(Int, Int)]( (10000, 35), @@ -91,10 +126,9 @@ class TableWriteBasicFunctionSuite extends QueryTest with SharedSparkSession { insert.collect() var columnarDataWrite = insert.queryExecution.executedPlan.asInstanceOf[CommandResultExec] .commandPhysicalPlan.find({ - case _: DataWritingCommandExec => true + case _: ColumnarDataWritingCommandExec => true case _ => false - } - ) + }) assert(columnarDataWrite.isDefined, "use columnar data writing command") val createTable = spark.sql("create table table_write_ut_map_test" + @@ -128,6 +162,22 @@ class TableWriteBasicFunctionSuite extends QueryTest with SharedSparkSession { "529314109398732268.884038357697864858", "the run value is error") } + test("Insert of parquet decimal 128") { + val drop = spark.sql("drop table if exists table_parquet_for_decimal_128") + drop.collect() + val createTable = spark.sql("create table table_parquet_for_decimal_128 " + + "(amount DECIMAL(38,18)) using parquet") + createTable.collect() + + val insert = spark.sql("insert into table_parquet_for_decimal_128 " + + "values(529314109398732268.884038357697864858)") + insert.collect() + val select = spark.sql("select * from table_parquet_for_decimal_128") + val runRows = select.collect() + assert(runRows(0).getDecimal(0).toString == + "529314109398732268.884038357697864858", "the run value is error") + } + test("replace child plan to columnar") { val drop = spark.sql("drop table if exists test_parquet_int") drop.collect() @@ -148,8 +198,8 @@ class TableWriteBasicFunctionSuite extends QueryTest with SharedSparkSession { val columnarFilter = 
insertNew.queryExecution.executedPlan.asInstanceOf[CommandResultExec] .commandPhysicalPlan.find({ - case _: ColumnarFilterExec => true - case _ => false + case _: ColumnarFilterExec => false + case _ => true } ) assert(columnarFilter.isDefined, "use columnar data writing command") @@ -169,6 +219,20 @@ class TableWriteBasicFunctionSuite extends QueryTest with SharedSparkSession { "1001-01-04", "the run value is error") } + test("rebase parquet date to julian") { + val drop = spark.sql("drop table if exists test_parquet_date") + drop.collect() + val createTable = spark.sql("create table test_parquet_date(date_col date) using parquet") + createTable.collect() + val insert = spark.sql("insert into table test_parquet_date values(cast('1001-01-04' as date))") + insert.collect() + + val select = spark.sql("select * from test_parquet_date") + val runRows = select.collect() + assert(runRows(0).getDate(0).toString == + "1001-01-04", "the run value is error") + } + test("empty string partition") { val drop = spark.sql("drop table if exists table_insert_varchar") drop.collect() @@ -201,4 +265,37 @@ class TableWriteBasicFunctionSuite extends QueryTest with SharedSparkSession { Row(13, "6884578", 6, null, null)) assert(QueryTest.sameRows(runRowsNP, expectedRowsNP).isEmpty, "the run value is error") } + + test("empty parquet string partition") { + val drop = spark.sql("drop table if exists table_parquet_insert_varchar") + drop.collect() + val createTable = spark.sql("create table table_parquet_insert_varchar" + + "(id int, c_varchar varchar(40)) using parquet partitioned by (p_varchar varchar(40))") + createTable.collect() + val insert = spark.sql("insert into table table_parquet_insert_varchar values" + + "(5,'',''), (13,'6884578', null), (6,'72135', '666')") + insert.collect() + + val select = spark.sql("select * from table_parquet_insert_varchar order by id, c_varchar, p_varchar") + val runRows = select.collect() + val expectedRows = Seq(Row(5, "", null), Row(6, "72135", "666"), Row(13, "6884578", null)) + assert(QueryTest.sameRows(runRows, expectedRows).isEmpty, "the run value is error") + + val dropNP = spark.sql("drop table if exists table_parquet_insert_varchar_np") + dropNP.collect() + val createTableNP = spark.sql("create table table_parquet_insert_varchar_np" + + "(id int, c_varchar varchar(40)) using parquet partitioned by " + + "(p_varchar1 int, p_varchar2 varchar(40), p_varchar3 varchar(40))") + createTableNP.collect() + val insertNP = spark.sql("insert into table table_parquet_insert_varchar_np values" + + "(5,'',1,'',''), (13,'6884578',6, null, null), (1,'abc',1,'',''), " + + "(3,'abcde',6,null,null), (4,'qqqqq', 8, 'a', 'b'), (6,'ooooo', 8, 'a', 'b')") + val selectNP = spark.sql("select * from table_parquet_insert_varchar_np " + + "order by id, c_varchar, p_varchar1") + val runRowsNP = selectNP.collect() + val expectedRowsNP = Seq(Row(1, "abc", 1, null, null), Row(3, "abcde", 6, null, null), + Row(4, "qqqqq", 8, "a", "b"), Row(5, "", 1, null, null), Row(6, "ooooo", 8, "a", "b"), + Row(13, "6884578", 6, null, null)) + assert(QueryTest.sameRows(runRowsNP, expectedRowsNP).isEmpty, "the run value is error") + } } diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala index ec03b275363ec70940db54d610ea30b5972906c6..403a1baed983d35df6cb21813ba4ff1201443307 100644 --- 
a/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala @@ -55,108 +55,15 @@ class HiveResourceSuite extends SparkFunSuite { } test("queryBySparkSql-HiveDataSource") { - runner.runQuery("q1", 1) - runner.runQuery("q2", 1) - runner.runQuery("q3", 1) - runner.runQuery("q4", 1) - runner.runQuery("q5", 1) - runner.runQuery("q6", 1) - runner.runQuery("q7", 1) - runner.runQuery("q8", 1) - runner.runQuery("q9", 1) - runner.runQuery("q10", 1) - runner.runQuery("q11", 1) - runner.runQuery("q12", 1) - runner.runQuery("q13", 1) - runner.runQuery("q14a", 1) - runner.runQuery("q14b", 1) - runner.runQuery("q15", 1) - runner.runQuery("q16", 1) - runner.runQuery("q17", 1) - runner.runQuery("q18", 1) runner.runQuery("q19", 1) - runner.runQuery("q20", 1) - runner.runQuery("q21", 1) - runner.runQuery("q22", 1) - runner.runQuery("q23a", 1) - runner.runQuery("q23b", 1) - runner.runQuery("q24a", 1) - runner.runQuery("q24b", 1) - runner.runQuery("q25", 1) - runner.runQuery("q26", 1) - runner.runQuery("q27", 1) - runner.runQuery("q28", 1) - runner.runQuery("q29", 1) - runner.runQuery("q30", 1) - runner.runQuery("q31", 1) - runner.runQuery("q32", 1) - runner.runQuery("q33", 1) - runner.runQuery("q34", 1) - runner.runQuery("q35", 1) - runner.runQuery("q36", 1) - runner.runQuery("q37", 1) - runner.runQuery("q38", 1) - runner.runQuery("q39a", 1) - runner.runQuery("q39b", 1) - runner.runQuery("q40", 1) - runner.runQuery("q41", 1) - runner.runQuery("q42", 1) - runner.runQuery("q43", 1) - runner.runQuery("q44", 1) - runner.runQuery("q45", 1) - runner.runQuery("q46", 1) runner.runQuery("q47", 1) - runner.runQuery("q48", 1) - runner.runQuery("q49", 1) - runner.runQuery("q50", 1) - runner.runQuery("q51", 1) - runner.runQuery("q52", 1) runner.runQuery("q53", 1) - runner.runQuery("q54", 1) - runner.runQuery("q55", 1) - runner.runQuery("q56", 1) - runner.runQuery("q57", 1) - runner.runQuery("q58", 1) - runner.runQuery("q59", 1) - runner.runQuery("q60", 1) - runner.runQuery("q61", 1) - runner.runQuery("q62", 1) runner.runQuery("q63", 1) - runner.runQuery("q64", 1) - runner.runQuery("q65", 1) - runner.runQuery("q66", 1) - runner.runQuery("q67", 1) - runner.runQuery("q68", 1) - runner.runQuery("q69", 1) - runner.runQuery("q70", 1) runner.runQuery("q71", 1) - runner.runQuery("q72", 1) - runner.runQuery("q73", 1) - runner.runQuery("q74", 1) - runner.runQuery("q75", 1) - runner.runQuery("q76", 1) - runner.runQuery("q77", 1) - runner.runQuery("q78", 1) runner.runQuery("q79", 1) - runner.runQuery("q80", 1) - runner.runQuery("q81", 1) runner.runQuery("q82", 1) - runner.runQuery("q83", 1) runner.runQuery("q84", 1) - runner.runQuery("q85", 1) - runner.runQuery("q86", 1) - runner.runQuery("q87", 1) - runner.runQuery("q88", 1) runner.runQuery("q89", 1) - runner.runQuery("q90", 1) - runner.runQuery("q91", 1) - runner.runQuery("q92", 1) - runner.runQuery("q93", 1) - runner.runQuery("q94", 1) - runner.runQuery("q95", 1) - runner.runQuery("q96", 1) - runner.runQuery("q97", 1) - runner.runQuery("q98", 1) runner.runQuery("q99", 1) } diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala 
b/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala deleted file mode 100644 index d8d7d0bd97807cb6bdcaad024e54232687b43937..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Project} -import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.util.sideBySide - -trait HeuristicJoinReorderPlanTestBase extends PlanTest { - - def outputsOf(plans: LogicalPlan*): Seq[Attribute] = { - plans.map(_.output).reduce(_ ++ _) - } - - def assertEqualJoinPlans( - optimizer: RuleExecutor[LogicalPlan], - originalPlan: LogicalPlan, - groundTruthBestPlan: LogicalPlan): Unit = { - val analyzed = originalPlan.analyze - val optimized = optimizer.execute(analyzed) - val expected = EliminateResolvedHint.apply(groundTruthBestPlan.analyze) - - assert(equivalentOutput(analyzed, expected)) - assert(equivalentOutput(analyzed, optimized)) - - compareJoinOrder(optimized, expected) - } - - protected def equivalentOutput(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { - normalizeExprIds(plan1).output == normalizeExprIds(plan2).output - } - - protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan): Unit = { - val normalized1 = normalizePlan(normalizeExprIds(plan1)) - val normalized2 = normalizePlan(normalizeExprIds(plan2)) - if (!sameJoinPlan(normalized1, normalized2)) { - fail( - s""" - |== FAIL: Plans do not match === - |${sideBySide( - rewriteNameFromAttrNullability(normalized1).treeString, - rewriteNameFromAttrNullability(normalized2).treeString).mkString("\n")} - """.stripMargin) - } - } - - private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { - (plan1, plan2) match { - case (j1: Join, j2: Join) => - (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right) - && j1.hint.leftHint == j2.hint.leftHint && j1.hint.rightHint == j2.hint.rightHint) || - (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left) - && j1.hint.leftHint == j2.hint.rightHint && j1.hint.rightHint == j2.hint.leftHint) - case (p1: Project, p2: Project) => - p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child) - case _ => - plan1 == plan2 - } - } -} diff 
--git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/joinReorder/ReorderJoinEnhancesSuite.scala similarity index 70% rename from omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala rename to omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/joinReorder/ReorderJoinEnhancesSuite.scala index c7ea9bd95ad953b136ff9f5f196a0e0bace24027..6ce18a507b01e57974784ecb035bfac05958047c 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/joinReorder/ReorderJoinEnhancesSuite.scala @@ -1,4 +1,5 @@ /* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. @@ -15,27 +16,40 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.optimizer +package org.apache.spark.sql.catalyst.optimizer.joinReorder import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.optimizer._ import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.logical.ColumnStat +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} +import org.apache.spark.sql.catalyst.optimizer.ReorderJoinEnhances -class HeuristicJoinReorderSuite - extends HeuristicJoinReorderPlanTestBase with StatsEstimationTestBase { +class ReorderJoinEnhancesSuite + extends JoinReorderPlanTestBase with StatsEstimationTestBase { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Resolve Hints", Once, + EliminateResolvedHint) :: + Batch("Operator Optimizations", FixedPoint(100), + CombineFilters, + PushPredicateThroughNonJoin, + ReorderJoin, + ReorderJoinEnhances, + PushPredicateThroughJoin, + ColumnPruning, + CollapseProject) :: Nil + } private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( attr("t1.k-1-2") -> rangeColumnStat(2, 0), attr("t1.v-1-10") -> rangeColumnStat(10, 0), attr("t2.k-1-5") -> rangeColumnStat(5, 0), - attr("t3.v-1-100") -> rangeColumnStat(100, 0), - attr("t4.k-1-2") -> rangeColumnStat(2, 0), - attr("t4.v-1-10") -> rangeColumnStat(10, 0), - attr("t5.k-1-5") -> rangeColumnStat(5, 0), - attr("t5.v-1-5") -> rangeColumnStat(5, 0) + attr("t3.v-1-100") -> rangeColumnStat(100, 0) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) @@ -65,17 +79,13 @@ class HeuristicJoinReorderSuite t1.join(t2).join(t3) .where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + 
.select(outputsOf(t1, t2, t3): _*) - val analyzed = originalPlan.analyze - val optimized = HeuristicJoinReorder.apply(analyzed).select(outputsOf(t1, t2, t3): _*) val expected = t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) .join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) .select(outputsOf(t1, t2, t3): _*) - assert(equivalentOutput(analyzed, expected)) - assert(equivalentOutput(analyzed, optimized)) - - compareJoinOrder(optimized, expected) + assertEqualJoinPlans(Optimize, originalPlan, expected) } } diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala index a3eee279a30287ddd48bf624bbc9035d82fab6e5..469aa53c8a2687cebbb0daab17d827c0b99fc44d 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala @@ -262,7 +262,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(" yeah ", "yeah", 10, 8.0, "", "Hello", 2, 2.0), Row(" yeah ", "yeah", 10, 8.0, " add", null, 1, null) - ), false) + ), true) } test("validate columnar shuffledHashJoin full outer join happened") { @@ -299,7 +299,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(" yeah ", "yeah", 10, 8.0, "abc", "", 4, 1.0), Row(" yeah ", "yeah", 10, 8.0, "", "Hello", 2, 2.0), Row(" yeah ", "yeah", 10, 8.0, " add", null, 1, null) - ), false) + ), true) } test("validate columnar shuffledHashJoin left semi join happened") { @@ -388,7 +388,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row("Adams", 24562), Row("Adams", 22456), Row("Bush", null) - ), false) + ), true) } test("BroadcastHashJoin and project funsion test for duplicate column") { @@ -403,7 +403,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row("Adams", 24562, 1), Row("Adams", 22456, 1), Row("Bush", null, null) - ), false) + ), true) } test("BroadcastHashJoin and project funsion test for reorder columns") { @@ -418,7 +418,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(24562, "Adams", 1), Row(22456, "Adams", 1), Row(null, "Bush", null) - ), false) + ), true) } test("BroadcastHashJoin and project are not funsed test") { @@ -433,7 +433,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(24563, "Adams"), Row(22457, "Adams"), Row(null, "Bush") - ), false) + ), true) } test("BroadcastHashJoin and project funsion test for alias") { @@ -448,7 +448,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row("Adams", 24562), Row("Adams", 22456), Row("Bush", null) - ), false) + ), true) } test("validate columnar shuffledHashJoin left anti join happened") { @@ -484,7 +484,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(null, null, null, null, "", "Hello", 2, 2.0), Row("", "Hello", 1, 1.0, " add", "World", 1, 3.0), Row(null, null, null, null, " yeah ", "yeah", 0, 4.0) - ), false) + ), true) } test("columnar ShuffledHashJoin right outer join is equal to native with null") { @@ -505,7 +505,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(null, null, null, null, "", "Hello", 2, 2.0), Row("", "Hello", 1, 1.0, " add", 
"World", 1, 3.0), Row(null, null, null, null, " yeah ", "yeah", 0, 4.0) - ), false) + ), true) } test("columnar BroadcastHashJoin right outer join is equal to native with null") { diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala index fa8c1390ebe2af2a3441234868c65321d2055847..42a13edb179e4db4931bc7aa45d791cef5bda0e8 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala @@ -62,8 +62,6 @@ class ColumnarTopNSortExecSuite extends ColumnarSparkPlanTest { hasTopNSortExec: Boolean = false): Unit = { // run ColumnarTopNSortExec config spark.conf.set("spark.omni.sql.columnar.topNSort", true) - spark.conf.set("spark.sql.execution.topNPushDownForWindow.enabled", true) - spark.conf.set("spark.sql.execution.topNPushDownForWindow.threshold", 100) spark.conf.set("spark.sql.adaptive.enabled", true) val omniResult = spark.sql(sql) omniResult.collect() @@ -80,15 +78,10 @@ class ColumnarTopNSortExecSuite extends ColumnarSparkPlanTest { val sparkPlan = sparkResult.queryExecution.executedPlan.toString() assert(!sparkPlan.contains("ColumnarTopNSort"), s"SQL:${sql}\n@SparkEnv have ColumnarTopNSortExec, sparkPlan:${sparkPlan}") - if (hasTopNSortExec) { - assert(sparkPlan.contains("TopNSort"), - s"SQL:${sql}\n@SparkEnv no TopNSortExec, sparkPlan:${sparkPlan}") - } // DataFrame do not support comparing with equals method, use DataFrame.except instead // DataFrame.except can do equal for rows misorder(with and without order by are same) assert(omniResult.except(sparkResult).isEmpty, s"SQL:${sql}\nomniResult:${omniResult.show()}\nsparkResult:${sparkResult.show()}\n") spark.conf.set("spark.omni.sql.columnar.topNSort", true) - spark.conf.set("spark.sql.execution.topNPushDownForWindow.enabled", false) } } diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/com/huawei/boostkit/spark/TableWriteBasicFunctionSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/com/huawei/boostkit/spark/TableWriteBasicFunctionSuite.scala index b32c3983d8331a4fa94cc5351f21f7d78da6727b..03b019e6cc33770743219e7cced38c5ea28b1b5d 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/com/huawei/boostkit/spark/TableWriteBasicFunctionSuite.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/com/huawei/boostkit/spark/TableWriteBasicFunctionSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ColumnarBroa import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarFilterExec, ColumnarProjectExec, ColumnarTakeOrderedAndProjectExec, CommandResultExec, LeafExecNode, OmniColumnarToRowExec, ProjectExec, RowToOmniColumnarExec, SparkPlan, TakeOrderedAndProjectExec, UnaryExecNode} import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.execution.ColumnarDataWritingCommandExec import scala.concurrent.Future @@ -42,6 +43,8 @@ class 
TableWriteBasicFunctionSuite extends QueryTest with SharedSparkSession { .setAppName("test tableWriteBasicFunctionSuit") .set(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key, "com.huawei.boostkit.spark.ColumnarPlugin") .set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "false") + .set(SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key, "LEGACY") + .set(SQLConf.PARQUET_REBASE_MODE_IN_READ.key, "LEGACY") .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.OmniColumnarShuffleManager") override def beforeAll(): Unit = { @@ -63,6 +66,24 @@ class TableWriteBasicFunctionSuite extends QueryTest with SharedSparkSession { val runRows = select.collect() val expectedRows = Seq(Row("Lisa", "Sales", 10000, 35), Row("Maggie", "Sales", 1, 2)) assert(QueryTest.sameRows(runRows, expectedRows).isEmpty, "the run value is error") + + } + + test("Insert basic data parquet (Non-partitioned table)") { + val dropParquet = spark.sql("drop table if exists employees_for_parquet_table_write_ut_test") + dropParquet.collect() + val employeesParquet = Seq[(String, String, Int, Int)]( + ("Lisa", "Sales", 10000, 35), + ).toDF("name", "dept", "salary", "age") + employeesParquet.write.format("parquet").saveAsTable("employees_for_parquet_table_write_ut_test") + + val insertParquet = spark.sql("insert into " + + "employees_for_parquet_table_write_ut_test values('Maggie', 'Sales', 1, 2)") + insertParquet.collect() + val selectParquet = spark.sql("select * from employees_for_parquet_table_write_ut_test") + val runRowsParquet = selectParquet.collect() + val expectedRowsParquet = Seq(Row("Lisa", "Sales", 10000, 35), Row("Maggie", "Sales", 1, 2)) + assert(QueryTest.sameRows(runRowsParquet, expectedRowsParquet).isEmpty, "the run value is error") } test("Insert Basic data (Partitioned table)") { @@ -81,6 +102,22 @@ class TableWriteBasicFunctionSuite extends QueryTest with SharedSparkSession { assert(QueryTest.sameRows(runRows, expectedRows).isEmpty, "the run value is error") } + test("Insert Basic data parquet (Partitioned table)") { + val drop = spark.sql("drop table if exists employees_for_parquet_table_write_ut_partition_test") + drop.collect() + val employees = Seq(("Lisa", "Sales", 10000, 35)).toDF("name", "dept", "salary", "age") + employees.write.format("parquet").partitionBy("age") + .saveAsTable("employees_for_parquet_table_write_ut_partition_test") + val insert = spark.sql("insert into employees_for_parquet_table_write_ut_partition_test " + + "values('Maggie','Sales',200,30),('Bob','Sales',2000,30),('Tom','Sales',5000,20)") + insert.collect() + val select = spark.sql("select * from employees_for_parquet_table_write_ut_partition_test") + val runRows = select.collect() + val expectedRows = Seq(Row("Lisa", "Sales", 10000, 35), Row("Maggie", "Sales", 200, 30), + Row("Bob", "Sales", 2000, 30), Row("Tom", "Sales", 5000, 20)) + assert(QueryTest.sameRows(runRows, expectedRows).isEmpty, "the run value is error") + } + test("Unsupported Scenarios") { val data = Seq[(Int, Int)]( (10000, 35), @@ -91,10 +128,9 @@ class TableWriteBasicFunctionSuite extends QueryTest with SharedSparkSession { insert.collect() var columnarDataWrite = insert.queryExecution.executedPlan.asInstanceOf[CommandResultExec] .commandPhysicalPlan.find({ - case _: DataWritingCommandExec => true + case _: ColumnarDataWritingCommandExec => true case _ => false - } - ) + }) assert(columnarDataWrite.isDefined, "use columnar data writing command") val createTable = spark.sql("create table table_write_ut_map_test" + @@ -128,6 +164,23 @@ class TableWriteBasicFunctionSuite extends 
QueryTest with SharedSparkSession { "529314109398732268.884038357697864858", "the run value is error") } + test("Insert of parquet decimal 128") { + val drop = spark.sql("drop table if exists table_parquet_for_decimal_128") + drop.collect() + val createTable = spark.sql("create table table_parquet_for_decimal_128 " + + "(amount DECIMAL(38,18)) using parquet") + createTable.collect() + + val insert = spark.sql("insert into table_parquet_for_decimal_128 " + + "values(529314109398732268.884038357697864858)") + insert.collect() + val select = spark.sql("select * from table_parquet_for_decimal_128") + val runRows = select.collect() + assert(runRows(0).getDecimal(0).toString == + "529314109398732268.884038357697864858", "the run value is error") + } + test("replace child plan to columnar") { val drop = spark.sql("drop table if exists test_parquet_int") drop.collect() @@ -148,8 +201,8 @@ class TableWriteBasicFunctionSuite extends QueryTest with SharedSparkSession { val columnarFilter = insertNew.queryExecution.executedPlan.asInstanceOf[CommandResultExec] .commandPhysicalPlan.find({ - case _: ColumnarFilterExec => true - case _ => false + case _: ColumnarFilterExec => false + case _ => true } ) assert(columnarFilter.isDefined, "use columnar data writing command") @@ -169,6 +222,20 @@ class TableWriteBasicFunctionSuite extends QueryTest with SharedSparkSession { "1001-01-04", "the run value is error") } + test("rebase parquet date to julian") { + val drop = spark.sql("drop table if exists test_parquet_date") + drop.collect() + val createTable = spark.sql("create table test_parquet_date(date_col date) using parquet") + createTable.collect() + val insert = spark.sql("insert into table test_parquet_date values(cast('1001-01-04' as date))") + insert.collect() + + val select = spark.sql("select * from test_parquet_date") + val runRows = select.collect() + assert(runRows(0).getDate(0).toString == + "1001-01-04", "the run value is error") + } + test("empty string partition") { val drop = spark.sql("drop table if exists table_insert_varchar") drop.collect() @@ -201,4 +268,37 @@ class TableWriteBasicFunctionSuite extends QueryTest with SharedSparkSession { Row(13, "6884578", 6, null, null)) assert(QueryTest.sameRows(runRowsNP, expectedRowsNP).isEmpty, "the run value is error") } + + test("empty parquet string partition") { + val drop = spark.sql("drop table if exists table_parquet_insert_varchar") + drop.collect() + val createTable = spark.sql("create table table_parquet_insert_varchar" + + "(id int, c_varchar varchar(40)) using parquet partitioned by (p_varchar varchar(40))") + createTable.collect() + val insert = spark.sql("insert into table table_parquet_insert_varchar values" + + "(5,'',''), (13,'6884578', null), (6,'72135', '666')") + insert.collect() + + val select = spark.sql("select * from table_parquet_insert_varchar order by id, c_varchar, p_varchar") + val runRows = select.collect() + val expectedRows = Seq(Row(5, "", null), Row(6, "72135", "666"), Row(13, "6884578", null)) + assert(QueryTest.sameRows(runRows, expectedRows).isEmpty, "the run value is error") + + val dropNP = spark.sql("drop table if exists table_parquet_insert_varchar_np") + dropNP.collect() + val createTableNP = spark.sql("create table table_parquet_insert_varchar_np" + + "(id int, c_varchar varchar(40)) using parquet partitioned by " + + "(p_varchar1 int, p_varchar2 varchar(40), p_varchar3 varchar(40))") + createTableNP.collect() + val insertNP = spark.sql("insert into table table_parquet_insert_varchar_np 
values" + + "(5,'',1,'',''), (13,'6884578',6, null, null), (1,'abc',1,'',''), " + + "(3,'abcde',6,null,null), (4,'qqqqq', 8, 'a', 'b'), (6,'ooooo', 8, 'a', 'b')") + val selectNP = spark.sql("select * from table_parquet_insert_varchar_np " + + "order by id, c_varchar, p_varchar1") + val runRowsNP = selectNP.collect() + val expectedRowsNP = Seq(Row(1, "abc", 1, null, null), Row(3, "abcde", 6, null, null), + Row(4, "qqqqq", 8, "a", "b"), Row(5, "", 1, null, null), Row(6, "ooooo", 8, "a", "b"), + Row(13, "6884578", 6, null, null)) + assert(QueryTest.sameRows(runRowsNP, expectedRowsNP).isEmpty, "the run value is error") + } } diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala index ec03b275363ec70940db54d610ea30b5972906c6..403a1baed983d35df6cb21813ba4ff1201443307 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala @@ -55,108 +55,15 @@ class HiveResourceSuite extends SparkFunSuite { } test("queryBySparkSql-HiveDataSource") { - runner.runQuery("q1", 1) - runner.runQuery("q2", 1) - runner.runQuery("q3", 1) - runner.runQuery("q4", 1) - runner.runQuery("q5", 1) - runner.runQuery("q6", 1) - runner.runQuery("q7", 1) - runner.runQuery("q8", 1) - runner.runQuery("q9", 1) - runner.runQuery("q10", 1) - runner.runQuery("q11", 1) - runner.runQuery("q12", 1) - runner.runQuery("q13", 1) - runner.runQuery("q14a", 1) - runner.runQuery("q14b", 1) - runner.runQuery("q15", 1) - runner.runQuery("q16", 1) - runner.runQuery("q17", 1) - runner.runQuery("q18", 1) runner.runQuery("q19", 1) - runner.runQuery("q20", 1) - runner.runQuery("q21", 1) - runner.runQuery("q22", 1) - runner.runQuery("q23a", 1) - runner.runQuery("q23b", 1) - runner.runQuery("q24a", 1) - runner.runQuery("q24b", 1) - runner.runQuery("q25", 1) - runner.runQuery("q26", 1) - runner.runQuery("q27", 1) - runner.runQuery("q28", 1) - runner.runQuery("q29", 1) - runner.runQuery("q30", 1) - runner.runQuery("q31", 1) - runner.runQuery("q32", 1) - runner.runQuery("q33", 1) - runner.runQuery("q34", 1) - runner.runQuery("q35", 1) - runner.runQuery("q36", 1) - runner.runQuery("q37", 1) - runner.runQuery("q38", 1) - runner.runQuery("q39a", 1) - runner.runQuery("q39b", 1) - runner.runQuery("q40", 1) - runner.runQuery("q41", 1) - runner.runQuery("q42", 1) - runner.runQuery("q43", 1) - runner.runQuery("q44", 1) - runner.runQuery("q45", 1) - runner.runQuery("q46", 1) runner.runQuery("q47", 1) - runner.runQuery("q48", 1) - runner.runQuery("q49", 1) - runner.runQuery("q50", 1) - runner.runQuery("q51", 1) - runner.runQuery("q52", 1) runner.runQuery("q53", 1) - runner.runQuery("q54", 1) - runner.runQuery("q55", 1) - runner.runQuery("q56", 1) - runner.runQuery("q57", 1) - runner.runQuery("q58", 1) - runner.runQuery("q59", 1) - runner.runQuery("q60", 1) - runner.runQuery("q61", 1) - runner.runQuery("q62", 1) runner.runQuery("q63", 1) - runner.runQuery("q64", 1) - runner.runQuery("q65", 1) - runner.runQuery("q66", 1) - runner.runQuery("q67", 1) - runner.runQuery("q68", 1) - runner.runQuery("q69", 1) - runner.runQuery("q70", 1) runner.runQuery("q71", 1) - runner.runQuery("q72", 1) - runner.runQuery("q73", 1) - 
runner.runQuery("q74", 1) - runner.runQuery("q75", 1) - runner.runQuery("q76", 1) - runner.runQuery("q77", 1) - runner.runQuery("q78", 1) runner.runQuery("q79", 1) - runner.runQuery("q80", 1) - runner.runQuery("q81", 1) runner.runQuery("q82", 1) - runner.runQuery("q83", 1) runner.runQuery("q84", 1) - runner.runQuery("q85", 1) - runner.runQuery("q86", 1) - runner.runQuery("q87", 1) - runner.runQuery("q88", 1) runner.runQuery("q89", 1) - runner.runQuery("q90", 1) - runner.runQuery("q91", 1) - runner.runQuery("q92", 1) - runner.runQuery("q93", 1) - runner.runQuery("q94", 1) - runner.runQuery("q95", 1) - runner.runQuery("q96", 1) - runner.runQuery("q97", 1) - runner.runQuery("q98", 1) runner.runQuery("q99", 1) } diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala deleted file mode 100644 index d8d7d0bd97807cb6bdcaad024e54232687b43937..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Project} -import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.util.sideBySide - -trait HeuristicJoinReorderPlanTestBase extends PlanTest { - - def outputsOf(plans: LogicalPlan*): Seq[Attribute] = { - plans.map(_.output).reduce(_ ++ _) - } - - def assertEqualJoinPlans( - optimizer: RuleExecutor[LogicalPlan], - originalPlan: LogicalPlan, - groundTruthBestPlan: LogicalPlan): Unit = { - val analyzed = originalPlan.analyze - val optimized = optimizer.execute(analyzed) - val expected = EliminateResolvedHint.apply(groundTruthBestPlan.analyze) - - assert(equivalentOutput(analyzed, expected)) - assert(equivalentOutput(analyzed, optimized)) - - compareJoinOrder(optimized, expected) - } - - protected def equivalentOutput(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { - normalizeExprIds(plan1).output == normalizeExprIds(plan2).output - } - - protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan): Unit = { - val normalized1 = normalizePlan(normalizeExprIds(plan1)) - val normalized2 = normalizePlan(normalizeExprIds(plan2)) - if (!sameJoinPlan(normalized1, normalized2)) { - fail( - s""" - |== FAIL: Plans do not match === - |${sideBySide( - rewriteNameFromAttrNullability(normalized1).treeString, - rewriteNameFromAttrNullability(normalized2).treeString).mkString("\n")} - """.stripMargin) - } - } - - private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { - (plan1, plan2) match { - case (j1: Join, j2: Join) => - (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right) - && j1.hint.leftHint == j2.hint.leftHint && j1.hint.rightHint == j2.hint.rightHint) || - (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left) - && j1.hint.leftHint == j2.hint.rightHint && j1.hint.rightHint == j2.hint.leftHint) - case (p1: Project, p2: Project) => - p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child) - case _ => - plan1 == plan2 - } - } -} diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/joinreorder/ReorderJoinEnhancesSuite.scala similarity index 70% rename from omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala rename to omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/joinreorder/ReorderJoinEnhancesSuite.scala index c7ea9bd95ad953b136ff9f5f196a0e0bace24027..6ce18a507b01e57974784ecb035bfac05958047c 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/joinreorder/ReorderJoinEnhancesSuite.scala @@ -1,4 +1,5 @@ /* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. 
* Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. @@ -15,27 +16,40 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.optimizer +package org.apache.spark.sql.catalyst.optimizer.joinReorder import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.optimizer._ import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.logical.ColumnStat +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} +import org.apache.spark.sql.catalyst.optimizer.ReorderJoinEnhances -class HeuristicJoinReorderSuite - extends HeuristicJoinReorderPlanTestBase with StatsEstimationTestBase { +class ReorderJoinEnhancesSuite + extends JoinReorderPlanTestBase with StatsEstimationTestBase { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Resolve Hints", Once, + EliminateResolvedHint) :: + Batch("Operator Optimizations", FixedPoint(100), + CombineFilters, + PushPredicateThroughNonJoin, + ReorderJoin, + ReorderJoinEnhances, + PushPredicateThroughJoin, + ColumnPruning, + CollapseProject) :: Nil + } private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( attr("t1.k-1-2") -> rangeColumnStat(2, 0), attr("t1.v-1-10") -> rangeColumnStat(10, 0), attr("t2.k-1-5") -> rangeColumnStat(5, 0), - attr("t3.v-1-100") -> rangeColumnStat(100, 0), - attr("t4.k-1-2") -> rangeColumnStat(2, 0), - attr("t4.v-1-10") -> rangeColumnStat(10, 0), - attr("t5.k-1-5") -> rangeColumnStat(5, 0), - attr("t5.v-1-5") -> rangeColumnStat(5, 0) + attr("t3.v-1-100") -> rangeColumnStat(100, 0) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) @@ -65,17 +79,13 @@ class HeuristicJoinReorderSuite t1.join(t2).join(t3) .where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(outputsOf(t1, t2, t3): _*) - val analyzed = originalPlan.analyze - val optimized = HeuristicJoinReorder.apply(analyzed).select(outputsOf(t1, t2, t3): _*) val expected = t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) .join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) .select(outputsOf(t1, t2, t3): _*) - assert(equivalentOutput(analyzed, expected)) - assert(equivalentOutput(analyzed, optimized)) - - compareJoinOrder(optimized, expected) + assertEqualJoinPlans(Optimize, originalPlan, expected) } } diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala index a3eee279a30287ddd48bf624bbc9035d82fab6e5..469aa53c8a2687cebbb0daab17d827c0b99fc44d 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala +++ 
b/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala @@ -262,7 +262,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(" yeah ", "yeah", 10, 8.0, "", "Hello", 2, 2.0), Row(" yeah ", "yeah", 10, 8.0, " add", null, 1, null) - ), false) + ), true) } test("validate columnar shuffledHashJoin full outer join happened") { @@ -299,7 +299,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(" yeah ", "yeah", 10, 8.0, "abc", "", 4, 1.0), Row(" yeah ", "yeah", 10, 8.0, "", "Hello", 2, 2.0), Row(" yeah ", "yeah", 10, 8.0, " add", null, 1, null) - ), false) + ), true) } test("validate columnar shuffledHashJoin left semi join happened") { @@ -388,7 +388,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row("Adams", 24562), Row("Adams", 22456), Row("Bush", null) - ), false) + ), true) } test("BroadcastHashJoin and project funsion test for duplicate column") { @@ -403,7 +403,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row("Adams", 24562, 1), Row("Adams", 22456, 1), Row("Bush", null, null) - ), false) + ), true) } test("BroadcastHashJoin and project funsion test for reorder columns") { @@ -418,7 +418,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(24562, "Adams", 1), Row(22456, "Adams", 1), Row(null, "Bush", null) - ), false) + ), true) } test("BroadcastHashJoin and project are not funsed test") { @@ -433,7 +433,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(24563, "Adams"), Row(22457, "Adams"), Row(null, "Bush") - ), false) + ), true) } test("BroadcastHashJoin and project funsion test for alias") { @@ -448,7 +448,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row("Adams", 24562), Row("Adams", 22456), Row("Bush", null) - ), false) + ), true) } test("validate columnar shuffledHashJoin left anti join happened") { @@ -484,7 +484,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(null, null, null, null, "", "Hello", 2, 2.0), Row("", "Hello", 1, 1.0, " add", "World", 1, 3.0), Row(null, null, null, null, " yeah ", "yeah", 0, 4.0) - ), false) + ), true) } test("columnar ShuffledHashJoin right outer join is equal to native with null") { @@ -505,7 +505,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(null, null, null, null, "", "Hello", 2, 2.0), Row("", "Hello", 1, 1.0, " add", "World", 1, 3.0), Row(null, null, null, null, " yeah ", "yeah", 0, 4.0) - ), false) + ), true) } test("columnar BroadcastHashJoin right outer join is equal to native with null") { diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala index fa8c1390ebe2af2a3441234868c65321d2055847..42a13edb179e4db4931bc7aa45d791cef5bda0e8 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala @@ -62,8 +62,6 @@ class ColumnarTopNSortExecSuite extends ColumnarSparkPlanTest { hasTopNSortExec: Boolean = false): Unit = { // run ColumnarTopNSortExec config spark.conf.set("spark.omni.sql.columnar.topNSort", true) - 
spark.conf.set("spark.sql.execution.topNPushDownForWindow.enabled", true) - spark.conf.set("spark.sql.execution.topNPushDownForWindow.threshold", 100) spark.conf.set("spark.sql.adaptive.enabled", true) val omniResult = spark.sql(sql) omniResult.collect() @@ -80,15 +78,10 @@ class ColumnarTopNSortExecSuite extends ColumnarSparkPlanTest { val sparkPlan = sparkResult.queryExecution.executedPlan.toString() assert(!sparkPlan.contains("ColumnarTopNSort"), s"SQL:${sql}\n@SparkEnv have ColumnarTopNSortExec, sparkPlan:${sparkPlan}") - if (hasTopNSortExec) { - assert(sparkPlan.contains("TopNSort"), - s"SQL:${sql}\n@SparkEnv no TopNSortExec, sparkPlan:${sparkPlan}") - } // DataFrame do not support comparing with equals method, use DataFrame.except instead // DataFrame.except can do equal for rows misorder(with and without order by are same) assert(omniResult.except(sparkResult).isEmpty, s"SQL:${sql}\nomniResult:${omniResult.show()}\nsparkResult:${sparkResult.show()}\n") spark.conf.set("spark.omni.sql.columnar.topNSort", true) - spark.conf.set("spark.sql.execution.topNPushDownForWindow.enabled", false) } } diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarDateTypeSqlSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarDateTypeSqlSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..b61279bfa259fb4e9c5ba2a5f4ff9c297c86706c --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarDateTypeSqlSuite.scala @@ -0,0 +1,81 @@ +/* + * Copyright (C) 2024-2025. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.forsql + +import org.apache.spark.SparkConf +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.internal.SQLConf + +import java.sql.Date + + +class ColumnarDateTypeSqlSuite extends ColumnarSparkPlanTest { + + import testImplicits._ + + override def sparkConf: SparkConf = super.sparkConf + .set(SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key, "LEGACY") + + protected override def beforeAll(): Unit = { + super.beforeAll() + Seq[(Int, Date, Date)]( + (1, Date.valueOf("2024-10-10"), Date.valueOf("2024-10-10")), + (2, Date.valueOf("1024-02-29"), Date.valueOf("1024-03-01")), + (3, null, Date.valueOf("1024-10-10")) + ).toDF("int_c", "date_c1", "date_c2").write.format("orc").saveAsTable("date_test_orc") + + Seq[(Int, Date, Date)]( + (1, Date.valueOf("2024-10-10"), Date.valueOf("2024-10-10")), + (2, Date.valueOf("1024-02-29"), Date.valueOf("1024-03-01")), + (3, null, Date.valueOf("1024-10-10")) + ).toDF("int_c", "date_c1", "date_c2").write.format("parquet").saveAsTable("date_test_parquet") + } + + protected override def afterAll(): Unit = { + spark.sql("drop table if exists date_test_orc") + spark.sql("drop table if exists date_test_parquet") + super.afterAll() + } + + test("Test rebaseJulianToGregorianDays") { + val orcRes = spark.sql("select * from date_test_orc") + var executedPlan = orcRes.queryExecution.executedPlan + assert(executedPlan.find(_.isInstanceOf[ColumnarFileSourceScanExec]).isDefined, s"ColumnarFileSourceScanExec not happened, executedPlan as follows: \n$executedPlan") + checkAnswer( + orcRes, + Seq( + Row(1, Date.valueOf("2024-10-10"), Date.valueOf("2024-10-10")), + Row(2, Date.valueOf("1024-02-29"), Date.valueOf("1024-03-01")), + Row(3, null, Date.valueOf("1024-10-10"))) + ) + + val parquetRes = spark.sql("select * from date_test_parquet") + executedPlan = parquetRes.queryExecution.executedPlan + assert(executedPlan.find(_.isInstanceOf[ColumnarFileSourceScanExec]).isDefined, s"ColumnarFileSourceScanExec not happened, executedPlan as follows: \n$executedPlan") + + checkAnswer( + parquetRes, + Seq( + Row(1, Date.valueOf("2024-10-10"), Date.valueOf("2024-10-10")), + Row(2, Date.valueOf("1024-02-29"), Date.valueOf("1024-03-01")), + Row(3, null, Date.valueOf("1024-10-10"))) + ) + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala index ec03b275363ec70940db54d610ea30b5972906c6..403a1baed983d35df6cb21813ba4ff1201443307 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala @@ -55,108 +55,15 @@ class HiveResourceSuite extends SparkFunSuite { } test("queryBySparkSql-HiveDataSource") { - runner.runQuery("q1", 1) - runner.runQuery("q2", 1) - runner.runQuery("q3", 1) - runner.runQuery("q4", 1) - runner.runQuery("q5", 1) - runner.runQuery("q6", 1) - runner.runQuery("q7", 1) - runner.runQuery("q8", 1) - runner.runQuery("q9", 1) - runner.runQuery("q10", 1) - runner.runQuery("q11", 1) - runner.runQuery("q12", 1) - runner.runQuery("q13", 1) - runner.runQuery("q14a", 1) - runner.runQuery("q14b", 1) - 
runner.runQuery("q15", 1) - runner.runQuery("q16", 1) - runner.runQuery("q17", 1) - runner.runQuery("q18", 1) runner.runQuery("q19", 1) - runner.runQuery("q20", 1) - runner.runQuery("q21", 1) - runner.runQuery("q22", 1) - runner.runQuery("q23a", 1) - runner.runQuery("q23b", 1) - runner.runQuery("q24a", 1) - runner.runQuery("q24b", 1) - runner.runQuery("q25", 1) - runner.runQuery("q26", 1) - runner.runQuery("q27", 1) - runner.runQuery("q28", 1) - runner.runQuery("q29", 1) - runner.runQuery("q30", 1) - runner.runQuery("q31", 1) - runner.runQuery("q32", 1) - runner.runQuery("q33", 1) - runner.runQuery("q34", 1) - runner.runQuery("q35", 1) - runner.runQuery("q36", 1) - runner.runQuery("q37", 1) - runner.runQuery("q38", 1) - runner.runQuery("q39a", 1) - runner.runQuery("q39b", 1) - runner.runQuery("q40", 1) - runner.runQuery("q41", 1) - runner.runQuery("q42", 1) - runner.runQuery("q43", 1) - runner.runQuery("q44", 1) - runner.runQuery("q45", 1) - runner.runQuery("q46", 1) runner.runQuery("q47", 1) - runner.runQuery("q48", 1) - runner.runQuery("q49", 1) - runner.runQuery("q50", 1) - runner.runQuery("q51", 1) - runner.runQuery("q52", 1) runner.runQuery("q53", 1) - runner.runQuery("q54", 1) - runner.runQuery("q55", 1) - runner.runQuery("q56", 1) - runner.runQuery("q57", 1) - runner.runQuery("q58", 1) - runner.runQuery("q59", 1) - runner.runQuery("q60", 1) - runner.runQuery("q61", 1) - runner.runQuery("q62", 1) runner.runQuery("q63", 1) - runner.runQuery("q64", 1) - runner.runQuery("q65", 1) - runner.runQuery("q66", 1) - runner.runQuery("q67", 1) - runner.runQuery("q68", 1) - runner.runQuery("q69", 1) - runner.runQuery("q70", 1) runner.runQuery("q71", 1) - runner.runQuery("q72", 1) - runner.runQuery("q73", 1) - runner.runQuery("q74", 1) - runner.runQuery("q75", 1) - runner.runQuery("q76", 1) - runner.runQuery("q77", 1) - runner.runQuery("q78", 1) runner.runQuery("q79", 1) - runner.runQuery("q80", 1) - runner.runQuery("q81", 1) runner.runQuery("q82", 1) - runner.runQuery("q83", 1) runner.runQuery("q84", 1) - runner.runQuery("q85", 1) - runner.runQuery("q86", 1) - runner.runQuery("q87", 1) - runner.runQuery("q88", 1) runner.runQuery("q89", 1) - runner.runQuery("q90", 1) - runner.runQuery("q91", 1) - runner.runQuery("q92", 1) - runner.runQuery("q93", 1) - runner.runQuery("q94", 1) - runner.runQuery("q95", 1) - runner.runQuery("q96", 1) - runner.runQuery("q97", 1) - runner.runQuery("q98", 1) runner.runQuery("q99", 1) } diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala deleted file mode 100644 index d8d7d0bd97807cb6bdcaad024e54232687b43937..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Project} -import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.util.sideBySide - -trait HeuristicJoinReorderPlanTestBase extends PlanTest { - - def outputsOf(plans: LogicalPlan*): Seq[Attribute] = { - plans.map(_.output).reduce(_ ++ _) - } - - def assertEqualJoinPlans( - optimizer: RuleExecutor[LogicalPlan], - originalPlan: LogicalPlan, - groundTruthBestPlan: LogicalPlan): Unit = { - val analyzed = originalPlan.analyze - val optimized = optimizer.execute(analyzed) - val expected = EliminateResolvedHint.apply(groundTruthBestPlan.analyze) - - assert(equivalentOutput(analyzed, expected)) - assert(equivalentOutput(analyzed, optimized)) - - compareJoinOrder(optimized, expected) - } - - protected def equivalentOutput(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { - normalizeExprIds(plan1).output == normalizeExprIds(plan2).output - } - - protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan): Unit = { - val normalized1 = normalizePlan(normalizeExprIds(plan1)) - val normalized2 = normalizePlan(normalizeExprIds(plan2)) - if (!sameJoinPlan(normalized1, normalized2)) { - fail( - s""" - |== FAIL: Plans do not match === - |${sideBySide( - rewriteNameFromAttrNullability(normalized1).treeString, - rewriteNameFromAttrNullability(normalized2).treeString).mkString("\n")} - """.stripMargin) - } - } - - private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { - (plan1, plan2) match { - case (j1: Join, j2: Join) => - (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right) - && j1.hint.leftHint == j2.hint.leftHint && j1.hint.rightHint == j2.hint.rightHint) || - (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left) - && j1.hint.leftHint == j2.hint.rightHint && j1.hint.rightHint == j2.hint.leftHint) - case (p1: Project, p2: Project) => - p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child) - case _ => - plan1 == plan2 - } - } -} diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/joinReorder/ReorderJoinEnhancesSuite.scala similarity index 70% rename from omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala rename to omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/joinReorder/ReorderJoinEnhancesSuite.scala index c7ea9bd95ad953b136ff9f5f196a0e0bace24027..6ce18a507b01e57974784ecb035bfac05958047c 100644 --- 
a/omnioperator/omniop-spark-extension/spark-extension-ut/spark33-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/joinReorder/ReorderJoinEnhancesSuite.scala @@ -1,4 +1,5 @@ /* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. @@ -15,27 +16,40 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.optimizer +package org.apache.spark.sql.catalyst.optimizer.joinReorder import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.optimizer._ import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.logical.ColumnStat +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} +import org.apache.spark.sql.catalyst.optimizer.ReorderJoinEnhances -class HeuristicJoinReorderSuite - extends HeuristicJoinReorderPlanTestBase with StatsEstimationTestBase { +class ReorderJoinEnhancesSuite + extends JoinReorderPlanTestBase with StatsEstimationTestBase { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Resolve Hints", Once, + EliminateResolvedHint) :: + Batch("Operator Optimizations", FixedPoint(100), + CombineFilters, + PushPredicateThroughNonJoin, + ReorderJoin, + ReorderJoinEnhances, + PushPredicateThroughJoin, + ColumnPruning, + CollapseProject) :: Nil + } private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( attr("t1.k-1-2") -> rangeColumnStat(2, 0), attr("t1.v-1-10") -> rangeColumnStat(10, 0), attr("t2.k-1-5") -> rangeColumnStat(5, 0), - attr("t3.v-1-100") -> rangeColumnStat(100, 0), - attr("t4.k-1-2") -> rangeColumnStat(2, 0), - attr("t4.v-1-10") -> rangeColumnStat(10, 0), - attr("t5.k-1-5") -> rangeColumnStat(5, 0), - attr("t5.v-1-5") -> rangeColumnStat(5, 0) + attr("t3.v-1-100") -> rangeColumnStat(100, 0) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) @@ -65,17 +79,13 @@ class HeuristicJoinReorderSuite t1.join(t2).join(t3) .where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(outputsOf(t1, t2, t3): _*) - val analyzed = originalPlan.analyze - val optimized = HeuristicJoinReorder.apply(analyzed).select(outputsOf(t1, t2, t3): _*) val expected = t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) .join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) .select(outputsOf(t1, t2, t3): _*) - assert(equivalentOutput(analyzed, expected)) - assert(equivalentOutput(analyzed, optimized)) - - compareJoinOrder(optimized, expected) + assertEqualJoinPlans(Optimize, originalPlan, expected) } } diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala 
b/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala index a3eee279a30287ddd48bf624bbc9035d82fab6e5..469aa53c8a2687cebbb0daab17d827c0b99fc44d 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala @@ -262,7 +262,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(" yeah ", "yeah", 10, 8.0, "", "Hello", 2, 2.0), Row(" yeah ", "yeah", 10, 8.0, " add", null, 1, null) - ), false) + ), true) } test("validate columnar shuffledHashJoin full outer join happened") { @@ -299,7 +299,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(" yeah ", "yeah", 10, 8.0, "abc", "", 4, 1.0), Row(" yeah ", "yeah", 10, 8.0, "", "Hello", 2, 2.0), Row(" yeah ", "yeah", 10, 8.0, " add", null, 1, null) - ), false) + ), true) } test("validate columnar shuffledHashJoin left semi join happened") { @@ -388,7 +388,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row("Adams", 24562), Row("Adams", 22456), Row("Bush", null) - ), false) + ), true) } test("BroadcastHashJoin and project funsion test for duplicate column") { @@ -403,7 +403,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row("Adams", 24562, 1), Row("Adams", 22456, 1), Row("Bush", null, null) - ), false) + ), true) } test("BroadcastHashJoin and project funsion test for reorder columns") { @@ -418,7 +418,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(24562, "Adams", 1), Row(22456, "Adams", 1), Row(null, "Bush", null) - ), false) + ), true) } test("BroadcastHashJoin and project are not funsed test") { @@ -433,7 +433,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(24563, "Adams"), Row(22457, "Adams"), Row(null, "Bush") - ), false) + ), true) } test("BroadcastHashJoin and project funsion test for alias") { @@ -448,7 +448,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row("Adams", 24562), Row("Adams", 22456), Row("Bush", null) - ), false) + ), true) } test("validate columnar shuffledHashJoin left anti join happened") { @@ -484,7 +484,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(null, null, null, null, "", "Hello", 2, 2.0), Row("", "Hello", 1, 1.0, " add", "World", 1, 3.0), Row(null, null, null, null, " yeah ", "yeah", 0, 4.0) - ), false) + ), true) } test("columnar ShuffledHashJoin right outer join is equal to native with null") { @@ -505,7 +505,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(null, null, null, null, "", "Hello", 2, 2.0), Row("", "Hello", 1, 1.0, " add", "World", 1, 3.0), Row(null, null, null, null, " yeah ", "yeah", 0, 4.0) - ), false) + ), true) } test("columnar BroadcastHashJoin right outer join is equal to native with null") { diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarLimitExecSuit.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarLimitExecSuit.scala index 5944618785ba67a869cba691d5ef223b7c13c045..aa35668839af8685b615f304b84b11a80220441b 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarLimitExecSuit.scala 
+++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarLimitExecSuit.scala @@ -66,6 +66,21 @@ class ColumnarLimitExecSuit extends ColumnarSparkPlanTest { assert(result.count() == 0) } + test("limit with offset and global limit columnar exec") { + val result = spark.sql("SELECT y FROM right WHERE x in " + + "(SELECT a FROM left WHERE a = 4 LIMIT 1 OFFSET 1)") + val plan = result.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[ColumnarLocalLimitExec]).isEmpty, + s"not match ColumnarLocalLimitExec, real plan: ${plan}") + assert(plan.find(_.isInstanceOf[LocalLimitExec]).isEmpty, + s"real plan: ${plan}") + assert(plan.find(_.isInstanceOf[ColumnarGlobalLimitExec]).isDefined, + s"not match ColumnarGlobalLimitExec, real plan: ${plan}") + assert(plan.find(_.isInstanceOf[GlobalLimitExec]).isEmpty, + s"real plan: ${plan}") + assert(result.count() == 0) + } + test("limit with rollback global limit to row-based exec") { spark.conf.set("spark.omni.sql.columnar.globalLimit", false) val result = spark.sql("SELECT a FROM left WHERE a in " + diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala index fa8c1390ebe2af2a3441234868c65321d2055847..42a13edb179e4db4931bc7aa45d791cef5bda0e8 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark34-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala @@ -62,8 +62,6 @@ class ColumnarTopNSortExecSuite extends ColumnarSparkPlanTest { hasTopNSortExec: Boolean = false): Unit = { // run ColumnarTopNSortExec config spark.conf.set("spark.omni.sql.columnar.topNSort", true) - spark.conf.set("spark.sql.execution.topNPushDownForWindow.enabled", true) - spark.conf.set("spark.sql.execution.topNPushDownForWindow.threshold", 100) spark.conf.set("spark.sql.adaptive.enabled", true) val omniResult = spark.sql(sql) omniResult.collect() @@ -80,15 +78,10 @@ class ColumnarTopNSortExecSuite extends ColumnarSparkPlanTest { val sparkPlan = sparkResult.queryExecution.executedPlan.toString() assert(!sparkPlan.contains("ColumnarTopNSort"), s"SQL:${sql}\n@SparkEnv have ColumnarTopNSortExec, sparkPlan:${sparkPlan}") - if (hasTopNSortExec) { - assert(sparkPlan.contains("TopNSort"), - s"SQL:${sql}\n@SparkEnv no TopNSortExec, sparkPlan:${sparkPlan}") - } // DataFrame do not support comparing with equals method, use DataFrame.except instead // DataFrame.except can do equal for rows misorder(with and without order by are same) assert(omniResult.except(sparkResult).isEmpty, s"SQL:${sql}\nomniResult:${omniResult.show()}\nsparkResult:${sparkResult.show()}\n") spark.conf.set("spark.omni.sql.columnar.topNSort", true) - spark.conf.set("spark.sql.execution.topNPushDownForWindow.enabled", false) } } diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala index 076a3959df1943881989e1e84d80e19ea25f94bc..b7370cd5dcec7e6df7fcf4fef8a16eff6ef41873 100644 --- 
a/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/com/huawei/boostkit/spark/hive/HiveResourceSuite.scala @@ -56,108 +56,15 @@ class HiveResourceSuite extends SparkFunSuite { } test("queryBySparkSql-HiveDataSource") { - runner.runQuery("q1", 1) - runner.runQuery("q2", 1) - runner.runQuery("q3", 1) - runner.runQuery("q4", 1) - runner.runQuery("q5", 1) - runner.runQuery("q6", 1) - runner.runQuery("q7", 1) - runner.runQuery("q8", 1) - runner.runQuery("q9", 1) - runner.runQuery("q10", 1) - runner.runQuery("q11", 1) - runner.runQuery("q12", 1) - runner.runQuery("q13", 1) - runner.runQuery("q14a", 1) - runner.runQuery("q14b", 1) - runner.runQuery("q15", 1) - runner.runQuery("q16", 1) - runner.runQuery("q17", 1) - runner.runQuery("q18", 1) runner.runQuery("q19", 1) - runner.runQuery("q20", 1) - runner.runQuery("q21", 1) - runner.runQuery("q22", 1) - runner.runQuery("q23a", 1) - runner.runQuery("q23b", 1) - runner.runQuery("q24a", 1) - runner.runQuery("q24b", 1) - runner.runQuery("q25", 1) - runner.runQuery("q26", 1) - runner.runQuery("q27", 1) - runner.runQuery("q28", 1) - runner.runQuery("q29", 1) - runner.runQuery("q30", 1) - runner.runQuery("q31", 1) - runner.runQuery("q32", 1) - runner.runQuery("q33", 1) - runner.runQuery("q34", 1) - runner.runQuery("q35", 1) - runner.runQuery("q36", 1) - runner.runQuery("q37", 1) - runner.runQuery("q38", 1) - runner.runQuery("q39a", 1) - runner.runQuery("q39b", 1) - runner.runQuery("q40", 1) - runner.runQuery("q41", 1) - runner.runQuery("q42", 1) - runner.runQuery("q43", 1) - runner.runQuery("q44", 1) - runner.runQuery("q45", 1) - runner.runQuery("q46", 1) runner.runQuery("q47", 1) - runner.runQuery("q48", 1) - runner.runQuery("q49", 1) - runner.runQuery("q50", 1) - runner.runQuery("q51", 1) - runner.runQuery("q52", 1) runner.runQuery("q53", 1) - runner.runQuery("q54", 1) - runner.runQuery("q55", 1) - runner.runQuery("q56", 1) - runner.runQuery("q57", 1) - runner.runQuery("q58", 1) - runner.runQuery("q59", 1) - runner.runQuery("q60", 1) - runner.runQuery("q61", 1) - runner.runQuery("q62", 1) runner.runQuery("q63", 1) - runner.runQuery("q64", 1) - runner.runQuery("q65", 1) - runner.runQuery("q66", 1) - runner.runQuery("q67", 1) - runner.runQuery("q68", 1) - runner.runQuery("q69", 1) - runner.runQuery("q70", 1) runner.runQuery("q71", 1) - runner.runQuery("q72", 1) - runner.runQuery("q73", 1) - runner.runQuery("q74", 1) - runner.runQuery("q75", 1) - runner.runQuery("q76", 1) - runner.runQuery("q77", 1) - runner.runQuery("q78", 1) runner.runQuery("q79", 1) - runner.runQuery("q80", 1) - runner.runQuery("q81", 1) runner.runQuery("q82", 1) - runner.runQuery("q83", 1) runner.runQuery("q84", 1) - runner.runQuery("q85", 1) - runner.runQuery("q86", 1) - runner.runQuery("q87", 1) - runner.runQuery("q88", 1) runner.runQuery("q89", 1) - runner.runQuery("q90", 1) - runner.runQuery("q91", 1) - runner.runQuery("q92", 1) - runner.runQuery("q93", 1) - runner.runQuery("q94", 1) - runner.runQuery("q95", 1) - runner.runQuery("q96", 1) - runner.runQuery("q97", 1) - runner.runQuery("q98", 1) runner.runQuery("q99", 1) } diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala 
b/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala deleted file mode 100644 index d8d7d0bd97807cb6bdcaad024e54232687b43937..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Project} -import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.util.sideBySide - -trait HeuristicJoinReorderPlanTestBase extends PlanTest { - - def outputsOf(plans: LogicalPlan*): Seq[Attribute] = { - plans.map(_.output).reduce(_ ++ _) - } - - def assertEqualJoinPlans( - optimizer: RuleExecutor[LogicalPlan], - originalPlan: LogicalPlan, - groundTruthBestPlan: LogicalPlan): Unit = { - val analyzed = originalPlan.analyze - val optimized = optimizer.execute(analyzed) - val expected = EliminateResolvedHint.apply(groundTruthBestPlan.analyze) - - assert(equivalentOutput(analyzed, expected)) - assert(equivalentOutput(analyzed, optimized)) - - compareJoinOrder(optimized, expected) - } - - protected def equivalentOutput(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { - normalizeExprIds(plan1).output == normalizeExprIds(plan2).output - } - - protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan): Unit = { - val normalized1 = normalizePlan(normalizeExprIds(plan1)) - val normalized2 = normalizePlan(normalizeExprIds(plan2)) - if (!sameJoinPlan(normalized1, normalized2)) { - fail( - s""" - |== FAIL: Plans do not match === - |${sideBySide( - rewriteNameFromAttrNullability(normalized1).treeString, - rewriteNameFromAttrNullability(normalized2).treeString).mkString("\n")} - """.stripMargin) - } - } - - private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { - (plan1, plan2) match { - case (j1: Join, j2: Join) => - (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right) - && j1.hint.leftHint == j2.hint.leftHint && j1.hint.rightHint == j2.hint.rightHint) || - (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left) - && j1.hint.leftHint == j2.hint.rightHint && j1.hint.rightHint == j2.hint.leftHint) - case (p1: Project, p2: Project) => - p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child) - case _ => - plan1 == plan2 - } - } -} diff 
--git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/joinReorder/ReorderJoinEnhancesSuite.scala similarity index 70% rename from omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala rename to omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/joinReorder/ReorderJoinEnhancesSuite.scala index c7ea9bd95ad953b136ff9f5f196a0e0bace24027..6ce18a507b01e57974784ecb035bfac05958047c 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark32-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/catalyst/optimizer/joinReorder/ReorderJoinEnhancesSuite.scala @@ -1,4 +1,5 @@ /* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. @@ -15,27 +16,40 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.optimizer +package org.apache.spark.sql.catalyst.optimizer.joinReorder import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.optimizer._ import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.logical.ColumnStat +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} +import org.apache.spark.sql.catalyst.optimizer.ReorderJoinEnhances -class HeuristicJoinReorderSuite - extends HeuristicJoinReorderPlanTestBase with StatsEstimationTestBase { +class ReorderJoinEnhancesSuite + extends JoinReorderPlanTestBase with StatsEstimationTestBase { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Resolve Hints", Once, + EliminateResolvedHint) :: + Batch("Operator Optimizations", FixedPoint(100), + CombineFilters, + PushPredicateThroughNonJoin, + ReorderJoin, + ReorderJoinEnhances, + PushPredicateThroughJoin, + ColumnPruning, + CollapseProject) :: Nil + } private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( attr("t1.k-1-2") -> rangeColumnStat(2, 0), attr("t1.v-1-10") -> rangeColumnStat(10, 0), attr("t2.k-1-5") -> rangeColumnStat(5, 0), - attr("t3.v-1-100") -> rangeColumnStat(100, 0), - attr("t4.k-1-2") -> rangeColumnStat(2, 0), - attr("t4.v-1-10") -> rangeColumnStat(10, 0), - attr("t5.k-1-5") -> rangeColumnStat(5, 0), - attr("t5.v-1-5") -> rangeColumnStat(5, 0) + attr("t3.v-1-100") -> rangeColumnStat(100, 0) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) @@ -65,17 +79,13 @@ class HeuristicJoinReorderSuite t1.join(t2).join(t3) .where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + 
.select(outputsOf(t1, t2, t3): _*) - val analyzed = originalPlan.analyze - val optimized = HeuristicJoinReorder.apply(analyzed).select(outputsOf(t1, t2, t3): _*) val expected = t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) .join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) .select(outputsOf(t1, t2, t3): _*) - assert(equivalentOutput(analyzed, expected)) - assert(equivalentOutput(analyzed, optimized)) - - compareJoinOrder(optimized, expected) + assertEqualJoinPlans(Optimize, originalPlan, expected) } } diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala index 1a8549cb93b8b1dad076a6ced88b6b60fb50ed7d..39637065f10132791354e606ff4ee1a82bf64858 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala @@ -262,7 +262,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(" yeah ", "yeah", 10, 8.0, "", "Hello", 2, 2.0), Row(" yeah ", "yeah", 10, 8.0, " add", null, 1, null) - ), false) + ), true) } test("validate columnar shuffledHashJoin full outer join happened") { @@ -299,7 +299,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(" yeah ", "yeah", 10, 8.0, "abc", "", 4, 1.0), Row(" yeah ", "yeah", 10, 8.0, "", "Hello", 2, 2.0), Row(" yeah ", "yeah", 10, 8.0, " add", null, 1, null) - ), false) + ), true) } test("validate columnar shuffledHashJoin left semi join happened") { @@ -334,25 +334,46 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { s" executedPlan as follows: \n${res.queryExecution.executedPlan}") } - test("columnar shuffledHashJoin left outer join is equal to native") { - val df = left.join(right.hint("SHUFFLE_HASH"), col("q") === col("c"), "leftouter") + test("columnar shuffledHashJoin left outer join and build left is equal to native") { + val df = left.hint("SHUFFLE_HASH").join(right, col("q") === col("c"), "leftouter") checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( - Row("abc", "", 4, 2.0, "abc", "", 4, 1.0), + Row("", "Hello", 1, 1.0, " add", "World", 1, 3.0), + Row(" add", "World", 8, 3.0, null, null, null, null), Row(" yeah ", "yeah", 10, 8.0, null, null, null, null), + Row("abc", "", 4, 2.0, "abc", "", 4, 1.0) + ), true) + } + + test("columnar shuffledHashJoin left outer join and build left is equal to native with null") { + val df = leftWithNull.hint("SHUFFLE_HASH").join(rightWithNull, + col("q") === col("c"), "leftouter") + checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( + Row("", "Hello", null, 1.0, null, null, null, null), + Row(" add", "World", 8, 3.0, null, null, null, null), + Row(" yeah ", "yeah", 10, 8.0, null, null, null, null), + Row("abc", null, 4, 2.0, "abc", "", 4, 1.0) + ), true) + } + + test("columnar shuffledHashJoin left outer join and build right is equal to native") { + val df = left.join(right.hint("SHUFFLE_HASH"), col("q") === col("c"), "leftouter") + checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( Row("", "Hello", 1, 1.0, " add", "World", 1, 3.0), - Row(" add", "World", 8, 3.0, null, null, null, null) - ), false) + Row(" add", "World", 8, 3.0, null, 
null, null, null), + Row(" yeah ", "yeah", 10, 8.0, null, null, null, null), + Row("abc", "", 4, 2.0, "abc", "", 4, 1.0) + ), true) } - test("columnar shuffledHashJoin left outer join is equal to native with null") { + test("columnar shuffledHashJoin left outer join and build right is equal to native with null") { val df = leftWithNull.join(rightWithNull.hint("SHUFFLE_HASH"), col("q") === col("c"), "leftouter") checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( - Row("abc", null, 4, 2.0, "abc", "", 4, 1.0), Row("", "Hello", null, 1.0, null, null, null, null), + Row(" add", "World", 8, 3.0, null, null, null, null), Row(" yeah ", "yeah", 10, 8.0, null, null, null, null), - Row(" add", "World", 8, 3.0, null, null, null, null) - ), false) + Row("abc", null, 4, 2.0, "abc", "", 4, 1.0) + ), true) } test("ColumnarBroadcastHashJoin is not rolled back with not_equal filter expr") { @@ -388,7 +409,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row("Adams", 24562), Row("Adams", 22456), Row("Bush", null) - ), false) + ), true) } test("BroadcastHashJoin and project funsion test for duplicate column") { @@ -403,7 +424,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row("Adams", 24562, 1), Row("Adams", 22456, 1), Row("Bush", null, null) - ), false) + ), true) } test("BroadcastHashJoin and project funsion test for reorder columns") { @@ -418,7 +439,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(24562, "Adams", 1), Row(22456, "Adams", 1), Row(null, "Bush", null) - ), false) + ), true) } test("BroadcastHashJoin and project are not funsed test") { @@ -433,7 +454,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(24563, "Adams"), Row(22457, "Adams"), Row(null, "Bush") - ), false) + ), true) } test("BroadcastHashJoin and project funsion test for alias") { @@ -448,7 +469,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row("Adams", 24562), Row("Adams", 22456), Row("Bush", null) - ), false) + ), true) } test("validate columnar shuffledHashJoin left anti join happened") { @@ -477,17 +498,38 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { ), false) } - test("columnar ShuffledHashJoin right outer join is equal to native") { + test("columnar ShuffledHashJoin right outer join and build left is equal to native") { + val df = left.hint("SHUFFLE_HASH").join(right, col("q") === col("c"), "rightouter") + checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( + Row(null, null, null, null, " yeah ", "yeah", 0, 4.0), + Row("abc", "", 4, 2.0, "abc", "", 4, 1.0), + Row("", "Hello", 1, 1.0, " add", "World", 1, 3.0), + Row(null, null, null, null, "", "Hello", 2, 2.0) + ), true) + } + + test("columnar ShuffledHashJoin right outer join and build left is equal to native with null") { + val df = leftWithNull.hint("SHUFFLE_HASH").join(rightWithNull, + col("q") === col("c"), "rightouter") + checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( + Row("abc", null, 4, 2.0, "abc", "", 4, 1.0), + Row(null, null, null, null, " yeah ", null, null, 4.0), + Row(null, null, null, null, " add", null, 1, null), + Row(null, null, null, null, "", "Hello", 2, 2.0) + ), true) + } + + test("columnar ShuffledHashJoin right outer join and build right is equal to native") { val df = left.join(right.hint("SHUFFLE_HASH"), col("q") === col("c"), "rightouter") checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( Row(null, null, null, null, " yeah ", "yeah", 0, 4.0), Row("abc", "", 4, 2.0, "abc", "", 4, 1.0), Row("", "Hello", 1, 1.0, " 
add", "World", 1, 3.0), Row(null, null, null, null, "", "Hello", 2, 2.0) - ), false) + ), true) } - test("columnar ShuffledHashJoin right outer join is equal to native with null") { + test("columnar ShuffledHashJoin right outer join and build right is equal to native with null") { val df = leftWithNull.join(rightWithNull.hint("SHUFFLE_HASH"), col("q") === col("c"), "rightouter") checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( @@ -495,7 +537,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(null, null, null, null, " yeah ", null, null, 4.0), Row(null, null, null, null, " add", null, 1, null), Row(null, null, null, null, "", "Hello", 2, 2.0) - ), false) + ), true) } test("columnar BroadcastHashJoin right outer join is equal to native") { @@ -505,7 +547,7 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { Row(null, null, null, null, "", "Hello", 2, 2.0), Row("", "Hello", 1, 1.0, " add", "World", 1, 3.0), Row(null, null, null, null, " yeah ", "yeah", 0, 4.0) - ), false) + ), true) } test("columnar BroadcastHashJoin right outer join is equal to native with null") { diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarLimitExecSuit.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarLimitExecSuit.scala index 5944618785ba67a869cba691d5ef223b7c13c045..aa35668839af8685b615f304b84b11a80220441b 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarLimitExecSuit.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarLimitExecSuit.scala @@ -66,6 +66,21 @@ class ColumnarLimitExecSuit extends ColumnarSparkPlanTest { assert(result.count() == 0) } + test("limit with offset and global limit columnar exec") { + val result = spark.sql("SELECT y FROM right WHERE x in " + + "(SELECT a FROM left WHERE a = 4 LIMIT 1 OFFSET 1)") + val plan = result.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[ColumnarLocalLimitExec]).isEmpty, + s"not match ColumnarLocalLimitExec, real plan: ${plan}") + assert(plan.find(_.isInstanceOf[LocalLimitExec]).isEmpty, + s"real plan: ${plan}") + assert(plan.find(_.isInstanceOf[ColumnarGlobalLimitExec]).isDefined, + s"not match ColumnarGlobalLimitExec, real plan: ${plan}") + assert(plan.find(_.isInstanceOf[GlobalLimitExec]).isEmpty, + s"real plan: ${plan}") + assert(result.count() == 0) + } + test("limit with rollback global limit to row-based exec") { spark.conf.set("spark.omni.sql.columnar.globalLimit", false) val result = spark.sql("SELECT a FROM left WHERE a in " + diff --git a/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarWindowGroupLimitExecSuite.scala b/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarWindowGroupLimitExecSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..20c9382077c7290660a55452087baf446761388a --- /dev/null +++ b/omnioperator/omniop-spark-extension/spark-extension-ut/spark35-ut/src/test/scala/org/apache/spark/sql/execution/ColumnarWindowGroupLimitExecSuite.scala @@ -0,0 +1,105 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. 
+ * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.types._ + +class ColumnarWindowGroupLimitExecSuite extends ColumnarSparkPlanTest { + + private var dealer: DataFrame = _ + + protected override def beforeAll(): Unit = { + super.beforeAll() + + dealer = spark.createDataFrame( + sparkContext.parallelize(Seq( + Row(1,"shanghai",10), + Row(2, "chengdu", 1), + Row(3,"guangzhou", 7), + Row(4, "beijing", 20), + Row(5, "hangzhou", 4), + Row(6, "tianjing", 3), + Row(7, "shenzhen", 5), + Row(8, "changsha", 5), + Row(9,"nanjing", 5), + Row(10, "wuhan", 6), + Row(11, "wuhan", 9), + Row(12, "wuhan", 9), + Row(12, "wuhan", 8) + )),new StructType() + .add("id", IntegerType) + .add("city", StringType) + .add("sales", IntegerType)) + dealer.createOrReplaceTempView("dealer") + } + + test("Test windowGroupLimit") { + val sql1 = "select * from (SELECT city, rank() OVER (ORDER BY sales) AS rk FROM dealer) where rk < 4 order by rk;" + assertColumnarWindowGroupLimitExecAndSparkResultEqual(sql1, true, true) + + val sql2 = "select * from (SELECT city, rank() OVER (PARTITION BY city ORDER BY sales) AS rk FROM dealer) where rk < 4 order by rk;" + assertColumnarWindowGroupLimitExecAndSparkResultEqual(sql2, true, true) + + val sql3 = "select * from (SELECT city, sales, row_number() OVER (PARTITION BY city ORDER BY sales) AS rn FROM dealer) where rn < 4 order by rn;" + assertColumnarWindowGroupLimitExecAndSparkResultEqual(sql3, true, true) + } + + private def assertColumnarWindowGroupLimitExecAndSparkResultEqual(sql: String, hasColumnarWindowGroupLimitExec: Boolean = true, + hasWindowGroupLimitExec: Boolean = false): Unit = { + // run WindowGroupLimitExec config + spark.conf.set("spark.omni.sql.columnar.windowGroupLimit", true) + spark.conf.set("spark.sql.adaptive.enabled", true) + val omniResult = spark.sql(sql) + omniResult.collect() + val omniPlan = omniResult.queryExecution.executedPlan.toString() + if (hasColumnarWindowGroupLimitExec) { + assert(omniPlan.contains("ColumnarWindowGroupLimit"), + s"SQL:${sql}\n@OmniEnv no ColumnarWindowGroupLimitExec, omniPlan:${omniPlan}") + } + + spark.conf.set("spark.sql.adaptive.enabled", false) + val omniResult_withoutAQE = spark.sql(sql) + omniResult_withoutAQE.collect() + val omniPlan_withoutAQE = omniResult_withoutAQE.queryExecution.executedPlan.toString() + if (hasColumnarWindowGroupLimitExec) { + assert(omniPlan_withoutAQE.contains("ColumnarWindowGroupLimit"), + s"SQL:${sql}\n@OmniEnv no ColumnarWindowGroupLimitExec, omniPlan_withoutAQE:${omniPlan_withoutAQE}") + } + assert(omniResult.except(omniResult_withoutAQE).isEmpty, + 
s"SQL:${sql}\nomniResult:${omniResult.show()}\nomniResult_withoutAQE:${omniResult_withoutAQE.show()}\n") + + // run WindowGroupLimitExec config + spark.conf.set("spark.omni.sql.columnar.windowGroupLimit", false) + spark.conf.set("spark.sql.adaptive.enabled", true) + val sparkResult = spark.sql(sql) + sparkResult.collect() + val sparkPlan = sparkResult.queryExecution.executedPlan.toString() + assert(!sparkPlan.contains("ColumnarWindowGroupLimit"), + s"SQL:${sql}\n@SparkEnv have ColumnarWindowGroupLimitExec, sparkPlan:${sparkPlan}") + if (hasWindowGroupLimitExec) { + assert(sparkPlan.contains("WindowGroupLimit"), + s"SQL:${sql}\n@SparkEnv no WindowGroupLimitExec, sparkPlan:${sparkPlan}") + } + // DataFrame do not support comparing with equals method, use DataFrame.except instead + // DataFrame.except can do equal for rows misorder(with and without order by are same) + assert(omniResult.except(sparkResult).isEmpty, + s"SQL:${sql}\nomniResult:${omniResult.show()}\nsparkResult:${sparkResult.show()}\n") + } +}