From a595cdb204b1f729ef90c9b20e3404cf674f7e0e Mon Sep 17 00:00:00 2001 From: muhuiqin <3508807397@qq.com> Date: Thu, 21 Aug 2025 09:46:26 +0800 Subject: [PATCH 1/4] =?UTF-8?q?br=5Fnoncom=5Fspark=5Fomniop=5F25.1.T7=5F25?= =?UTF-8?q?0730=E5=88=86=E6=94=AF=E4=BB=A3=E7=A0=81=E8=BF=81=E7=A7=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 63 + BUILDING.MD | 403 + CMakeLists.txt | 133 + README.md | 58 +- bindings/java/pom.xml | 272 + bindings/java/src/main/cpp/CMakeLists.txt | 19 + .../java/src/main/cpp/src/jni_common_def.cpp | 44 + .../java/src/main/cpp/src/jni_common_def.h | 111 + .../java/src/main/cpp/src/jni_constants.cpp | 12 + .../java/src/main/cpp/src/jni_constants.h | 24 + bindings/java/src/main/cpp/src/jni_helper.cpp | 25 + bindings/java/src/main/cpp/src/jni_helper.h | 27 + .../java/src/main/cpp/src/jni_operator.cpp | 377 + bindings/java/src/main/cpp/src/jni_operator.h | 101 + .../src/main/cpp/src/jni_operator_factory.cpp | 1637 ++++ .../src/main/cpp/src/jni_operator_factory.h | 360 + bindings/java/src/main/cpp/src/jni_vector.cpp | 299 + bindings/java/src/main/cpp/src/jni_vector.h | 208 + .../java/nova/hetu/omniruntime/OmniLibs.java | 79 + .../hetu/omniruntime/constants/BuildSide.java | 38 + .../hetu/omniruntime/constants/Constant.java | 58 + .../omniruntime/constants/ConstantHelper.java | 27 + .../omniruntime/constants/FunctionType.java | 98 + .../hetu/omniruntime/constants/JoinType.java | 58 + .../constants/OmniWindowFrameBoundType.java | 48 + .../constants/OmniWindowFrameType.java | 33 + .../omniruntime/constants/OperatorType.java | 108 + .../hetu/omniruntime/constants/Status.java | 38 + .../omniruntime/memory/MemoryManager.java | 80 + .../omniruntime/operator/OmniExprVerify.java | 37 + .../omniruntime/operator/OmniOperator.java | 189 + .../operator/OmniOperatorFactory.java | 95 + .../operator/OmniOperatorFactoryContext.java | 55 + .../omniruntime/operator/OmniResults.java | 58 + .../omniruntime/operator/OmniRowResults.java | 57 + .../OmniAggregationOperatorFactory.java | 146 + ...mniAggregationWithExprOperatorFactory.java | 224 + .../OmniHashAggregationOperatorFactory.java | 231 + ...ashAggregationWithExprOperatorFactory.java | 225 + .../operator/config/OperatorConfig.java | 278 + .../operator/config/OverflowConfig.java | 62 + .../operator/config/SparkSpillConfig.java | 127 + .../operator/config/SpillConfig.java | 223 + .../OmniBloomFilterOperatorFactory.java | 72 + .../OmniFilterAndProjectOperatorFactory.java | 169 + .../join/OmniHashBuilderOperatorFactory.java | 118 + ...mniHashBuilderWithExprOperatorFactory.java | 224 + .../join/OmniLookupJoinOperatorFactory.java | 245 + ...OmniLookupJoinWithExprOperatorFactory.java | 243 + .../OmniLookupOuterJoinOperatorFactory.java | 139 + ...ookupOuterJoinWithExprOperatorFactory.java | 143 + ...mniNestedLoopJoinBuildOperatorFactory.java | 153 + ...niNestedLoopJoinLookupOperatorFactory.java | 121 + ...jBufferedTableWithExprOperatorFactory.java | 127 + ...ufferedTableWithExprOperatorFactoryV3.java | 128 + ...jStreamedTableWithExprOperatorFactory.java | 127 + ...treamedTableWithExprOperatorFactoryV3.java | 127 + .../OmniDistinctLimitOperatorFactory.java | 114 + .../limit/OmniLimitOperatorFactory.java | 81 + .../OmniPartitionedOutPutOperatorFactory.java | 162 + .../project/OmniProjectOperatorFactory.java | 151 + .../sort/OmniSortOperatorFactory.java | 123 + .../sort/OmniSortWithExprOperatorFactory.java | 123 + .../topn/OmniTopNOperatorFactory.java | 156 + .../topn/OmniTopNWithExprOperatorFactory.java | 160 + .../OmniTopNSortWithExprOperatorFactory.java | 141 + .../union/OmniUnionOperatorFactory.java | 93 + ...ndowGroupLimitWithExprOperatorFactory.java | 140 + .../window/OmniWindowOperatorFactory.java | 243 + .../OmniWindowWithExprOperatorFactory.java | 244 + .../omniruntime/type/BooleanDataType.java | 26 + .../hetu/omniruntime/type/ByteDataType.java | 26 + .../hetu/omniruntime/type/CharDataType.java | 54 + .../omniruntime/type/ContainerDataType.java | 56 + .../nova/hetu/omniruntime/type/DataType.java | 154 + .../omniruntime/type/DataTypeSerializer.java | 184 + .../hetu/omniruntime/type/Date32DataType.java | 57 + .../hetu/omniruntime/type/Date64DataType.java | 57 + .../omniruntime/type/Decimal128DataType.java | 44 + .../omniruntime/type/Decimal64DataType.java | 41 + .../omniruntime/type/DecimalDataType.java | 64 + .../hetu/omniruntime/type/DoubleDataType.java | 26 + .../hetu/omniruntime/type/IntDataType.java | 26 + .../omniruntime/type/InvalidDataType.java | 24 + .../hetu/omniruntime/type/LongDataType.java | 26 + .../hetu/omniruntime/type/NoneDataType.java | 24 + .../hetu/omniruntime/type/ShortDataType.java | 26 + .../omniruntime/type/TimestampDataType.java | 26 + .../omniruntime/type/VarcharDataType.java | 76 + .../hetu/omniruntime/utils/JsonUtils.java | 83 + .../hetu/omniruntime/utils/NativeLog.java | 50 + .../omniruntime/utils/NullsBufHelper.java | 170 + .../hetu/omniruntime/utils/OmniErrorType.java | 60 + .../utils/OmniRuntimeException.java | 62 + .../hetu/omniruntime/utils/ParseUtil.java | 88 + .../omniruntime/utils/ShuffleHashHelper.java | 22 + .../hetu/omniruntime/utils/TraceUtil.java | 32 + .../hetu/omniruntime/vector/BooleanVec.java | 114 + .../nova/hetu/omniruntime/vector/ByteVec.java | 101 + .../hetu/omniruntime/vector/ContainerVec.java | 201 + .../omniruntime/vector/Decimal128Vec.java | 189 + .../hetu/omniruntime/vector/DecimalVec.java | 158 + .../omniruntime/vector/DictionaryVec.java | 351 + .../hetu/omniruntime/vector/DoubleVec.java | 102 + .../omniruntime/vector/FixedWidthVec.java | 39 + .../nova/hetu/omniruntime/vector/IntVec.java | 101 + .../hetu/omniruntime/vector/JvmUtils.java | 120 + .../nova/hetu/omniruntime/vector/LongVec.java | 115 + .../hetu/omniruntime/vector/OmniBuffer.java | 215 + .../omniruntime/vector/OmniBufferFactory.java | 26 + .../vector/OmniBufferUnsafeV8.java | 145 + .../nova/hetu/omniruntime/vector/Row.java | 79 + .../hetu/omniruntime/vector/RowBatch.java | 91 + .../hetu/omniruntime/vector/ShortVec.java | 101 + .../hetu/omniruntime/vector/VarcharVec.java | 196 + .../omniruntime/vector/VariableWidthVec.java | 218 + .../nova/hetu/omniruntime/vector/Vec.java | 587 ++ .../hetu/omniruntime/vector/VecBatch.java | 174 + .../hetu/omniruntime/vector/VecEncoding.java | 18 + .../hetu/omniruntime/vector/VecFactory.java | 166 + .../vector/serialize/OmniRowDeserializer.java | 58 + .../serialize/ProtoVecBatchSerializer.java | 308 + .../vector/serialize/VecBatchSerializer.java | 30 + .../serialize/VecBatchSerializerFactory.java | 21 + .../java/src/main/proto/vec_batch_serde.proto | 67 + bindings/java/src/main/scripts/build_cpp.sh | 36 + .../constant/ConstantLoadTest.java | 61 + .../omniruntime/memory/TestMemoryManager.java | 64 + .../operator/OmniAggregationOperatorTest.java | 486 ++ .../operator/OmniBloomFilterOperatorTest.java | 43 + .../OmniDistinctLimitOperatorTest.java | 151 + .../operator/OmniExprVerifyTest.java | 89 + .../OmniFilterAndProjectOperatorTest.java | 1492 ++++ .../OmniHashAggregationOperatorTest.java | 402 + ...niHashAggregationWithExprOperatorTest.java | 410 + .../operator/OmniHashJoinOperatorsTest.java | 1196 +++ .../OmniHashJoinWithExprOperatorsTest.java | 940 +++ .../operator/OmniLimitOperatorTest.java | 279 + .../OmniNestedLoopJoinOperatorTest.java | 377 + .../operator/OmniOperatorConfigTest.java | 43 + .../operator/OmniOperatorFactoryTest.java | 141 + .../operator/OmniPartionOutOperatorTest.java | 159 + .../operator/OmniProjectOperatorTest.java | 339 + ...mniSortMergeJoinWithExprOperatorsTest.java | 447 ++ ...iSortMergeJoinWithExprOperatorsV3Test.java | 259 + .../operator/OmniSortOperatorTest.java | 635 ++ .../OmniSortWithExprOperatorTest.java | 614 ++ .../operator/OmniTopNOperatorTest.java | 345 + .../OmniTopNSortWithExprOperatorTest.java | 136 + .../OmniTopNWithExprOperatorTest.java | 301 + .../operator/OmniUnionOperatorTest.java | 170 + ...iWindowGroupLimitWithExprOperatorTest.java | 178 + .../operator/OmniWindowOperatorTest.java | 315 + .../OmniWindowWithExprOperatorTest.java | 333 + .../Sql10ForOmniFilterOperatorTest.java | 240 + .../tensql/Sql1ForOmniFilterOperatorTest.java | 314 + .../tensql/Sql2ForOmniFilterOperatorTest.java | 242 + .../tensql/Sql3ForOmniFilterOperatorTest.java | 303 + .../tensql/Sql4ForOmniFilterOperatorTest.java | 267 + .../tensql/Sql5ForOmniFilterOperatorTest.java | 253 + .../tensql/Sql6ForOmniFilterOperatorTest.java | 437 ++ .../tensql/Sql7ForOmniFilterOperatorTest.java | 167 + .../tensql/Sql8ForOmniFilterOperatorTest.java | 351 + .../tensql/Sql9ForOmniFilterOperatorTest.java | 303 + .../type/BenchmarkDataTypeSerializer.java | 67 + .../hetu/omniruntime/type/TestDataType.java | 65 + .../type/TestDataTypeSerializer.java | 35 + .../hetu/omniruntime/util/TestJsonUtils.java | 66 + .../nova/hetu/omniruntime/util/TestUtils.java | 1593 ++++ .../vector/BenchmarkDecimal128Vec.java | 352 + .../vector/BenchmarkDoubleVec.java | 359 + .../omniruntime/vector/BenchmarkIntVec.java | 359 + .../omniruntime/vector/BenchmarkLongVec.java | 359 + .../omniruntime/vector/BenchmarkShortVec.java | 364 + .../vector/BenchmarkVarcharVec.java | 371 + .../omniruntime/vector/MergeVectorsTest.java | 146 + .../omniruntime/vector/TestBooleanVec.java | 179 + .../omniruntime/vector/TestContainerVec.java | 176 + .../omniruntime/vector/TestDecimal128Vec.java | 252 + .../omniruntime/vector/TestDictionaryVec.java | 259 + .../omniruntime/vector/TestDoubleVec.java | 183 + .../hetu/omniruntime/vector/TestIntVec.java | 277 + .../hetu/omniruntime/vector/TestLongVec.java | 184 + .../hetu/omniruntime/vector/TestOmniRow.java | 162 + .../hetu/omniruntime/vector/TestShortVec.java | 277 + .../omniruntime/vector/TestVarcharVec.java | 464 ++ .../hetu/omniruntime/vector/TestVecBatch.java | 56 + .../nova/hetu/omniruntime/vector/VecUtil.java | 28 + .../serialize/VecBatchSerializerTest.java | 640 ++ build.sh | 137 + build_scripts/build.sh | 102 + build_scripts/env_check.sh | 97 + core/CMakeLists.txt | 15 + core/config.h.in | 8 + core/secDTFuzz/Dockerfile_build | 2 + core/secDTFuzz/Dockerfile_run | 1 + core/secDTFuzz/SecDTFuzz.yaml | 15 + core/secDTFuzz/build.sh | 31 + core/src/CMakeLists.txt | 21 + core/src/README.MD | 110 + core/src/codegen/CMakeLists.txt | 14 + core/src/codegen/batch_codegen_context.h | 52 + core/src/codegen/batch_expression_codegen.cpp | 1730 +++++ core/src/codegen/batch_expression_codegen.h | 163 + core/src/codegen/batch_filter_codegen.cpp | 87 + core/src/codegen/batch_filter_codegen.h | 31 + .../codegen/batch_func_registry_datetime.cpp | 26 + .../codegen/batch_func_registry_datetime.h | 17 + .../codegen/batch_func_registry_decimal.cpp | 435 ++ .../src/codegen/batch_func_registry_decimal.h | 38 + .../batch_func_registry_dictionary.cpp | 48 + .../codegen/batch_func_registry_dictionary.h | 17 + core/src/codegen/batch_func_registry_hash.cpp | 39 + core/src/codegen/batch_func_registry_hash.h | 17 + core/src/codegen/batch_func_registry_math.cpp | 156 + core/src/codegen/batch_func_registry_math.h | 28 + .../codegen/batch_func_registry_string.cpp | 352 + core/src/codegen/batch_func_registry_string.h | 68 + core/src/codegen/batch_func_registry_util.cpp | 133 + core/src/codegen/batch_func_registry_util.h | 17 + .../batch_func_registry_varchar_vector.cpp | 24 + .../batch_func_registry_varchar_vector.h | 17 + .../batch_datetime_functions.cpp | 96 + .../batch_datetime_functions.h | 34 + .../batch_decimal_arithmetic_functions.cpp | 1842 +++++ .../batch_decimal_arithmetic_functions.h | 467 ++ .../batch_decimal_cast_functions.cpp | 500 ++ .../batch_decimal_cast_functions.h | 134 + .../batch_dictionaryfunctions.cpp | 126 + .../batch_dictionaryfunctions.h | 54 + .../batch_functions/batch_mathfunctions.cpp | 404 + .../batch_functions/batch_mathfunctions.h | 164 + .../batch_functions/batch_murmur3_hash.cpp | 97 + .../batch_functions/batch_murmur3_hash.h | 43 + .../batch_functions/batch_stringfunctions.cpp | 1313 ++++ .../batch_functions/batch_stringfunctions.h | 430 ++ .../batch_functions/batch_utilfunctions.cpp | 299 + .../batch_functions/batch_utilfunctions.h | 172 + .../batch_varcharVectorfunctions.cpp | 40 + .../batch_varcharVectorfunctions.h | 25 + core/src/codegen/batch_projection_codegen.cpp | 131 + core/src/codegen/batch_projection_codegen.h | 32 + core/src/codegen/bloom_filter.cpp | 145 + core/src/codegen/bloom_filter.h | 73 + core/src/codegen/codegen_base.cpp | 101 + core/src/codegen/codegen_base.h | 57 + core/src/codegen/codegen_context.h | 54 + core/src/codegen/codegen_value.h | 82 + core/src/codegen/common_util.h | 153 + core/src/codegen/context_helper.cpp | 87 + core/src/codegen/context_helper.h | 139 + core/src/codegen/expr_evaluator.cpp | 579 ++ core/src/codegen/expr_evaluator.h | 238 + core/src/codegen/expr_function.cpp | 185 + core/src/codegen/expr_function.h | 120 + core/src/codegen/expr_info_extractor.cpp | 80 + core/src/codegen/expr_info_extractor.h | 46 + core/src/codegen/expression_codegen.cpp | 2295 ++++++ core/src/codegen/expression_codegen.h | 219 + core/src/codegen/filter_codegen.cpp | 175 + core/src/codegen/filter_codegen.h | 43 + core/src/codegen/func_registry.cpp | 267 + core/src/codegen/func_registry.h | 84 + core/src/codegen/func_registry_base.h | 18 + core/src/codegen/func_registry_context.cpp | 17 + core/src/codegen/func_registry_context.h | 17 + core/src/codegen/func_registry_datetime.cpp | 33 + core/src/codegen/func_registry_datetime.h | 19 + core/src/codegen/func_registry_decimal.cpp | 702 ++ core/src/codegen/func_registry_decimal.h | 62 + core/src/codegen/func_registry_dictionary.cpp | 32 + core/src/codegen/func_registry_dictionary.h | 27 + core/src/codegen/func_registry_hash.cpp | 65 + core/src/codegen/func_registry_hash.h | 17 + core/src/codegen/func_registry_hive_udf.cpp | 68 + core/src/codegen/func_registry_hive_udf.h | 19 + core/src/codegen/func_registry_math.cpp | 330 + core/src/codegen/func_registry_math.h | 27 + .../codegen/func_registry_might_contain.cpp | 23 + .../src/codegen/func_registry_might_contain.h | 17 + core/src/codegen/func_registry_string.cpp | 478 ++ core/src/codegen/func_registry_string.h | 72 + .../codegen/func_registry_varchar_vector.cpp | 21 + .../codegen/func_registry_varchar_vector.h | 20 + core/src/codegen/func_signature.cpp | 106 + core/src/codegen/func_signature.h | 40 + core/src/codegen/function.cpp | 62 + core/src/codegen/function.h | 66 + core/src/codegen/functions/README.md | 117 + .../codegen/functions/datetime_functions.cpp | 128 + .../codegen/functions/datetime_functions.h | 40 + .../decimal_arithmetic_functions.cpp | 1415 ++++ .../functions/decimal_arithmetic_functions.h | 426 ++ .../functions/decimal_cast_functions.cpp | 728 ++ .../functions/decimal_cast_functions.h | 194 + .../codegen/functions/dictionaryfunctions.cpp | 76 + .../codegen/functions/dictionaryfunctions.h | 37 + core/src/codegen/functions/dtoa.cpp | 1218 +++ core/src/codegen/functions/dtoa.h | 410 + core/src/codegen/functions/mathfunctions.cpp | 584 ++ core/src/codegen/functions/mathfunctions.h | 270 + core/src/codegen/functions/md5.cpp | 229 + core/src/codegen/functions/md5.h | 39 + core/src/codegen/functions/mightcontain.cpp | 21 + core/src/codegen/functions/mightcontain.h | 21 + core/src/codegen/functions/murmur3_hash.cpp | 269 + core/src/codegen/functions/murmur3_hash.h | 43 + .../src/codegen/functions/stringfunctions.cpp | 1203 +++ core/src/codegen/functions/stringfunctions.h | 364 + core/src/codegen/functions/udffunctions.cpp | 54 + core/src/codegen/functions/udffunctions.h | 26 + .../functions/varcharVectorfunctions.cpp | 36 + .../functions/varcharVectorfunctions.h | 24 + core/src/codegen/functions/xxhash64_hash.cpp | 121 + core/src/codegen/functions/xxhash64_hash.h | 35 + core/src/codegen/llvm_engine.cpp | 331 + core/src/codegen/llvm_engine.h | 131 + core/src/codegen/llvm_types.cpp | 236 + core/src/codegen/llvm_types.h | 86 + core/src/codegen/projection_codegen.cpp | 236 + core/src/codegen/projection_codegen.h | 49 + core/src/codegen/simple_filter_codegen.cpp | 169 + core/src/codegen/simple_filter_codegen.h | 44 + core/src/codegen/string_util.h | 258 + core/src/codegen/time_util.h | 234 + core/src/compute/CMakeLists.txt | 4 + core/src/compute/ColumnarBatchIterator.h | 19 + core/src/compute/ResultIterator.cpp | 9 + core/src/compute/ResultIterator.h | 80 + core/src/compute/cpuWall_timer.cpp | 22 + core/src/compute/cpuWall_timer.h | 102 + core/src/compute/driver.cpp | 249 + core/src/compute/driver.h | 176 + core/src/compute/local_planner.cpp | 215 + core/src/compute/local_planner.h | 34 + core/src/compute/operator_stats.h | 207 + core/src/compute/plannode_stats.cpp | 206 + core/src/compute/plannode_stats.h | 142 + core/src/compute/process_base.cpp | 24 + core/src/compute/process_base.h | 17 + core/src/compute/reason.h | 118 + core/src/compute/task.cpp | 80 + core/src/compute/task.h | 61 + core/src/compute/task_stats.h | 108 + core/src/cpu_checker/CMakeLists.txt | 9 + .../cpu_checker/omniruntime_cpu_checker.cpp | 105 + .../src/cpu_checker/omniruntime_cpu_checker.h | 39 + core/src/expression/CMakeLists.txt | 8 + core/src/expression/README.md | 32 + core/src/expression/expr_printer.cpp | 469 ++ core/src/expression/expr_printer.h | 32 + core/src/expression/expr_verifier.cpp | 308 + core/src/expression/expr_verifier.h | 33 + core/src/expression/expr_visitor.cpp | 62 + core/src/expression/expr_visitor.h | 26 + core/src/expression/expressions.cpp | 379 + core/src/expression/expressions.h | 256 + core/src/expression/jsonparser/jsonparser.cpp | 462 ++ core/src/expression/jsonparser/jsonparser.h | 37 + core/src/expression/parser/parser.cpp | 381 + core/src/expression/parser/parser.h | 44 + core/src/expression/parserhelper.cpp | 110 + core/src/expression/parserhelper.h | 21 + core/src/memory/CMakeLists.txt | 8 + core/src/memory/aligned_buffer.h | 87 + core/src/memory/allocator.h | 47 + core/src/memory/chunk.cpp | 34 + core/src/memory/chunk.h | 41 + core/src/memory/memory_manager.cpp | 174 + core/src/memory/memory_manager.h | 129 + core/src/memory/memory_manager_allocator.h | 81 + core/src/memory/memory_pool.cpp | 161 + core/src/memory/memory_pool.h | 29 + core/src/memory/memory_trace.cpp | 71 + core/src/memory/memory_trace.h | 52 + core/src/memory/simple_arena_allocator.h | 225 + core/src/memory/thread_memory_manager.cpp | 62 + core/src/memory/thread_memory_manager.h | 117 + core/src/memory/thread_memory_trace.cpp | 205 + core/src/memory/thread_memory_trace.h | 65 + core/src/metrics/metrics.h | 122 + core/src/metrics/metrics_config.h | 12 + core/src/metrics/metrics_memory_info.h | 51 + core/src/metrics/metrics_row_counter.h | 48 + core/src/metrics/metrics_spill_info.h | 41 + core/src/metrics/omni_metrics.h | 115 + core/src/operator/CMakeLists.txt | 27 + core/src/operator/aggregation/GROUPBY.MD | 14 + core/src/operator/aggregation/agg_util.h | 112 + core/src/operator/aggregation/aggregation.cpp | 92 + core/src/operator/aggregation/aggregation.h | 77 + .../aggregation/aggregator/aggregator.h | 333 + .../aggregator/aggregator_factory.cpp | 384 + .../aggregator/aggregator_factory.h | 276 + .../aggregator/aggregator_util.cpp | 48 + .../aggregation/aggregator/aggregator_util.h | 34 + .../aggregation/aggregator/all_aggregators.h | 25 + .../aggregator/average_aggregator.cpp | 498 ++ .../aggregator/average_aggregator.h | 153 + .../aggregator/average_flatim_aggregator.h | 351 + .../average_spark_decimal_aggregator.h | 571 ++ .../aggregator/count_all_aggregator.h | 69 + .../aggregator/count_column_aggregator.cpp | 246 + .../aggregator/count_column_aggregator.h | 161 + .../aggregation/aggregator/first_aggregator.h | 436 ++ .../mask_column_assistant_aggregator.h | 126 + .../aggregation/aggregator/max_aggregator.cpp | 356 + .../aggregation/aggregator/max_aggregator.h | 157 + .../aggregator/max_varchar_aggregator.cpp | 208 + .../aggregator/max_varchar_aggregator.h | 130 + .../aggregation/aggregator/min_aggregator.cpp | 356 + .../aggregation/aggregator/min_aggregator.h | 157 + .../aggregator/min_varchar_aggregator.cpp | 202 + .../aggregator/min_varchar_aggregator.h | 368 + .../aggregator/only_aggregator_factory.h | 34 + .../aggregator/operations_aggregator.h | 326 + .../aggregator/operations_hash_aggregator.h | 400 + .../aggregator/state_flag_operation.h | 142 + .../aggregator/stddev_samp_aggregator.h | 387 + .../aggregation/aggregator/sum_aggregator.cpp | 482 ++ .../aggregation/aggregator/sum_aggregator.h | 295 + .../aggregator/sum_flatim_aggregator.h | 305 + .../aggregator/sum_spark_decimal_aggregator.h | 597 ++ .../aggregator/try_sum_flatim_aggregator.h | 293 + .../aggregator/typed_aggregator.cpp | 66 + .../aggregation/aggregator/typed_aggregator.h | 576 ++ .../typed_mask_column_assistant_aggregator.h | 411 + .../operator/aggregation/container_vector.h | 199 + core/src/operator/aggregation/definitions.h | 17 + .../aggregation/group_aggregation.cpp | 1541 ++++ .../operator/aggregation/group_aggregation.h | 303 + .../aggregation/group_aggregation_expr.cpp | 308 + .../aggregation/group_aggregation_expr.h | 86 + .../aggregation/group_aggregation_sort.cpp | 38 + .../aggregation/group_aggregation_sort.h | 88 + .../aggregation/non_group_aggregation.cpp | 158 + .../aggregation/non_group_aggregation.h | 109 + .../non_group_aggregation_expr.cpp | 208 + .../aggregation/non_group_aggregation_expr.h | 70 + .../operator/aggregation/one_row_adaptor.h | 147 + core/src/operator/aggregation/vector_getter.h | 117 + core/src/operator/config/operator_config.cpp | 193 + core/src/operator/config/operator_config.h | 299 + core/src/operator/execution_context.h | 51 + core/src/operator/expand/expand.cpp | 88 + core/src/operator/expand/expand.h | 72 + .../operator/filter/filter_and_project.cpp | 241 + core/src/operator/filter/filter_and_project.h | 139 + core/src/operator/grouping/grouping.cpp | 176 + core/src/operator/grouping/grouping.h | 73 + core/src/operator/hash_util.cpp | 84 + core/src/operator/hash_util.h | 300 + core/src/operator/hashmap/array_map.h | 187 + core/src/operator/hashmap/base_hash_map.h | 843 ++ core/src/operator/hashmap/column_marshaller.h | 290 + core/src/operator/hashmap/crc32c.h | 122 + core/src/operator/hashmap/crc_hasher.h | 96 + core/src/operator/hashmap/group_hasher.h | 52 + core/src/operator/hashmap/vector_analyzer.cpp | 151 + core/src/operator/hashmap/vector_analyzer.h | 78 + .../operator/hashmap/vector_marshaller.cpp | 647 ++ core/src/operator/hashmap/vector_marshaller.h | 52 + core/src/operator/join/common_join.h | 28 + core/src/operator/join/hash_builder.cpp | 215 + core/src/operator/join/hash_builder.h | 96 + core/src/operator/join/hash_builder_expr.cpp | 127 + core/src/operator/join/hash_builder_expr.h | 89 + .../join/join_hash_table_variants.cpp | 964 +++ .../operator/join/join_hash_table_variants.h | 293 + core/src/operator/join/lookup_join.cpp | 2007 +++++ core/src/operator/join/lookup_join.h | 305 + core/src/operator/join/lookup_join_expr.cpp | 236 + core/src/operator/join/lookup_join_expr.h | 85 + .../src/operator/join/lookup_join_wrapper.cpp | 98 + core/src/operator/join/lookup_join_wrapper.h | 64 + core/src/operator/join/lookup_outer_join.cpp | 374 + core/src/operator/join/lookup_outer_join.h | 104 + .../operator/join/lookup_outer_join_expr.cpp | 124 + .../operator/join/lookup_outer_join_expr.h | 72 + .../operator/join/nest_loop_join_builder.cpp | 136 + .../operator/join/nest_loop_join_builder.h | 65 + .../operator/join/nest_loop_join_lookup.cpp | 423 ++ .../src/operator/join/nest_loop_join_lookup.h | 101 + .../join/nest_loop_join_lookup_wrapper.cpp | 99 + .../join/nest_loop_join_lookup_wrapper.h | 54 + core/src/operator/join/row_ref.h | 203 + .../sortmergejoin/dynamic_pages_index.cpp | 104 + .../join/sortmergejoin/dynamic_pages_index.h | 111 + .../join/sortmergejoin/sort_merge_join.cpp | 276 + .../join/sortmergejoin/sort_merge_join.h | 123 + .../sortmergejoin/sort_merge_join_expr.cpp | 228 + .../join/sortmergejoin/sort_merge_join_expr.h | 149 + .../sortmergejoin/sort_merge_join_expr_v2.cpp | 220 + .../sortmergejoin/sort_merge_join_expr_v2.h | 221 + .../sortmergejoin/sort_merge_join_expr_v3.cpp | 141 + .../sortmergejoin/sort_merge_join_expr_v3.h | 120 + .../sort_merge_join_resultBuilder.cpp | 749 ++ .../sort_merge_join_resultBuilder.h | 234 + .../sortmergejoin/sort_merge_join_scanner.cpp | 852 +++ .../sortmergejoin/sort_merge_join_scanner.h | 320 + .../join/sortmergejoin/sort_merge_join_v3.cpp | 1257 +++ .../join/sortmergejoin/sort_merge_join_v3.h | 305 + core/src/operator/limit/distinct_limit.cpp | 505 ++ core/src/operator/limit/distinct_limit.h | 112 + core/src/operator/limit/distinct_state_func.h | 231 + core/src/operator/limit/limit.cpp | 102 + core/src/operator/limit/limit.h | 53 + core/src/operator/memory_builder.h | 20 + .../src/operator/omni_id_type_vector_traits.h | 101 + core/src/operator/operator.h | 267 + core/src/operator/operator_factory.h | 51 + core/src/operator/pages_hash_strategy.cpp | 170 + core/src/operator/pages_hash_strategy.h | 77 + core/src/operator/pages_index.cpp | 1475 ++++ core/src/operator/pages_index.h | 264 + core/src/operator/projection/projection.cpp | 66 + core/src/operator/projection/projection.h | 65 + core/src/operator/radix_sort.h | 160 + core/src/operator/sort/sort.cpp | 455 ++ core/src/operator/sort/sort.h | 154 + core/src/operator/sort/sort_expr.cpp | 130 + core/src/operator/sort/sort_expr.h | 77 + core/src/operator/spill/loser_tree.h | 247 + core/src/operator/spill/spill_merger.cpp | 260 + core/src/operator/spill/spill_merger.h | 273 + core/src/operator/spill/spill_tracker.cpp | 47 + core/src/operator/spill/spill_tracker.h | 79 + core/src/operator/spill/spiller.cpp | 486 ++ core/src/operator/spill/spiller.h | 147 + core/src/operator/status.h | 15 + core/src/operator/topn/topn.cpp | 396 + core/src/operator/topn/topn.h | 140 + core/src/operator/topn/topn_expr.cpp | 88 + core/src/operator/topn/topn_expr.h | 63 + core/src/operator/topnsort/topn_sort.cpp | 807 ++ core/src/operator/topnsort/topn_sort.h | 225 + core/src/operator/topnsort/topn_sort_expr.cpp | 93 + core/src/operator/topnsort/topn_sort_expr.h | 68 + core/src/operator/union/union.cpp | 109 + core/src/operator/union/union.h | 97 + core/src/operator/util/function_type.h | 36 + core/src/operator/util/mm3_util.h | 423 ++ core/src/operator/util/operator_util.h | 595 ++ core/src/operator/window/window.cpp | 800 ++ core/src/operator/window/window.h | 309 + core/src/operator/window/window_expr.cpp | 196 + core/src/operator/window/window_expr.h | 91 + core/src/operator/window/window_frame.h | 82 + core/src/operator/window/window_function.cpp | 184 + core/src/operator/window/window_function.h | 142 + .../operator/window/window_group_limit.cpp | 872 +++ core/src/operator/window/window_group_limit.h | 232 + .../window/window_group_limit_expr.cpp | 79 + .../operator/window/window_group_limit_expr.h | 54 + core/src/operator/window/window_partition.cpp | 254 + core/src/operator/window/window_partition.h | 97 + core/src/plannode/CMakeLists.txt | 5 + core/src/plannode/RowVectorStream.h | 142 + core/src/plannode/planFragment.cpp | 5 + core/src/plannode/planFragment.h | 50 + core/src/plannode/planNode.cpp | 12 + core/src/plannode/planNode.h | 926 +++ core/src/simd/CMakeLists.txt | 1 + core/src/simd/base.h | 1527 ++++ core/src/simd/func/match.h | 58 + core/src/simd/func/quick_sort_simd.cpp | 383 + core/src/simd/func/quick_sort_simd.h | 19 + core/src/simd/func/reduce.h | 304 + core/src/simd/func/small_case_sort.h | 1383 ++++ core/src/simd/func/traits-inl.h | 312 + core/src/simd/instruction/arm_neon-inl.h | 6746 +++++++++++++++++ core/src/simd/instruction/arm_sve-inl.h | 5153 +++++++++++++ core/src/simd/instruction/generic_ops-inl.h | 4657 ++++++++++++ core/src/simd/instruction/inside-inl.h | 618 ++ core/src/simd/instruction/set_macros-inl.h | 152 + core/src/simd/instruction/shared-inl.h | 545 ++ core/src/simd/simd.h | 16 + core/src/simd/targets.h | 24 + core/src/type/CMakeLists.txt | 8 + core/src/type/base_operations.h | 62 + core/src/type/big_integer.cpp | 644 ++ core/src/type/big_integer.h | 177 + core/src/type/data_operations.h | 289 + core/src/type/data_type.h | 664 ++ core/src/type/data_type_serializer.cpp | 90 + core/src/type/data_type_serializer.h | 25 + core/src/type/data_types.h | 97 + core/src/type/data_utils.h | 69 + core/src/type/date32.cpp | 562 ++ core/src/type/date32.h | 172 + core/src/type/date_base.h | 18 + core/src/type/date_time_utils.h | 141 + core/src/type/decimal128.cpp | 51 + core/src/type/decimal128.h | 155 + core/src/type/decimal128_utils.h | 97 + core/src/type/decimal_base.h | 25 + core/src/type/decimal_operations.h | 1770 +++++ core/src/type/double_utils.h | 810 ++ core/src/type/integer256.cpp | 261 + core/src/type/integer256.h | 385 + core/src/type/string_Impl.h | 127 + core/src/type/string_ref.h | 92 + core/src/type/width_integer.h | 136 + core/src/udf/cplusplus/CMakeLists.txt | 10 + core/src/udf/cplusplus/java_udf_functions.cpp | 179 + core/src/udf/cplusplus/java_udf_functions.h | 25 + core/src/udf/cplusplus/jni_util.cpp | 280 + core/src/udf/cplusplus/jni_util.h | 71 + core/src/udf/java/pom.xml | 166 + .../java/omniruntime/udf/HiveUdfExecutor.java | 595 ++ .../main/java/omniruntime/udf/UdfUtil.java | 98 + .../test/java/omniruntime/udf/AddIntUDF.java | 30 + .../test/java/omniruntime/udf/AddLongUDF.java | 30 + .../java/omniruntime/udf/AddShortUDF.java | 35 + .../java/omniruntime/udf/AndBooleanUDF.java | 28 + .../java/omniruntime/udf/ConcatStringUDF.java | 29 + .../omniruntime/udf/HiveUdfExecutorTest.java | 504 ++ .../java/omniruntime/udf/HiveUdfTestUtil.java | 424 ++ .../java/omniruntime/udf/MaxDoubleUDF.java | 30 + core/src/util/CMakeLists.txt | 6 + core/src/util/bit_array.h | 56 + core/src/util/bit_map.h | 68 + core/src/util/bit_util.h | 221 + core/src/util/bits_selectivity_vector.h | 287 + core/src/util/compiler_util.h | 26 + core/src/util/config/Config.h | 95 + core/src/util/config/ConfigBase.h | 199 + core/src/util/config/QueryConfig.h | 366 + core/src/util/config_util.cpp | 385 + core/src/util/config_util.h | 108 + core/src/util/debug.h | 117 + core/src/util/error_code.cpp | 38 + core/src/util/error_code.h | 36 + core/src/util/format.h | 55 + core/src/util/global_log.h | 44 + core/src/util/native_log.cpp | 110 + core/src/util/native_log.h | 27 + core/src/util/omni_exception.h | 58 + core/src/util/perf_util.cpp | 66 + core/src/util/perf_util.h | 25 + core/src/util/policy.h | 312 + core/src/util/property.h | 53 + core/src/util/trace_util.h | 34 + core/src/util/type_util.cpp | 248 + core/src/util/type_util.h | 124 + core/src/util/utf8_util.h | 142 + core/src/vector/CMakeLists.txt | 21 + core/src/vector/dictionary_container.h | 103 + core/src/vector/large_string_container.cpp | 123 + core/src/vector/large_string_container.h | 54 + core/src/vector/nulls_buffer.h | 167 + core/src/vector/omni_row.h | 617 ++ core/src/vector/omni_row_value.h | 283 + core/src/vector/selectivity_vector.cpp | 187 + core/src/vector/selectivity_vector.h | 202 + core/src/vector/string_utils.h | 61 + core/src/vector/type_utils.h | 61 + core/src/vector/unsafe_dictionary_container.h | 42 + core/src/vector/unsafe_string_container.h | 53 + core/src/vector/unsafe_vector.h | 125 + core/src/vector/vector.h | 620 ++ core/src/vector/vector_batch.cpp | 108 + core/src/vector/vector_batch.h | 66 + core/src/vector/vector_common.h | 19 + core/src/vector/vector_helper.h | 851 +++ core/test/CMakeLists.txt | 62 + core/test/benchmark/CMakeLists.txt | 52 + core/test/benchmark/codegen/CMakeLists.txt | 42 + core/test/benchmark/codegen/decimal.cpp | 81 + core/test/benchmark/omni_benchmark.cpp | 7 + core/test/benchmark/operator/AggLarge.cpp | 302 + core/test/benchmark/operator/CMakeLists.txt | 29 + core/test/benchmark/operator/agg.cpp | 531 ++ .../benchmark/operator/agg_in_hash_agg.cpp | 566 ++ .../test/benchmark/operator/common/common.cpp | 253 + core/test/benchmark/operator/common/common.h | 195 + .../benchmark/operator/common/vector_util.cpp | 202 + .../benchmark/operator/common/vector_util.h | 30 + .../benchmark/operator/distinct_limit.cpp | 64 + .../benchmark/operator/filter_and_project.cpp | 375 + .../benchmark/operator/groupby_hashmap.cpp | 488 ++ core/test/benchmark/operator/hash_agg.cpp | 238 + .../operator/hash_agg_large_group.cpp | 107 + core/test/benchmark/operator/hash_join.cpp | 282 + .../operator/scripts/extract_output.sh | 38 + .../benchmark/operator/scripts/plot_agg.py | 164 + .../operator/scripts/plot_hash_agg.py | 53 + .../scripts/plot_hashagg_vs_ck_benchmark.py | 84 + .../benchmark/operator/scripts/plot_util.py | 129 + core/test/benchmark/operator/sort.cpp | 89 + .../benchmark/operator/sort_merge_join.cpp | 140 + core/test/benchmark/operator/topn.cpp | 80 + .../test/benchmark/operator/tpcds/HowToUse.md | 4 + .../benchmark/operator/tpcds/tpcds_common.cpp | 205 + .../benchmark/operator/tpcds/tpcds_common.h | 64 + .../operator/tpcds/tpcds_hash_agg.cpp | 69 + .../operator/tpcds/tpcds_hash_join.cpp | 99 + .../benchmark/operator/tpcds/tpcds_sort.cpp | 43 + .../operator/tpcds/tpcds_sort_merge_join.cpp | 101 + core/test/benchmark/operator/window.cpp | 159 + core/test/codegen/CMakeLists.txt | 8 + .../codegen/batch_codegen_binary_test.cpp | 989 +++ .../codegen/batch_codegen_func_other_test.cpp | 766 ++ .../codegen/batch_codegen_if_switch_test.cpp | 648 ++ ...batch_codegen_in_between_coalesce_test.cpp | 791 ++ core/test/codegen/batch_function_test.cpp | 2835 +++++++ core/test/codegen/codegen_test.cpp | 2470 ++++++ core/test/codegen/codegen_util.cpp | 138 + core/test/codegen/codegen_util.h | 27 + core/test/codegen/decimal_function_test.cpp | 1637 ++++ core/test/codegen/expression_test.cpp | 1875 +++++ core/test/codegen/function_test.cpp | 3394 +++++++++ core/test/codegen/jni_mock.cpp | 177 + core/test/codegen/jni_mock.h | 68 + core/test/compute/CMakeLists.txt | 6 + core/test/compute/LimitTest.cpp | 316 + core/test/compute/OrderByTest.cpp | 188 + core/test/compute/PlanNodeStatsTest.cpp | 142 + core/test/compute/ProccessBaseTest.cpp | 22 + core/test/compute/UnionTest.cpp | 143 + core/test/dt/CMakeLists.txt | 2 + core/test/dt/README.md | 4 + core/test/dt/sourcetree/CMakeLists.txt | 9 + core/test/dt/sourcetree/fuzz_wrapper.cpp | 545 ++ core/test/dt/sourcetree/fuzz_wrapper.h | 23 + core/test/dt/testtree/CMakeLists.txt | 50 + core/test/dt/testtree/cases/dtframe.json | 47 + core/test/dt/testtree/dtframe.cfg | 62 + core/test/function/CMakeLists.txt | 4 + .../function/compare_function_bench_test.cpp | 225 + .../function/equals_function_bench_test.cpp | 183 + .../function/hash_function_bench_test.cpp | 134 + core/test/memory/CMakeLists.txt | 8 + core/test/memory/aligned_buffer_test.cpp | 60 + core/test/memory/memory_allocator_test.cpp | 154 + core/test/memory/memory_benchmark.cpp | 128 + .../memory/memory_manager_allocator_test.cpp | 56 + core/test/memory/memory_manager_test.cpp | 422 ++ core/test/memory/memory_trace_test.cpp | 46 + .../memory/simple_arena_allocator_test.cpp | 130 + .../memory/thread_memory_manager_test.cpp | 83 + core/test/memory/thread_memory_trace_test.cpp | 206 + core/test/omtest.cpp | 12 + core/test/operator/CMakeLists.txt | 14 + .../operator/adaptive_partialagg_test.cpp | 641 ++ core/test/operator/aggregation_test.cpp | 5570 ++++++++++++++ .../operator/aggregation_with_expr_test.cpp | 2364 ++++++ .../aggregator/aggregator_cast_test.cpp | 1243 +++ .../aggregator_multi_stage_complete_test.cpp | 543 ++ .../aggregator_multi_stage_no_groupby.h | 1030 +++ .../aggregator_multi_stage_with_groupby.h | 551 ++ .../aggregator_single_stage_complete_test.cpp | 586 ++ .../aggregator/container_vector_test.cpp | 127 + .../neon_agg_calculate_func_test.cpp | 319 + core/test/operator/array_map_test.cpp | 55 + core/test/operator/base_hashmap_test.cpp | 749 ++ core/test/operator/bloom_filter_test.cpp | 81 + core/test/operator/config_test.cpp | 154 + core/test/operator/distinctlimit_test.cpp | 380 + .../operator/dt_fuzz_factory_create_test.cpp | 180 + core/test/operator/execution_context_test.cpp | 38 + core/test/operator/expand_test.cpp | 87 + core/test/operator/filter_test.cpp | 3518 +++++++++ core/test/operator/hash_util_test.cpp | 300 + core/test/operator/join_test.cpp | 4775 ++++++++++++ core/test/operator/join_with_expr_test.cpp | 568 ++ .../test/operator/join_with_plannode_test.cpp | 197 + core/test/operator/limit_test.cpp | 278 + core/test/operator/nested_loop_join_test.cpp | 1008 +++ core/test/operator/project_test.cpp | 3520 +++++++++ core/test/operator/random_join_test.cpp | 1645 ++++ core/test/operator/sort_merge_join_test.cpp | 1804 +++++ .../sort_merge_join_with_expr_test.cpp | 4585 +++++++++++ .../sort_merge_join_with_expr_v2_test.cpp | 167 + .../sort_merge_join_with_expr_v3_test.cpp | 3679 +++++++++ core/test/operator/sort_test.cpp | 3018 ++++++++ core/test/operator/sort_with_expr_test.cpp | 480 ++ core/test/operator/spill_test.cpp | 91 + .../operator/topn_sort_with_expr_test.cpp | 621 ++ core/test/operator/topn_test.cpp | 1823 +++++ core/test/operator/topn_with_expr_test.cpp | 216 + core/test/operator/union_test.cpp | 65 + core/test/operator/vector_analyzer_test.cpp | 76 + .../window_group_limit_with_expr_test.cpp | 575 ++ core/test/operator/window_test.cpp | 4867 ++++++++++++ core/test/operator/window_with_expr_test.cpp | 1744 +++++ core/test/parser/CMakeLists.txt | 8 + core/test/parser/jsonparser_test.cpp | 936 +++ core/test/parser/parser_test.cpp | 754 ++ core/test/plannode/CMakeLists.txt | 6 + core/test/plannode/PlanNodeTest.cpp | 42 + core/test/type/CMakeLists.txt | 7 + core/test/type/big_integer_test.cpp | 183 + core/test/type/date32_test.cpp | 131 + core/test/type/decimal128_test.cpp | 1075 +++ core/test/type/decimal_operations_test.cpp | 200 + core/test/type/dtoa_test.cpp | 63 + core/test/type/type_serialization_test.cpp | 112 + core/test/util/CMakeLists.txt | 3 + .../test/util/dt_fuzz_factory_create_util.cpp | 336 + core/test/util/dt_fuzz_factory_create_util.h | 49 + core/test/util/test_native_log.cpp | 67 + core/test/util/test_util.cpp | 879 +++ core/test/util/test_util.h | 272 + core/test/vector/CMakeLists.txt | 10 + core/test/vector/basic_test.cpp | 713 ++ core/test/vector/benchmark.cpp | 317 + core/test/vector/container_benchmark.cpp | 76 + core/test/vector/container_vec_test.cpp | 51 + core/test/vector/decimal128_vector_test.cpp | 194 + core/test/vector/dict_container_test.cpp | 63 + core/test/vector/dictionary_vector_test.cpp | 446 ++ core/test/vector/omni_row_test.cpp | 333 + core/test/vector/selectivity_vector_test.cpp | 851 +++ core/test/vector/slice_test.cpp | 121 + core/test/vector/string_container_test.cpp | 329 + core/test/vector/unsafe_vector_test.cpp | 332 + core/test/vector/vector_batch_test.cpp | 81 + core/test/vector/vector_helper_test.cpp | 68 + core/test/vector/vector_nulls_test.cpp | 91 + core/test/vector/vector_test_util.h | 111 + env_check.sh | 101 + examples/README.MD | 0 .../externalfunctions/externalfunctions.cpp | 38 + .../externalfunctions/externalfunctions.h | 29 + .../externalregistration.conf | 13 + .../externalfunctions/extstringfunctions.cpp | 31 + .../externalfunctions/extstringfunctions.h | 29 + 828 files changed, 242157 insertions(+), 25 deletions(-) create mode 100644 .gitignore create mode 100644 BUILDING.MD create mode 100644 CMakeLists.txt create mode 100644 bindings/java/pom.xml create mode 100644 bindings/java/src/main/cpp/CMakeLists.txt create mode 100644 bindings/java/src/main/cpp/src/jni_common_def.cpp create mode 100644 bindings/java/src/main/cpp/src/jni_common_def.h create mode 100644 bindings/java/src/main/cpp/src/jni_constants.cpp create mode 100644 bindings/java/src/main/cpp/src/jni_constants.h create mode 100644 bindings/java/src/main/cpp/src/jni_helper.cpp create mode 100644 bindings/java/src/main/cpp/src/jni_helper.h create mode 100644 bindings/java/src/main/cpp/src/jni_operator.cpp create mode 100644 bindings/java/src/main/cpp/src/jni_operator.h create mode 100644 bindings/java/src/main/cpp/src/jni_operator_factory.cpp create mode 100644 bindings/java/src/main/cpp/src/jni_operator_factory.h create mode 100644 bindings/java/src/main/cpp/src/jni_vector.cpp create mode 100644 bindings/java/src/main/cpp/src/jni_vector.h create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/OmniLibs.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/constants/BuildSide.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/constants/Constant.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/constants/ConstantHelper.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/constants/FunctionType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/constants/JoinType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/constants/OmniWindowFrameBoundType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/constants/OmniWindowFrameType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/constants/OperatorType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/constants/Status.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/memory/MemoryManager.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniExprVerify.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniOperator.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniOperatorFactoryContext.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniResults.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniRowResults.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/aggregator/OmniAggregationOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/aggregator/OmniAggregationWithExprOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/aggregator/OmniHashAggregationOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/aggregator/OmniHashAggregationWithExprOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/config/OperatorConfig.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/config/OverflowConfig.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/config/SparkSpillConfig.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/config/SpillConfig.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/filter/OmniBloomFilterOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/filter/OmniFilterAndProjectOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniHashBuilderOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniHashBuilderWithExprOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniLookupJoinOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniLookupJoinWithExprOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniLookupOuterJoinOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniLookupOuterJoinWithExprOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniNestedLoopJoinBuildOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniNestedLoopJoinLookupOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniSmjBufferedTableWithExprOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniSmjBufferedTableWithExprOperatorFactoryV3.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniSmjStreamedTableWithExprOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniSmjStreamedTableWithExprOperatorFactoryV3.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/limit/OmniDistinctLimitOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/limit/OmniLimitOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/partitionedoutput/OmniPartitionedOutPutOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/project/OmniProjectOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/sort/OmniSortOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/sort/OmniSortWithExprOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/topn/OmniTopNOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/topn/OmniTopNWithExprOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/topnsort/OmniTopNSortWithExprOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/union/OmniUnionOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/window/OmniWindowGroupLimitWithExprOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/window/OmniWindowOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/operator/window/OmniWindowWithExprOperatorFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/type/BooleanDataType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/type/ByteDataType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/type/CharDataType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/type/ContainerDataType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/type/DataType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/type/DataTypeSerializer.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/type/Date32DataType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/type/Date64DataType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/type/Decimal128DataType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/type/Decimal64DataType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/type/DecimalDataType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/type/DoubleDataType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/type/IntDataType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/type/InvalidDataType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/type/LongDataType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/type/NoneDataType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/type/ShortDataType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/type/TimestampDataType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/type/VarcharDataType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/utils/JsonUtils.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/utils/NativeLog.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/utils/NullsBufHelper.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/utils/OmniErrorType.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/utils/OmniRuntimeException.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/utils/ParseUtil.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/utils/ShuffleHashHelper.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/utils/TraceUtil.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/BooleanVec.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/ByteVec.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/ContainerVec.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/Decimal128Vec.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/DecimalVec.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/DictionaryVec.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/DoubleVec.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/FixedWidthVec.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/IntVec.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/JvmUtils.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/LongVec.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/OmniBuffer.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/OmniBufferFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/OmniBufferUnsafeV8.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/Row.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/RowBatch.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/ShortVec.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/VarcharVec.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/VariableWidthVec.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/Vec.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/VecBatch.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/VecEncoding.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/VecFactory.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/serialize/OmniRowDeserializer.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/serialize/ProtoVecBatchSerializer.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/serialize/VecBatchSerializer.java create mode 100644 bindings/java/src/main/java/nova/hetu/omniruntime/vector/serialize/VecBatchSerializerFactory.java create mode 100644 bindings/java/src/main/proto/vec_batch_serde.proto create mode 100644 bindings/java/src/main/scripts/build_cpp.sh create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/constant/ConstantLoadTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/memory/TestMemoryManager.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniAggregationOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniBloomFilterOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniDistinctLimitOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniExprVerifyTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniFilterAndProjectOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniHashAggregationOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniHashAggregationWithExprOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniHashJoinOperatorsTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniHashJoinWithExprOperatorsTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniLimitOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniNestedLoopJoinOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniOperatorConfigTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniOperatorFactoryTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniPartionOutOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniProjectOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniSortMergeJoinWithExprOperatorsTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniSortMergeJoinWithExprOperatorsV3Test.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniSortOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniSortWithExprOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniTopNOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniTopNSortWithExprOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniTopNWithExprOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniUnionOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniWindowGroupLimitWithExprOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniWindowOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniWindowWithExprOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql10ForOmniFilterOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql1ForOmniFilterOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql2ForOmniFilterOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql3ForOmniFilterOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql4ForOmniFilterOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql5ForOmniFilterOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql6ForOmniFilterOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql7ForOmniFilterOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql8ForOmniFilterOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql9ForOmniFilterOperatorTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/type/BenchmarkDataTypeSerializer.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/type/TestDataType.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/type/TestDataTypeSerializer.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/util/TestJsonUtils.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/util/TestUtils.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkDecimal128Vec.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkDoubleVec.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkIntVec.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkLongVec.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkShortVec.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkVarcharVec.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/vector/MergeVectorsTest.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestBooleanVec.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestContainerVec.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestDecimal128Vec.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestDictionaryVec.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestDoubleVec.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestIntVec.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestLongVec.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestOmniRow.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestShortVec.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestVarcharVec.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestVecBatch.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/vector/VecUtil.java create mode 100644 bindings/java/src/test/java/nova/hetu/omniruntime/vector/serialize/VecBatchSerializerTest.java create mode 100644 build.sh create mode 100644 build_scripts/build.sh create mode 100644 build_scripts/env_check.sh create mode 100644 core/CMakeLists.txt create mode 100644 core/config.h.in create mode 100644 core/secDTFuzz/Dockerfile_build create mode 100644 core/secDTFuzz/Dockerfile_run create mode 100644 core/secDTFuzz/SecDTFuzz.yaml create mode 100644 core/secDTFuzz/build.sh create mode 100644 core/src/CMakeLists.txt create mode 100644 core/src/README.MD create mode 100644 core/src/codegen/CMakeLists.txt create mode 100644 core/src/codegen/batch_codegen_context.h create mode 100644 core/src/codegen/batch_expression_codegen.cpp create mode 100644 core/src/codegen/batch_expression_codegen.h create mode 100644 core/src/codegen/batch_filter_codegen.cpp create mode 100644 core/src/codegen/batch_filter_codegen.h create mode 100644 core/src/codegen/batch_func_registry_datetime.cpp create mode 100644 core/src/codegen/batch_func_registry_datetime.h create mode 100644 core/src/codegen/batch_func_registry_decimal.cpp create mode 100644 core/src/codegen/batch_func_registry_decimal.h create mode 100644 core/src/codegen/batch_func_registry_dictionary.cpp create mode 100644 core/src/codegen/batch_func_registry_dictionary.h create mode 100644 core/src/codegen/batch_func_registry_hash.cpp create mode 100644 core/src/codegen/batch_func_registry_hash.h create mode 100644 core/src/codegen/batch_func_registry_math.cpp create mode 100644 core/src/codegen/batch_func_registry_math.h create mode 100644 core/src/codegen/batch_func_registry_string.cpp create mode 100644 core/src/codegen/batch_func_registry_string.h create mode 100644 core/src/codegen/batch_func_registry_util.cpp create mode 100644 core/src/codegen/batch_func_registry_util.h create mode 100644 core/src/codegen/batch_func_registry_varchar_vector.cpp create mode 100644 core/src/codegen/batch_func_registry_varchar_vector.h create mode 100644 core/src/codegen/batch_functions/batch_datetime_functions.cpp create mode 100644 core/src/codegen/batch_functions/batch_datetime_functions.h create mode 100644 core/src/codegen/batch_functions/batch_decimal_arithmetic_functions.cpp create mode 100644 core/src/codegen/batch_functions/batch_decimal_arithmetic_functions.h create mode 100644 core/src/codegen/batch_functions/batch_decimal_cast_functions.cpp create mode 100644 core/src/codegen/batch_functions/batch_decimal_cast_functions.h create mode 100644 core/src/codegen/batch_functions/batch_dictionaryfunctions.cpp create mode 100644 core/src/codegen/batch_functions/batch_dictionaryfunctions.h create mode 100644 core/src/codegen/batch_functions/batch_mathfunctions.cpp create mode 100644 core/src/codegen/batch_functions/batch_mathfunctions.h create mode 100644 core/src/codegen/batch_functions/batch_murmur3_hash.cpp create mode 100644 core/src/codegen/batch_functions/batch_murmur3_hash.h create mode 100644 core/src/codegen/batch_functions/batch_stringfunctions.cpp create mode 100644 core/src/codegen/batch_functions/batch_stringfunctions.h create mode 100644 core/src/codegen/batch_functions/batch_utilfunctions.cpp create mode 100644 core/src/codegen/batch_functions/batch_utilfunctions.h create mode 100644 core/src/codegen/batch_functions/batch_varcharVectorfunctions.cpp create mode 100644 core/src/codegen/batch_functions/batch_varcharVectorfunctions.h create mode 100644 core/src/codegen/batch_projection_codegen.cpp create mode 100644 core/src/codegen/batch_projection_codegen.h create mode 100644 core/src/codegen/bloom_filter.cpp create mode 100644 core/src/codegen/bloom_filter.h create mode 100644 core/src/codegen/codegen_base.cpp create mode 100644 core/src/codegen/codegen_base.h create mode 100644 core/src/codegen/codegen_context.h create mode 100644 core/src/codegen/codegen_value.h create mode 100644 core/src/codegen/common_util.h create mode 100644 core/src/codegen/context_helper.cpp create mode 100644 core/src/codegen/context_helper.h create mode 100644 core/src/codegen/expr_evaluator.cpp create mode 100644 core/src/codegen/expr_evaluator.h create mode 100644 core/src/codegen/expr_function.cpp create mode 100644 core/src/codegen/expr_function.h create mode 100644 core/src/codegen/expr_info_extractor.cpp create mode 100644 core/src/codegen/expr_info_extractor.h create mode 100644 core/src/codegen/expression_codegen.cpp create mode 100644 core/src/codegen/expression_codegen.h create mode 100644 core/src/codegen/filter_codegen.cpp create mode 100644 core/src/codegen/filter_codegen.h create mode 100644 core/src/codegen/func_registry.cpp create mode 100644 core/src/codegen/func_registry.h create mode 100644 core/src/codegen/func_registry_base.h create mode 100644 core/src/codegen/func_registry_context.cpp create mode 100644 core/src/codegen/func_registry_context.h create mode 100644 core/src/codegen/func_registry_datetime.cpp create mode 100644 core/src/codegen/func_registry_datetime.h create mode 100644 core/src/codegen/func_registry_decimal.cpp create mode 100644 core/src/codegen/func_registry_decimal.h create mode 100644 core/src/codegen/func_registry_dictionary.cpp create mode 100644 core/src/codegen/func_registry_dictionary.h create mode 100644 core/src/codegen/func_registry_hash.cpp create mode 100644 core/src/codegen/func_registry_hash.h create mode 100644 core/src/codegen/func_registry_hive_udf.cpp create mode 100644 core/src/codegen/func_registry_hive_udf.h create mode 100644 core/src/codegen/func_registry_math.cpp create mode 100644 core/src/codegen/func_registry_math.h create mode 100644 core/src/codegen/func_registry_might_contain.cpp create mode 100644 core/src/codegen/func_registry_might_contain.h create mode 100644 core/src/codegen/func_registry_string.cpp create mode 100644 core/src/codegen/func_registry_string.h create mode 100644 core/src/codegen/func_registry_varchar_vector.cpp create mode 100644 core/src/codegen/func_registry_varchar_vector.h create mode 100644 core/src/codegen/func_signature.cpp create mode 100644 core/src/codegen/func_signature.h create mode 100644 core/src/codegen/function.cpp create mode 100644 core/src/codegen/function.h create mode 100644 core/src/codegen/functions/README.md create mode 100644 core/src/codegen/functions/datetime_functions.cpp create mode 100644 core/src/codegen/functions/datetime_functions.h create mode 100644 core/src/codegen/functions/decimal_arithmetic_functions.cpp create mode 100644 core/src/codegen/functions/decimal_arithmetic_functions.h create mode 100644 core/src/codegen/functions/decimal_cast_functions.cpp create mode 100644 core/src/codegen/functions/decimal_cast_functions.h create mode 100644 core/src/codegen/functions/dictionaryfunctions.cpp create mode 100644 core/src/codegen/functions/dictionaryfunctions.h create mode 100644 core/src/codegen/functions/dtoa.cpp create mode 100644 core/src/codegen/functions/dtoa.h create mode 100644 core/src/codegen/functions/mathfunctions.cpp create mode 100644 core/src/codegen/functions/mathfunctions.h create mode 100644 core/src/codegen/functions/md5.cpp create mode 100644 core/src/codegen/functions/md5.h create mode 100644 core/src/codegen/functions/mightcontain.cpp create mode 100644 core/src/codegen/functions/mightcontain.h create mode 100644 core/src/codegen/functions/murmur3_hash.cpp create mode 100644 core/src/codegen/functions/murmur3_hash.h create mode 100644 core/src/codegen/functions/stringfunctions.cpp create mode 100644 core/src/codegen/functions/stringfunctions.h create mode 100644 core/src/codegen/functions/udffunctions.cpp create mode 100644 core/src/codegen/functions/udffunctions.h create mode 100644 core/src/codegen/functions/varcharVectorfunctions.cpp create mode 100644 core/src/codegen/functions/varcharVectorfunctions.h create mode 100644 core/src/codegen/functions/xxhash64_hash.cpp create mode 100644 core/src/codegen/functions/xxhash64_hash.h create mode 100644 core/src/codegen/llvm_engine.cpp create mode 100644 core/src/codegen/llvm_engine.h create mode 100644 core/src/codegen/llvm_types.cpp create mode 100644 core/src/codegen/llvm_types.h create mode 100644 core/src/codegen/projection_codegen.cpp create mode 100644 core/src/codegen/projection_codegen.h create mode 100644 core/src/codegen/simple_filter_codegen.cpp create mode 100644 core/src/codegen/simple_filter_codegen.h create mode 100644 core/src/codegen/string_util.h create mode 100644 core/src/codegen/time_util.h create mode 100644 core/src/compute/CMakeLists.txt create mode 100644 core/src/compute/ColumnarBatchIterator.h create mode 100644 core/src/compute/ResultIterator.cpp create mode 100644 core/src/compute/ResultIterator.h create mode 100644 core/src/compute/cpuWall_timer.cpp create mode 100644 core/src/compute/cpuWall_timer.h create mode 100644 core/src/compute/driver.cpp create mode 100644 core/src/compute/driver.h create mode 100644 core/src/compute/local_planner.cpp create mode 100644 core/src/compute/local_planner.h create mode 100644 core/src/compute/operator_stats.h create mode 100644 core/src/compute/plannode_stats.cpp create mode 100644 core/src/compute/plannode_stats.h create mode 100644 core/src/compute/process_base.cpp create mode 100644 core/src/compute/process_base.h create mode 100644 core/src/compute/reason.h create mode 100644 core/src/compute/task.cpp create mode 100644 core/src/compute/task.h create mode 100644 core/src/compute/task_stats.h create mode 100644 core/src/cpu_checker/CMakeLists.txt create mode 100644 core/src/cpu_checker/omniruntime_cpu_checker.cpp create mode 100644 core/src/cpu_checker/omniruntime_cpu_checker.h create mode 100644 core/src/expression/CMakeLists.txt create mode 100644 core/src/expression/README.md create mode 100644 core/src/expression/expr_printer.cpp create mode 100644 core/src/expression/expr_printer.h create mode 100644 core/src/expression/expr_verifier.cpp create mode 100644 core/src/expression/expr_verifier.h create mode 100644 core/src/expression/expr_visitor.cpp create mode 100644 core/src/expression/expr_visitor.h create mode 100644 core/src/expression/expressions.cpp create mode 100644 core/src/expression/expressions.h create mode 100644 core/src/expression/jsonparser/jsonparser.cpp create mode 100644 core/src/expression/jsonparser/jsonparser.h create mode 100644 core/src/expression/parser/parser.cpp create mode 100644 core/src/expression/parser/parser.h create mode 100644 core/src/expression/parserhelper.cpp create mode 100644 core/src/expression/parserhelper.h create mode 100644 core/src/memory/CMakeLists.txt create mode 100644 core/src/memory/aligned_buffer.h create mode 100644 core/src/memory/allocator.h create mode 100644 core/src/memory/chunk.cpp create mode 100644 core/src/memory/chunk.h create mode 100644 core/src/memory/memory_manager.cpp create mode 100644 core/src/memory/memory_manager.h create mode 100644 core/src/memory/memory_manager_allocator.h create mode 100644 core/src/memory/memory_pool.cpp create mode 100644 core/src/memory/memory_pool.h create mode 100644 core/src/memory/memory_trace.cpp create mode 100644 core/src/memory/memory_trace.h create mode 100644 core/src/memory/simple_arena_allocator.h create mode 100644 core/src/memory/thread_memory_manager.cpp create mode 100644 core/src/memory/thread_memory_manager.h create mode 100644 core/src/memory/thread_memory_trace.cpp create mode 100644 core/src/memory/thread_memory_trace.h create mode 100644 core/src/metrics/metrics.h create mode 100644 core/src/metrics/metrics_config.h create mode 100644 core/src/metrics/metrics_memory_info.h create mode 100644 core/src/metrics/metrics_row_counter.h create mode 100644 core/src/metrics/metrics_spill_info.h create mode 100644 core/src/metrics/omni_metrics.h create mode 100644 core/src/operator/CMakeLists.txt create mode 100644 core/src/operator/aggregation/GROUPBY.MD create mode 100644 core/src/operator/aggregation/agg_util.h create mode 100644 core/src/operator/aggregation/aggregation.cpp create mode 100644 core/src/operator/aggregation/aggregation.h create mode 100644 core/src/operator/aggregation/aggregator/aggregator.h create mode 100644 core/src/operator/aggregation/aggregator/aggregator_factory.cpp create mode 100644 core/src/operator/aggregation/aggregator/aggregator_factory.h create mode 100644 core/src/operator/aggregation/aggregator/aggregator_util.cpp create mode 100644 core/src/operator/aggregation/aggregator/aggregator_util.h create mode 100644 core/src/operator/aggregation/aggregator/all_aggregators.h create mode 100644 core/src/operator/aggregation/aggregator/average_aggregator.cpp create mode 100644 core/src/operator/aggregation/aggregator/average_aggregator.h create mode 100644 core/src/operator/aggregation/aggregator/average_flatim_aggregator.h create mode 100644 core/src/operator/aggregation/aggregator/average_spark_decimal_aggregator.h create mode 100644 core/src/operator/aggregation/aggregator/count_all_aggregator.h create mode 100644 core/src/operator/aggregation/aggregator/count_column_aggregator.cpp create mode 100644 core/src/operator/aggregation/aggregator/count_column_aggregator.h create mode 100644 core/src/operator/aggregation/aggregator/first_aggregator.h create mode 100644 core/src/operator/aggregation/aggregator/mask_column_assistant_aggregator.h create mode 100644 core/src/operator/aggregation/aggregator/max_aggregator.cpp create mode 100644 core/src/operator/aggregation/aggregator/max_aggregator.h create mode 100644 core/src/operator/aggregation/aggregator/max_varchar_aggregator.cpp create mode 100644 core/src/operator/aggregation/aggregator/max_varchar_aggregator.h create mode 100644 core/src/operator/aggregation/aggregator/min_aggregator.cpp create mode 100644 core/src/operator/aggregation/aggregator/min_aggregator.h create mode 100644 core/src/operator/aggregation/aggregator/min_varchar_aggregator.cpp create mode 100644 core/src/operator/aggregation/aggregator/min_varchar_aggregator.h create mode 100644 core/src/operator/aggregation/aggregator/only_aggregator_factory.h create mode 100644 core/src/operator/aggregation/aggregator/operations_aggregator.h create mode 100644 core/src/operator/aggregation/aggregator/operations_hash_aggregator.h create mode 100644 core/src/operator/aggregation/aggregator/state_flag_operation.h create mode 100644 core/src/operator/aggregation/aggregator/stddev_samp_aggregator.h create mode 100644 core/src/operator/aggregation/aggregator/sum_aggregator.cpp create mode 100644 core/src/operator/aggregation/aggregator/sum_aggregator.h create mode 100644 core/src/operator/aggregation/aggregator/sum_flatim_aggregator.h create mode 100644 core/src/operator/aggregation/aggregator/sum_spark_decimal_aggregator.h create mode 100644 core/src/operator/aggregation/aggregator/try_sum_flatim_aggregator.h create mode 100644 core/src/operator/aggregation/aggregator/typed_aggregator.cpp create mode 100644 core/src/operator/aggregation/aggregator/typed_aggregator.h create mode 100644 core/src/operator/aggregation/aggregator/typed_mask_column_assistant_aggregator.h create mode 100644 core/src/operator/aggregation/container_vector.h create mode 100644 core/src/operator/aggregation/definitions.h create mode 100644 core/src/operator/aggregation/group_aggregation.cpp create mode 100644 core/src/operator/aggregation/group_aggregation.h create mode 100644 core/src/operator/aggregation/group_aggregation_expr.cpp create mode 100644 core/src/operator/aggregation/group_aggregation_expr.h create mode 100644 core/src/operator/aggregation/group_aggregation_sort.cpp create mode 100644 core/src/operator/aggregation/group_aggregation_sort.h create mode 100644 core/src/operator/aggregation/non_group_aggregation.cpp create mode 100644 core/src/operator/aggregation/non_group_aggregation.h create mode 100644 core/src/operator/aggregation/non_group_aggregation_expr.cpp create mode 100644 core/src/operator/aggregation/non_group_aggregation_expr.h create mode 100644 core/src/operator/aggregation/one_row_adaptor.h create mode 100644 core/src/operator/aggregation/vector_getter.h create mode 100644 core/src/operator/config/operator_config.cpp create mode 100644 core/src/operator/config/operator_config.h create mode 100644 core/src/operator/execution_context.h create mode 100644 core/src/operator/expand/expand.cpp create mode 100644 core/src/operator/expand/expand.h create mode 100644 core/src/operator/filter/filter_and_project.cpp create mode 100644 core/src/operator/filter/filter_and_project.h create mode 100644 core/src/operator/grouping/grouping.cpp create mode 100644 core/src/operator/grouping/grouping.h create mode 100644 core/src/operator/hash_util.cpp create mode 100644 core/src/operator/hash_util.h create mode 100644 core/src/operator/hashmap/array_map.h create mode 100644 core/src/operator/hashmap/base_hash_map.h create mode 100644 core/src/operator/hashmap/column_marshaller.h create mode 100644 core/src/operator/hashmap/crc32c.h create mode 100644 core/src/operator/hashmap/crc_hasher.h create mode 100644 core/src/operator/hashmap/group_hasher.h create mode 100644 core/src/operator/hashmap/vector_analyzer.cpp create mode 100644 core/src/operator/hashmap/vector_analyzer.h create mode 100644 core/src/operator/hashmap/vector_marshaller.cpp create mode 100644 core/src/operator/hashmap/vector_marshaller.h create mode 100644 core/src/operator/join/common_join.h create mode 100644 core/src/operator/join/hash_builder.cpp create mode 100644 core/src/operator/join/hash_builder.h create mode 100644 core/src/operator/join/hash_builder_expr.cpp create mode 100644 core/src/operator/join/hash_builder_expr.h create mode 100644 core/src/operator/join/join_hash_table_variants.cpp create mode 100644 core/src/operator/join/join_hash_table_variants.h create mode 100644 core/src/operator/join/lookup_join.cpp create mode 100644 core/src/operator/join/lookup_join.h create mode 100644 core/src/operator/join/lookup_join_expr.cpp create mode 100644 core/src/operator/join/lookup_join_expr.h create mode 100644 core/src/operator/join/lookup_join_wrapper.cpp create mode 100644 core/src/operator/join/lookup_join_wrapper.h create mode 100644 core/src/operator/join/lookup_outer_join.cpp create mode 100644 core/src/operator/join/lookup_outer_join.h create mode 100644 core/src/operator/join/lookup_outer_join_expr.cpp create mode 100644 core/src/operator/join/lookup_outer_join_expr.h create mode 100644 core/src/operator/join/nest_loop_join_builder.cpp create mode 100644 core/src/operator/join/nest_loop_join_builder.h create mode 100644 core/src/operator/join/nest_loop_join_lookup.cpp create mode 100644 core/src/operator/join/nest_loop_join_lookup.h create mode 100644 core/src/operator/join/nest_loop_join_lookup_wrapper.cpp create mode 100644 core/src/operator/join/nest_loop_join_lookup_wrapper.h create mode 100644 core/src/operator/join/row_ref.h create mode 100644 core/src/operator/join/sortmergejoin/dynamic_pages_index.cpp create mode 100644 core/src/operator/join/sortmergejoin/dynamic_pages_index.h create mode 100644 core/src/operator/join/sortmergejoin/sort_merge_join.cpp create mode 100644 core/src/operator/join/sortmergejoin/sort_merge_join.h create mode 100644 core/src/operator/join/sortmergejoin/sort_merge_join_expr.cpp create mode 100644 core/src/operator/join/sortmergejoin/sort_merge_join_expr.h create mode 100644 core/src/operator/join/sortmergejoin/sort_merge_join_expr_v2.cpp create mode 100644 core/src/operator/join/sortmergejoin/sort_merge_join_expr_v2.h create mode 100644 core/src/operator/join/sortmergejoin/sort_merge_join_expr_v3.cpp create mode 100644 core/src/operator/join/sortmergejoin/sort_merge_join_expr_v3.h create mode 100644 core/src/operator/join/sortmergejoin/sort_merge_join_resultBuilder.cpp create mode 100644 core/src/operator/join/sortmergejoin/sort_merge_join_resultBuilder.h create mode 100644 core/src/operator/join/sortmergejoin/sort_merge_join_scanner.cpp create mode 100644 core/src/operator/join/sortmergejoin/sort_merge_join_scanner.h create mode 100644 core/src/operator/join/sortmergejoin/sort_merge_join_v3.cpp create mode 100644 core/src/operator/join/sortmergejoin/sort_merge_join_v3.h create mode 100644 core/src/operator/limit/distinct_limit.cpp create mode 100644 core/src/operator/limit/distinct_limit.h create mode 100644 core/src/operator/limit/distinct_state_func.h create mode 100644 core/src/operator/limit/limit.cpp create mode 100644 core/src/operator/limit/limit.h create mode 100644 core/src/operator/memory_builder.h create mode 100644 core/src/operator/omni_id_type_vector_traits.h create mode 100644 core/src/operator/operator.h create mode 100644 core/src/operator/operator_factory.h create mode 100644 core/src/operator/pages_hash_strategy.cpp create mode 100644 core/src/operator/pages_hash_strategy.h create mode 100644 core/src/operator/pages_index.cpp create mode 100644 core/src/operator/pages_index.h create mode 100644 core/src/operator/projection/projection.cpp create mode 100644 core/src/operator/projection/projection.h create mode 100644 core/src/operator/radix_sort.h create mode 100644 core/src/operator/sort/sort.cpp create mode 100644 core/src/operator/sort/sort.h create mode 100644 core/src/operator/sort/sort_expr.cpp create mode 100644 core/src/operator/sort/sort_expr.h create mode 100644 core/src/operator/spill/loser_tree.h create mode 100644 core/src/operator/spill/spill_merger.cpp create mode 100644 core/src/operator/spill/spill_merger.h create mode 100644 core/src/operator/spill/spill_tracker.cpp create mode 100644 core/src/operator/spill/spill_tracker.h create mode 100644 core/src/operator/spill/spiller.cpp create mode 100644 core/src/operator/spill/spiller.h create mode 100644 core/src/operator/status.h create mode 100644 core/src/operator/topn/topn.cpp create mode 100644 core/src/operator/topn/topn.h create mode 100644 core/src/operator/topn/topn_expr.cpp create mode 100644 core/src/operator/topn/topn_expr.h create mode 100644 core/src/operator/topnsort/topn_sort.cpp create mode 100644 core/src/operator/topnsort/topn_sort.h create mode 100644 core/src/operator/topnsort/topn_sort_expr.cpp create mode 100644 core/src/operator/topnsort/topn_sort_expr.h create mode 100644 core/src/operator/union/union.cpp create mode 100644 core/src/operator/union/union.h create mode 100644 core/src/operator/util/function_type.h create mode 100644 core/src/operator/util/mm3_util.h create mode 100644 core/src/operator/util/operator_util.h create mode 100644 core/src/operator/window/window.cpp create mode 100644 core/src/operator/window/window.h create mode 100644 core/src/operator/window/window_expr.cpp create mode 100644 core/src/operator/window/window_expr.h create mode 100644 core/src/operator/window/window_frame.h create mode 100644 core/src/operator/window/window_function.cpp create mode 100644 core/src/operator/window/window_function.h create mode 100644 core/src/operator/window/window_group_limit.cpp create mode 100644 core/src/operator/window/window_group_limit.h create mode 100644 core/src/operator/window/window_group_limit_expr.cpp create mode 100644 core/src/operator/window/window_group_limit_expr.h create mode 100644 core/src/operator/window/window_partition.cpp create mode 100644 core/src/operator/window/window_partition.h create mode 100644 core/src/plannode/CMakeLists.txt create mode 100644 core/src/plannode/RowVectorStream.h create mode 100644 core/src/plannode/planFragment.cpp create mode 100644 core/src/plannode/planFragment.h create mode 100644 core/src/plannode/planNode.cpp create mode 100644 core/src/plannode/planNode.h create mode 100644 core/src/simd/CMakeLists.txt create mode 100644 core/src/simd/base.h create mode 100644 core/src/simd/func/match.h create mode 100644 core/src/simd/func/quick_sort_simd.cpp create mode 100644 core/src/simd/func/quick_sort_simd.h create mode 100644 core/src/simd/func/reduce.h create mode 100644 core/src/simd/func/small_case_sort.h create mode 100644 core/src/simd/func/traits-inl.h create mode 100644 core/src/simd/instruction/arm_neon-inl.h create mode 100644 core/src/simd/instruction/arm_sve-inl.h create mode 100644 core/src/simd/instruction/generic_ops-inl.h create mode 100644 core/src/simd/instruction/inside-inl.h create mode 100644 core/src/simd/instruction/set_macros-inl.h create mode 100644 core/src/simd/instruction/shared-inl.h create mode 100644 core/src/simd/simd.h create mode 100644 core/src/simd/targets.h create mode 100644 core/src/type/CMakeLists.txt create mode 100644 core/src/type/base_operations.h create mode 100644 core/src/type/big_integer.cpp create mode 100644 core/src/type/big_integer.h create mode 100644 core/src/type/data_operations.h create mode 100644 core/src/type/data_type.h create mode 100644 core/src/type/data_type_serializer.cpp create mode 100644 core/src/type/data_type_serializer.h create mode 100644 core/src/type/data_types.h create mode 100644 core/src/type/data_utils.h create mode 100644 core/src/type/date32.cpp create mode 100644 core/src/type/date32.h create mode 100644 core/src/type/date_base.h create mode 100644 core/src/type/date_time_utils.h create mode 100644 core/src/type/decimal128.cpp create mode 100644 core/src/type/decimal128.h create mode 100644 core/src/type/decimal128_utils.h create mode 100644 core/src/type/decimal_base.h create mode 100644 core/src/type/decimal_operations.h create mode 100644 core/src/type/double_utils.h create mode 100644 core/src/type/integer256.cpp create mode 100644 core/src/type/integer256.h create mode 100644 core/src/type/string_Impl.h create mode 100644 core/src/type/string_ref.h create mode 100644 core/src/type/width_integer.h create mode 100644 core/src/udf/cplusplus/CMakeLists.txt create mode 100644 core/src/udf/cplusplus/java_udf_functions.cpp create mode 100644 core/src/udf/cplusplus/java_udf_functions.h create mode 100644 core/src/udf/cplusplus/jni_util.cpp create mode 100644 core/src/udf/cplusplus/jni_util.h create mode 100644 core/src/udf/java/pom.xml create mode 100644 core/src/udf/java/src/main/java/omniruntime/udf/HiveUdfExecutor.java create mode 100644 core/src/udf/java/src/main/java/omniruntime/udf/UdfUtil.java create mode 100644 core/src/udf/java/src/test/java/omniruntime/udf/AddIntUDF.java create mode 100644 core/src/udf/java/src/test/java/omniruntime/udf/AddLongUDF.java create mode 100644 core/src/udf/java/src/test/java/omniruntime/udf/AddShortUDF.java create mode 100644 core/src/udf/java/src/test/java/omniruntime/udf/AndBooleanUDF.java create mode 100644 core/src/udf/java/src/test/java/omniruntime/udf/ConcatStringUDF.java create mode 100644 core/src/udf/java/src/test/java/omniruntime/udf/HiveUdfExecutorTest.java create mode 100644 core/src/udf/java/src/test/java/omniruntime/udf/HiveUdfTestUtil.java create mode 100644 core/src/udf/java/src/test/java/omniruntime/udf/MaxDoubleUDF.java create mode 100644 core/src/util/CMakeLists.txt create mode 100644 core/src/util/bit_array.h create mode 100644 core/src/util/bit_map.h create mode 100644 core/src/util/bit_util.h create mode 100644 core/src/util/bits_selectivity_vector.h create mode 100644 core/src/util/compiler_util.h create mode 100644 core/src/util/config/Config.h create mode 100644 core/src/util/config/ConfigBase.h create mode 100644 core/src/util/config/QueryConfig.h create mode 100644 core/src/util/config_util.cpp create mode 100644 core/src/util/config_util.h create mode 100644 core/src/util/debug.h create mode 100644 core/src/util/error_code.cpp create mode 100644 core/src/util/error_code.h create mode 100644 core/src/util/format.h create mode 100644 core/src/util/global_log.h create mode 100644 core/src/util/native_log.cpp create mode 100644 core/src/util/native_log.h create mode 100644 core/src/util/omni_exception.h create mode 100644 core/src/util/perf_util.cpp create mode 100644 core/src/util/perf_util.h create mode 100644 core/src/util/policy.h create mode 100644 core/src/util/property.h create mode 100644 core/src/util/trace_util.h create mode 100644 core/src/util/type_util.cpp create mode 100644 core/src/util/type_util.h create mode 100644 core/src/util/utf8_util.h create mode 100644 core/src/vector/CMakeLists.txt create mode 100644 core/src/vector/dictionary_container.h create mode 100644 core/src/vector/large_string_container.cpp create mode 100644 core/src/vector/large_string_container.h create mode 100644 core/src/vector/nulls_buffer.h create mode 100644 core/src/vector/omni_row.h create mode 100644 core/src/vector/omni_row_value.h create mode 100644 core/src/vector/selectivity_vector.cpp create mode 100644 core/src/vector/selectivity_vector.h create mode 100644 core/src/vector/string_utils.h create mode 100644 core/src/vector/type_utils.h create mode 100644 core/src/vector/unsafe_dictionary_container.h create mode 100644 core/src/vector/unsafe_string_container.h create mode 100644 core/src/vector/unsafe_vector.h create mode 100644 core/src/vector/vector.h create mode 100644 core/src/vector/vector_batch.cpp create mode 100644 core/src/vector/vector_batch.h create mode 100644 core/src/vector/vector_common.h create mode 100644 core/src/vector/vector_helper.h create mode 100644 core/test/CMakeLists.txt create mode 100644 core/test/benchmark/CMakeLists.txt create mode 100644 core/test/benchmark/codegen/CMakeLists.txt create mode 100644 core/test/benchmark/codegen/decimal.cpp create mode 100644 core/test/benchmark/omni_benchmark.cpp create mode 100644 core/test/benchmark/operator/AggLarge.cpp create mode 100644 core/test/benchmark/operator/CMakeLists.txt create mode 100644 core/test/benchmark/operator/agg.cpp create mode 100644 core/test/benchmark/operator/agg_in_hash_agg.cpp create mode 100644 core/test/benchmark/operator/common/common.cpp create mode 100644 core/test/benchmark/operator/common/common.h create mode 100644 core/test/benchmark/operator/common/vector_util.cpp create mode 100644 core/test/benchmark/operator/common/vector_util.h create mode 100644 core/test/benchmark/operator/distinct_limit.cpp create mode 100644 core/test/benchmark/operator/filter_and_project.cpp create mode 100644 core/test/benchmark/operator/groupby_hashmap.cpp create mode 100644 core/test/benchmark/operator/hash_agg.cpp create mode 100644 core/test/benchmark/operator/hash_agg_large_group.cpp create mode 100644 core/test/benchmark/operator/hash_join.cpp create mode 100644 core/test/benchmark/operator/scripts/extract_output.sh create mode 100644 core/test/benchmark/operator/scripts/plot_agg.py create mode 100644 core/test/benchmark/operator/scripts/plot_hash_agg.py create mode 100644 core/test/benchmark/operator/scripts/plot_hashagg_vs_ck_benchmark.py create mode 100644 core/test/benchmark/operator/scripts/plot_util.py create mode 100644 core/test/benchmark/operator/sort.cpp create mode 100644 core/test/benchmark/operator/sort_merge_join.cpp create mode 100644 core/test/benchmark/operator/topn.cpp create mode 100644 core/test/benchmark/operator/tpcds/HowToUse.md create mode 100644 core/test/benchmark/operator/tpcds/tpcds_common.cpp create mode 100644 core/test/benchmark/operator/tpcds/tpcds_common.h create mode 100644 core/test/benchmark/operator/tpcds/tpcds_hash_agg.cpp create mode 100644 core/test/benchmark/operator/tpcds/tpcds_hash_join.cpp create mode 100644 core/test/benchmark/operator/tpcds/tpcds_sort.cpp create mode 100644 core/test/benchmark/operator/tpcds/tpcds_sort_merge_join.cpp create mode 100644 core/test/benchmark/operator/window.cpp create mode 100644 core/test/codegen/CMakeLists.txt create mode 100644 core/test/codegen/batch_codegen_binary_test.cpp create mode 100644 core/test/codegen/batch_codegen_func_other_test.cpp create mode 100644 core/test/codegen/batch_codegen_if_switch_test.cpp create mode 100644 core/test/codegen/batch_codegen_in_between_coalesce_test.cpp create mode 100644 core/test/codegen/batch_function_test.cpp create mode 100644 core/test/codegen/codegen_test.cpp create mode 100644 core/test/codegen/codegen_util.cpp create mode 100644 core/test/codegen/codegen_util.h create mode 100644 core/test/codegen/decimal_function_test.cpp create mode 100644 core/test/codegen/expression_test.cpp create mode 100644 core/test/codegen/function_test.cpp create mode 100644 core/test/codegen/jni_mock.cpp create mode 100644 core/test/codegen/jni_mock.h create mode 100644 core/test/compute/CMakeLists.txt create mode 100644 core/test/compute/LimitTest.cpp create mode 100644 core/test/compute/OrderByTest.cpp create mode 100644 core/test/compute/PlanNodeStatsTest.cpp create mode 100644 core/test/compute/ProccessBaseTest.cpp create mode 100644 core/test/compute/UnionTest.cpp create mode 100644 core/test/dt/CMakeLists.txt create mode 100644 core/test/dt/README.md create mode 100644 core/test/dt/sourcetree/CMakeLists.txt create mode 100644 core/test/dt/sourcetree/fuzz_wrapper.cpp create mode 100644 core/test/dt/sourcetree/fuzz_wrapper.h create mode 100644 core/test/dt/testtree/CMakeLists.txt create mode 100644 core/test/dt/testtree/cases/dtframe.json create mode 100644 core/test/dt/testtree/dtframe.cfg create mode 100644 core/test/function/CMakeLists.txt create mode 100644 core/test/function/compare_function_bench_test.cpp create mode 100644 core/test/function/equals_function_bench_test.cpp create mode 100644 core/test/function/hash_function_bench_test.cpp create mode 100644 core/test/memory/CMakeLists.txt create mode 100644 core/test/memory/aligned_buffer_test.cpp create mode 100644 core/test/memory/memory_allocator_test.cpp create mode 100644 core/test/memory/memory_benchmark.cpp create mode 100644 core/test/memory/memory_manager_allocator_test.cpp create mode 100644 core/test/memory/memory_manager_test.cpp create mode 100644 core/test/memory/memory_trace_test.cpp create mode 100644 core/test/memory/simple_arena_allocator_test.cpp create mode 100644 core/test/memory/thread_memory_manager_test.cpp create mode 100644 core/test/memory/thread_memory_trace_test.cpp create mode 100644 core/test/omtest.cpp create mode 100644 core/test/operator/CMakeLists.txt create mode 100644 core/test/operator/adaptive_partialagg_test.cpp create mode 100644 core/test/operator/aggregation_test.cpp create mode 100644 core/test/operator/aggregation_with_expr_test.cpp create mode 100644 core/test/operator/aggregator/aggregator_cast_test.cpp create mode 100644 core/test/operator/aggregator/aggregator_multi_stage_complete_test.cpp create mode 100644 core/test/operator/aggregator/aggregator_multi_stage_no_groupby.h create mode 100644 core/test/operator/aggregator/aggregator_multi_stage_with_groupby.h create mode 100644 core/test/operator/aggregator/aggregator_single_stage_complete_test.cpp create mode 100644 core/test/operator/aggregator/container_vector_test.cpp create mode 100644 core/test/operator/aggregator/neon_agg_calculate_func_test.cpp create mode 100644 core/test/operator/array_map_test.cpp create mode 100644 core/test/operator/base_hashmap_test.cpp create mode 100644 core/test/operator/bloom_filter_test.cpp create mode 100644 core/test/operator/config_test.cpp create mode 100644 core/test/operator/distinctlimit_test.cpp create mode 100644 core/test/operator/dt_fuzz_factory_create_test.cpp create mode 100644 core/test/operator/execution_context_test.cpp create mode 100644 core/test/operator/expand_test.cpp create mode 100644 core/test/operator/filter_test.cpp create mode 100644 core/test/operator/hash_util_test.cpp create mode 100644 core/test/operator/join_test.cpp create mode 100644 core/test/operator/join_with_expr_test.cpp create mode 100644 core/test/operator/join_with_plannode_test.cpp create mode 100644 core/test/operator/limit_test.cpp create mode 100644 core/test/operator/nested_loop_join_test.cpp create mode 100644 core/test/operator/project_test.cpp create mode 100644 core/test/operator/random_join_test.cpp create mode 100644 core/test/operator/sort_merge_join_test.cpp create mode 100644 core/test/operator/sort_merge_join_with_expr_test.cpp create mode 100644 core/test/operator/sort_merge_join_with_expr_v2_test.cpp create mode 100644 core/test/operator/sort_merge_join_with_expr_v3_test.cpp create mode 100644 core/test/operator/sort_test.cpp create mode 100644 core/test/operator/sort_with_expr_test.cpp create mode 100644 core/test/operator/spill_test.cpp create mode 100644 core/test/operator/topn_sort_with_expr_test.cpp create mode 100644 core/test/operator/topn_test.cpp create mode 100644 core/test/operator/topn_with_expr_test.cpp create mode 100644 core/test/operator/union_test.cpp create mode 100644 core/test/operator/vector_analyzer_test.cpp create mode 100644 core/test/operator/window_group_limit_with_expr_test.cpp create mode 100644 core/test/operator/window_test.cpp create mode 100644 core/test/operator/window_with_expr_test.cpp create mode 100644 core/test/parser/CMakeLists.txt create mode 100644 core/test/parser/jsonparser_test.cpp create mode 100644 core/test/parser/parser_test.cpp create mode 100644 core/test/plannode/CMakeLists.txt create mode 100644 core/test/plannode/PlanNodeTest.cpp create mode 100644 core/test/type/CMakeLists.txt create mode 100644 core/test/type/big_integer_test.cpp create mode 100644 core/test/type/date32_test.cpp create mode 100644 core/test/type/decimal128_test.cpp create mode 100644 core/test/type/decimal_operations_test.cpp create mode 100644 core/test/type/dtoa_test.cpp create mode 100644 core/test/type/type_serialization_test.cpp create mode 100644 core/test/util/CMakeLists.txt create mode 100644 core/test/util/dt_fuzz_factory_create_util.cpp create mode 100644 core/test/util/dt_fuzz_factory_create_util.h create mode 100644 core/test/util/test_native_log.cpp create mode 100644 core/test/util/test_util.cpp create mode 100644 core/test/util/test_util.h create mode 100644 core/test/vector/CMakeLists.txt create mode 100644 core/test/vector/basic_test.cpp create mode 100644 core/test/vector/benchmark.cpp create mode 100644 core/test/vector/container_benchmark.cpp create mode 100644 core/test/vector/container_vec_test.cpp create mode 100644 core/test/vector/decimal128_vector_test.cpp create mode 100644 core/test/vector/dict_container_test.cpp create mode 100644 core/test/vector/dictionary_vector_test.cpp create mode 100644 core/test/vector/omni_row_test.cpp create mode 100644 core/test/vector/selectivity_vector_test.cpp create mode 100644 core/test/vector/slice_test.cpp create mode 100644 core/test/vector/string_container_test.cpp create mode 100644 core/test/vector/unsafe_vector_test.cpp create mode 100644 core/test/vector/vector_batch_test.cpp create mode 100644 core/test/vector/vector_helper_test.cpp create mode 100644 core/test/vector/vector_nulls_test.cpp create mode 100644 core/test/vector/vector_test_util.h create mode 100644 env_check.sh create mode 100644 examples/README.MD create mode 100644 examples/externalfunctions/externalfunctions.cpp create mode 100644 examples/externalfunctions/externalfunctions.h create mode 100644 examples/externalfunctions/externalregistration.conf create mode 100644 examples/externalfunctions/extstringfunctions.cpp create mode 100644 examples/externalfunctions/extstringfunctions.h diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c01dd41 --- /dev/null +++ b/.gitignore @@ -0,0 +1,63 @@ +.idea +.vscode +.DS_Store +.classpath +.settings +.project +.externalToolBuilders +.checkstyle +.mvn/timing.properties +.omnitools + +*.pyc +*.class +*.iml +*.ipr +*.iws +*.so +*.a +*.gch +*.cmake +*.ll +*.cbp +*.lock +*.log + +*/dependency-reduced-pom.xml +*/.idea +*/.gitignore + +*/*.iml + +target +benchmark_outputs +node_modules +libs +CMakeCache.txt +CMakeFiles +Makefile +cmake-build-debug +cmake-build-debug-coverage +cmake-build-release +cmake-build-release-coverage + +core/omni-runtime.cbp +core/config.h +bindings/java/*.log +omtest +install_manifest.txt + +core/build/test_coverage +test.info +test_detail.xml +*.tar.gz +*.zip +boostkit-omniop-operator-* + +installtree +registration +dist +dtframe_ +log +test_report.json +coverage_lcov* \ No newline at end of file diff --git a/BUILDING.MD b/BUILDING.MD new file mode 100644 index 0000000..7b8f315 --- /dev/null +++ b/BUILDING.MD @@ -0,0 +1,403 @@ +# Preface +This document describes how to compile and deploy the openLooKeng (OLK) and OmniRuntime package for function and performance debugging in ***yellow zone***. The operating system uses ***ubuntu 18.04***. + +# Preparing the compilation environment + +## Configure HTTP proxy + +### Edit file /etc/profile + +Add the following information to the file `/etc/profile` + +``` +export http_proxy=http://username:password@proxy.huawei.com:8080 +export https_proxy=http://username:password@proxy.huawei.com:8080 +export no_proxy=127.0.0.1,*.huawei.com,localhost,local,*.local,10.* +``` + +Notices: + +- Enter your own domain username and password for the username and password. + +- Special characters in the username and password must be escaped. For details, see the following links: + + Common escape character : + + All escape character: + +### Check whether the external network can be pinged + +Type the command `curl www.baidu.com ` on the command line. + +Normally, the following information is displayed, indicating that the external network can be pinged successfully. + +![1661582746423](http://image.huawei.com/tiny-lts/v1/images/5498ed56051b5e62988d9e2e182b60a3_865x593.png@900-0-90-f.png) + +## Configure Ubuntu mirror source + +The default mirror source configuration file is `/etc/apt/sources.list`,which can configure Ubuntu mirror source.However, it is recommended that add the .list file in the directory `/etc/apt/sources.list.d`.More specifically,create file 00ubuntu.list in the directory `/etc/apt/sources.list.d` and add the following content to file 00ubuntu.list. + +``` +deb http://mirrors.aliyun.com/ubuntu/ bionic main restricted universe multiverse +deb http://mirrors.aliyun.com/ubuntu/ bionic-security main restricted universe multiverse +deb http://mirrors.aliyun.com/ubuntu/ bionic-updates main restricted universe multiverse +deb http://mirrors.aliyun.com/ubuntu/ bionic-proposed main restricted universe multiverse +deb http://mirrors.aliyun.com/ubuntu/ bionic-backports main restricted universe multiverse +deb-src http://mirrors.aliyun.com/ubuntu/ bionic main restricted universe multiverse +deb-src http://mirrors.aliyun.com/ubuntu/ bionic-security main restricted universe multiverse +deb-src http://mirrors.aliyun.com/ubuntu/ bionic-updates main restricted universe multiverse +deb-src http://mirrors.aliyun.com/ubuntu/ bionic-proposed main restricted universe multiverse +deb-src http://mirrors.aliyun.com/ubuntu/ bionic-backports main restricted universe multiverse +``` + +Another solution is to configure the HUAWEI CLOUD source. However, the following error will be reported during subsequent use, which may be caused by the server side. + +![1661585348159](http://image.huawei.com/tiny-lts/v1/images/ddba80d5d937f8ac532fec970b66d604_865x399.png@900-0-90-f.png) + +After the Ubuntu mirror source configuration, run the following command: + +``` +apt update +apt upgrade +``` + +For details about the above configuration, refer to the reference link . + +## Install OmniOperatorJit Dependencies + +The components required for OmniOperatorJit compilation include cmake, LLVM, googletest, jemalloc, json, huawei securec, JDK, and protobuf. The involved software is as follows: + +| Software | Version | +| -------------- | -------- | +| python | Python3 | +| gcc/g++ | 7.3.0 | +| zlib | 1.2.8 | +| autoconf | 2.69 | +| cmake | 3.13.4 | +| LLVM | 12.0.1 | +| gtest | 1.10.0 | +| jemalloc | 5.2.1 | +| json | 3.7.3 | +| huawei securec | 无 | +| JDK | 1.8 | +| protobuf | 无 | + +### Confirm the Python Version 3 + +![1661586155228](http://image.huawei.com/tiny-lts/v1/images/538785cfce46f6cac67e0d4600956db0_544x89.png) + +Python3 is installed by default. + +### Install gcc and g++ + +#### Install gcc + +Run the command `apt-get install gcc ` + +![1661586482227](http://image.huawei.com/tiny-lts/v1/images/16619fd9c801e744c233b363b18def61_831x131.png) + +#### Install g++ + +Run the command `apt-get install gcc-c++ ` + +![1661586482227](http://image.huawei.com/tiny-lts/v1/images/4c2cc402beb2f009462560a2f56c1ea7_845x131.png@900-0-90-f.png) + +### Install zlib + +Run the command `cat /usr/lib64/pkgconfig/zlib.pc ` to check whether zlib has been installed.To install zlib, the installation options are as follows: + +- Install zlib via mirror source which means that run the command `apt-get install zlib ` + +- Install zlib via source code which means that download the source code from http://www.zlib.net/fossils/zlib-1.2.8.tar.gz and run the following command to install zlib + + ``` + tar -xvzf zlib-1.2.8.tar.gz + cd zlib-1.2.8.tar.gz + ./configure + make + make install + ``` + +### Install autoconf + +Run the command `apt-get install autoconf ` + +![1661587943041](http://image.huawei.com/tiny-lts/v1/images/65ede1529b53f9d8c20dd062e1696aec_848x186.png@900-0-90-f.png) + +### Install cmake + +- Download source code from https://github.com/Kitware/CMake/archive/refs/tags/v3.13.4.tar.gz + +- Run the following command to install cmake via the source code + + ``` + tar zxvf CMake-3.13.4.tar.gz + cd CMake-3.13.4 + ./bootstrap + gmake + gmake install + ``` + + +### Install LLVM + +- Download source code from + +- Run the following command to install LLVM via the source code + + ``` + tar zxvf llvm-project-llvmorg-12.0.1.tar.gz + cd llvm-project-llvmorg-12.0.1 + mkdir build + cd build + cmake -DCMAKE_INSTALL_PREFIX=/opt/omni-operator/llvm12 -DCMAKE_BUILD_TYPE=Release -DLLVM_BUILD_LLVM_DYLIB=true -DLLVM_ENABLE_PROJECTS="clang;libcxx;libcxxabi" -G "Unix Makefiles" ../llvm + make -j4 + make install + ``` + + Notice:the path `/opt/omni-operator/llvm12 `can be customized. + +### Install gtest + +- Download source code from   + +- Run the following command to install gtest via the source code + + ``` + tar zxvf googletest-release-1.10.0.tar.gz + cd googletest-release-1.10.0 + cmake CMakeLists.txt + make + make install + ``` + +### Install jemalloc + +- Download source code from + +- Run the following command to install jemalloc via the source code + + ``` + tar zxvf jemalloc-5.2.1.tar.gz + cd jemalloc-5.2.1 + ./autogen.sh --disable-initial-exec-tls + make -j2 + sudo make install + ``` + +### Install json + +The json version of OmniOperator requires at least 3.7.3. If the GCC version is 10.., choose a later version of json such as 3.9.1. Otherwise, it will result in compilation failure.If the GCC version is 7.3.0,the json version 3.7.3 is recommended. + +- Download source code from   + +- Run the following command to install jemalloc via the source code + + ``` + tar zxvf json-3.9.1.tar.gz + cd json-3.9.1 + mkdir build + cd build/ + cmake .. + make + make install + ``` + +Notice:The installation procedures for versions 3.9.1 and 3.7.3 are the same. + +### Install huawei securec + +- Download source code from + +- Run the following command to install jemalloc via the source code + + ``` + tar zxvf huawei_secure_c-tag_Huawei_Secure_C_V100R001C01SPC011B003_00001.tar.gz + mv huawei_secure_c-tag_Huawei_Secure_C_V100R001C01SPC011B003_00001 huawei_secure_c + cd huawei_secure_c/src + make + ``` + +### Install JDK + +Run the command `apt-get install java-8-openjdk ` + +### Install protobuf + +To install protobuf, the installation options are as follows: + +- Install protobuf via mirror source which means that run the command `apt-get install protobuf-compiler` + +- Install protobuf via source code which means that run the command `git clone [git@github.com](mailto:git@github.com):protocolbuffers/protobuf.git ` to download the source code and run the following command to install protobuf + + ``` + apt-get install autoconf automake libtool curl make g++ unzip + ./autogen.sh + ./configure + make + make check + sudo make install + sudo ldconfig + ``` + +If the version number is displayed after run the command `protoc --version` , the installation of protobuf is successful. + +### Configure environment variables + +Run the command `vi /etc/profile `,and add the following content to the file: + +- Configure LLVM + + To compile OmniOperatorJit, the LLVM configuration is as follows: + + ``` + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib:/opt/omni-operator/llvm12/lib + export LIBRARY_PATH=$LIBRARY_PATH:/usr/local/lib:/opt/omni-operator/llvm12/lib + export PATH=${PATH}:/opt/omni-operator/llvm12/bin + export C_INCLUDE_PATH=$C_INCLUDE_PATH:/opt/omni-operator/llvm12/include + export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/opt/omni-operator/llvm12/include + ``` + + Adjust the path that the variables point to based on the path of the LLVM compiled previously. + +- Configure huawei securec + + To compile OmniOperatorJit, the huawei securec configuration is as follows: + + ``` + export LD_LIBRARY_PATH(add the path) ":/opt/omni-operator/software/huawei_secure_c/lib" + export LIBRARY_PATH(add the path) ":/opt/omni-operator/software/huawei_secure_c/lib" + export C_INCLUDE_PATH(add the path) ":/opt/omni-operator/software" + export CPLUS_INCLUDE_PATH(add the path) ":/opt/omni-operator/software" + ``` + + huawei_secure_c is stored in the directory `/opt/omni-operator/software` . It is acceptable to place huawei_secure_c in a proper directory and specify the defined directory in the configuration file. + +- Configure JDK + + Run the command `find / -name "jni.h" ` to find the JDK installation location and configure it in `/etc/profile` + + ![1661592130232](http://image.huawei.com/tiny-lts/v1/images/260d2c72c5a7446a2adc40818f93281c_628x47.png) + + ![1661592155915](http://image.huawei.com/tiny-lts/v1/images/079e4d3db80efc911d9c6caf1eb9a7dd_865x130.png) + + Finally,run the command `source /etc/profile` to update environment variables. + + In the section 2.3.6 LLVM Installation, the path of installation of LLVM is directory `/opt/omni-operator/llvm12` + + ``` + cd /opt/omni-operator/llvm12 + ln –s clang++ clang++-12 + ``` + +## Compile OmniOperatorJit + +code repository + +- openLooKeng Extension code repository: +- OmniOperatorJIT code repository: +- openLooKeng code repository: + +Clone code from the preceding repositories to the local host. The compilation sequence is OmniOperator Jit Core, OmniOperator Jit bindings, omniop-openlookeng-extension, and openLooKeng. + + Notice: the openLooKeng branch must be set to 1.6 or later. + +### Compile OmniOperatorJit Core + +- Run command `cd /OmniOperatorJIT/core/build/ `to go to the directory + + - Run the different compilation commands based on different options, as follows: + + ``` + RELEASE: + sh build.sh release [--disable-cpuchecker] [--enable-dt] + + DEBUG: + sh build.sh debug [op/vec/llvm/all] [--disable-cpuchecker] [--enable-dt] + + TRACE: + sh build.sh trace [op/vec/llvm/all] [--disable-cpuchecker] [--enable-dt] + + COVERAGE: + sh build.sh coverage [op/vec/llvm/all] [--disable-cpuchecker] [--enable-dt] + ``` + + Notice: run the following command to download the jar package before compilation: + + ``` + wget https://szxy1.artifactory.cd-cloud-artifact.tools.huawei.com/artifactory/sz-maven-public/com/huawei/devtest/devtestcov-maven-plugin/2.1.1/devtestcov-maven-plugin-2.1.1.jar --proxy=off --no-check-certificate + + wget https://szxy1.artifactory.cd-cloud-artifact.tools.huawei.com/artifactory/sz-maven-public/com/huawei/devtest/devtestcov-maven-plugin/2.1.1/devtestcov-maven-plugin-2.1.1.pom --proxy=off --no-check-certificate + + mvn install:install-file -Dfile=devtestcov-maven-plugin-2.1.1.jar -DpomFile=devtestcov-maven-plugin-2.1.1.pom + ``` + + ``` + wget https://cmc.centralrepo.rnd.huawei.com/artifactory/product_maven/com/huawei/bepcloud/rebuild-maven-plugin/2.1.T1.SPC2/rebuild-maven-plugin-2.1.T1.SPC2.jar --proxy=off --no-check-certificate + + wget https://cmc.centralrepo.rnd.huawei.com/artifactory/product_maven/com/huawei/bepcloud/rebuild-maven-plugin/2.1.T1.SPC2/rebuild-maven-plugin-2.1.T1.SPC2.pom --proxy=off --no-check-certificate + + mvn install:install-file -Dfile=rebuild-maven-plugin-2.1.T1.SPC2.jar -DpomFile=rebuild-maven-plugin-2.1.T1.SPC2.pom + ``` + +### Compile OmniOperatorJit bindings + +- Run command `cd /OmniOperatorJIT/bindings/java `to go to the directory +- Run command `mvn clean install -DskipTests -T 1C ` to compile code,and the compiled target jar package is stored in the directory `/OmniOperatorJIT/bindings/java/target/ ` + +### Compile omni-openlookeng-extension + +- Run command `cd /hetu-core ` to go to the directory and choose the branch-1.6 +- Run command `mvn clean install -DskipTests -T 1C ` to compile code and the compiled target jar package is stored in the directory `/hetu-core/hetu-server/target/ `.In addition, it is acceptable to download it from the official website of openLooKeng. + +## Install openLooKeng with OmniOperatorJit + +### Prepare documents + +In this example,` /opt/omni-operator` is used as the root directory for installing openLooKeng. + +Obtain the following packages and files from the compiled target files: + +hetu-server-*.*.*-SNAPSHOT.tar.gz + +omni-openLooKeng-adapter-*.*.*-SNAPSHOT.jar + +boostkit-omniop-bindings-1.0.0-aarch64.jar + +*. *. * indicates the version number, for example, 1.6.1. + +### Deploy the OpenLooKeng Engine + +- Deploy the openLooKeng engine on the node. For details, see the openLooKeng deployment guide:https://docs.openlookeng.io/zh/docs/docs/installation/deployment.html. + +- After the openLooKeng is deployed, locate the deployment directory of the openLooKeng engine (for example, `/opt/hetu-server-1.6.1`) and place the file boostkit-omniop-openlookeng-1.6.1-1.0.0-aarch64.jar to the directory `/opt/omni-operator` . + +- Configure openLooKengExtension + + Add the following content to file hetu-server-1.6.1/etc/config.properties + + ``` + extension_execution_planner_enabled=false + extension_execution_planner_jar_path=file:///opt/omni-operator/boostkit-omniop-openlookeng-1.6.1-1.0.0-aarch64.jar + extension_execution_planner_class_path=nova.hetu.olk.OmniLocalExecutionPlanner + ``` + +- Copy boostkit-omniop-bindings-1.0.0-aarch64.jar to the lib directory of hetu-server-1.6.1. + +- Start hetu-server + + ``` + export LD_LIBRARY_PATH=/opt/omni-operator/lib + export OMNI_HOME=/opt/omni-operator + /opt/hetu-server-1.6.1/bin/launcher start + ``` + +# Common issues +1. ***Error***: fatal, unable to access 'https://github.com/jemalloc/jemalloc.git/': server certificate verification failed. CAfile: /etc/ssl/certs/ca-certificates.crt CRLfile: none.
+ ***Solution***: `export GIT_SSL_NO_VERIFY=1` +2. ***Error***: fatal, unable to access 'https://codehub-dg-y.huawei.com/data-app-lab/omni-runtime.git/': Received HTTP code 504 from proxy after CONNECT.
+ ***Solution***: `unset http_proxy; unset https_proxy` +3. ***Tip***: If your username or password has some special characters, please add `'\'` before them when configuring proxy. + + + diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..7a4d75f --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,133 @@ +# required cmake version +cmake_minimum_required(VERSION 3.14.1) + +#project name +project(omni-runtime) + +message(STATUS "Now processor is ${CMAKE_HOST_SYSTEM_PROCESSOR}") + +# omniruntime version +set(OMNI_RUNTIME_VERSION 1.9.0) +set(OS_ARCHITECTURES aarch64) +set(OMNI_CODEGEN_SO boostkit-omniop-codegen-${OMNI_RUNTIME_VERSION}-${OS_ARCHITECTURES}) +set(OMNI_OPERATOR_SO boostkit-omniop-operator-${OMNI_RUNTIME_VERSION}-${OS_ARCHITECTURES}) +set(OMNI_RUNTIME_SO boostkit-omniop-runtime-${OMNI_RUNTIME_VERSION}-${OS_ARCHITECTURES}) +set(OMNI_VECTOR_SO boostkit-omniop-vector-${OMNI_RUNTIME_VERSION}-${OS_ARCHITECTURES}) + +# configure cmake +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_COMPILER "g++") + +# this is used to generate LLVM IR for jitting +set(IR_COMPILER "clang++-15") +message(STATUS "IR_COMPILER : ${IR_COMPILER}") + +if (DEFINED COVERAGE) + if (${COVERAGE} STREQUAL "ON") + set(DEBUG_COMPILE_LEVEL "-O2 -fsanitize=address -fsanitize-recover=address,all -static-libasan") + if (DEFINED ENABLE_DT AND ENABLE_DT STREQUAL "ON") + set(DEBUG_COMPILE_LEVEL "-fsanitize=address -fsanitize-coverage=trace-pc") + endif () + else () + set(DEBUG_COMPILE_LEVEL "-O0") + endif () +elseif(DEFINED TRACE) + set(DEBUG_COMPILE_LEVEL "-O3 -s") +else () + set(DEBUG_COMPILE_LEVEL "-O0") +endif () + +find_program(CCACHE_FOUND ccache) +if (CCACHE_FOUND) + set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache) + set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK ccache) +endif (CCACHE_FOUND) + +option(ENABLE_SVE "Enable SVE" OFF) +message(STATUS "Option Enable SVE: ${ENABLE_SVE}") + +if (ENABLE_SVE) + add_definitions(-DENABLE_SVE) + set(SIMD_COMPILE_LEVEL "-march=armv8-a+crc+sve -fpermissive") +else () + set(SIMD_COMPILE_LEVEL "-march=armv8-a+crc -fpermissive") +endif () + +#-ffast-math +#-fopt-info-vec-optimized +set(CMAKE_CXX_FLAGS_DEBUG "${DEBUG_COMPILE_LEVEL} ${SIMD_COMPILE_LEVEL} -fno-var-tracking-assignments -pipe -g -Wall -fPIC -fno-omit-frame-pointer -fno-common -fno-stack-protector -ftest-coverage -fprofile-arcs") +set(CMAKE_CXX_FLAGS_RELEASE "-O3 ${SIMD_COMPILE_LEVEL} -fno-var-tracking-assignments -pipe -Wall -Wtrampolines -D_FORTIFY_SOURCE=2 -fPIC -finline-functions -fstack-protector-strong -s -Wl,-z,noexecstack -Wl,-z,relro,-z,now") +set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -finline-limit=6000 --param inline-unit-growth=300") + +if(ENABLE_COMPILE_TIME_REPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ftime-report") +endif() + +set(CMAKE_VERBOSE_MAKEFILE OFF) + +if (EXISTS $ENV{OMNI_HOME}) + set(CMAKE_INSTALL_PREFIX $ENV{OMNI_HOME}/lib) +else () + set(CMAKE_INSTALL_PREFIX /opt/lib) +endif () + +message(STATUS "Build by ${CMAKE_BUILD_TYPE}") +message(STATUS "OMNI_HOME is $ENV{OMNI_HOME}") +message(STATUS "JAVA_HOME is $ENV{JAVA_HOME}") + +# clean all target +set(target_generated ${CMAKE_INSTALL_PREFIX}/lib${OMNI_RUNTIME_SO}.so + ${CMAKE_INSTALL_PREFIX}/lib${OMNI_VECTOR_SO}.so + ${CMAKE_INSTALL_PREFIX}/lib${OMNI_CODEGEN_SO}.so + ${CMAKE_INSTALL_PREFIX}/ir + ${CMAKE_INSTALL_PREFIX}/lib${OMNI_OPERATOR_SO}.so + ) +foreach (file ${target_generated}) + if (EXISTS ${file}) + file(REMOVE_RECURSE ${file}) + message(STATUS "Delete file or directory:${file}") + endif () +endforeach (file) + +aux_source_directory(${CMAKE_CURRENT_LIST_DIR} ROOT_SRCS) +# for header searching +include_directories(SYSTEM core/src) +include_directories(SYSTEM core) +include_directories(${CMAKE_INSTALL_PREFIX}/lib/APL/include) + +# compile library +add_subdirectory(core) + +add_subdirectory(bindings/java/src/main/cpp EXCLUDE_FROM_ALL) + + +# options +option(DEBUG "Debug" OFF) +message(STATUS "Option DEBUG: ${DEBUG}") + +option(DEBUG "Trace" OFF) +message(STATUS "Option TRACE: ${TRACE}") + +option(DEBUG_OPERATOR "Using for debug into native operator code" OFF) +message(STATUS "Option DEBUG_OPERATOR: ${DEBUG_OPERATOR}") + +option(DEBUG_VECTOR "Using for debug into native vector code" OFF) +message(STATUS "Option DEBUG_VECTOR: ${DEBUG_VECTOR}") + +option(DEBUG_LLVM "Using for debug into JIT and codegen code" OFF) +message(STATUS "Option DEBUG_LLVM: ${DEBUG_LLVM}") + +option(DISABLE_JIT "Disable OmniJit to optimize operator" OFF) +message(STATUS "Option DISABLE_JIT: ${DISABLE_JIT}") + +option(COVERAGE "Enable coverage to optimize operator" OFF) +message(STATUS "Option Enable COVERAGE: ${COVERAGE}") + +option(COVERAGE "Disable CPU checker to optimize operator" OFF) +message(STATUS "Option Disable CPU checker: ${DISABLE_CPU_CHECKER}") + +option(COVERAGE "Enable DT" OFF) +message(STATUS "Option Enable DT: ${ENABLE_DT}") + +option(ENABLE_BENCHMARK "Enable Benchmark" OFF) +message(STATUS "Option Enable Benchmark: ${ENABLE_BENCHMARK}") diff --git a/README.md b/README.md index 4e6f011..ae5502a 100644 --- a/README.md +++ b/README.md @@ -1,37 +1,45 @@ -# OmniOperator +# OmniOperatorJIT +OmniRuntime技术项目转产品化项目 -#### 介绍 -OmniOperator operator acceleration is implemented using native code (C/C++) to optimize big data SQL operators. +##Welcome to OmniRuntime! +### Overview +The OmniRuntime is designed to -#### 软件架构 -软件架构说明 +* Fast OmniRuntime data exchange, potential via TCP/UPD/RDMA to directly and efficiently transfer vectors into the memory +* Support inter-language zero-copy +* SIMD optimization via Weld/LLVM +The OmniRuntime is library used to implement higher level data analytics logics -#### 安装教程 +#### OMVector - +The OmniRuntime provides c/c++ interface similar to `vector` with Java binding using JNI. The `OMVector` also provides SIMD enabled +operations called 'in-situ operation', which normally takes another `vector` as parameter. For scalar parameters, it will be +a `vector` with single value. -1. xxxx -2. xxxx -3. xxxx +We try to allocate the memory needed for the vector in continuous space as much as possible. Not requiring all +elements stored in continuous memory space allows us to expand the `vectors` when needed. Separated allocated +memory spaces are referred to as `chunk`, which is the smallest unit of operator for allocation and de-allocation. -#### 使用说明 +To enable fast computation, the `vector` stores metadata such as `min`/`max`/`average`/`sum`/`bitmap`, this will help with +aggregation and locating a specific element in the `vector`. The `vector` will also keep track of the `last used timestamp` of the chunks +in the `vector`. -1. xxxx -2. xxxx -3. xxxx +#### OMCache +OmniRuntime will also provide a `OMCache` API which keep tracks of all the loaded `OMVector`. The OMCache also maintains the mapping between +the `OMVector` and the `Table`, e.g. the schema informtion to support SQL alike operations -#### 参与贡献 -1. Fork 本仓库 -2. 新建 Feat_xxx 分支 -3. 提交代码 -4. 新建 Pull Request +#### Transport +TO BE ADDED +#### Java Binding +MORE DETAILS TO BE ADDED -#### 特技 +##### Vector - the base java class +##### LongVector - long data type +##### VarcharVector - variable length string +... -1. 使用 Readme\_XXX.md 来支持不同的语言,例如 Readme\_en.md, Readme\_zh.md -2. Gitee 官方博客 [blog.gitee.com](https://blog.gitee.com) -3. 你可以 [https://gitee.com/explore](https://gitee.com/explore) 这个地址来了解 Gitee 上的优秀开源项目 -4. [GVP](https://gitee.com/gvp) 全称是 Gitee 最有价值开源项目,是综合评定出的优秀开源项目 -5. Gitee 官方提供的使用手册 [https://gitee.com/help](https://gitee.com/help) -6. Gitee 封面人物是一档用来展示 Gitee 会员风采的栏目 [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/) + +### Getting Started +We provide the guidance to help developers setup and install OmniRuntime. See [building OmniRuntime](./BUILDING.MD). \ No newline at end of file diff --git a/bindings/java/pom.xml b/bindings/java/pom.xml new file mode 100644 index 0000000..3017950 --- /dev/null +++ b/bindings/java/pom.xml @@ -0,0 +1,272 @@ + + + + 4.0.0 + + com.huawei.boostkit + boostkit-omniop-bindings + jar + 1.9.0 + + + UTF-8 + 30.0-jre + 8 + 8 + 8 + 5.0.0 + 3.13.0-h19 + 2.10.0 + 6.10 + 1.20 + 3.2.4 + 1.2.78 + 0.7.9 + 1.7.30 + ${env.OMNI_HOME} + NONE + + + + + com.google.guava + guava + ${guava.version} + provided + + + com.google.protobuf + protobuf-java + ${protobuf.version} + provided + + + com.fasterxml.jackson.core + jackson-annotations + ${jackson.version} + provided + + + com.fasterxml.jackson.core + jackson-databind + ${jackson.version} + provided + + + + + org.testng + testng + ${testng.version} + test + + + org.openjdk.jmh + jmh-core + ${jdk.jmh.version} + test + + + org.openjdk.jmh + jmh-generator-annprocess + ${jdk.jmh.version} + test + + + net.openhft + affinity + ${openhft.version} + test + + + org.apache.arrow + arrow-memory-core + ${arrow.version} + test + + + org.slf4j + slf4j-api + ${log.version} + provided + + + org.slf4j + slf4j-simple + ${log.version} + test + + + org.apache.arrow + arrow-memory-netty + ${arrow.version} + test + + + org.apache.arrow + arrow-vector + ${arrow.version} + test + + + com.alibaba + fastjson + ${fastjson.version} + test + + + org.jacoco + org.jacoco.agent + runtime + ${jacoco.version} + test + + + + + + + ${omni.home}/java-binding + + *.so + + + + + + exec-maven-plugin + org.codehaus.mojo + 3.0.0 + + + Build CPP + generate-resources + + exec + + + bash + ${project.basedir}/../.. + + ${omni.home} + + + ${project.build.scriptSourceDirectory}/build_cpp.sh + ${omni.build.options} + + + + + + + org.apache.maven.plugins + maven-jar-plugin + 3.1.2 + + + + jar + + default-jar + + aarch64 + + false + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.8.0 + + 1.8 + 1.8 + UTF-8 + true + + -Xlint:all + + + + + org.xolstice.maven.plugins + protobuf-maven-plugin + 0.6.1 + + protoc + + + + + compile + + + + + + org.apache.maven.plugins + maven-surefire-plugin + 3.0.0-M6 + + + ${omni.home}/lib:${env.LD_LIBRARY_PATH} + + + + + com.huawei.devtest + devtestcov-maven-plugin + 2.1.1 + + ${active.devtest} + SiteManagement + 2.1.0 + true + + **/nova/hetu/omniruntime/vector/serialize/* + + + + + insert + + dev_insert + + + + + + com.huawei.bepcloud + rebuild-maven-plugin + 2.1.T1.SPC2 + + + reproducibleBuilds + + reproducibleBuild + + + true + true + + + MANIFEST.MF + Include-Resource + + + pom.properties + # + + + + + + + + + diff --git a/bindings/java/src/main/cpp/CMakeLists.txt b/bindings/java/src/main/cpp/CMakeLists.txt new file mode 100644 index 0000000..04920b8 --- /dev/null +++ b/bindings/java/src/main/cpp/CMakeLists.txt @@ -0,0 +1,19 @@ +aux_source_directory(${CMAKE_CURRENT_LIST_DIR}/src JNI_LIBS_SOURCE) + +set(BINDING_TARGET boostkit-omniop-java-binding-${OMNI_RUNTIME_VERSION}-${OS_ARCHITECTURES}) + +add_library(${BINDING_TARGET} SHARED ${JNI_LIBS_SOURCE}) + +# dependent include +target_include_directories(${BINDING_TARGET} PUBLIC $ENV{JAVA_HOME}/include) +target_include_directories(${BINDING_TARGET} PUBLIC $ENV{JAVA_HOME}/include/linux) + +target_link_libraries(${BINDING_TARGET} PUBLIC ${OMNI_OPERATOR_SO} ${OMNI_CODEGEN_SO} ${OMNI_VECTOR_SO}) + +if (EXISTS $ENV{OMNI_HOME}) + file(REMOVE_RECURSE $ENV{OMNI_HOME}/java-binding) + set_target_properties(${BINDING_TARGET} PROPERTIES LIBRARY_OUTPUT_DIRECTORY $ENV{OMNI_HOME}/java-binding) + set_target_properties(${BINDING_TARGET} PROPERTIES SKIP_BUILD_RPATH YES) +endif() + + diff --git a/bindings/java/src/main/cpp/src/jni_common_def.cpp b/bindings/java/src/main/cpp/src/jni_common_def.cpp new file mode 100644 index 0000000..9438e8d --- /dev/null +++ b/bindings/java/src/main/cpp/src/jni_common_def.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved. + * Description: JNI common functions + */ + +#include "jni_common_def.h" + +jclass bufCls; +jclass traceUtilCls; +jclass omniRuntimeExceptionClass; + +jmethodID traceUtilStackMethodId; + +static const jint JNI_VERSION = JNI_VERSION_1_6; + +jclass CreateGlobalClassRef(JNIEnv *env, const char *className) +{ + jclass localClass = env->FindClass(className); + jclass globalClass = (jclass)env->NewGlobalRef(localClass); + env->DeleteLocalRef(localClass); + return globalClass; +} + +jint JNI_OnLoad(JavaVM *vm, void *reserved) +{ + JNIEnv *env = nullptr; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION) != JNI_OK) { + return JNI_ERR; + } + bufCls = CreateGlobalClassRef(env, "java/nio/ByteBuffer"); + traceUtilCls = CreateGlobalClassRef(env, "nova/hetu/omniruntime/utils/TraceUtil"); + traceUtilStackMethodId = env->GetStaticMethodID(traceUtilCls, "stack", "()Ljava/lang/String;"); + omniRuntimeExceptionClass = CreateGlobalClassRef(env, "nova/hetu/omniruntime/utils/OmniRuntimeException"); + return JNI_VERSION; +} + +void JNI_OnUnload(JavaVM *vm, const void *reserved) +{ + JNIEnv *env = nullptr; + vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); + env->DeleteGlobalRef(bufCls); + env->DeleteGlobalRef(traceUtilCls); + env->DeleteGlobalRef(omniRuntimeExceptionClass); +} \ No newline at end of file diff --git a/bindings/java/src/main/cpp/src/jni_common_def.h b/bindings/java/src/main/cpp/src/jni_common_def.h new file mode 100644 index 0000000..82cd3e2 --- /dev/null +++ b/bindings/java/src/main/cpp/src/jni_common_def.h @@ -0,0 +1,111 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved. + * Description: JNI common functions + */ +#ifndef JNI_COMMON_DEF_H +#define JNI_COMMON_DEF_H + +#include +#include "util/omni_exception.h" + +#define JNI_METHOD_START try { +// macro end + +#define JNI_METHOD_END(fallBackExpr) \ + } \ + catch (const std::exception &e) \ + { \ + env->ThrowNew(omniRuntimeExceptionClass, e.what()); \ + return fallBackExpr; \ + } \ + // macro end + +#define JNI_METHOD_END_WITH_EXPRS_RELEASE(fallBackExpr, toDeleteExprs) \ + } \ + catch (const std::exception &e) \ + { \ + Expr::DeleteExprs(toDeleteExprs); \ + env->ThrowNew(omniRuntimeExceptionClass, e.what()); \ + return fallBackExpr; \ + } \ + // macro end + +#define JNI_METHOD_END_WITH_MULTI_EXPRS(fallBackExpr, toDeleteExprs1, toDeleteExprs2) \ + } \ + catch (const std::exception &e) \ + { \ + Expr::DeleteExprs(toDeleteExprs1); \ + Expr::DeleteExprs(toDeleteExprs2); \ + env->ThrowNew(omniRuntimeExceptionClass, e.what()); \ + return fallBackExpr; \ + } \ + // macro end + +#define JNI_METHOD_END_WITH_THREE_EXPRS(fallBackExpr, toDeleteExprs1, toDeleteExprs2, toDeleteExprs3) \ + } \ + catch (const std::exception &e) \ + { \ + Expr::DeleteExprs(toDeleteExprs1); \ + Expr::DeleteExprs(toDeleteExprs2); \ + Expr::DeleteExprs(toDeleteExprs3); \ + env->ThrowNew(omniRuntimeExceptionClass, e.what()); \ + return fallBackExpr; \ + } \ + // macro end + + +#define JNI_METHOD_END_WITH_OVERFLOW(fallBackExpr, overflowConfig) \ + } \ + catch (const std::exception &e) \ + { \ + delete (overflowConfig); \ + env->ThrowNew(omniRuntimeExceptionClass, e.what()); \ + return fallBackExpr; \ + } + +#define JNI_METHOD_END_WITH_EXPRS_OVERFLOW(fallBackExpr, toDeleteExprs, overflowConfig) \ + } \ + catch (const std::exception &e) \ + { \ + Expr::DeleteExprs(toDeleteExprs); \ + delete (overflowConfig); \ + env->ThrowNew(omniRuntimeExceptionClass, e.what()); \ + return fallBackExpr; \ + } + + +#define JNI_METHOD_END_WITH_MULTI_EXPRS_OVERFLOW(fallBackExpr, toDeleteExprs1, toDeleteExprs2, overflowConfig) \ + } \ + catch (const std::exception &e) \ + { \ + Expr::DeleteExprs(toDeleteExprs1); \ + Expr::DeleteExprs(toDeleteExprs2); \ + delete (overflowConfig); \ + env->ThrowNew(omniRuntimeExceptionClass, e.what()); \ + return fallBackExpr; \ + } + +#define JNI_METHOD_END_WITH_VECBATCH(fallBackExpr, toDeleteVecBatch) \ + } \ + catch (const std::exception &e) \ + { \ + VectorHelper::FreeVecBatch(toDeleteVecBatch); \ + env->ThrowNew(omniRuntimeExceptionClass, e.what()); \ + return fallBackExpr; \ + } + +#ifdef __cplusplus +extern "C" { +#endif + +extern jclass bufCls; +extern jclass traceUtilCls; +extern jclass omniRuntimeExceptionClass; +extern jmethodID traceUtilStackMethodId; + +jclass CreateGlobalClassRef(JNIEnv *env, const char *className); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/bindings/java/src/main/cpp/src/jni_constants.cpp b/bindings/java/src/main/cpp/src/jni_constants.cpp new file mode 100644 index 0000000..d00fa1c --- /dev/null +++ b/bindings/java/src/main/cpp/src/jni_constants.cpp @@ -0,0 +1,12 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved. + * Description: JNI Constants + */ + +#include "jni_constants.h" + +JNIEXPORT jstring JNICALL Java_nova_hetu_omniruntime_OmniLibs_getVersion(JNIEnv *env, jclass ignore) +{ + return (*env).NewStringUTF( + "Product Name: Kunpeng BoostKit\nProduct Version: 25.0.0\nComponent Name: BoostKit-omniop\nComponent Version: 1.9.0"); +} \ No newline at end of file diff --git a/bindings/java/src/main/cpp/src/jni_constants.h b/bindings/java/src/main/cpp/src/jni_constants.h new file mode 100644 index 0000000..6c52d26 --- /dev/null +++ b/bindings/java/src/main/cpp/src/jni_constants.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved. + * Description: JNI Constants + */ +#ifndef JNI_CONSTANTS_H +#define JNI_CONSTANTS_H + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Class: nova_hetu_omniruntime_OmniLibs + * Method: getVersion + * Signature: ()Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_nova_hetu_omniruntime_OmniLibs_getVersion(JNIEnv *env, jclass ignore); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/bindings/java/src/main/cpp/src/jni_helper.cpp b/bindings/java/src/main/cpp/src/jni_helper.cpp new file mode 100644 index 0000000..da4c9f3 --- /dev/null +++ b/bindings/java/src/main/cpp/src/jni_helper.cpp @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * Description: JNI Operator Operations Source File + */ +#include "jni_helper.h" +#include "operator/hash_util.h" + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_utils_ShuffleHashHelper_computePartitionIds(JNIEnv *env, + jclass jClass, jlongArray vecAddrArray, jint partitionNum, jint rowCount) +{ + if (partitionNum == 0) { + env->ThrowNew(omniRuntimeExceptionClass, "PartitionNum should not be 0"); + return 0; + } + jsize length = env->GetArrayLength(vecAddrArray); + jlong *addrs = (*env).GetLongArrayElements(vecAddrArray, nullptr); + std::vector vecs; + for (int i = 0; i < length; ++i) { + auto vec = reinterpret_cast(addrs[i]); + vecs.push_back(vec); + } + env->ReleaseLongArrayElements(vecAddrArray, addrs, JNI_ABORT); + auto ret = omniruntime::op::HashUtil::ComputePartitionIds(vecs, partitionNum, rowCount); + return (jlong)ret.release(); +} \ No newline at end of file diff --git a/bindings/java/src/main/cpp/src/jni_helper.h b/bindings/java/src/main/cpp/src/jni_helper.h new file mode 100644 index 0000000..e3004ca --- /dev/null +++ b/bindings/java/src/main/cpp/src/jni_helper.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * Description: JNI Operator Operations Source File + */ + +#ifndef JNI_HELPER_H +#define JNI_HELPER_H +#ifdef __cplusplus + +#include +#include "jni_common_def.h" + +extern "C" { +#endif + +/* + * Class: nova_hetu_omniruntime_utils_ShuffleHashHelper + * Method: ComputePartitionIds + * Signature: ([JII)J + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_utils_ShuffleHashHelper_computePartitionIds(JNIEnv *env, + jclass jClass, jlongArray vecAddrArray, jint partitionNum, jint rowCount); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/bindings/java/src/main/cpp/src/jni_operator.cpp b/bindings/java/src/main/cpp/src/jni_operator.cpp new file mode 100644 index 0000000..86ffdb5 --- /dev/null +++ b/bindings/java/src/main/cpp/src/jni_operator.cpp @@ -0,0 +1,377 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2025. All rights reserved. + * Description: JNI Operator Operations Source File + */ + +#include +#include "securec.h" +#include "vector/vector_batch.h" +#include "vector/vector_helper.h" +#include "jni_common_def.h" +#include "operator/operator_factory.h" +#include "operator/aggregation/group_aggregation_expr.h" +#include "vector/omni_row.h" +#include "jni_operator.h" + +using namespace omniruntime::op; +using namespace omniruntime::vec; + +static std::once_flag loadVecBatchClsFlag; + +static jclass vecBatchCls = nullptr; +static jclass rowBatchCls = nullptr; +static jclass rowCls = nullptr; +static jclass omniResultsCls = nullptr; +static jclass rowResultsCls = nullptr; + +static jmethodID omniResultsInitMethodId = nullptr; +static jmethodID vecBatchInitMethodId = nullptr; +static jmethodID rowBatchInitMethodId = nullptr; +static jmethodID rowInitMethodId = nullptr; +static jmethodID rowResultsInitMethodId = nullptr; + +static void RecordInputVectorsStack(VectorBatch *vectorBatch, JNIEnv *env) +{ +#ifdef DEBUG_VECTOR + jstring jstack = (jstring)env->CallStaticObjectMethod(traceUtilCls, traceUtilStackMethodId); + auto stackChars = env->GetStringUTFChars(jstack, JNI_FALSE); + std::string stack(stackChars); + int32_t vecCount = vectorBatch->GetVectorCount(); + for (int i = 0; i < vecCount; ++i) { + Vector *vector = vectorBatch->GetVector(i); + vector->RecordStack(stack, VecOpType::JNI_ADD_INPUT); + } + env->ReleaseStringUTFChars(jstack, stackChars); +#endif +} + +static void RecordOutputVectorsStack(VectorBatch &outputVecBatch, JNIEnv *env) +{ +#ifdef DEBUG_VECTOR + jstring jstack = (jstring)env->CallStaticObjectMethod(traceUtilCls, traceUtilStackMethodId); + auto stackChars = env->GetStringUTFChars(jstack, JNI_FALSE); + std::string stack(stackChars); + for (int j = 0; j < outputVecBatch.GetVectorCount(); ++j) { + Vector *vector = outputVecBatch.GetVector(j); + vector->RecordStack(stack, VecOpType::JNI_GET_OUTPUT); + } + env->ReleaseStringUTFChars(jstack, stackChars); +#endif +} + +static void LoadVecBatchAndOmniResults(JNIEnv *env) +{ + if (vecBatchCls == nullptr) { + // the java adaptor maybe only use VecBatch or Vec like ColumnarBroadcastExchangeExec + // it will load VecBatch class, and then will call System.load to load so + // so load VecBatch class on demand to avoid deadlock + vecBatchCls = CreateGlobalClassRef(env, "nova/hetu/omniruntime/vector/VecBatch"); + vecBatchInitMethodId = env->GetMethodID(vecBatchCls, "", "(J[J[J[J[J[I[II)V"); + omniResultsCls = CreateGlobalClassRef(env, "nova/hetu/omniruntime/operator/OmniResults"); + + rowCls = CreateGlobalClassRef(env, "nova/hetu/omniruntime/vector/Row"); + rowBatchCls = CreateGlobalClassRef(env, "nova/hetu/omniruntime/vector/RowBatch"); + rowResultsCls = CreateGlobalClassRef(env, "nova/hetu/omniruntime/operator/OmniRowResults"); + + omniResultsInitMethodId = + env->GetMethodID(omniResultsCls, "", "(Lnova/hetu/omniruntime/vector/VecBatch;I)V"); + + // Row(long dataAddr, int hashPos, int len) + rowInitMethodId = env->GetMethodID(rowCls, "", "(JI)V"); + + // RowBatch(long nativeAddress, Row[] rows, int rowCount) + rowBatchInitMethodId = env->GetMethodID(rowBatchCls, "", "(J[Lnova/hetu/omniruntime/vector/Row;I)V"); + + // OmniRowResults(RowBatch rowBatch, int status) + rowResultsInitMethodId = + env->GetMethodID(rowResultsCls, "", "(Lnova/hetu/omniruntime/vector/RowBatch;I)V"); + } +} + +static jobject Transform(JNIEnv *env, VectorBatch &result) +{ + int32_t vecCount = result.GetVectorCount(); + int64_t vecAddresses[vecCount]; + int32_t encodings[vecCount]; + int32_t dataTypeIds[vecCount]; + int64_t valueBufAddrs[vecCount]; + int64_t nullBufAddrs[vecCount]; + int64_t offsetsBufAddrs[vecCount]; + for (int32_t i = 0; i < vecCount; ++i) { + BaseVector *vector = result.Get(i); + vecAddresses[i] = reinterpret_cast(vector); + dataTypeIds[i] = vector->GetTypeId(); + encodings[i] = vector->GetEncoding(); + // By default, all 3 buf arrays will have a value, + // if not, it will be 0, which means a null pointer. + valueBufAddrs[i] = reinterpret_cast(VectorHelper::UnsafeGetValues(vector)); + nullBufAddrs[i] = reinterpret_cast(omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(vector)); + offsetsBufAddrs[i] = reinterpret_cast(VectorHelper::UnsafeGetOffsetsAddr(vector)); + } + + // set vector addresses parameter to vector batch construct. + jlongArray jVecAddresses = env->NewLongArray(vecCount); + env->SetLongArrayRegion(jVecAddresses, 0, vecCount, vecAddresses); + + // set vector encoding + jintArray jVecEncodingIds = env->NewIntArray(vecCount); + env->SetIntArrayRegion(jVecEncodingIds, 0, vecCount, encodings); + + // set vector type ids parameter to vector batch construct. + jintArray jDataTypeIds = env->NewIntArray(vecCount); + env->SetIntArrayRegion(jDataTypeIds, 0, vecCount, dataTypeIds); + + // set vector value buf address + jlongArray jVecValueBufAddrs = env->NewLongArray(vecCount); + env->SetLongArrayRegion(jVecValueBufAddrs, 0, vecCount, valueBufAddrs); + + // set vec null buf address + jlongArray jVecNullBufAddrs = env->NewLongArray(vecCount); + env->SetLongArrayRegion(jVecNullBufAddrs, 0, vecCount, nullBufAddrs); + + // set vec offsets buf address + jlongArray jVecOffsetsBufAddrs = env->NewLongArray(vecCount); + env->SetLongArrayRegion(jVecOffsetsBufAddrs, 0, vecCount, offsetsBufAddrs); + + // create vector batch java object. + jobject obj = env->NewObject(vecBatchCls, vecBatchInitMethodId, (jlong)((int64_t)(&result)), jVecAddresses, + jVecValueBufAddrs, jVecNullBufAddrs, jVecOffsetsBufAddrs, jVecEncodingIds, jDataTypeIds, result.GetRowCount()); + return obj; +} + +static jobject TransformFromRow(JNIEnv *env, RowBatch &result) +{ + int32_t rowCount = result.GetRowCount(); + jobjectArray resultArray = env->NewObjectArray(rowCount, rowCls, nullptr); + for (int32_t i = 0; i < rowCount; ++i) { + RowInfo *row = result.Get(i); + jobject rowObject = env->NewObject(rowCls, rowInitMethodId, reinterpret_cast(row->row), row->length); + env->SetObjectArrayElement(resultArray, i, rowObject); + env->DeleteLocalRef(rowObject); + } + + // create vector batch java object. + jobject obj = env->NewObject(rowBatchCls, rowBatchInitMethodId, (jlong)(&result), resultArray, rowCount); + env->DeleteLocalRef(resultArray); + return obj; +} + +JNIEXPORT jint JNICALL Java_nova_hetu_omniruntime_operator_OmniOperator_addInputNative(JNIEnv *env, jobject jObj, + jlong jOperatorAddress, jlong jVecBatchAddress) +{ + int32_t errNo = 0; + auto *vecBatch = reinterpret_cast(jVecBatchAddress); + auto *nativeOperator = reinterpret_cast(jOperatorAddress); + JNI_METHOD_START + RecordInputVectorsStack(vecBatch, env); + nativeOperator->SetInputVecBatch(vecBatch); + errNo = nativeOperator->AddInput(vecBatch); + JNI_METHOD_END_WITH_VECBATCH(errNo, nativeOperator->GetInputVecBatch()) + return errNo; +} + +/* + * Class: nova_hetu_omniruntime_operator_OmniOperator + * Method: getOutputNative + * Signature: (J)[Lnova/hetu/omniruntime/operator/OMResult; + */ +JNIEXPORT jobject JNICALL Java_nova_hetu_omniruntime_operator_OmniOperator_getOutputNative(JNIEnv *env, jobject jObj, + jlong jOperatorAddr) +{ + std::call_once(loadVecBatchClsFlag, LoadVecBatchAndOmniResults, env); + if (vecBatchCls == nullptr || omniResultsCls == nullptr) { + env->ThrowNew(omniRuntimeExceptionClass, "The class VecBatch or OmniResult has not load yet."); + return nullptr; + } + + auto *nativeOperator = reinterpret_cast(jOperatorAddr); + VectorBatch *outputVecBatch = nullptr; + JNI_METHOD_START + nativeOperator->GetOutput(&outputVecBatch); + JNI_METHOD_END_WITH_VECBATCH(nullptr, outputVecBatch) + jobject result = nullptr; + if (outputVecBatch) { + RecordOutputVectorsStack(*outputVecBatch, env); + result = Transform(env, *outputVecBatch); + } + return env->NewObject(omniResultsCls, omniResultsInitMethodId, result, nativeOperator->GetStatus()); +} + +/* + * Class: nova_hetu_omniruntime_operator_OmniOperator + * Method: close + * Signature: (J)[Lnova/hetu/omniruntime/operator/void; + */ +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_operator_OmniOperator_closeNative(JNIEnv *env, jobject jObj, + jlong jOperatorAddr) +{ + try { + auto *nativeOperator = reinterpret_cast(jOperatorAddr); + op::Operator::DeleteOperator(nativeOperator); + } catch (const std::exception &e) { + env->ThrowNew(omniRuntimeExceptionClass, e.what()); + } +} + +/* + * Class: nova_hetu_omniruntime_operator_OmniOperator + * Method: getSpilledBytesNative + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_operator_OmniOperator_getSpilledBytesNative(JNIEnv *env, + jobject jObj, jlong jOperatorAddr) +{ + auto *nativeOperator = reinterpret_cast(jOperatorAddr); + return static_cast(nativeOperator->GetSpilledBytes()); +} + +JNIEXPORT jlongArray JNICALL Java_nova_hetu_omniruntime_operator_OmniOperator_getMetricsInfoNative(JNIEnv *env, + jobject jObj, jlong jOperatorAddr) +{ + auto *nativeOperator = reinterpret_cast(jOperatorAddr); + // get simpleMetrics info, used by all operators. + static const uint64_t metricsLength = 200; + static const uint64_t boundaryIndex = 100; + jlongArray metricsInfoArray = env->NewLongArray(metricsLength); + jlong* elementsSimple = env->GetLongArrayElements(metricsInfoArray, nullptr); + elementsSimple[0] = static_cast(nativeOperator->GetSpilledBytes()); + // get specialMetrics info, every operator is different. + std::vector specialMetricsInfoArray = nativeOperator->GetSpecialMetricsInfo(); + long specialMetricsLength = specialMetricsInfoArray.size(); + for (uint64_t i = 0; i < specialMetricsLength; i++) { + elementsSimple[i + boundaryIndex] = specialMetricsInfoArray[i]; + } + + env->ReleaseLongArrayElements(metricsInfoArray, elementsSimple, 0); + return metricsInfoArray; +} + +JNIEXPORT jobject JNICALL Java_nova_hetu_omniruntime_operator_OmniOperator_alignSchemaNative(JNIEnv *env, jobject jObj, + jlong jOperatorAddr, jlong jVecBatchAddr) +{ + auto nativeOperator = (op::Operator *)jOperatorAddr; + auto nativeVeBatch = (vec::VectorBatch *)jVecBatchAddr; + auto outputVecBatch = nativeOperator->AlignSchema(nativeVeBatch); + jobject result = nullptr; + if (outputVecBatch) { + result = Transform(env, *outputVecBatch); + } + return result; +} + +/* + * Class: nova_hetu_omniruntime_operator_OmniOperator + * Method: getHashMapUniqueKeysNative + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_operator_OmniOperator_getHashMapUniqueKeysNative(JNIEnv *env, + jobject jObj, jlong jOperatorAddr) +{ + auto *nativeOperator = (op::Operator *)jOperatorAddr; + return static_cast(nativeOperator->GetHashMapUniqueKeys()); +} + +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_vector_RowBatch_freeRowBatchNative(JNIEnv *env, jclass jcls, + jlong jrowBatchAddress) +{ + auto *rowBatch = reinterpret_cast(jrowBatchAddress); + delete rowBatch; +} + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_RowBatch_newRowBatchNative(JNIEnv *env, jclass jcls, + jobjectArray rows, jint rowCount) +{ + jclass rowClass = env->FindClass("nova/hetu/omniruntime/vector/Row"); + + jfieldID rowAddrId = env->GetFieldID(rowClass, "nativeRow", "J"); + jfieldID lengthId = env->GetFieldID(rowClass, "length", "I"); + + auto *rowBatch = new RowBatch(rowCount); + + for (int i = 0; i < rowCount; ++i) { + auto obj = env->GetObjectArrayElement(rows, i); + auto rowAddr = env->GetLongField(obj, rowAddrId); + auto length = env->GetIntField(obj, lengthId); + rowBatch->SetRow(i, new RowInfo((uint8_t *)(rowAddr), length)); + } + return reinterpret_cast(rowBatch); +} + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_RowBatch_transFromVectorBatch(JNIEnv *env, jclass jcls, + jlong vectorBatch) +{ + auto *vecBatch = reinterpret_cast(vectorBatch); + int32_t vecCount = vecBatch->GetVectorCount(); + std::vector typeIds; + std::vector encodings; + for (int i = 0; i < vecCount; i++) { + BaseVector *vector = vecBatch->Get(i); + typeIds.push_back(vector->GetTypeId()); + encodings.push_back(vector->GetEncoding()); + } + + auto rowBuffer = std::make_unique(typeIds, encodings, typeIds.size() - 1); + + auto rowBatch = std::make_unique(vecBatch->GetRowCount(), typeIds); + for (int32_t i = 0; i < vecBatch->GetRowCount(); ++i) { + // 1.get value from vector batch + rowBuffer->TransValueFromVectorBatch(vecBatch, i); + + // 2.generate one buffer of one row + auto oneRowLen = rowBuffer->FillBuffer(); + + // 3.set one row + rowBatch->SetRow(i, new RowInfo(rowBuffer->TakeRowBuffer(), oneRowLen)); + } + return reinterpret_cast(rowBatch.release()); +} + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_serialize_OmniRowDeserializer_newOmniRowDeserializer( + JNIEnv *env, jclass jcls, jintArray typeArray, jlongArray vecArray) +{ + jboolean isCopy = false; + auto *types = env->GetIntArrayElements(typeArray, &isCopy); + auto *vecs = env->GetLongArrayElements(vecArray, &isCopy); + auto len = env->GetArrayLength(typeArray); + auto *parser = new RowParser((type::DataTypeId *)types, vecs, len); + env->ReleaseIntArrayElements(typeArray, types, 0); + return reinterpret_cast(parser); +} + +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_vector_serialize_OmniRowDeserializer_freeOmniRowDeserializer( + JNIEnv *env, jclass jcls, jlong parserAddr) +{ + auto *rowParser = reinterpret_cast(parserAddr); + delete rowParser; +} + +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_vector_serialize_OmniRowDeserializer_parseOneRow(JNIEnv *env, + jclass jcls, jlong parserAddr, jbyteArray bytes, jint rowIndex) +{ + jboolean isCopy = false; + auto *row = env->GetByteArrayElements(bytes, &isCopy); + auto *parser = reinterpret_cast(parserAddr); + parser->ParseOnRow(reinterpret_cast(row), rowIndex); + env->ReleaseByteArrayElements(bytes, row, 0); +} + +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_vector_serialize_OmniRowDeserializer_parseOneRowByAddr(JNIEnv *env, + jclass jcls, jlong parserAddr, jlong rowAddr, jint rowIndex) +{ + auto *parser = reinterpret_cast(parserAddr); + auto *row = reinterpret_cast(rowAddr); + if (row != nullptr) { + parser->ParseOnRow(row, rowIndex); + } +} + +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_vector_serialize_OmniRowDeserializer_parseAllRow(JNIEnv *env, + jclass jcls, jlong parserAddr, jlong rowBatchAddr) +{ + auto *parser = reinterpret_cast(parserAddr); + auto *rowBatch = reinterpret_cast(rowBatchAddr); + + for (int i = 0; i < rowBatch->GetRowCount(); ++i) { + parser->ParseOnRow(rowBatch->Get(i)->row, i); + } +} \ No newline at end of file diff --git a/bindings/java/src/main/cpp/src/jni_operator.h b/bindings/java/src/main/cpp/src/jni_operator.h new file mode 100644 index 0000000..6605cab --- /dev/null +++ b/bindings/java/src/main/cpp/src/jni_operator.h @@ -0,0 +1,101 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved. + * Description: JNI Operator Operations Header + */ +#ifndef JNI_OPERATOR_H +#define JNI_OPERATOR_H + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Class: nova_hetu_omniruntime_operator_OmniOperator + * Method: getOutputNative + * Signature: (J)[Lnova/hetu/omniruntime/operator/OMResult; + */ + +/* + * Class: nova_hetu_omniruntime_operator_OmniOperator + * Method: addInputNative + * Signature: (JJIJI)I + */ +JNIEXPORT jint JNICALL Java_nova_hetu_omniruntime_operator_OmniOperator_addInputNative(JNIEnv *, jobject, jlong, jlong); +/* + * Class: nova_hetu_omniruntime_operator_OmniOperator + * Method: getOutputNative + * Signature: (J)[Lnova/hetu/omniruntime/operator/OMResult; + */ + +JNIEXPORT jobject JNICALL Java_nova_hetu_omniruntime_operator_OmniOperator_getOutputNative(JNIEnv *, jobject, jlong); + +/* + * Class: nova_hetu_omniruntime_operator_OmniOperator + * Method: close + * Signature: (J)[Lnova/hetu/omniruntime/operator/void; + */ +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_operator_OmniOperator_closeNative(JNIEnv *, jobject, jlong); + +/* + * Class: nova_hetu_omniruntime_operator_OmniOperator + * Method: getSpilledBytesNative + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_operator_OmniOperator_getSpilledBytesNative(JNIEnv *, jobject, + jlong); + +/* + * Class: nova_hetu_omniruntime_operator_OmniOperator + * Method: getMetricsInfoNative + * Signature: (J)J + */ +JNIEXPORT jlongArray JNICALL Java_nova_hetu_omniruntime_operator_OmniOperator_getMetricsInfoNative(JNIEnv *, jobject, + jlong); + +/* + * Class: nova_hetu_omniruntime_operator_OmniOperator + * Method: alignSchemaNative + * Signature: (JJ)[Lnova/hetu/omniruntime/vector/VecBatch + */ +JNIEXPORT jobject JNICALL Java_nova_hetu_omniruntime_operator_OmniOperator_alignSchemaNative(JNIEnv *, jobject, jlong, + jlong); + +/* + * Class: nova_hetu_omniruntime_operator_OmniOperator + * Method: getHashMapUniqueKeysNative + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_operator_OmniOperator_getHashMapUniqueKeysNative(JNIEnv *, jobject, + jlong); + + +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_vector_RowBatch_freeRowBatchNative(JNIEnv *env, jclass jcls, + jlong jVecBatchAddress); + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_RowBatch_transFromVectorBatch(JNIEnv *env, jclass jcls, + jlong vectorBatch); + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_RowBatch_newRowBatchNative(JNIEnv *env, jclass jcls, + jobjectArray rows, jint rowCount); + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_serialize_OmniRowDeserializer_newOmniRowDeserializer( + JNIEnv *env, jclass jcls, jintArray typeArray, jlongArray vecs); + +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_vector_serialize_OmniRowDeserializer_freeOmniRowDeserializer( + JNIEnv *env, jclass jcls, jlong parserAddr); + +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_vector_serialize_OmniRowDeserializer_parseOneRow(JNIEnv *env, + jclass jcls, jlong parserAddr, jbyteArray bytes, jint rowIndex); + +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_vector_serialize_OmniRowDeserializer_parseOneRowByAddr(JNIEnv *env, + jclass jcls, jlong parserAddr, jlong rowAddr, jint rowIndex); + +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_vector_serialize_OmniRowDeserializer_parseAllRow(JNIEnv *env, + jclass jcls, jlong parserAddr, jlong rowBatchAddr); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/bindings/java/src/main/cpp/src/jni_operator_factory.cpp b/bindings/java/src/main/cpp/src/jni_operator_factory.cpp new file mode 100644 index 0000000..440c515 --- /dev/null +++ b/bindings/java/src/main/cpp/src/jni_operator_factory.cpp @@ -0,0 +1,1637 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2025. All rights reserved. + * Description: JNI Operator Factory Source File + */ + +#include "jni_operator_factory.h" +#include "operator/operator_factory.h" +#include "operator/sort/sort.h" +#include "operator/sort/sort_expr.h" +#include "operator/aggregation/aggregator/aggregator_util.h" +#include "operator/aggregation/group_aggregation.h" +#include "operator/aggregation/group_aggregation_expr.h" +#include "operator/aggregation/non_group_aggregation.h" +#include "operator/aggregation/non_group_aggregation_expr.h" +#include "operator/filter/filter_and_project.h" +#include "codegen/bloom_filter.h" +#include "operator/window/window.h" +#include "operator/join/hash_builder.h" +#include "operator/join/lookup_join.h" +#include "operator/join/hash_builder_expr.h" +#include "operator/join/lookup_join_expr.h" +#include "operator/join/lookup_outer_join.h" +#include "operator/join/lookup_outer_join_expr.h" +#include "operator/join/sortmergejoin/sort_merge_join_expr.h" +#include "operator/join/sortmergejoin/sort_merge_join_expr_v3.h" +#include "operator/topn/topn.h" +#include "operator/topn/topn_expr.h" +#include "operator/topnsort/topn_sort_expr.h" +#include "operator/union/union.h" +#include "operator/window/window_expr.h" +#include "operator/window/window_group_limit_expr.h" +#include "operator/limit/limit.h" +#include "operator/limit/distinct_limit.h" +#include "operator/config/operator_config.h" +#include "util/config_util.h" +#include "config.h" +#include "jni_common_def.h" +#include "expression/expr_verifier.h" +#include "operator/join/nest_loop_join_builder.h" +#include "operator/join/nest_loop_join_lookup.h" + + +using namespace omniruntime::op; +using namespace omniruntime::expressions; +using namespace std; + +void GetColumnsFromExpressions(JNIEnv *env, jobjectArray &jExpressions, int32_t *columns, int32_t length) +{ + for (int32_t i = 0; i < length; i++) { + auto jSortCol = static_cast(env->GetObjectArrayElement(jExpressions, i)); + const char *columnString = env->GetStringUTFChars(jSortCol, JNI_FALSE); + columns[i] = std::stoi(columnString + 1); + env->ReleaseStringUTFChars(jSortCol, columnString); + } +} + +void GetExpressions(JNIEnv *env, jobjectArray jExpressions, std::string *expressions, int32_t expressionCount) +{ + for (int32_t i = 0; i < expressionCount; i++) { + auto jExpression = static_cast(env->GetObjectArrayElement(jExpressions, i)); + auto key = env->GetStringUTFChars(jExpression, JNI_FALSE); + expressions[i] = key; + env->ReleaseStringUTFChars(jExpression, key); + } +} + +void GetExprsFromJson(const string *keysArr, jint keyCount, std::vector &expressions) +{ + for (int32_t i = 0; i < keyCount; i++) { + auto jsonExpression = nlohmann::json::parse(keysArr[i]); + auto expression = JSONParser::ParseJSON(jsonExpression); + if (expression == nullptr) { + Expr::DeleteExprs(expressions); + throw omniruntime::exception::OmniException("EXPRESSION_NOT_SUPPORT", + "The expression is not supported yet: " + jsonExpression.dump()); + } + expressions.push_back(expression); + } +} + +void GetExprsFromJson(std::vector &keysArr, jint keyCount, + std::vector &expressions) +{ + for (int32_t i = 0; i < keyCount; i++) { + auto jsonExpression = nlohmann::json::parse(keysArr.at(i)); + auto expression = JSONParser::ParseJSON(jsonExpression); + if (expression == nullptr) { + Expr::DeleteExprs(expressions); + throw omniruntime::exception::OmniException("EXPRESSION_NOT_SUPPORT", + "The expression is not supported yet: " + jsonExpression.dump()); + } + expressions.push_back(expression); + } +} + +void GetFilterExprsFromJson(std::string *keysArr, jint keyCount, + std::vector &expressions) +{ + for (int32_t i = 0; i < keyCount; i++) { + if (keysArr[i].empty()) { + expressions.push_back(nullptr); + continue; + } + auto jsonExpression = nlohmann::json::parse(keysArr[i]); + auto expression = JSONParser::ParseJSON(jsonExpression); + expressions.push_back(expression); + } +} + +void GetBoolVector(JNIEnv *env, jbooleanArray booleanArray, std::vector &output) +{ + auto length = static_cast(env->GetArrayLength(booleanArray)); + auto bools = env->GetBooleanArrayElements(booleanArray, JNI_FALSE); + for (int32_t i = 0; i < length; i++) { + output.push_back(bools[i]); + } + env->ReleaseBooleanArrayElements(booleanArray, bools, 0); +} + +void GetIntVector(JNIEnv *env, jintArray intArray, std::vector &output) +{ + auto length = static_cast(env->GetArrayLength(intArray)); + auto ptr = env->GetIntArrayElements(intArray, JNI_FALSE); + for (int32_t i = 0; i < length; i++) { + output.push_back(ptr[i]); + } + env->ReleaseIntArrayElements(intArray, ptr, 0); +} + + +void GetDataTypesVector(JNIEnv *env, jobjectArray jSourceType, std::vector &output) +{ + auto len = static_cast(env->GetArrayLength(jSourceType)); + for (int i = 0; i < len; ++i) { + auto str = static_cast(env->GetObjectArrayElement(jSourceType, i)); + auto sourceTypesCharPtr = env->GetStringUTFChars(str, JNI_FALSE); + auto dataTypes = Deserialize(sourceTypesCharPtr); + env->ReleaseStringUTFChars(str, sourceTypesCharPtr); + output.push_back(dataTypes); + } +} + +void DeserializeJsonToArray(const char *str, std::vector &arr) +{ + auto result = nlohmann::json::parse(str); + for (auto &json : result) { + arr.push_back(json); + } +} + +/* + * Class: nova_hetu_omniruntime_operator_OmniOperatorFactory + * Method: createOperatorNative + * Signature: (J)JJ + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_operator_OmniOperatorFactory_createOperatorNative(JNIEnv *env, + jobject jObj, jlong jNativeFactoryObj) +{ + auto operatorFactory = (OperatorFactory *)jNativeFactoryObj; + omniruntime::op::Operator *nativeOperator = nullptr; + + JNI_METHOD_START + nativeOperator = operatorFactory->CreateOperator(); + if (nativeOperator == nullptr) { + throw omniruntime::exception::OmniException("CREATE_OPERATOR_FAILED", + "return a null pointer when creating operator"); + } + JNI_METHOD_END(0L) + + return reinterpret_cast(static_cast(nativeOperator)); +} + +/* + * Return an HashAggregationFactory object address. + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_aggregator_OmniHashAggregationOperatorFactory_createHashAggregationOperatorFactory( + JNIEnv *env, jclass jObj, jobjectArray jGroupByChannel, jstring jGroupByType, jobjectArray jAggChannel, + jstring jAggType, jintArray jAggFuncType, jintArray jMaskCols, jstring jOutPutTye, jboolean inputRaw, + jboolean outputPartial, jstring jOperatorConfig) +{ + // groupby channel and id + auto groupByNum = static_cast(env->GetArrayLength(jGroupByChannel)); + int32_t groupByCols[groupByNum]; + GetColumnsFromExpressions(env, jGroupByChannel, groupByCols, static_cast(groupByNum)); + auto groupByTypesCharPtr = env->GetStringUTFChars(jGroupByType, JNI_FALSE); + auto aggInputChannelNum = static_cast(env->GetArrayLength(jAggChannel)); + int32_t aggCols[aggInputChannelNum]; + GetColumnsFromExpressions(env, jAggChannel, aggCols, static_cast(aggInputChannelNum)); + auto aggTypesCharPtr = env->GetStringUTFChars(jAggType, JNI_FALSE); + jint *aggFuncTypes = env->GetIntArrayElements(jAggFuncType, JNI_FALSE); + jint *maskColumns = env->GetIntArrayElements(jMaskCols, JNI_FALSE); + auto outTypesCharPtr = env->GetStringUTFChars(jOutPutTye, JNI_FALSE); + + auto groupByDataTypes = Deserialize(groupByTypesCharPtr); + auto aggDataTypes = Deserialize(aggTypesCharPtr); + auto outDataTypes = Deserialize(outTypesCharPtr); + env->ReleaseStringUTFChars(jGroupByType, groupByTypesCharPtr); + env->ReleaseStringUTFChars(jAggType, aggTypesCharPtr); + env->ReleaseStringUTFChars(jOutPutTye, outTypesCharPtr); + + auto operatorConfigChars = env->GetStringUTFChars(jOperatorConfig, JNI_FALSE); + auto operatorConfig = OperatorConfig::DeserializeOperatorConfig(operatorConfigChars); + env->ReleaseStringUTFChars(jOperatorConfig, operatorConfigChars); + + auto aggNum = static_cast(env->GetArrayLength(jAggFuncType)); + + std::vector groupByColVector = std::vector(reinterpret_cast(groupByCols), + reinterpret_cast(groupByCols) + groupByNum); + std::vector aggColVector = std::vector(reinterpret_cast(aggCols), + reinterpret_cast(aggCols) + aggInputChannelNum); + std::vector aggFuncTypeVector = std::vector(reinterpret_cast(aggFuncTypes), + reinterpret_cast(aggFuncTypes) + aggNum); + std::vector maskColumnVector = std::vector(reinterpret_cast(maskColumns), + reinterpret_cast(maskColumns) + aggNum); + + auto aggColVectorWrap = AggregatorUtil::WrapWithVector(aggColVector); + auto aggInputTypesWrap = AggregatorUtil::WrapWithVector(aggDataTypes); + auto aggOutputTypesWrap = AggregatorUtil::WrapWithVector(outDataTypes); + auto inputRawsWrap = std::vector(aggFuncTypeVector.size(), inputRaw); + auto outputPartialsWrap = std::vector(aggFuncTypeVector.size(), outputPartial); + + HashAggregationOperatorFactory *nativeOperatorFactory = nullptr; + JNI_METHOD_START + nativeOperatorFactory = + new HashAggregationOperatorFactory(groupByColVector, groupByDataTypes, aggColVectorWrap, aggInputTypesWrap, + aggOutputTypesWrap, aggFuncTypeVector, maskColumnVector, inputRawsWrap, outputPartialsWrap, operatorConfig); + JNI_METHOD_END(0L) + nativeOperatorFactory->Init(); + + env->ReleaseIntArrayElements(jAggFuncType, aggFuncTypes, 0); + env->ReleaseIntArrayElements(jMaskCols, maskColumns, 0); + return reinterpret_cast(static_cast(nativeOperatorFactory)); +} + +/* + * Return an AggregationFactory object address. + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_aggregator_OmniAggregationOperatorFactory_createAggregationOperatorFactory( + JNIEnv *env, jclass jObj, jstring jSourceTypes, jintArray jAggFuncTypes, jintArray jAggInputCols, + jintArray jMaskCols, jstring jAggOutputTypes, jboolean inputRaw, jboolean outputPartial) +{ + auto sourceTypesCharPtr = env->GetStringUTFChars(jSourceTypes, JNI_FALSE); + auto sourceTypes = Deserialize(sourceTypesCharPtr); + env->ReleaseStringUTFChars(jSourceTypes, sourceTypesCharPtr); + + auto aggFuncTypes = env->GetIntArrayElements(jAggFuncTypes, JNI_FALSE); + auto aggInputCols = env->GetIntArrayElements(jAggInputCols, JNI_FALSE); + auto maskCols = env->GetIntArrayElements(jMaskCols, JNI_FALSE); + auto aggOutputTypesCharPtr = env->GetStringUTFChars(jAggOutputTypes, JNI_FALSE); + auto aggOutputTypes = Deserialize(aggOutputTypesCharPtr); + env->ReleaseStringUTFChars(jAggOutputTypes, aggOutputTypesCharPtr); + + auto aggInputColsCount = static_cast(env->GetArrayLength(jAggInputCols)); + auto aggCount = static_cast(aggOutputTypes.GetSize()); + + std::vector aggInputColsVector = vector(reinterpret_cast(aggInputCols), + reinterpret_cast(aggInputCols) + aggInputColsCount); + std::vector maskColsVector = std::vector(reinterpret_cast(maskCols), + reinterpret_cast(maskCols) + aggCount); + std::vector aggFuncTypesVector = std::vector(reinterpret_cast(aggFuncTypes), + reinterpret_cast(aggFuncTypes) + aggCount); + + auto aggInputColsVectorWrap = AggregatorUtil::WrapWithVector(aggInputColsVector); + auto aggOutputTypesWrap = AggregatorUtil::WrapWithVector(aggOutputTypes); + auto inputRawWrap = std::vector(aggFuncTypesVector.size(), inputRaw); + auto outputPartialWrap = std::vector(aggFuncTypesVector.size(), outputPartial); + + AggregationOperatorFactory *nativeOperatorFactory = nullptr; + JNI_METHOD_START + nativeOperatorFactory = new AggregationOperatorFactory(sourceTypes, aggFuncTypesVector, aggInputColsVectorWrap, + maskColsVector, aggOutputTypesWrap, inputRawWrap, outputPartialWrap); + JNI_METHOD_END(0L) + nativeOperatorFactory->Init(); + + env->ReleaseIntArrayElements(jAggFuncTypes, aggFuncTypes, 0); + env->ReleaseIntArrayElements(jAggInputCols, aggInputCols, 0); + env->ReleaseIntArrayElements(jMaskCols, maskCols, 0); + return reinterpret_cast(static_cast(nativeOperatorFactory)); +} + +/* + * Class: nova_hetu_omniruntime_operator_sort_OmniSortOperatorFactory + * Method: createSortOperatorFactory + * Signature: (Ljava/lang/String;[I[Ljava/lang/String;[I[IJLjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_operator_sort_OmniSortOperatorFactory_createSortOperatorFactory( + JNIEnv *env, jclass jObj, jstring jSourceTypes, jintArray jOutputCols, jobjectArray jSortCols, + jintArray jAscendings, jintArray jNullFirsts, jstring jOperatorConfig) +{ + auto sourceTypesCharPtr = env->GetStringUTFChars(jSourceTypes, JNI_FALSE); + jint *outputColsArr = env->GetIntArrayElements(jOutputCols, JNI_FALSE); + auto outputColsCount = env->GetArrayLength(jOutputCols); + auto sortColsCount = env->GetArrayLength(jSortCols); + int32_t sortColsArr[sortColsCount]; + GetColumnsFromExpressions(env, jSortCols, sortColsArr, sortColsCount); + jint *ascendingsArr = env->GetIntArrayElements(jAscendings, JNI_FALSE); + jint *nullFirstsArr = env->GetIntArrayElements(jNullFirsts, JNI_FALSE); + + auto sourceDataTypes = Deserialize(sourceTypesCharPtr); + env->ReleaseStringUTFChars(jSourceTypes, sourceTypesCharPtr); + + auto operatorConfigChars = env->GetStringUTFChars(jOperatorConfig, JNI_FALSE); + auto operatorConfig = OperatorConfig::DeserializeOperatorConfig(operatorConfigChars); + env->ReleaseStringUTFChars(jOperatorConfig, operatorConfigChars); + + SortOperatorFactory *sortOperatorFactory = nullptr; + JNI_METHOD_START + sortOperatorFactory = SortOperatorFactory::CreateSortOperatorFactory(sourceDataTypes, outputColsArr, + outputColsCount, sortColsArr, ascendingsArr, nullFirstsArr, sortColsCount, operatorConfig); + JNI_METHOD_END(0L) + + env->ReleaseIntArrayElements(jOutputCols, outputColsArr, 0); + env->ReleaseIntArrayElements(jAscendings, ascendingsArr, 0); + env->ReleaseIntArrayElements(jNullFirsts, nullFirstsArr, 0); + return reinterpret_cast(static_cast(sortOperatorFactory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_window_OmniWindowOperatorFactory_createWindowOperatorFactory(JNIEnv *env, + jobject jObj, jstring jSourceTypes, jintArray jOutputChannels, jintArray jWindowFunction, + jintArray jPartitionChannels, jintArray JPreGroupedChannels, jintArray jSortChannels, jintArray jSortOrder, + jintArray jSortNullFirsts, jint preSortedChannelPrefix, jint expectedPositions, jintArray jArgumentChannels, + jstring jWindowFunctionReturnType, jintArray jWindowFrameTypes, jintArray jWindowFrameStartTypes, + jintArray jWindowFrameStartChannels, jintArray jWindowFrameEndTypes, jintArray jWindowFrameEndChannels, + jstring jOperatorConfig) +{ + auto sourceTypesCharPtr = env->GetStringUTFChars(jSourceTypes, JNI_FALSE); + jint *outputChannels = env->GetIntArrayElements(jOutputChannels, JNI_FALSE); + jint *windowFunction = env->GetIntArrayElements(jWindowFunction, JNI_FALSE); + jint *partitionChannels = env->GetIntArrayElements(jPartitionChannels, JNI_FALSE); + jint *preGroupedChannels = env->GetIntArrayElements(JPreGroupedChannels, JNI_FALSE); + jint *sortChannels = env->GetIntArrayElements(jSortChannels, JNI_FALSE); + jint *sortOrder = env->GetIntArrayElements(jSortOrder, JNI_FALSE); + jint *sortNullFirsts = env->GetIntArrayElements(jSortNullFirsts, JNI_FALSE); + jint *argumentChannels = env->GetIntArrayElements(jArgumentChannels, JNI_FALSE); + jint *windowFrameTypes = env->GetIntArrayElements(jWindowFrameTypes, JNI_FALSE); + jint *windowFrameStartTypes = env->GetIntArrayElements(jWindowFrameStartTypes, JNI_FALSE); + jint *windowFrameStartChannels = env->GetIntArrayElements(jWindowFrameStartChannels, JNI_FALSE); + jint *windowFrameEndTypes = env->GetIntArrayElements(jWindowFrameEndTypes, JNI_FALSE); + jint *windowFrameEndChannels = env->GetIntArrayElements(jWindowFrameEndChannels, JNI_FALSE); + + auto windowFunctionReturnTypeCharPtr = env->GetStringUTFChars(jWindowFunctionReturnType, JNI_FALSE); + + auto inputDataTypes = Deserialize(sourceTypesCharPtr); + auto outputDataTypes = Deserialize(windowFunctionReturnTypeCharPtr); + env->ReleaseStringUTFChars(jSourceTypes, sourceTypesCharPtr); + env->ReleaseStringUTFChars(jWindowFunctionReturnType, windowFunctionReturnTypeCharPtr); + + auto operatorConfigChars = env->GetStringUTFChars(jOperatorConfig, JNI_FALSE); + auto operatorConfig = OperatorConfig::DeserializeOperatorConfig(operatorConfigChars); + env->ReleaseStringUTFChars(jOperatorConfig, operatorConfigChars); + + jint outputColsCount = env->GetArrayLength(jOutputChannels); + jint windowFunctionCount = env->GetArrayLength(jWindowFunction); + jint partitionCount = env->GetArrayLength(jPartitionChannels); + jint preGroupedCount = env->GetArrayLength(JPreGroupedChannels); + jint sortColCount = env->GetArrayLength(jSortChannels); + jint argumentChannelsCount = env->GetArrayLength(jArgumentChannels); + + std::vector allTypesVec; + allTypesVec.insert(allTypesVec.end(), inputDataTypes.Get().begin(), inputDataTypes.Get().end()); + allTypesVec.insert(allTypesVec.end(), outputDataTypes.Get().begin(), outputDataTypes.Get().end()); + + DataTypes allTypes(allTypesVec); + + WindowOperatorFactory *windowOperatorFactory = nullptr; + JNI_METHOD_START + windowOperatorFactory = WindowOperatorFactory::CreateWindowOperatorFactory(inputDataTypes, outputChannels, + outputColsCount, windowFunction, windowFunctionCount, partitionChannels, partitionCount, preGroupedChannels, + preGroupedCount, sortChannels, sortOrder, sortNullFirsts, sortColCount, preSortedChannelPrefix, + expectedPositions, allTypes, argumentChannels, argumentChannelsCount, windowFrameTypes, windowFrameStartTypes, + windowFrameStartChannels, windowFrameEndTypes, windowFrameEndChannels, operatorConfig); + JNI_METHOD_END(0L) + windowOperatorFactory->Init(); + + env->ReleaseIntArrayElements(jOutputChannels, outputChannels, 0); + env->ReleaseIntArrayElements(jWindowFunction, windowFunction, 0); + env->ReleaseIntArrayElements(jPartitionChannels, partitionChannels, 0); + env->ReleaseIntArrayElements(JPreGroupedChannels, preGroupedChannels, 0); + env->ReleaseIntArrayElements(jSortChannels, sortChannels, 0); + env->ReleaseIntArrayElements(jSortOrder, sortOrder, 0); + env->ReleaseIntArrayElements(jSortNullFirsts, sortNullFirsts, 0); + env->ReleaseIntArrayElements(jWindowFrameTypes, windowFrameTypes, 0); + env->ReleaseIntArrayElements(jWindowFrameStartTypes, windowFrameStartTypes, 0); + env->ReleaseIntArrayElements(jWindowFrameStartChannels, windowFrameStartChannels, 0); + env->ReleaseIntArrayElements(jWindowFrameEndTypes, windowFrameEndTypes, 0); + env->ReleaseIntArrayElements(jWindowFrameEndChannels, windowFrameEndChannels, 0); + return reinterpret_cast(static_cast(windowOperatorFactory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_topn_OmniTopNOperatorFactory_createTopNOperatorFactory(JNIEnv *env, jclass jObj, + jstring jSourceTypes, jint jN, jint jOffset, jobjectArray jSortCols, jintArray jSortAsc, jintArray jSortNullFirsts) +{ + auto sourceTypesCharPtr = env->GetStringUTFChars(jSourceTypes, JNI_FALSE); + jint sortColCount = env->GetArrayLength(jSortCols); + int32_t sortColsArr[sortColCount]; + GetColumnsFromExpressions(env, jSortCols, sortColsArr, sortColCount); + jint *sortAsc = env->GetIntArrayElements(jSortAsc, JNI_FALSE); + jint *sortNullFirsts = env->GetIntArrayElements(jSortNullFirsts, JNI_FALSE); + + auto sourceTypes = Deserialize(sourceTypesCharPtr); + env->ReleaseStringUTFChars(jSourceTypes, sourceTypesCharPtr); + + TopNOperatorFactory *topNOperatorFactory = nullptr; + JNI_METHOD_START + topNOperatorFactory = + new TopNOperatorFactory(sourceTypes, jN, jOffset, sortColsArr, sortAsc, sortNullFirsts, sortColCount); + JNI_METHOD_END(0L) + + env->ReleaseIntArrayElements(jSortAsc, sortAsc, 0); + env->ReleaseIntArrayElements(jSortNullFirsts, sortNullFirsts, 0); + return reinterpret_cast(static_cast(topNOperatorFactory)); +} + +static bool CheckExpressionSupported(bool skipVerify, Expr *filterExpr) +{ + if (!skipVerify) { + ExprVerifier verifier; + if (!verifier.VisitExpr(*filterExpr)) { +#ifdef DEBUG + std::cout << "The filter expression is not supported: " << std::endl; + ExprPrinter p; + filterExpr->Accept(p); + std::cout << std::endl; +#endif + LogWarn("Verifier failed"); + return false; + } + } + return true; +} + +static bool CheckExpressionsSupported(bool skipVerify, const std::vector &projectExprs) +{ + if (!skipVerify) { + auto exprSize = projectExprs.size(); + ExprVerifier verifier; + for (size_t i = 0; i < exprSize; i++) { + if (!verifier.VisitExpr(*projectExprs[i])) { +#ifdef DEBUG + std::cout << "The " << i << "-th project expression is not supported: " << std::endl; + ExprPrinter p; + projectExprs[i]->Accept(p); + std::cout << std::endl; +#endif + LogWarn("Verifier failed"); + return false; + } + } + } + return true; +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_filter_OmniFilterAndProjectOperatorFactory_createFilterAndProjectOperatorFactory( + JNIEnv *env, jclass jObj, jstring jInputTypes, jint jInputLength, jstring jExpression, jobjectArray jProjections, + jint jProjectLength, jint jParseFormat, jstring jOperatorConfig) +{ + auto expressionCharPtr = env->GetStringUTFChars(jExpression, JNI_FALSE); + std::string filterExpression = std::string(expressionCharPtr); + auto inputTypesCharPtr = env->GetStringUTFChars(jInputTypes, JNI_FALSE); + auto inputDataTypes = Deserialize(inputTypesCharPtr); + env->ReleaseStringUTFChars(jInputTypes, inputTypesCharPtr); + env->ReleaseStringUTFChars(jExpression, expressionCharPtr); + auto inputLength = (int32_t)jInputLength; + + auto operatorConfigChars = env->GetStringUTFChars(jOperatorConfig, JNI_FALSE); + auto operatorConfig = OperatorConfig::DeserializeOperatorConfig(operatorConfigChars); + auto overflowConfig = operatorConfig.GetOverflowConfig(); + bool isSkipVerify = operatorConfig.IsSkipVerify(); + env->ReleaseStringUTFChars(jOperatorConfig, operatorConfigChars); + + auto parseFormat = static_cast((int8_t)jParseFormat); + std::string projectExpressions[jProjectLength]; + GetExpressions(env, jProjections, projectExpressions, jProjectLength); + + std::vector projectExprs; + omniruntime::expressions::Expr *filterExpr = nullptr; + if (parseFormat == JSON) { + JNI_METHOD_START + auto filterJsonExpr = nlohmann::json::parse(filterExpression); + filterExpr = JSONParser::ParseJSON(filterJsonExpr); + JNI_METHOD_END(0L) + JNI_METHOD_START + nlohmann::json jsonProjectExprs[jProjectLength]; + for (int32_t i = 0; i < jProjectLength; i++) { + jsonProjectExprs[i] = nlohmann::json::parse(projectExpressions[i]); + } + projectExprs = JSONParser::ParseJSON(jsonProjectExprs, jProjectLength); + JNI_METHOD_END_WITH_EXPRS_RELEASE(0L, { filterExpr }) + } else { + Parser parser; + JNI_METHOD_START + filterExpr = parser.ParseRowExpression(filterExpression, inputDataTypes, inputLength); + JNI_METHOD_END(0L) + JNI_METHOD_START + projectExprs = parser.ParseExpressions(projectExpressions, jProjectLength, inputDataTypes); + JNI_METHOD_END_WITH_EXPRS_RELEASE(0L, { filterExpr }) + } + if (filterExpr == nullptr || (projectExprs.size() != static_cast(jProjectLength))) { + delete filterExpr; + Expr::DeleteExprs(projectExprs); + return 0; + } + + if (!CheckExpressionSupported(isSkipVerify, filterExpr)) { + delete filterExpr; + Expr::DeleteExprs(projectExprs); + return 0; + } + if (!CheckExpressionsSupported(isSkipVerify, projectExprs)) { + delete filterExpr; + Expr::DeleteExprs(projectExprs); + return 0; + } + + FilterAndProjectOperatorFactory *factory = nullptr; + auto exprEvaluator = + std::make_shared(filterExpr, projectExprs, inputDataTypes, overflowConfig); + if (!exprEvaluator->IsSupportedExpr()) { + return 0; + } + + factory = new FilterAndProjectOperatorFactory(std::move(exprEvaluator)); + + return reinterpret_cast(static_cast(factory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_project_OmniProjectOperatorFactory_createProjectOperatorFactory(JNIEnv *env, + jclass jobj, jstring jInputTypes, jint jInputLength, jobjectArray jExprs, jint jExprsLength, jint jParseFormat, + jstring jOperatorConfig) +{ + auto parseFormat = static_cast((int8_t)jParseFormat); + std::string exprs[jExprsLength]; + GetExpressions(env, jExprs, exprs, jExprsLength); + + auto inputTypesCharPtr = env->GetStringUTFChars(jInputTypes, JNI_FALSE); + auto inputDataTypes = Deserialize(inputTypesCharPtr); + env->ReleaseStringUTFChars(jInputTypes, inputTypesCharPtr); + + auto operatorConfigChars = env->GetStringUTFChars(jOperatorConfig, JNI_FALSE); + auto operatorConfig = OperatorConfig::DeserializeOperatorConfig(operatorConfigChars); + auto overflowConfig = operatorConfig.GetOverflowConfig(); + bool isSkipVerify = operatorConfig.IsSkipVerify(); + env->ReleaseStringUTFChars(jOperatorConfig, operatorConfigChars); + + std::vector expressions; + JNI_METHOD_START + if (parseFormat == JSON) { + nlohmann::json jsonExprs[jExprsLength]; + for (int32_t i = 0; i < jExprsLength; i++) { + jsonExprs[i] = nlohmann::json::parse(exprs[i]); + } + expressions = JSONParser::ParseJSON(jsonExprs, jExprsLength); + } else { + Parser parser; + expressions = parser.ParseExpressions(exprs, jExprsLength, inputDataTypes); + } + JNI_METHOD_END(0L) + if (expressions.size() != static_cast(jExprsLength)) { + Expr::DeleteExprs(expressions); + return 0; + } + + if (!CheckExpressionsSupported(isSkipVerify, expressions)) { + Expr::DeleteExprs(expressions); + return 0; + } + + ProjectionOperatorFactory *factory = nullptr; + auto exprEvaluator = std::make_shared(expressions, inputDataTypes, overflowConfig); + if (!exprEvaluator->IsSupportedExpr()) { + return 0; + } + + factory = new ProjectionOperatorFactory(std::move(exprEvaluator)); + return reinterpret_cast(static_cast(factory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniHashBuilderOperatorFactory_createHashBuilderOperatorFactory(JNIEnv *env, + jclass jObj, jint jJoinType, jstring jBuildTypes, jintArray jBuildHashCols, jint jOperatorCount, + jstring jOperatorConfig) +{ + auto buildTypesCharPtr = env->GetStringUTFChars(jBuildTypes, JNI_FALSE); + auto buildHashColsCount = env->GetArrayLength(jBuildHashCols); + auto buildHashColsArr = env->GetIntArrayElements(jBuildHashCols, JNI_FALSE); + + auto buildDataTypes = Deserialize(buildTypesCharPtr); + env->ReleaseStringUTFChars(jBuildTypes, buildTypesCharPtr); + + HashBuilderOperatorFactory *hashBuilderOperatorFactory = nullptr; + JNI_METHOD_START + hashBuilderOperatorFactory = HashBuilderOperatorFactory::CreateHashBuilderOperatorFactory((JoinType)jJoinType, + buildDataTypes, buildHashColsArr, buildHashColsCount, jOperatorCount); + JNI_METHOD_END(0L) + + env->ReleaseIntArrayElements(jBuildHashCols, buildHashColsArr, 0); + return reinterpret_cast(static_cast(hashBuilderOperatorFactory)); +} + +omniruntime::expressions::Expr *CreateJoinFilterExpr(const std::string &filterString) +{ + omniruntime::expressions::Expr *filterExpr = nullptr; + if (!filterString.empty()) { + filterExpr = JSONParser::ParseJSON(nlohmann::json::parse(filterString)); + if (filterExpr == nullptr) { + throw omniruntime::exception::OmniException("EXPRESSION_NOT_SUPPORT", + "The expression is not supported yet."); + } + } + return filterExpr; +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniLookupJoinOperatorFactory_createLookupJoinOperatorFactory(JNIEnv *env, + jclass jObj, jstring jProbeTypes, jintArray jProbeOutputCols, jintArray jProbeHashCols, jintArray jBuildOutputCols, + jstring jBuildOutputTypes, jlong jHashBuilderOperatorFactory, jstring jFilter, + jboolean isShuffleExchangeBuildPlan, jstring jOperatorConfig) +{ + auto probeTypesCharPtr = env->GetStringUTFChars(jProbeTypes, JNI_FALSE); + auto probeOutputColsArr = env->GetIntArrayElements(jProbeOutputCols, JNI_FALSE); + auto probeHashColsCount = env->GetArrayLength(jProbeHashCols); + auto probeHashColsArr = env->GetIntArrayElements(jProbeHashCols, JNI_FALSE); + auto buildOutputColsArr = env->GetIntArrayElements(jBuildOutputCols, JNI_FALSE); + auto buildOutputColsCount = env->GetArrayLength(jBuildOutputCols); + auto buildOutputTypesCharPtr = env->GetStringUTFChars(jBuildOutputTypes, JNI_FALSE); + auto probeOutputColsCount = env->GetArrayLength(jProbeOutputCols); + + auto probeDataTypes = Deserialize(probeTypesCharPtr); + auto buildOutputDataTypes = Deserialize(buildOutputTypesCharPtr); + env->ReleaseStringUTFChars(jProbeTypes, probeTypesCharPtr); + env->ReleaseStringUTFChars(jBuildOutputTypes, buildOutputTypesCharPtr); + + auto filterChars = env->GetStringUTFChars(jFilter, JNI_FALSE); + std::string filterExpression = std::string(filterChars); + env->ReleaseStringUTFChars(jFilter, filterChars); + Expr *filterExpr = nullptr; + JNI_METHOD_START + // extract the expression and the BuildDataTypes to parse the expression + filterExpr = CreateJoinFilterExpr(filterExpression); + JNI_METHOD_END(0L) + + auto operatorConfigChars = env->GetStringUTFChars(jOperatorConfig, JNI_FALSE); + auto operatorConfig = OperatorConfig::DeserializeOperatorConfig(operatorConfigChars); + auto overflowConfig = operatorConfig.GetOverflowConfig(); + env->ReleaseStringUTFChars(jOperatorConfig, operatorConfigChars); + + LookupJoinOperatorFactory *lookupJoinOperatorFactory = nullptr; + JNI_METHOD_START + lookupJoinOperatorFactory = LookupJoinOperatorFactory::CreateLookupJoinOperatorFactory(probeDataTypes, + probeOutputColsArr, probeOutputColsCount, probeHashColsArr, probeHashColsCount, buildOutputColsArr, + buildOutputColsCount, buildOutputDataTypes, jHashBuilderOperatorFactory, filterExpr, + isShuffleExchangeBuildPlan, overflowConfig); + JNI_METHOD_END_WITH_EXPRS_RELEASE(0L, { filterExpr }) + Expr::DeleteExprs({ filterExpr }); + + env->ReleaseIntArrayElements(jProbeOutputCols, probeOutputColsArr, 0); + env->ReleaseIntArrayElements(jProbeHashCols, probeHashColsArr, 0); + env->ReleaseIntArrayElements(jBuildOutputCols, buildOutputColsArr, 0); + return reinterpret_cast(static_cast(lookupJoinOperatorFactory)); +} + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_operator_union_OmniUnionOperatorFactory_createUnionOperatorFactory( + JNIEnv *env, jobject jObj, jstring jSourceTypes, jboolean jDistinct) +{ + const char *sourceTypesCharPtr = env->GetStringUTFChars(jSourceTypes, JNI_FALSE); + auto sourcesTypes = Deserialize(sourceTypesCharPtr); + env->ReleaseStringUTFChars(jSourceTypes, sourceTypesCharPtr); + int32_t sourceTypesCount = sourcesTypes.GetSize(); + + UnionOperatorFactory *unionOperatorFactory = nullptr; + JNI_METHOD_START + unionOperatorFactory = new UnionOperatorFactory(sourcesTypes, sourceTypesCount, jDistinct); + JNI_METHOD_END(0L) + + return reinterpret_cast(static_cast(unionOperatorFactory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_sort_OmniSortWithExprOperatorFactory_createSortWithExprOperatorFactory(JNIEnv *env, + jclass jObj, jstring jSourceTypes, jintArray jOutputCols, jobjectArray jSortKeys, jintArray jAscendings, + jintArray jNullFirsts, jstring jOperatorConfig) +{ + auto sourceTypesChars = env->GetStringUTFChars(jSourceTypes, JNI_FALSE); + jint *outputCols = env->GetIntArrayElements(jOutputCols, JNI_FALSE); + auto outputColsCount = env->GetArrayLength(jOutputCols); + auto sortKeysCount = env->GetArrayLength(jSortKeys); + std::string sortKeysArr[sortKeysCount]; + GetExpressions(env, jSortKeys, sortKeysArr, sortKeysCount); + jint *ascendings = env->GetIntArrayElements(jAscendings, JNI_FALSE); + jint *nullFirsts = env->GetIntArrayElements(jNullFirsts, JNI_FALSE); + + auto sourceDataTypes = Deserialize(sourceTypesChars); + env->ReleaseStringUTFChars(jSourceTypes, sourceTypesChars); + + auto operatorConfigChars = env->GetStringUTFChars(jOperatorConfig, JNI_FALSE); + auto operatorConfig = OperatorConfig::DeserializeOperatorConfig(operatorConfigChars); + env->ReleaseStringUTFChars(jOperatorConfig, operatorConfigChars); + + std::vector sortKeyExprArr; + JNI_METHOD_START + // parse the expressions + GetExprsFromJson(sortKeysArr, sortKeysCount, sortKeyExprArr); + JNI_METHOD_END(0L) + + SortWithExprOperatorFactory *operatorFactory = nullptr; + JNI_METHOD_START + operatorFactory = SortWithExprOperatorFactory::CreateSortWithExprOperatorFactory(sourceDataTypes, outputCols, + outputColsCount, sortKeyExprArr, ascendings, nullFirsts, sortKeysCount, operatorConfig); + JNI_METHOD_END_WITH_EXPRS_RELEASE(0L, sortKeyExprArr) + Expr::DeleteExprs(sortKeyExprArr); + + env->ReleaseIntArrayElements(jOutputCols, outputCols, 0); + env->ReleaseIntArrayElements(jAscendings, ascendings, 0); + env->ReleaseIntArrayElements(jNullFirsts, nullFirsts, 0); + return reinterpret_cast(static_cast(operatorFactory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniHashBuilderWithExprOperatorFactory_createHashBuilderWithExprOperatorFactory( + JNIEnv *env, jclass jObj, jint jJoinType, jint jBuildSide, jstring jBuildTypes, jobjectArray jBuildHashKeys, + jint jHashTableCount, jstring jOperatorConfig) +{ + auto buildTypesChars = env->GetStringUTFChars(jBuildTypes, JNI_FALSE); + auto buildHashKeysCount = env->GetArrayLength(jBuildHashKeys); + std::string buildHashKeysArr[buildHashKeysCount]; + GetExpressions(env, jBuildHashKeys, buildHashKeysArr, buildHashKeysCount); + auto buildDataTypes = Deserialize(buildTypesChars); + env->ReleaseStringUTFChars(jBuildTypes, buildTypesChars); + + auto operatorConfigChars = env->GetStringUTFChars(jOperatorConfig, JNI_FALSE); + auto operatorConfig = OperatorConfig::DeserializeOperatorConfig(operatorConfigChars); + auto overflowConfig = operatorConfig.GetOverflowConfig(); + env->ReleaseStringUTFChars(jOperatorConfig, operatorConfigChars); + + std::vector buildHashKeysArrExprs; + JNI_METHOD_START + // parse the expressions + GetExprsFromJson(buildHashKeysArr, buildHashKeysCount, buildHashKeysArrExprs); + JNI_METHOD_END(0L) + + HashBuilderWithExprOperatorFactory *operatorFactory = nullptr; + JNI_METHOD_START + operatorFactory = HashBuilderWithExprOperatorFactory::CreateHashBuilderWithExprOperatorFactory((JoinType)jJoinType, + (BuildSide)jBuildSide, buildDataTypes, buildHashKeysArrExprs, jHashTableCount, overflowConfig); + JNI_METHOD_END_WITH_EXPRS_RELEASE(0L, buildHashKeysArrExprs) + Expr::DeleteExprs(buildHashKeysArrExprs); + + return reinterpret_cast(static_cast(operatorFactory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniLookupJoinWithExprOperatorFactory_createLookupJoinWithExprOperatorFactory( + JNIEnv *env, jclass jObj, jstring jProbeTypes, jintArray jProbeOutputCols, jobjectArray jProbeHashKeys, + jintArray jBuildOutputCols, jstring jBuildOutputTypes, jlong jHashBuilderOperatorFactory, + jstring jFilter, jboolean isShuffleExchangeBuildPlan, jstring jOperatorConfig) +{ + auto probeTypesChars = env->GetStringUTFChars(jProbeTypes, JNI_FALSE); + auto probeOutputCols = env->GetIntArrayElements(jProbeOutputCols, JNI_FALSE); + auto probeHashKeysCount = env->GetArrayLength(jProbeHashKeys); + std::string probeHashKeysArr[probeHashKeysCount]; + GetExpressions(env, jProbeHashKeys, probeHashKeysArr, probeHashKeysCount); + auto buildOutputCols = env->GetIntArrayElements(jBuildOutputCols, JNI_FALSE); + auto buildOutputColsCount = env->GetArrayLength(jBuildOutputCols); + auto buildOutputTypesChars = env->GetStringUTFChars(jBuildOutputTypes, JNI_FALSE); + jint probeOutputColsCount = env->GetArrayLength(jProbeOutputCols); + + auto probeDataTypes = Deserialize(probeTypesChars); + auto buildOutputDataTypes = Deserialize(buildOutputTypesChars); + env->ReleaseStringUTFChars(jProbeTypes, probeTypesChars); + env->ReleaseStringUTFChars(jBuildOutputTypes, buildOutputTypesChars); + + auto filterChars = env->GetStringUTFChars(jFilter, JNI_FALSE); + std::string filterExpression = std::string(filterChars); + env->ReleaseStringUTFChars(jFilter, filterChars); + Expr *filterExpr = nullptr; + JNI_METHOD_START + // extract the expression and the BuildDataTypes to parse the expression + filterExpr = CreateJoinFilterExpr(filterExpression); + JNI_METHOD_END(0L) + + auto operatorConfigChars = env->GetStringUTFChars(jOperatorConfig, JNI_FALSE); + auto operatorConfig = OperatorConfig::DeserializeOperatorConfig(operatorConfigChars); + auto overflowConfig = operatorConfig.GetOverflowConfig(); + env->ReleaseStringUTFChars(jOperatorConfig, operatorConfigChars); + + std::vector probeHashKeysArrExprs; + JNI_METHOD_START + // parse the expressions + GetExprsFromJson(probeHashKeysArr, probeHashKeysCount, probeHashKeysArrExprs); + JNI_METHOD_END_WITH_EXPRS_RELEASE(0L, { filterExpr }) + + LookupJoinWithExprOperatorFactory *operatorFactory = nullptr; + JNI_METHOD_START + operatorFactory = LookupJoinWithExprOperatorFactory::CreateLookupJoinWithExprOperatorFactory(probeDataTypes, + probeOutputCols, probeOutputColsCount, probeHashKeysArrExprs, probeHashKeysCount, buildOutputCols, + buildOutputColsCount, buildOutputDataTypes, jHashBuilderOperatorFactory, filterExpr, + isShuffleExchangeBuildPlan, overflowConfig); + JNI_METHOD_END_WITH_MULTI_EXPRS(0L, { filterExpr }, probeHashKeysArrExprs) + Expr::DeleteExprs({ filterExpr }); + Expr::DeleteExprs(probeHashKeysArrExprs); + + env->ReleaseIntArrayElements(jProbeOutputCols, probeOutputCols, 0); + env->ReleaseIntArrayElements(jBuildOutputCols, buildOutputCols, 0); + return reinterpret_cast(static_cast(operatorFactory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniLookupOuterJoinWithExprOperatorFactory_createLookupOuterJoinWithExprOperatorFactory( + JNIEnv *env, jclass jObj, jstring jProbeTypes, jintArray jProbeOutputCols, jobjectArray jProbeHashKeys, + jintArray jBuildOutputCols, jstring jBuildOutputTypes, jlong jHashBuilderOperatorFactory) +{ + auto probeTypesChars = env->GetStringUTFChars(jProbeTypes, JNI_FALSE); + auto probeOutputCols = env->GetIntArrayElements(jProbeOutputCols, JNI_FALSE); + auto probeHashKeysCount = env->GetArrayLength(jProbeHashKeys); + std::string probeHashKeysArr[probeHashKeysCount]; + GetExpressions(env, jProbeHashKeys, probeHashKeysArr, probeHashKeysCount); + auto buildOutputCols = env->GetIntArrayElements(jBuildOutputCols, JNI_FALSE); + auto buildOutputTypesChars = env->GetStringUTFChars(jBuildOutputTypes, JNI_FALSE); + jint probeOutputColsCount = env->GetArrayLength(jProbeOutputCols); + + auto probeDataTypes = Deserialize(probeTypesChars); + auto buildOutputDataTypes = Deserialize(buildOutputTypesChars); + env->ReleaseStringUTFChars(jProbeTypes, probeTypesChars); + env->ReleaseStringUTFChars(jBuildOutputTypes, buildOutputTypesChars); + + std::vector probeHashKeysArrExprs; + JNI_METHOD_START + // parse the expressions + GetExprsFromJson(probeHashKeysArr, probeHashKeysCount, probeHashKeysArrExprs); + JNI_METHOD_END(0L) + + LookupOuterJoinWithExprOperatorFactory *operatorFactory = nullptr; + JNI_METHOD_START + operatorFactory = LookupOuterJoinWithExprOperatorFactory::CreateLookupOuterJoinWithExprOperatorFactory( + probeDataTypes, probeOutputCols, probeOutputColsCount, probeHashKeysArrExprs, probeHashKeysCount, + buildOutputCols, buildOutputDataTypes, jHashBuilderOperatorFactory); + JNI_METHOD_END_WITH_EXPRS_RELEASE(0L, probeHashKeysArrExprs) + Expr::DeleteExprs(probeHashKeysArrExprs); + + env->ReleaseIntArrayElements(jProbeOutputCols, probeOutputCols, 0); + env->ReleaseIntArrayElements(jBuildOutputCols, buildOutputCols, 0); + return reinterpret_cast(static_cast(operatorFactory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniLookupOuterJoinOperatorFactory_createLookupOuterJoinOperatorFactory( + JNIEnv *env, jclass jObj, jstring jProbeTypes, jintArray jProbeOutputCols, jintArray jBuildOutputCols, + jstring jBuildOutputTypes, jlong jHashBuilderOperatorFactory) +{ + auto probeTypesChars = env->GetStringUTFChars(jProbeTypes, JNI_FALSE); + auto probeOutputCols = env->GetIntArrayElements(jProbeOutputCols, JNI_FALSE); + auto buildOutputCols = env->GetIntArrayElements(jBuildOutputCols, JNI_FALSE); + auto buildOutputTypesChars = env->GetStringUTFChars(jBuildOutputTypes, JNI_FALSE); + jint probeOutputColsCount = env->GetArrayLength(jProbeOutputCols); + + auto probeDataTypes = Deserialize(probeTypesChars); + auto buildOutputDataTypes = Deserialize(buildOutputTypesChars); + env->ReleaseStringUTFChars(jProbeTypes, probeTypesChars); + env->ReleaseStringUTFChars(jBuildOutputTypes, buildOutputTypesChars); + + LookupOuterJoinOperatorFactory *operatorFactory = nullptr; + JNI_METHOD_START + operatorFactory = LookupOuterJoinOperatorFactory::CreateLookupOuterJoinOperatorFactory(probeDataTypes, + probeOutputCols, probeOutputColsCount, buildOutputCols, buildOutputDataTypes, jHashBuilderOperatorFactory); + JNI_METHOD_END(0L) + + env->ReleaseIntArrayElements(jProbeOutputCols, probeOutputCols, 0); + env->ReleaseIntArrayElements(jBuildOutputCols, buildOutputCols, 0); + return reinterpret_cast(static_cast(operatorFactory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_window_OmniWindowWithExprOperatorFactory_createWindowWithExprOperatorFactory( + JNIEnv *env, jobject jObj, jstring jSourceTypes, jintArray jOutputChannels, jintArray jWindowFunction, + jintArray jPartitionChannels, jintArray JPreGroupedChannels, jintArray jSortChannels, jintArray jSortOrder, + jintArray jSortNullFirsts, jint preSortedChannelPrefix, jint expectedPositions, jobjectArray jArgumentKeys, + jstring jWindowFunctionReturnType, jintArray jWindowFrameTypes, jintArray jWindowFrameStartTypes, + jintArray jWindowFrameStartChannels, jintArray jWindowFrameEndTypes, jintArray jWindowFrameEndChannels, + jstring jOperatorConfig) +{ + auto sourceTypesCharPtr = env->GetStringUTFChars(jSourceTypes, JNI_FALSE); + jint *outputChannels = env->GetIntArrayElements(jOutputChannels, JNI_FALSE); + jint *windowFunction = env->GetIntArrayElements(jWindowFunction, JNI_FALSE); + jint *partitionChannels = env->GetIntArrayElements(jPartitionChannels, JNI_FALSE); + jint *preGroupedChannels = env->GetIntArrayElements(JPreGroupedChannels, JNI_FALSE); + jint *sortChannels = env->GetIntArrayElements(jSortChannels, JNI_FALSE); + jint *sortOrder = env->GetIntArrayElements(jSortOrder, JNI_FALSE); + jint *sortNullFirsts = env->GetIntArrayElements(jSortNullFirsts, JNI_FALSE); + jint *windowFrameTypes = env->GetIntArrayElements(jWindowFrameTypes, JNI_FALSE); + jint *windowFrameStartTypes = env->GetIntArrayElements(jWindowFrameStartTypes, JNI_FALSE); + jint *windowFrameStartChannels = env->GetIntArrayElements(jWindowFrameStartChannels, JNI_FALSE); + jint *windowFrameEndTypes = env->GetIntArrayElements(jWindowFrameEndTypes, JNI_FALSE); + jint *windowFrameEndChannels = env->GetIntArrayElements(jWindowFrameEndChannels, JNI_FALSE); + + auto argumentKeysArrCount = env->GetArrayLength(jArgumentKeys); + std::string argumentKeysArr[argumentKeysArrCount]; + GetExpressions(env, jArgumentKeys, argumentKeysArr, argumentKeysArrCount); + auto windowFunctionReturnTypeCharPtr = env->GetStringUTFChars(jWindowFunctionReturnType, JNI_FALSE); + + auto inputDataTypes = Deserialize(sourceTypesCharPtr); + auto outputDataTypes = Deserialize(windowFunctionReturnTypeCharPtr); + env->ReleaseStringUTFChars(jSourceTypes, sourceTypesCharPtr); + env->ReleaseStringUTFChars(jWindowFunctionReturnType, windowFunctionReturnTypeCharPtr); + + auto operatorConfigChars = env->GetStringUTFChars(jOperatorConfig, JNI_FALSE); + auto operatorConfig = OperatorConfig::DeserializeOperatorConfig(operatorConfigChars); + env->ReleaseStringUTFChars(jOperatorConfig, operatorConfigChars); + + jint outputColsCount = env->GetArrayLength(jOutputChannels); + jint windowFunctionCount = env->GetArrayLength(jWindowFunction); + jint partitionCount = env->GetArrayLength(jPartitionChannels); + jint preGroupedCount = env->GetArrayLength(JPreGroupedChannels); + jint sortColCount = env->GetArrayLength(jSortChannels); + jint argumentKeysCount = env->GetArrayLength(jArgumentKeys); + + std::vector argumentKeysArrExprs; + JNI_METHOD_START + // parse the expressions + GetExprsFromJson(argumentKeysArr, argumentKeysCount, argumentKeysArrExprs); + JNI_METHOD_END(0L) + + WindowWithExprOperatorFactory *windowWithExprOperatorFactory = nullptr; + JNI_METHOD_START + windowWithExprOperatorFactory = + WindowWithExprOperatorFactory::CreateWindowWithExprOperatorFactory(inputDataTypes, outputChannels, + outputColsCount, windowFunction, windowFunctionCount, partitionChannels, partitionCount, preGroupedChannels, + preGroupedCount, sortChannels, sortOrder, sortNullFirsts, sortColCount, preSortedChannelPrefix, + expectedPositions, outputDataTypes, argumentKeysArrExprs, argumentKeysCount, windowFrameTypes, + windowFrameStartTypes, windowFrameStartChannels, windowFrameEndTypes, windowFrameEndChannels, operatorConfig); + JNI_METHOD_END_WITH_EXPRS_RELEASE(0L, argumentKeysArrExprs) + Expr::DeleteExprs(argumentKeysArrExprs); + + env->ReleaseIntArrayElements(jOutputChannels, outputChannels, 0); + env->ReleaseIntArrayElements(jWindowFunction, windowFunction, 0); + env->ReleaseIntArrayElements(jPartitionChannels, partitionChannels, 0); + env->ReleaseIntArrayElements(JPreGroupedChannels, preGroupedChannels, 0); + env->ReleaseIntArrayElements(jSortChannels, sortChannels, 0); + env->ReleaseIntArrayElements(jSortOrder, sortOrder, 0); + env->ReleaseIntArrayElements(jSortNullFirsts, sortNullFirsts, 0); + env->ReleaseIntArrayElements(jWindowFrameTypes, windowFrameTypes, 0); + env->ReleaseIntArrayElements(jWindowFrameStartTypes, windowFrameStartTypes, 0); + env->ReleaseIntArrayElements(jWindowFrameStartChannels, windowFrameStartChannels, 0); + env->ReleaseIntArrayElements(jWindowFrameEndTypes, windowFrameEndTypes, 0); + env->ReleaseIntArrayElements(jWindowFrameEndChannels, windowFrameEndChannels, 0); + return reinterpret_cast(static_cast(windowWithExprOperatorFactory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_aggregator_OmniHashAggregationWithExprOperatorFactory_createHashAggregationWithExprOperatorFactory( + JNIEnv *env, jclass jObj, jobjectArray jGroupByChannel, jobjectArray jAggChannels, jobjectArray jAggChannelsFilter, + jstring jSourceType, jintArray jAggFuncType, jintArray jMaskCols, jobjectArray jOutputType, + jbooleanArray jInputRaws, jbooleanArray jOutputPartials, jstring jOperatorConfig) +{ + // groupby channel and id + auto groupByNum = static_cast(env->GetArrayLength(jGroupByChannel)); + std::string groupByKeys[groupByNum]; + GetExpressions(env, jGroupByChannel, groupByKeys, groupByNum); + + auto aggChannelsLength = static_cast(env->GetArrayLength(jAggChannels)); + std::vector> aggKeysVector; + std::vector aggColsNums; + for (int i = 0; i < aggChannelsLength; ++i) { + auto jAggChannel = static_cast(env->GetObjectArrayElement(jAggChannels, i)); + auto aggChannelCharPtr = env->GetStringUTFChars(jAggChannel, JNI_FALSE); + std::vector expressions; + DeserializeJsonToArray(aggChannelCharPtr, expressions); + aggKeysVector.push_back(expressions); + aggColsNums.push_back(expressions.size()); + } + + // parse string expression + auto sourceTypesCharPtr = env->GetStringUTFChars(jSourceType, JNI_FALSE); + auto sourceDataTypes = Deserialize(sourceTypesCharPtr); + + std::vector outDataTypes; + GetDataTypesVector(env, jOutputType, outDataTypes); + + std::vector aggFuncTypes; + GetIntVector(env, jAggFuncType, aggFuncTypes); + + auto aggFilterCount = env->GetArrayLength(jAggChannelsFilter); + std::string aggFilterArr[aggFilterCount]; + GetExpressions(env, jAggChannelsFilter, aggFilterArr, aggFilterCount); + + std::vector maskColumns; + GetIntVector(env, jMaskCols, maskColumns); + + std::vector inputRaws; + GetBoolVector(env, jInputRaws, inputRaws); + std::vector outputPartials; + GetBoolVector(env, jOutputPartials, outputPartials); + + auto operatorConfigChars = env->GetStringUTFChars(jOperatorConfig, JNI_FALSE); + auto operatorConfig = OperatorConfig::DeserializeOperatorConfig(operatorConfigChars); + env->ReleaseStringUTFChars(jOperatorConfig, operatorConfigChars); + + std::vector groupByKeysExprs; + JNI_METHOD_START + // parse the expressions + GetExprsFromJson(groupByKeys, groupByNum, groupByKeysExprs); + JNI_METHOD_END(0L) + + std::vector> aggKeysExprsVector; + for (int i = 0; i < aggChannelsLength; ++i) { + std::vector aggKeysExprs; + JNI_METHOD_START + // parse the expressions + GetExprsFromJson(aggKeysVector.at(i), aggColsNums.at(i), aggKeysExprs); + JNI_METHOD_END(0L) + aggKeysExprsVector.push_back(aggKeysExprs); + } + + // parse the filter expressions + auto aggChannelsFilterLength = static_cast(env->GetArrayLength(jAggChannelsFilter)); + std::vector aggFilterExprs; + JNI_METHOD_START + // parse the expressions + GetFilterExprsFromJson(aggFilterArr, aggChannelsFilterLength, aggFilterExprs); + JNI_METHOD_END_WITH_THREE_EXPRS(0L, groupByKeysExprs, aggKeysExprsVector, aggFilterExprs) + + HashAggregationWithExprOperatorFactory *nativeOperatorFactory = nullptr; + JNI_METHOD_START + nativeOperatorFactory = + new HashAggregationWithExprOperatorFactory(groupByKeysExprs, groupByNum, aggKeysExprsVector, aggFilterExprs, + sourceDataTypes, outDataTypes, aggFuncTypes, maskColumns, inputRaws, outputPartials, operatorConfig); + JNI_METHOD_END_WITH_THREE_EXPRS(0L, groupByKeysExprs, aggKeysExprsVector, aggFilterExprs) + + Expr::DeleteExprs(groupByKeysExprs); + Expr::DeleteExprs(aggKeysExprsVector); + Expr::DeleteExprs(aggFilterExprs); + + return reinterpret_cast(static_cast(nativeOperatorFactory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_aggregator_OmniAggregationWithExprOperatorFactory_createAggregationWithExprOperatorFactory( + JNIEnv *env, jclass jObj, jobjectArray jGroupByChannel, jobjectArray jAggChannels, jobjectArray jAggChannelsFilter, + jstring jSourceType, jintArray jAggFuncType, jintArray jMaskCols, jobjectArray jOutputType, + jbooleanArray jInputRaws, jbooleanArray jOutputPartials, jstring jOperatorConfig) +{ + // groupby channel and id + auto groupByNum = static_cast(env->GetArrayLength(jGroupByChannel)); + std::string groupByKeys[groupByNum]; + GetExpressions(env, jGroupByChannel, groupByKeys, groupByNum); + + auto aggChannelsLength = static_cast(env->GetArrayLength(jAggChannels)); + std::vector> aggKeysVector; + std::vector aggColsNums; + for (int i = 0; i < aggChannelsLength; ++i) { + auto jAggChannel = static_cast(env->GetObjectArrayElement(jAggChannels, i)); + auto aggChannelCharPtr = env->GetStringUTFChars(jAggChannel, JNI_FALSE); + std::vector expressions; + DeserializeJsonToArray(aggChannelCharPtr, expressions); + aggKeysVector.push_back(expressions); + aggColsNums.push_back(expressions.size()); + } + + auto sourceTypesCharPtr = env->GetStringUTFChars(jSourceType, JNI_FALSE); + auto sourceDataTypes = Deserialize(sourceTypesCharPtr); + + std::vector outDataTypes; + GetDataTypesVector(env, jOutputType, outDataTypes); + + std::vector aggFuncTypes; + GetIntVector(env, jAggFuncType, aggFuncTypes); + std::vector maskColumns; + GetIntVector(env, jMaskCols, maskColumns); + + auto aggFilterCount = env->GetArrayLength(jAggChannelsFilter); + std::string aggFilterArr[aggFilterCount]; + GetExpressions(env, jAggChannelsFilter, aggFilterArr, aggFilterCount); + + std::vector inputRaws; + GetBoolVector(env, jInputRaws, inputRaws); + std::vector outputPartials; + GetBoolVector(env, jOutputPartials, outputPartials); + + auto operatorConfigChars = env->GetStringUTFChars(jOperatorConfig, JNI_FALSE); + auto operatorConfig = OperatorConfig::DeserializeOperatorConfig(operatorConfigChars); + auto overflowConfig = operatorConfig.GetOverflowConfig(); + auto isStatisticalAggregate = operatorConfig.IsStatisticalAggregate(); + env->ReleaseStringUTFChars(jOperatorConfig, operatorConfigChars); + + std::vector groupByKeysExprs; + JNI_METHOD_START + // parse the expressions + GetExprsFromJson(groupByKeys, groupByNum, groupByKeysExprs); + JNI_METHOD_END(0L) + + std::vector> aggKeysExprsVector; + for (int i = 0; i < aggChannelsLength; ++i) { + std::vector aggKeysExprs; + JNI_METHOD_START + // parse the expressions + GetExprsFromJson(aggKeysVector.at(i), aggColsNums.at(i), aggKeysExprs); + JNI_METHOD_END(0L) + aggKeysExprsVector.push_back(aggKeysExprs); + } + // parse the filter expressions + auto aggChannelsFilterLength = static_cast(env->GetArrayLength(jAggChannelsFilter)); + std::vector aggFilterExprs; + JNI_METHOD_START + // parse the expressions + GetFilterExprsFromJson(aggFilterArr, aggChannelsFilterLength, aggFilterExprs); + JNI_METHOD_END_WITH_THREE_EXPRS(0L, groupByKeysExprs, aggKeysExprsVector, aggFilterExprs) + AggregationWithExprOperatorFactory *nativeOperatorFactory = nullptr; + JNI_METHOD_START + nativeOperatorFactory = new AggregationWithExprOperatorFactory(groupByKeysExprs, groupByNum, aggKeysExprsVector, + sourceDataTypes, outDataTypes, aggFuncTypes, aggFilterExprs, maskColumns, inputRaws, outputPartials, + overflowConfig, isStatisticalAggregate); + JNI_METHOD_END_WITH_THREE_EXPRS(0L, groupByKeysExprs, aggKeysExprsVector, aggFilterExprs) + + Expr::DeleteExprs(groupByKeysExprs); + Expr::DeleteExprs(aggKeysExprsVector); + Expr::DeleteExprs(aggFilterExprs); + + return reinterpret_cast(static_cast(nativeOperatorFactory)); +} + + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_topn_OmniTopNWithExprOperatorFactory_createTopNWithExprOperatorFactory(JNIEnv *env, + jclass jObj, jstring jSourceTypes, jint jN, jint jOffset, jobjectArray jSortKeys, jintArray jSortAsc, + jintArray jSortNullFirsts, jstring jOperatorConfig) +{ + auto sourceTypesCharPtr = env->GetStringUTFChars(jSourceTypes, JNI_FALSE); + jint sortKeyCount = env->GetArrayLength(jSortKeys); + std::string sortKeysArr[sortKeyCount]; + GetExpressions(env, jSortKeys, sortKeysArr, sortKeyCount); + + jint *sortAsc = env->GetIntArrayElements(jSortAsc, JNI_FALSE); + jint *sortNullFirsts = env->GetIntArrayElements(jSortNullFirsts, JNI_FALSE); + auto limit = (int32_t)jN; + auto offset = static_cast(jOffset); + auto sourceDataTypes = Deserialize(sourceTypesCharPtr); + env->ReleaseStringUTFChars(jSourceTypes, sourceTypesCharPtr); + + auto operatorConfigChars = env->GetStringUTFChars(jOperatorConfig, JNI_FALSE); + auto operatorConfig = OperatorConfig::DeserializeOperatorConfig(operatorConfigChars); + auto overflowConfig = operatorConfig.GetOverflowConfig(); + env->ReleaseStringUTFChars(jOperatorConfig, operatorConfigChars); + + std::vector sortKeyExprArr; + JNI_METHOD_START + // parse the expressions + GetExprsFromJson(sortKeysArr, sortKeyCount, sortKeyExprArr); + JNI_METHOD_END(0L) + + TopNWithExprOperatorFactory *topNWithExprOperatorFactory = nullptr; + JNI_METHOD_START + topNWithExprOperatorFactory = new TopNWithExprOperatorFactory(sourceDataTypes, limit, offset, sortKeyExprArr, + sortAsc, sortNullFirsts, sortKeyCount, overflowConfig); + JNI_METHOD_END_WITH_EXPRS_RELEASE(0L, sortKeyExprArr) + Expr::DeleteExprs(sortKeyExprArr); + + env->ReleaseIntArrayElements(jSortAsc, sortAsc, 0); + env->ReleaseIntArrayElements(jSortNullFirsts, sortNullFirsts, 0); + return reinterpret_cast(topNWithExprOperatorFactory); +} + +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_operator_OmniOperatorFactory_closeNativeOperatorFactory(JNIEnv *env, + jclass jclz, jlong jNativeOperatorFactory) +{ + auto nativeOperatorFactory = reinterpret_cast(jNativeOperatorFactory); + delete nativeOperatorFactory; +} + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_operator_limit_OmniLimitOperatorFactory_createLimitOperatorFactory( + JNIEnv *env, jclass jObj, jint jLimit, jint jOffset) +{ + LimitOperatorFactory *limitOperatorFactory = nullptr; + JNI_METHOD_START + limitOperatorFactory = LimitOperatorFactory::CreateLimitOperatorFactory(jLimit, jOffset); + JNI_METHOD_END(0L) + return reinterpret_cast(static_cast(limitOperatorFactory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_limit_OmniDistinctLimitOperatorFactory_createDistinctLimitOperatorFactory( + JNIEnv *env, jclass jObj, jstring jSoureTypes, jintArray jDistinctChannel, jint jHashChannel, jlong jLimit) +{ + auto distinctColCount = (int32_t)env->GetArrayLength(jDistinctChannel); + jint *distinctCols = env->GetIntArrayElements(jDistinctChannel, JNI_FALSE); + + const char *sourceTypesCharPtr = env->GetStringUTFChars(jSoureTypes, JNI_FALSE); + auto sourceTypes = Deserialize(sourceTypesCharPtr); + env->ReleaseStringUTFChars(jSoureTypes, sourceTypesCharPtr); + + DistinctLimitOperatorFactory *distinctLimitOperatorFactory = nullptr; + JNI_METHOD_START + distinctLimitOperatorFactory = DistinctLimitOperatorFactory::CreateDistinctLimitOperatorFactory(sourceTypes, + distinctCols, distinctColCount, jHashChannel, jLimit); + JNI_METHOD_END(0L) + env->ReleaseIntArrayElements(jDistinctChannel, distinctCols, 0); + return reinterpret_cast(static_cast(distinctLimitOperatorFactory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniSmjStreamedTableWithExprOperatorFactory_createSmjStreamedTableWithExprOperatorFactory( + JNIEnv *env, jclass jObj, jstring jSourceTypes, jobjectArray jEqualKeyExprs, jintArray jOutputChannels, + jint jJoinType, jstring jFilter, jstring jOperatorConfig) +{ + switch ((JoinType)jJoinType) { + case JoinType::OMNI_JOIN_TYPE_INNER: + case JoinType::OMNI_JOIN_TYPE_LEFT: + case JoinType::OMNI_JOIN_TYPE_FULL: + case JoinType::OMNI_JOIN_TYPE_LEFT_SEMI: + case JoinType::OMNI_JOIN_TYPE_LEFT_ANTI: + break; + default: + return 0L; + } + + auto streamedTypesChars = env->GetStringUTFChars(jSourceTypes, JNI_FALSE); + auto streamedDataTypes = Deserialize(streamedTypesChars); + env->ReleaseStringUTFChars(jSourceTypes, streamedTypesChars); + + auto streamedKeyExpsCount = env->GetArrayLength(jEqualKeyExprs); + std::string streamedKeyExpsArr[streamedKeyExpsCount]; + GetExpressions(env, jEqualKeyExprs, streamedKeyExpsArr, streamedKeyExpsCount); + + auto streamedOutputColsCnt = env->GetArrayLength(jOutputChannels); + auto streamedOutputCols = env->GetIntArrayElements(jOutputChannels, JNI_FALSE); + + auto operatorConfigChars = env->GetStringUTFChars(jOperatorConfig, JNI_FALSE); + auto operatorConfig = OperatorConfig::DeserializeOperatorConfig(operatorConfigChars); + auto overflowConfig = operatorConfig.GetOverflowConfig(); + env->ReleaseStringUTFChars(jOperatorConfig, operatorConfigChars); + + std::string filterExpression; + if (jFilter == nullptr) { + filterExpression = ""; + } else { + auto filterChars = env->GetStringUTFChars(jFilter, JNI_FALSE); + filterExpression = std::string(filterChars); + env->ReleaseStringUTFChars(jFilter, filterChars); + } + + std::vector streamedKeysArrExprs; + JNI_METHOD_START + // parse the expressions + GetExprsFromJson(streamedKeyExpsArr, streamedKeyExpsCount, streamedKeysArrExprs); + JNI_METHOD_END(0L) + + StreamedTableWithExprOperatorFactory *operatorFactory = nullptr; + JNI_METHOD_START + operatorFactory = StreamedTableWithExprOperatorFactory::CreateStreamedTableWithExprOperatorFactory( + streamedDataTypes, streamedKeysArrExprs, streamedKeyExpsCount, streamedOutputCols, streamedOutputColsCnt, + (JoinType)jJoinType, filterExpression, overflowConfig); + JNI_METHOD_END_WITH_EXPRS_RELEASE(0L, streamedKeysArrExprs) + Expr::DeleteExprs(streamedKeysArrExprs); + + env->ReleaseIntArrayElements(jOutputChannels, streamedOutputCols, 0); + return reinterpret_cast(static_cast(operatorFactory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniSmjBufferedTableWithExprOperatorFactory_createSmjBufferedTableWithExprOperatorFactory( + JNIEnv *env, jclass jObj, jstring jSourceTypes, jobjectArray jEqualKeyExprs, jintArray jOutputChannels, + jlong jSmjStreamedTableWithExprOperatorFactory, jstring jOperatorConfig) +{ + auto bufferedTypesChars = env->GetStringUTFChars(jSourceTypes, JNI_FALSE); + auto bufferedDataTypes = Deserialize(bufferedTypesChars); + env->ReleaseStringUTFChars(jSourceTypes, bufferedTypesChars); + + auto bufferedKeyExpsCnt = env->GetArrayLength(jEqualKeyExprs); + std::string bufferedKeyExpsArr[bufferedKeyExpsCnt]; + GetExpressions(env, jEqualKeyExprs, bufferedKeyExpsArr, bufferedKeyExpsCnt); + + auto bufferedOutputCols = env->GetIntArrayElements(jOutputChannels, JNI_FALSE); + auto bufferedOutputColsCnt = env->GetArrayLength(jOutputChannels); + + auto operatorConfigChars = env->GetStringUTFChars(jOperatorConfig, JNI_FALSE); + auto operatorConfig = OperatorConfig::DeserializeOperatorConfig(operatorConfigChars); + auto overflowConfig = operatorConfig.GetOverflowConfig(); + env->ReleaseStringUTFChars(jOperatorConfig, operatorConfigChars); + + std::vector bufferedKeysArrExprs; + JNI_METHOD_START + // parse the expressions + GetExprsFromJson(bufferedKeyExpsArr, bufferedKeyExpsCnt, bufferedKeysArrExprs); + JNI_METHOD_END(0L) + + BufferedTableWithExprOperatorFactory *operatorFactory = nullptr; + JNI_METHOD_START + operatorFactory = BufferedTableWithExprOperatorFactory::CreateBufferedTableWithExprOperatorFactory( + bufferedDataTypes, bufferedKeysArrExprs, bufferedKeyExpsCnt, bufferedOutputCols, bufferedOutputColsCnt, + jSmjStreamedTableWithExprOperatorFactory, overflowConfig); + JNI_METHOD_END_WITH_EXPRS_RELEASE(0L, bufferedKeysArrExprs) + Expr::DeleteExprs(bufferedKeysArrExprs); + + env->ReleaseIntArrayElements(jOutputChannels, bufferedOutputCols, 0); + return reinterpret_cast(static_cast(operatorFactory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniSmjStreamedTableWithExprOperatorFactoryV3_createSmjStreamedTableWithExprOperatorFactoryV3( + JNIEnv *env, jclass jObj, jstring jSourceTypes, jobjectArray jEqualKeyExprs, jintArray jOutputChannels, + jint jJoinType, jstring jFilter, jstring jOperatorConfig) +{ + switch ((JoinType)jJoinType) { + case JoinType::OMNI_JOIN_TYPE_INNER: + case JoinType::OMNI_JOIN_TYPE_LEFT: + case JoinType::OMNI_JOIN_TYPE_FULL: + case JoinType::OMNI_JOIN_TYPE_LEFT_SEMI: + case JoinType::OMNI_JOIN_TYPE_LEFT_ANTI: + break; + default: + return 0L; + } + + auto streamedTypesChars = env->GetStringUTFChars(jSourceTypes, JNI_FALSE); + auto streamedDataTypes = Deserialize(streamedTypesChars); + env->ReleaseStringUTFChars(jSourceTypes, streamedTypesChars); + + auto streamedKeyExpsCount = env->GetArrayLength(jEqualKeyExprs); + std::string streamedKeyExpsArr[streamedKeyExpsCount]; + GetExpressions(env, jEqualKeyExprs, streamedKeyExpsArr, streamedKeyExpsCount); + + auto streamedOutputColsCnt = env->GetArrayLength(jOutputChannels); + auto streamedOutputColsPtr = env->GetIntArrayElements(jOutputChannels, JNI_FALSE); + std::vector streamedOutputCols(streamedOutputColsPtr, streamedOutputColsPtr + streamedOutputColsCnt); + env->ReleaseIntArrayElements(jOutputChannels, streamedOutputColsPtr, 0); + + auto operatorConfigChars = env->GetStringUTFChars(jOperatorConfig, JNI_FALSE); + auto operatorConfig = OperatorConfig::DeserializeOperatorConfig(operatorConfigChars); + env->ReleaseStringUTFChars(jOperatorConfig, operatorConfigChars); + + std::string filterExpression; + if (jFilter == nullptr) { + filterExpression = ""; + } else { + auto filterChars = env->GetStringUTFChars(jFilter, JNI_FALSE); + filterExpression = std::string(filterChars); + env->ReleaseStringUTFChars(jFilter, filterChars); + } + + std::vector streamedKeysArrExprs; + JNI_METHOD_START + // parse the expressions + GetExprsFromJson(streamedKeyExpsArr, streamedKeyExpsCount, streamedKeysArrExprs); + JNI_METHOD_END(0L) + + StreamedTableWithExprOperatorFactoryV3 *operatorFactory = nullptr; + JNI_METHOD_START + operatorFactory = + StreamedTableWithExprOperatorFactoryV3::CreateStreamedTableWithExprOperatorFactory(streamedDataTypes, + streamedKeysArrExprs, streamedOutputCols, (JoinType)jJoinType, filterExpression, operatorConfig); + JNI_METHOD_END_WITH_EXPRS_RELEASE(0L, streamedKeysArrExprs) + Expr::DeleteExprs(streamedKeysArrExprs); + + return reinterpret_cast(static_cast(operatorFactory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniSmjBufferedTableWithExprOperatorFactoryV3_createSmjBufferedTableWithExprOperatorFactoryV3( + JNIEnv *env, jclass jObj, jstring jSourceTypes, jobjectArray jEqualKeyExprs, jintArray jOutputChannels, + jlong jSmjStreamedTableWithExprOperatorFactory, jstring jOperatorConfig) +{ + auto bufferedTypesChars = env->GetStringUTFChars(jSourceTypes, JNI_FALSE); + auto bufferedDataTypes = Deserialize(bufferedTypesChars); + env->ReleaseStringUTFChars(jSourceTypes, bufferedTypesChars); + + auto bufferedKeyExpsCnt = env->GetArrayLength(jEqualKeyExprs); + std::string bufferedKeyExpsArr[bufferedKeyExpsCnt]; + GetExpressions(env, jEqualKeyExprs, bufferedKeyExpsArr, bufferedKeyExpsCnt); + + auto bufferedOutputColsPtr = env->GetIntArrayElements(jOutputChannels, JNI_FALSE); + auto bufferedOutputColsCnt = env->GetArrayLength(jOutputChannels); + std::vector bufferedOutputCols(bufferedOutputColsPtr, bufferedOutputColsPtr + bufferedOutputColsCnt); + env->ReleaseIntArrayElements(jOutputChannels, bufferedOutputColsPtr, 0); + + auto operatorConfigChars = env->GetStringUTFChars(jOperatorConfig, JNI_FALSE); + auto operatorConfig = OperatorConfig::DeserializeOperatorConfig(operatorConfigChars); + env->ReleaseStringUTFChars(jOperatorConfig, operatorConfigChars); + + std::vector bufferedKeysArrExprs; + JNI_METHOD_START + // parse the expressions + GetExprsFromJson(bufferedKeyExpsArr, bufferedKeyExpsCnt, bufferedKeysArrExprs); + JNI_METHOD_END(0L) + + BufferedTableWithExprOperatorFactoryV3 *operatorFactory = nullptr; + JNI_METHOD_START + operatorFactory = BufferedTableWithExprOperatorFactoryV3::CreateBufferedTableWithExprOperatorFactory( + bufferedDataTypes, bufferedKeysArrExprs, bufferedOutputCols, + (StreamedTableWithExprOperatorFactoryV3 *)jSmjStreamedTableWithExprOperatorFactory, operatorConfig); + JNI_METHOD_END_WITH_EXPRS_RELEASE(0L, bufferedKeysArrExprs) + Expr::DeleteExprs(bufferedKeysArrExprs); + + return reinterpret_cast(static_cast(operatorFactory)); +} + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_operator_OmniExprVerify_exprVerify(JNIEnv *env, jclass jObj, + jstring jInputTypes, jint jInputLength, jstring jExpression, jobjectArray jProjections, jint jProjectLength, + jint jParseFormat) +{ + omniruntime::expressions::Expr *filterExpr = nullptr; + std::vector projectExprs; + JNI_METHOD_START + auto expressionCharPtr = env->GetStringUTFChars(jExpression, JNI_FALSE); + std::string filterExpression = std::string(expressionCharPtr); + auto inputTypesCharPtr = env->GetStringUTFChars(jInputTypes, JNI_FALSE); + auto inputDataTypes = Deserialize(inputTypesCharPtr); + env->ReleaseStringUTFChars(jInputTypes, inputTypesCharPtr); + env->ReleaseStringUTFChars(jExpression, expressionCharPtr); + auto inputLength = (int32_t)jInputLength; + + auto parseFormat = static_cast((int8_t)jParseFormat); + std::string projectExpressions[jProjectLength]; + GetExpressions(env, jProjections, projectExpressions, jProjectLength); + + if (parseFormat == JSON) { + if (!filterExpression.empty()) { + auto filterJsonExpr = nlohmann::json::parse(filterExpression); + filterExpr = JSONParser::ParseJSON(filterJsonExpr); + if (filterExpr == nullptr) { + LogWarn("The filter expression is not supported: %s", filterJsonExpr.dump(1).c_str()); + return 0; + } + } + nlohmann::json jsonProjectExprs[jProjectLength]; + for (int32_t i = 0; i < jProjectLength; i++) { + jsonProjectExprs[i] = nlohmann::json::parse(projectExpressions[i]); + } + projectExprs = JSONParser::ParseJSON(jsonProjectExprs, jProjectLength); + } else { + Parser parser; + if (!filterExpression.empty()) { + filterExpr = parser.ParseRowExpression(filterExpression, inputDataTypes, inputLength); + } + projectExprs = parser.ParseExpressions(projectExpressions, jProjectLength, inputDataTypes); + } + + if ((!filterExpression.empty() && filterExpr == nullptr) || + (static_cast(jProjectLength) != projectExprs.size())) { + delete filterExpr; + Expr::DeleteExprs(projectExprs); + return 0; + } + + if (filterExpr != nullptr && !CheckExpressionSupported(false, filterExpr)) { + delete filterExpr; + Expr::DeleteExprs(projectExprs); + return 0; + } + if (!CheckExpressionsSupported(false, projectExprs)) { + delete filterExpr; + Expr::DeleteExprs(projectExprs); + return 0; + } + JNI_METHOD_END_WITH_MULTI_EXPRS(0, { filterExpr }, projectExprs) + Expr::DeleteExprs({ filterExpr }); + Expr::DeleteExprs(projectExprs); + return 1; +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_filter_OmniBloomFilterOperatorFactory_createBloomFilterOperatorFactory(JNIEnv *env, + jclass jObj, jint jInputVersion) +{ + auto inputVersion = (int32_t)jInputVersion; + + BloomFilterOperatorFactory *factory = nullptr; + JNI_METHOD_START + factory = new BloomFilterOperatorFactory(inputVersion); + JNI_METHOD_END(0L) + + return reinterpret_cast(static_cast(factory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_topnsort_OmniTopNSortWithExprOperatorFactory_createTopNSortWithExprOperatorFactory( + JNIEnv *env, jclass jObj, jstring jInputTypes, jint jLimitN, jboolean jIsStrict, jobjectArray jPartitionKeys, + jobjectArray jSortKeys, jintArray jSortAsc, jintArray jSortNullFirsts, jstring jOperatorConfig) +{ + auto inputTypesCharPtr = env->GetStringUTFChars(jInputTypes, JNI_FALSE); + auto inputTypes = Deserialize(inputTypesCharPtr); + env->ReleaseStringUTFChars(jInputTypes, inputTypesCharPtr); + + jint partitionKeyCount = env->GetArrayLength(jPartitionKeys); + std::string partitionKeysArr[partitionKeyCount]; + GetExpressions(env, jPartitionKeys, partitionKeysArr, partitionKeyCount); + + jint sortKeyCount = env->GetArrayLength(jSortKeys); + std::string sortKeysArr[sortKeyCount]; + GetExpressions(env, jSortKeys, sortKeysArr, sortKeyCount); + + std::vector partitionKeys; + JNI_METHOD_START + // parse the expressions + GetExprsFromJson(partitionKeysArr, partitionKeyCount, partitionKeys); + JNI_METHOD_END(0L) + std::vector sortKeys; + JNI_METHOD_START + // parse the expressions + GetExprsFromJson(sortKeysArr, sortKeyCount, sortKeys); + JNI_METHOD_END_WITH_EXPRS_RELEASE(0L, partitionKeys) + + jint *sortAscPtr = env->GetIntArrayElements(jSortAsc, JNI_FALSE); + jint *sortNullFirstsPtr = env->GetIntArrayElements(jSortNullFirsts, JNI_FALSE); + std::vector sortAscendings(sortAscPtr, sortAscPtr + sortKeyCount); + std::vector sortNullFirsts(sortNullFirstsPtr, sortNullFirstsPtr + sortKeyCount); + env->ReleaseIntArrayElements(jSortAsc, sortAscPtr, 0); + env->ReleaseIntArrayElements(jSortNullFirsts, sortNullFirstsPtr, 0); + + auto operatorConfigChars = env->GetStringUTFChars(jOperatorConfig, JNI_FALSE); + auto operatorConfig = OperatorConfig::DeserializeOperatorConfig(operatorConfigChars); + auto overflowConfig = operatorConfig.GetOverflowConfig(); + env->ReleaseStringUTFChars(jOperatorConfig, operatorConfigChars); + + TopNSortWithExprOperatorFactory *operatorFactory = nullptr; + JNI_METHOD_START + operatorFactory = new TopNSortWithExprOperatorFactory(inputTypes, jLimitN, jIsStrict, partitionKeys, sortKeys, + sortAscendings, sortNullFirsts, overflowConfig); + JNI_METHOD_END_WITH_MULTI_EXPRS(0L, partitionKeys, sortKeys) + + Expr::DeleteExprs(partitionKeys); + Expr::DeleteExprs(sortKeys); + return reinterpret_cast(static_cast(operatorFactory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_window_OmniWindowGroupLimitWithExprOperatorFactory_createWindowGroupLimitWithExprOperatorFactory( + JNIEnv *env, jclass jObj, jstring jInputTypes, jint jN, jstring jFuncName, jobjectArray jPartitionKeys, + jobjectArray jSortKeys, jintArray jSortAsc, jintArray jSortNullFirsts, jstring jOperatorConfig) +{ + auto inputTypesCharPtr = env->GetStringUTFChars(jInputTypes, JNI_FALSE); + auto inputTypes = Deserialize(inputTypesCharPtr); + env->ReleaseStringUTFChars(jInputTypes, inputTypesCharPtr); + + auto funcNameCharPtr = env->GetStringUTFChars(jFuncName, JNI_FALSE); + std::string funcName = std::string(funcNameCharPtr); + env->ReleaseStringUTFChars(jFuncName, funcNameCharPtr); + + jint partitionKeyCount = env->GetArrayLength(jPartitionKeys); + std::string partitionKeysArr[partitionKeyCount]; + GetExpressions(env, jPartitionKeys, partitionKeysArr, partitionKeyCount); + + jint sortKeyCount = env->GetArrayLength(jSortKeys); + std::string sortKeysArr[sortKeyCount]; + GetExpressions(env, jSortKeys, sortKeysArr, sortKeyCount); + + std::vector partitionKeys; + // parse the expressions + JNI_METHOD_START + GetExprsFromJson(partitionKeysArr, partitionKeyCount, partitionKeys); + JNI_METHOD_END(0L) + std::vector sortKeys; + JNI_METHOD_START + GetExprsFromJson(sortKeysArr, sortKeyCount, sortKeys); + JNI_METHOD_END_WITH_EXPRS_RELEASE(0L, partitionKeys) + + jint *sortAscPtr = env->GetIntArrayElements(jSortAsc, JNI_FALSE); + jint *sortNullFirstsPtr = env->GetIntArrayElements(jSortNullFirsts, JNI_FALSE); + std::vector sortAscendings(sortAscPtr, sortAscPtr + sortKeyCount); + std::vector sortNullFirsts(sortNullFirstsPtr, sortNullFirstsPtr + sortKeyCount); + env->ReleaseIntArrayElements(jSortAsc, sortAscPtr, 0); + env->ReleaseIntArrayElements(jSortNullFirsts, sortNullFirstsPtr, 0); + + auto operatorConfigChars = env->GetStringUTFChars(jOperatorConfig, JNI_FALSE); + auto operatorConfig = OperatorConfig::DeserializeOperatorConfig(operatorConfigChars); + auto overflowConfig = operatorConfig.GetOverflowConfig(); + env->ReleaseStringUTFChars(jOperatorConfig, operatorConfigChars); + + WindowGroupLimitWithExprOperatorFactory *operatorFactory = nullptr; + JNI_METHOD_START + operatorFactory = new WindowGroupLimitWithExprOperatorFactory(inputTypes, jN, funcName, partitionKeys, sortKeys, + sortAscendings, sortNullFirsts, overflowConfig); + JNI_METHOD_END_WITH_MULTI_EXPRS(0L, partitionKeys, sortKeys) + + Expr::DeleteExprs(partitionKeys); + Expr::DeleteExprs(sortKeys); + return reinterpret_cast(static_cast(operatorFactory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniNestedLoopJoinBuildOperatorFactory_createNestedLoopJoinBuildOperatorFactory( + JNIEnv *env, jclass jObj, jstring jBuildTypes, jintArray jBuildOutputCols) +{ + auto buildTypesCharPtr = env->GetStringUTFChars(jBuildTypes, JNI_FALSE); + auto buildOutputColsCount = env->GetArrayLength(jBuildOutputCols); + auto buildOutputColsArr = env->GetIntArrayElements(jBuildOutputCols, JNI_FALSE); + + auto buildDataTypes = Deserialize(buildTypesCharPtr); + env->ReleaseStringUTFChars(jBuildTypes, buildTypesCharPtr); + + NestedLoopJoinBuildOperatorFactory *nestedLoopJoinBuildOperatorFactory = nullptr; + JNI_METHOD_START + nestedLoopJoinBuildOperatorFactory = + new NestedLoopJoinBuildOperatorFactory(buildDataTypes, buildOutputColsArr, buildOutputColsCount); + JNI_METHOD_END(0L) + + env->ReleaseIntArrayElements(jBuildOutputCols, buildOutputColsArr, 0); + return reinterpret_cast(static_cast(nestedLoopJoinBuildOperatorFactory)); +} + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniNestedLoopJoinLookupOperatorFactory_createNestedLoopJoinLookupOperatorFactory( + JNIEnv *env, jclass jObj, jint jJoinType, jstring jProbeTypes, jintArray jProbeOutputCols, jstring jFilter, + jlong jNestedLoopJoinBuildOperatorFactory, jstring jOperatorConfig) +{ + auto probeTypesCharPtr = env->GetStringUTFChars(jProbeTypes, JNI_FALSE); + auto probeOutputColsCount = env->GetArrayLength(jProbeOutputCols); + auto probeOutputColsArr = env->GetIntArrayElements(jProbeOutputCols, JNI_FALSE); + auto filterChars = env->GetStringUTFChars(jFilter, JNI_FALSE); + auto probeDataTypes = Deserialize(probeTypesCharPtr); + std::string filterExpression = std::string(filterChars); + + env->ReleaseStringUTFChars(jProbeTypes, probeTypesCharPtr); + env->ReleaseStringUTFChars(jFilter, filterChars); + Expr *filterExpr = nullptr; + JNI_METHOD_START + // extract the expression and the BuildDataTypes to parse the expression + filterExpr = CreateJoinFilterExpr(filterExpression); + JNI_METHOD_END(0L) + NestLoopJoinLookupOperatorFactory *nestLoopJoinLookupOperatorFactory = nullptr; + JNI_METHOD_START + auto operatorConfigChars = env->GetStringUTFChars(jOperatorConfig, JNI_FALSE); + auto operatorConfig = OperatorConfig::DeserializeOperatorConfig(operatorConfigChars); + auto *overflowConfig = operatorConfig.GetOverflowConfig(); + env->ReleaseStringUTFChars(jOperatorConfig, operatorConfigChars); + auto joinType = (JoinType)jJoinType; + nestLoopJoinLookupOperatorFactory = + NestLoopJoinLookupOperatorFactory::CreateNestLoopJoinLookupOperatorFactory(joinType, probeDataTypes, + probeOutputColsArr, probeOutputColsCount, filterExpr, jNestedLoopJoinBuildOperatorFactory, overflowConfig); + JNI_METHOD_END(0L) + Expr::DeleteExprs({ filterExpr }); + env->ReleaseIntArrayElements(jProbeOutputCols, probeOutputColsArr, 0); + return reinterpret_cast(static_cast(nestLoopJoinLookupOperatorFactory)); +} \ No newline at end of file diff --git a/bindings/java/src/main/cpp/src/jni_operator_factory.h b/bindings/java/src/main/cpp/src/jni_operator_factory.h new file mode 100644 index 0000000..b7e395d --- /dev/null +++ b/bindings/java/src/main/cpp/src/jni_operator_factory.h @@ -0,0 +1,360 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. + * Description: Type Operator Factory Header + */ +#ifndef JNI_OPERATOR_FACTORY_H +#define JNI_OPERATOR_FACTORY_H + +#include +#include "expression/parserhelper.h" +#include "expression/jsonparser/jsonparser.h" +#include "codegen/expr_evaluator.h" + +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: nova_hetu_omniruntime_operator_OmniOperatorFactory + * Method: createOperatorNative + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_operator_OmniOperatorFactory_createOperatorNative(JNIEnv *env, + jobject jObj, jlong jNativeFactoryObj); + +/* + * Class: nova_hetu_omniruntime_operator_OmniOperatorFactory + * Method: closeNativeOperatorFactory + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_operator_OmniOperatorFactory_closeNativeOperatorFactory(JNIEnv *env, + jclass jclz, jlong jNativeOperatorFactory); + +/* + * Class: nova_hetu_omniruntime_operator_sort_OmniSortOperatorFactory + * Method: createSortOperatorFactory + * Signature: (Ljava/lang/String;[I[Ljava/lang/String;[I[ILjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_operator_sort_OmniSortOperatorFactory_createSortOperatorFactory( + JNIEnv *env, jclass jObj, jstring jSourceTypes, jintArray jOutputCols, jobjectArray jSortCols, + jintArray jAscendings, jintArray jNullFirsts, jstring jOperatorConfig); + +/* + * Class: nova_hetu_omniruntime_operator_aggregator_OmniHashAggregationOperatorFactory + * Method: createHashAggregationOperatorFactory + * Signature: ([Ljava/lang/String;Ljava/lang/String;[Ljava/lang/String;Ljava/lang/String;[I[ILjava/lang/String;ZZ)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_aggregator_OmniHashAggregationOperatorFactory_createHashAggregationOperatorFactory( + JNIEnv *env, jclass jObj, jobjectArray jGroupByChannel, jstring jGroupByType, jobjectArray jAggChannel, + jstring jAggType, jintArray jAggFuncType, jintArray jMaskCols, jstring jOutPutTye, jboolean inputRaw, + jboolean outputPartial, jstring jOperatorConfig); + +/* + * Class: nova_hetu_omniruntime_operator_aggregator_OmniAggregationOperatorFactory + * Method: createAggregationOperatorFactory + * Signature: (Ljava/lang/String;[I[I[ILjava/lang/String;ZZ)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_aggregator_OmniAggregationOperatorFactory_createAggregationOperatorFactory( + JNIEnv *env, jclass jObj, jstring jSourceTypes, jintArray jAggFuncTypes, jintArray jAggInputCols, + jintArray jMaskCols, jstring jAggOutputTypes, jboolean inputRaw, jboolean outputPartial); + +/* + * Class: nova_hetu_omniruntime_operator_filter_OmniFilterAndProjectOperatorFactory + * Method: createFilterAndProjectOperatorFactory + * Signature: (Ljava/lang/String;ILjava/lang/String;[Ljava/lang/Object;IIZ)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_filter_OmniFilterAndProjectOperatorFactory_createFilterAndProjectOperatorFactory( + JNIEnv *env, jclass jObj, jstring jInputTypes, jint jInputLength, jstring jExpression, jobjectArray jProjections, + jint jProjectLength, jint jParseFormat, jstring jOperatorConfig); + +/* + * Class: nova_hetu_omniruntime_operator_project_OmniProjectOperatorFactory + * Method: createProjectOperatorFactory + * Signature: (Ljava/lang/String;I[Ljava/lang/Object;IIZ)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_project_OmniProjectOperatorFactory_createProjectOperatorFactory(JNIEnv *env, + jclass jobj, jstring jInputTypes, jint jInputLength, jobjectArray jExprs, jint jExprsLength, jint jParseFormat, + jstring jOperatorConfig); + +/* + * Class: nova_hetu_omniruntime_operator_window_OmniWindowOperatorFactory + * Method: createWindowOperatorFactory + * Signature: (Ljava/lang/String;[I[I[I[I[I[I[III[ILjava/lang/String;[I[I[I[I[I)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_window_OmniWindowOperatorFactory_createWindowOperatorFactory(JNIEnv *env, + jobject jObj, jstring jSourceTypes, jintArray jOutputChannels, jintArray jWindowFunction, + jintArray jPartitionChannels, jintArray JPreGroupedChannels, jintArray jSortChannels, jintArray jSortOrder, + jintArray jSortNullFirsts, jint preSortedChannelPrefix, jint expectedPositions, jintArray jArgumentChannels, + jstring jWindowFunctionReturnType, jintArray jWindowFrameTypes, jintArray jWindowFrameStartTypes, + jintArray jWindowFrameStartChannels, jintArray jWindowFrameEndTypes, jintArray jWindowFrameEndChannels, + jstring jOperatorConfig); + +/* + * Class: nova_hetu_omniruntime_operator_topn_OmniTopNOperatorFactory + * Method: createTopNOperatorFactory + * Signature: (Ljava/lang/String;I[Ljava/lang/String;[I[I)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_topn_OmniTopNOperatorFactory_createTopNOperatorFactory(JNIEnv *env, jclass jObj, + jstring jSourceTypes, jint jN, jint jOffset, jobjectArray jSortCols, jintArray jSortAsc, jintArray jSortNullFirsts); + +/* + * Class: nova_hetu_omniruntime_operator_join_OmniHashBuilderOperatorFactory + * Method: createHashBuilderOperatorFactory + * Signature: (Ljava/lang/String;[IILjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniHashBuilderOperatorFactory_createHashBuilderOperatorFactory(JNIEnv *env, + jclass jObj, jint jJoinType, jstring jBuildTypes, jintArray jBuildHashCols, jint jOperatorCount, + jstring jOperatorConfig); + +/* + * Class: nova_hetu_omniruntime_operator_join_OmniLookupJoinOperatorFactory + * Method: createLookupJoinOperatorFactory + * Signature: (Ljava/lang/String;[I[I[ILjava/lang/String;IJLjava/lang/String;Ljava/lang/String;)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniLookupJoinOperatorFactory_createLookupJoinOperatorFactory(JNIEnv *env, + jclass jObj, jstring jProbeTypes, jintArray jProbeOutputCols, jintArray jProbeHashCols, jintArray jBuildOutputCols, + jstring jBuildOutputTypes, jlong jHashBuilderOperatorFactory, + jstring jFilter, jboolean isShuffleExchangeBuildPlan, jstring jOperatorConfig); + +/* + * Class: nova_hetu_omniruntime_operator_partitionedoutput_OmniPartitionedOutPutOperatorFactory + * Method: createPartitionedOutputOperatorFactory + * Signature: (Ljava/lang/String;ZI[II[IZLjava/lang/String;[I)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_partitionedoutput_OmniPartitionedOutPutOperatorFactory_createPartitionedOutputOperatorFactory( + JNIEnv *env, jobject jObj, jstring jSourceTypes, jboolean jReplicatesAnyRow, jint jNullChannel, + jintArray jPartitionChannels, jint jPartitionCount, jintArray jBucketToPartition, jboolean isHashPrecomputed, + jstring jHashChannelTypes, jintArray jHashChannels); + +/* + * Class: nova_hetu_omniruntime_operator_union_OmniUnionOperatorFactory + * Method: createUnionOperatorFactory + * Signature: (Ljava/lang/String;Z)J + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_operator_union_OmniUnionOperatorFactory_createUnionOperatorFactory( + JNIEnv *env, jobject jObj, jstring jSourceTypes, jboolean jDistinct); + +/* + * Class: nova_hetu_omniruntime_operator_sort_OmniSortWithExprOperatorFactory + * Method: createSortWithExprOperatorFactory + * Signature: (Ljava/lang/String;[I[Ljava/lang/String;[I[ILjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_sort_OmniSortWithExprOperatorFactory_createSortWithExprOperatorFactory(JNIEnv *env, + jclass jObj, jstring jSourceTypes, jintArray jOutputCols, jobjectArray jSortKeys, jintArray jAscendings, + jintArray jNullFirsts, jstring jOperatorConfig); + +/* + * Class: nova_hetu_omniruntime_operator_join_OmniHashBuilderWithExprOperatorFactory + * Method: createHashBuilderWithExprOperatorFactory + * Signature: (Ljava/lang/String;[Ljava/lang/String;ILjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniHashBuilderWithExprOperatorFactory_createHashBuilderWithExprOperatorFactory( + JNIEnv *env, jclass jObj, jint jJoinType, jint jBuildSide, jstring jBuildTypes, jobjectArray jBuildHashKeys, + jint jHashTableCount, jstring jOperatorConfig); + +/* + * Class: nova_hetu_omniruntime_operator_join_OmniLookupJoinWithExprOperatorFactory + * Method: createLookupJoinWithExprOperatorFactory + * Signature: (Ljava/lang/String;[I[Ljava/lang/String;[ILjava/lang/String;IJLjava/lang/String;Ljava/lang/String;)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniLookupJoinWithExprOperatorFactory_createLookupJoinWithExprOperatorFactory( + JNIEnv *env, jclass jObj, jstring jProbeTypes, jintArray jProbeOutputCols, jobjectArray jProbeHashKeys, + jintArray jBuildOutputCols, jstring jBuildOutputTypes, jlong jHashBuilderOperatorFactory, + jstring jFilter, jboolean isShuffleExchangeBuildPlan, jstring jOperatorConfig); + +/* + * Class: nova_hetu_omniruntime_operator_join_OmniLookupOuterJoinWithExprOperatorFactory + * Method: createLookupOuterJoinWithExprOperatorFactory + * Signature: (Ljava/lang/String;[I[Ljava/lang/String;[ILjava/lang/String;J)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniLookupOuterJoinWithExprOperatorFactory_createLookupOuterJoinWithExprOperatorFactory( + JNIEnv *env, jclass jObj, jstring jProbeTypes, jintArray jProbeOutputCols, jobjectArray jProbeHashKeys, + jintArray jBuildOutputCols, jstring jBuildOutputTypes, jlong jHashBuilderOperatorFactory); + +/* + * Class: nova_hetu_omniruntime_operator_join_OmniLookupOuterJoinOperatorFactory + * Method: createLookupOuterJoinOperatorFactory + * Signature: (Ljava/lang/String;[I[ILjava/lang/String;J)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniLookupOuterJoinOperatorFactory_createLookupOuterJoinOperatorFactory( + JNIEnv *env, jclass jObj, jstring jProbeTypes, jintArray jProbeOutputCols, jintArray jBuildOutputCols, + jstring jBuildOutputTypes, jlong jHashBuilderOperatorFactory); + +/* + * Class: nova_hetu_omniruntime_operator_window_OmniWindowWithExprOperatorFactory + * Method: createWindowWithExprOperatorFactory + * Signature: (Ljava/lang/String;[I[I[I[I[I[I[III[Ljava/lang/String;Ljava/lang/String;[I[I[I[I[I)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_window_OmniWindowWithExprOperatorFactory_createWindowWithExprOperatorFactory( + JNIEnv *env, jobject jObj, jstring jSourceTypes, jintArray jOutputChannels, jintArray jWindowFunction, + jintArray jPartitionChannels, jintArray JPreGroupedChannels, jintArray jSortChannels, jintArray jSortOrder, + jintArray jSortNullFirsts, jint preSortedChannelPrefix, jint expectedPositions, jobjectArray jArgumentKeys, + jstring jWindowFunctionReturnType, jintArray jWindowFrameTypes, jintArray jWindowFrameStartTypes, + jintArray jWindowFrameStartChannels, jintArray jWindowFrameEndTypes, jintArray jWindowFrameEndChannels, + jstring jOperatorConfig); + +/* + * Class: nova_hetu_omniruntime_operator_aggregator_OmniHashAggregationWithExprOperatorFactory + * Method: createHashAggregationWithExprOperatorFactory + * Signature: ([Ljava/lang/String;[Ljava/lang/String;Ljava/lang/String;[I[ILjava/lang/String;ZZ)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_aggregator_OmniHashAggregationWithExprOperatorFactory_createHashAggregationWithExprOperatorFactory( + JNIEnv *env, jclass jObj, jobjectArray jGroupByChannel, jobjectArray jAggChannels, jobjectArray jAggChannelsFilter, + jstring jSourceType, jintArray jAggFuncType, jintArray jMaskCols, jobjectArray jOutputType, + jbooleanArray jInputRaws, jbooleanArray jOutputPartials, jstring jOperatorConfig); + +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_aggregator_OmniAggregationWithExprOperatorFactory_createAggregationWithExprOperatorFactory( + JNIEnv *env, jclass jObj, jobjectArray jGroupByChannel, jobjectArray jAggChannels, jobjectArray jAggChannelsFilter, + jstring jSourceType, jintArray jAggFuncType, jintArray jMaskCols, jobjectArray jOutputType, + jbooleanArray jInputRaws, jbooleanArray jOutputPartials, jstring jOperatorConfig); + +/* + * Class: nova_hetu_omniruntime_operator_topn_OmniTopNWithExprOperatorFactory + * Method: createTopNWithExprOperatorFactory + * Signature: (Ljava/lang/String;I[Ljava/lang/String;[I[I)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_topn_OmniTopNWithExprOperatorFactory_createTopNWithExprOperatorFactory(JNIEnv *env, + jclass jObj, jstring jSourceTypes, jint jN, jint jOffset, jobjectArray jSortKeys, jintArray jSortAsc, + jintArray jSortNullFirsts, jstring jOperatorConfig); + +/* + * Class: nova_hetu_omniruntime_operator_limit_OmniLimitOperatorFactory + * Method: createLimitOperatorFactory + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_operator_limit_OmniLimitOperatorFactory_createLimitOperatorFactory( + JNIEnv *env, jclass jObj, jint jLimit, jint jOffset); + +/* + * Class: nova_hetu_omniruntime_operator_limit_OmniDistinctLimitOperatorFactory + * Method: createDistinctLimitOperatorFactory + * Signature: (Ljava/lang/String;[IIJ)J + * Note: out put seq as below: + * 1. distinct cols + * 2. normal cols + * 3. hash col(jHashChannel) + * Note: put jHashChannel to -1 if no precomputed hash value for distinct cols + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_limit_OmniDistinctLimitOperatorFactory_createDistinctLimitOperatorFactory( + JNIEnv *env, jclass jObj, jstring jSoureTypes, jintArray jDistinctChannel, jint jHashChannel, jlong jLimit); + +/* + * Class: nova_hetu_omniruntime_operator_join_OmniSmjStreamedTableWithExprOperatorFactory + * Method: createSmjStreamedTableWithExprOperatorFactory + * Signature: (Ljava/lang/String;[Ljava/lang/String;[IILjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniSmjStreamedTableWithExprOperatorFactory_createSmjStreamedTableWithExprOperatorFactory( + JNIEnv *env, jclass jObj, jstring jSourceTypes, jobjectArray jEqualKeyExprs, jintArray jOutputChannels, + jint jJoinType, jstring jFilter, jstring jOperatorConfig); + +/* + * Class: nova_hetu_omniruntime_operator_join_OmniSmjBufferedTableWithExprOperatorFactory + * Method: createSmjBufferedTableWithExprOperatorFactory + * Signature: (Ljava/lang/String;[Ljava/lang/String;[IJ)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniSmjBufferedTableWithExprOperatorFactory_createSmjBufferedTableWithExprOperatorFactory( + JNIEnv *env, jclass jObj, jstring jSourceTypes, jobjectArray jEqualKeyExprs, jintArray jOutputChannels, + jlong jSmjStreamedTableWithExprOperatorFactory, jstring jOperatorConfig); + +/* + * Class: nova_hetu_omniruntime_operator_join_OmniSmjStreamedTableWithExprOperatorFactoryV3 + * Method: createSmjStreamedTableWithExprOperatorFactoryV3 + * Signature: (Ljava/lang/String;[Ljava/lang/String;[IILjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniSmjStreamedTableWithExprOperatorFactoryV3_createSmjStreamedTableWithExprOperatorFactoryV3( + JNIEnv *env, jclass jObj, jstring jSourceTypes, jobjectArray jEqualKeyExprs, jintArray jOutputChannels, + jint jJoinType, jstring jFilter, jstring jOperatorConfig); + +/* + * Class: nova_hetu_omniruntime_operator_join_OmniSmjBufferedTableWithExprOperatorFactoryV3 + * Method: createSmjBufferedTableWithExprOperatorFactoryV3 + * Signature: (Ljava/lang/String;[Ljava/lang/String;[IJ)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniSmjBufferedTableWithExprOperatorFactoryV3_createSmjBufferedTableWithExprOperatorFactoryV3( + JNIEnv *env, jclass jObj, jstring jSourceTypes, jobjectArray jEqualKeyExprs, jintArray jOutputChannels, + jlong jSmjStreamedTableWithExprOperatorFactoryV3, jstring jOperatorConfig); + + +/* + * Class: nova_hetu_omniruntime_operator_OmniExprVerify + * Method: exprVerify + * Signature: (Ljava/lang/String;ILjava/lang/String;[Ljava/lang/Object;II)J + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_operator_OmniExprVerify_exprVerify(JNIEnv *env, jclass jObj, + jstring jInputTypes, jint jInputLength, jstring jExpression, jobjectArray jProjections, jint jProjectLength, + jint jParseFormat); + +/* + * Class: nova_hetu_omniruntime_operator_filter_OmniBloomFilterOperatorFactory + * Method: createBloomFilterOperatorFactory + * Signature: ([J[ILjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_filter_OmniBloomFilterOperatorFactory_createBloomFilterOperatorFactory(JNIEnv *env, + jclass jObj, jint jInputVersion); + +/* + * Class: nova_hetu_omniruntime_operator_topnsort_OmniTopNSortWithExprOperatorFactory + * Method: createTopNSortWithExprOperatorFactory + * Signature: (Ljava/lang/String;IZ[Ljava/lang/String;[Ljava/lang/String;[I[ILjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_topnsort_OmniTopNSortWithExprOperatorFactory_createTopNSortWithExprOperatorFactory( + JNIEnv *, jclass, jstring, jint, jboolean, jobjectArray, jobjectArray, jintArray, jintArray, jstring); + +/* + * Class: nova_hetu_omniruntime_operator_window_OmniWindowGroupLimitWithExprOperatorFactory + * Method: createWindowGroupLimitWithExprOperatorFactory + * Signature: (Ljava/lang/String;IZ[Ljava/lang/String;[Ljava/lang/String;[I[ILjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_window_OmniWindowGroupLimitWithExprOperatorFactory_createWindowGroupLimitWithExprOperatorFactory( + JNIEnv *env, jclass jObj, jstring jInputTypes, jint jN, jstring jFuncName, jobjectArray jPartitionKeys, + jobjectArray jSortKeys, jintArray jSortAsc, jintArray jSortNullFirsts, jstring jOperatorConfig); + + +/* + * Class: nova_hetu_omniruntime_operator_join_OmniNestedLoopJoinBuildOperatorFactory + * Method: createNestedLoopJoinBuildOperatorFactory + * Signature: (Ljava/lang/String;IZ[Ljava/lang/String;[Ljava/lang/String;[I[ILjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniNestedLoopJoinBuildOperatorFactory_createNestedLoopJoinBuildOperatorFactory( + JNIEnv *env, jclass jObj, jstring jBuildTypes, jintArray jBuildOutputCols); + +/* + * Class: nova_hetu_omniruntime_operator_join_OmniNestedLoopJoinLookupOperatorFactory + * Method: createNestedLoopJoinLookupOperatorFactory + * Signature: (Ljava/lang/String;IZ[Ljava/lang/String;[Ljava/lang/String;[I[ILjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL +Java_nova_hetu_omniruntime_operator_join_OmniNestedLoopJoinLookupOperatorFactory_createNestedLoopJoinLookupOperatorFactory( + JNIEnv *env, jclass jObj, jint jJoinType, jstring jProbeTypes, jintArray jProbeOutputCols, jstring jFilter, + jlong jNestedLoopJoinBuildOperatorFactory, jstring jOperatorConfig); + +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/bindings/java/src/main/cpp/src/jni_vector.cpp b/bindings/java/src/main/cpp/src/jni_vector.cpp new file mode 100644 index 0000000..84b9d50 --- /dev/null +++ b/bindings/java/src/main/cpp/src/jni_vector.cpp @@ -0,0 +1,299 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved. + * Description: JNI Vector Operations Source File + */ +#include "jni_vector.h" +#include +#include "memory/memory_pool.h" +#include "vector/vector_batch.h" +#include "vector/unsafe_vector.h" +#include "vector/vector_helper.h" +#include "vector/vector.h" +#include "jni_common_def.h" +#include "operator/aggregation/container_vector.h" +#include "type/data_type_serializer.h" +#include "memory/thread_memory_manager.h" + +using namespace omniruntime::vec; +using namespace omniruntime::mem; + +static ALWAYS_INLINE BaseVector *TransformVector(long vectorAddr) +{ + return reinterpret_cast(vectorAddr); +} + +#ifdef TRACE +static void RecordStack(BaseVector *vector, JNIEnv *env) +{ + jstring jstack = (jstring)env->CallStaticObjectMethod(traceUtilCls, traceUtilStackMethodId); + auto stackChars = env->GetStringUTFChars(jstack, JNI_FALSE); + std::string stack(stackChars); + ThreadMemoryTrace *threadMemoryTrace = ThreadMemoryTrace::GetThreadMemoryTrace(); + // replace c++ stack with java stack after vector is created. + threadMemoryTrace->ReplaceVectorTracedLog(reinterpret_cast(vector), stack); + env->ReleaseStringUTFChars(jstack, stackChars); +} +#endif + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_Vec_newVectorNative(JNIEnv *env, jclass jcls, + jint jValueCount, jint jVectorEncodingId, jint jVectorTypeId, jint jCapacityInBytes) +{ + BaseVector *vector = nullptr; + JNI_METHOD_START + vector = VectorHelper::CreateVector(jVectorEncodingId, jVectorTypeId, jValueCount, jCapacityInBytes); + if (UNLIKELY(vector == nullptr)) { + throw omniruntime::exception::OmniException("CREATE_FLAT_VECTOR_FAILED", + "return a null pointer when creating flat vector"); + } + JNI_METHOD_END(0) +#ifdef TRACE + RecordStack(vector, env); +#endif + return reinterpret_cast(reinterpret_cast(vector)); +} + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_Vec_newDictionaryVectorNative(JNIEnv *env, jclass jcls, + jlong jDictionaryNativeVector, jintArray jIds, jint size, jint dataTypeId) +{ + BaseVector *dictionaryVector = TransformVector(jDictionaryNativeVector); + jint idsArray[size]; + env->GetIntArrayRegion(jIds, 0, size, idsArray); + jint *ids = idsArray; + BaseVector *vector = nullptr; + JNI_METHOD_START + vector = VectorHelper::CreateDictionaryVector(ids, size, dictionaryVector, dataTypeId); + if (UNLIKELY(vector == nullptr)) { + throw omniruntime::exception::OmniException("CREATE_DICTIONARY_VECTOR_FAILED", + "return a null pointer when creating dictionary vector"); + } + JNI_METHOD_END(0) +#ifdef TRACE + RecordStack(vector, env); +#endif + return reinterpret_cast(reinterpret_cast(vector)); +} + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_Vec_sliceVectorNative(JNIEnv *env, jclass jcls, + jlong jNativeVector, jint jStartIndex, jint jLength) +{ + BaseVector *nativeVector = TransformVector(jNativeVector); + BaseVector *sliceVector = nullptr; + JNI_METHOD_START + sliceVector = VectorHelper::SliceVector(nativeVector, jStartIndex, jLength); + JNI_METHOD_END(0) +#ifdef TRACE + RecordStack(sliceVector, env); +#endif + return reinterpret_cast(reinterpret_cast(sliceVector)); +} + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_Vec_copyPositionsNative(JNIEnv *env, jclass jcls, + jlong jNativeVector, jintArray jPositions, jint jOffset, jint jLength) +{ + BaseVector *nativeVector = TransformVector(jNativeVector); + jint positionArray[jLength]; + env->GetIntArrayRegion(jPositions, jOffset, jLength, positionArray); + jint *positions = positionArray; + BaseVector *copyVector = nullptr; + JNI_METHOD_START + copyVector = VectorHelper::CopyPositionsVector(nativeVector, reinterpret_cast(positions), 0, jLength); + JNI_METHOD_END(0) +#ifdef TRACE + RecordStack(copyVector, env); +#endif + return reinterpret_cast(reinterpret_cast(copyVector)); +} + +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_vector_Vec_freeVectorNative(JNIEnv *env, jclass jcls, + jlong jNativeVector) +{ + BaseVector *nativeVector = TransformVector(jNativeVector); + if (nativeVector == nullptr) { + std::cerr << "free vector native vector is null:" << jNativeVector << std::endl; + } + delete nativeVector; +} + +JNIEXPORT jint JNICALL Java_nova_hetu_omniruntime_vector_Vec_getCapacityInBytesNative(JNIEnv *env, jclass jcls, + jlong jNativeVector) +{ + BaseVector *nativeVector = TransformVector(jNativeVector); + DataTypeId typeId = nativeVector->GetTypeId(); + if (typeId != omniruntime::type::OMNI_VARCHAR && typeId != omniruntime::type::OMNI_CHAR) { + throw omniruntime::exception::OmniException("vector type is no supported", + "the interface only supports varchar/char vector."); + } + auto *varCharVector = reinterpret_cast> *>(nativeVector); + return omniruntime::vec::unsafe::UnsafeStringVector::GetContainer(varCharVector)->GetCapacityInBytes(); +} + +JNIEXPORT jint JNICALL Java_nova_hetu_omniruntime_vector_Vec_getSizeNative(JNIEnv *env, jclass jcls, + jlong jNativeVector) +{ + BaseVector *nativeVector = TransformVector(jNativeVector); + return nativeVector->GetSize(); +} + +JNIEXPORT jint JNICALL Java_nova_hetu_omniruntime_vector_Vec_setSizeNative(JNIEnv *env, jclass jcls, + jlong jNativeVector, jint jSize) +{ + BaseVector *nativeVector = TransformVector(jNativeVector); + if (jSize < 0 || jSize > nativeVector->GetSize()) { + std::cerr << "size is error, the range is[0," << nativeVector->GetSize() << "]" << std::endl; + return jSize; + } + omniruntime::vec::unsafe::UnsafeBaseVector::SetSize(nativeVector, jSize); + return jSize; +} + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_Vec_getValuesNative(JNIEnv *env, jclass jlcls, + jlong jNativeVector) +{ + BaseVector *nativeVector = TransformVector(jNativeVector); + return reinterpret_cast(VectorHelper::UnsafeGetValues(nativeVector)); +} + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_Vec_getValueNullsNative(JNIEnv *env, jclass jcls, + jlong jNativeVector) +{ + BaseVector *nativeVector = TransformVector(jNativeVector); + return reinterpret_cast(omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(nativeVector)); +} + +JNIEXPORT jint JNICALL Java_nova_hetu_omniruntime_vector_ContainerVec_getPositionNative(JNIEnv *env, jclass jcls, + jlong jNativeVector) +{ + ContainerVector *containerVec = reinterpret_cast(jNativeVector); + return containerVec->GetSize(); +} + +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_vector_ContainerVec_setDataTypesNative(JNIEnv *env, jclass jcls, + jlong jNativeVector, jstring dataTypes) +{ + ContainerVector *containerVec = reinterpret_cast(jNativeVector); + auto dataTypeString = env->GetStringUTFChars(dataTypes, JNI_FALSE); + containerVec->SetDataTypes(omniruntime::type::Deserialize(dataTypeString).Get()); + env->ReleaseStringUTFChars(dataTypes, dataTypeString); +} + +JNIEXPORT jstring JNICALL Java_nova_hetu_omniruntime_vector_ContainerVec_getDataTypesNative(JNIEnv *env, jclass jcls, + jlong jNativeVector) +{ + ContainerVector *containerVec = reinterpret_cast(jNativeVector); + auto &DataTypes = containerVec->GetDataTypes(); + return env->NewStringUTF(Serialize(DataTypes).data()); +} + +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_vector_Vec_appendVectorNative(JNIEnv *env, jclass jcls, + jlong jNativeVectorDest, jint jOffSet, jlong jNativeVectorSrc, jint jLength) +{ + BaseVector *nativeVectorSrc = TransformVector(jNativeVectorSrc); + BaseVector *nativeVectorDest = TransformVector(jNativeVectorDest); + JNI_METHOD_START + VectorHelper::AppendVector(nativeVectorDest, (int32_t)jOffSet, nativeVectorSrc, (int32_t)jLength); + JNI_METHOD_END() +} + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_VariableWidthVec_getValueOffsetsNative(JNIEnv *env, + jclass jcls, jlong jNativeVector) +{ + BaseVector *nativeVector = TransformVector(jNativeVector); + auto offsetsAddr = VectorHelper::UnsafeGetOffsetsAddr(nativeVector); + if (UNLIKELY(offsetsAddr == nullptr)) { + throw omniruntime::exception::OmniException("GET_OFFSETS_FAILED", + "return a null pointer when getting offsets address"); + } + return reinterpret_cast(offsetsAddr); +} + +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_memory_MemoryManager_setGlobalMemoryLimitNative(JNIEnv *env, + jclass jcls, jlong jLimit) +{ + omniruntime::mem::MemoryManager::SetGlobalMemoryLimit(jLimit); +} + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_memory_MemoryManager_getAllocatedMemoryNative(JNIEnv *env, + jclass jcls) +{ + auto threadMemoryManager = omniruntime::mem::ThreadMemoryManager::GetThreadMemoryManager(); + int64_t accountedMemory = threadMemoryManager->GetThreadAccountedMemory(); + int64_t untrackedMemory = threadMemoryManager->GetUntrackedMemory(); + return accountedMemory + untrackedMemory; +} + +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_memory_MemoryManager_memoryClearNative(JNIEnv *env, jclass jcls) +{ + auto threadMemoryManager = omniruntime::mem::ThreadMemoryManager::GetThreadMemoryManager(); + threadMemoryManager->Clear(); +} + +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_memory_MemoryManager_memoryReclamationNative(JNIEnv *env, jclass jcls) +{ + ThreadMemoryTrace *threadMemoryTrace = ThreadMemoryTrace::GetThreadMemoryTrace(); + if (threadMemoryTrace->HasMemoryLeak()) { + threadMemoryTrace->FreeLeakedMemory(); + } +} + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_VecBatch_newVectorBatchNative(JNIEnv *env, jclass jcls, + jlongArray jVectorAddresses, jint rRowCount) +{ + jlong *vecAddresses = env->GetLongArrayElements(jVectorAddresses, JNI_FALSE); + jsize vecCount = env->GetArrayLength(jVectorAddresses); + VectorBatch *vecBatch = new VectorBatch(rRowCount); + for (int i = 0; i < vecCount; ++i) { + vecBatch->Append(reinterpret_cast(vecAddresses[i])); + } + env->ReleaseLongArrayElements(jVectorAddresses, vecAddresses, JNI_ABORT); + return reinterpret_cast(reinterpret_cast(vecBatch)); +} + +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_vector_VecBatch_freeVectorBatchNative(JNIEnv *env, jclass jcls, + jlong jVecBatchAddress) +{ + VectorBatch *vecBatch = reinterpret_cast(jVecBatchAddress); + vecBatch->ClearVectors(); + delete vecBatch; +} + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_DictionaryVec_getDictionaryNative(JNIEnv *env, jclass jcls, + jlong jNativeVector) +{ + BaseVector *nativeVector = TransformVector(jNativeVector); + auto dictionaryAddr = VectorHelper::UnsafeGetDictionary(nativeVector); + if (UNLIKELY(dictionaryAddr == nullptr)) { + throw omniruntime::exception::OmniException("GET_DICTIONARY_NATIVE_FAILED", + "return a null pointer when getting dictionary address"); + } + return reinterpret_cast(dictionaryAddr); +} + +JNIEXPORT jint JNICALL Java_nova_hetu_omniruntime_vector_Vec_getVecEncodingNative(JNIEnv *env, jclass jcls, + jlong jNativeVector) +{ + BaseVector *nativeVector = TransformVector(jNativeVector); + return nativeVector->GetEncoding(); +} + +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_VarcharVec_expandDataCapacity(JNIEnv *env, jclass jcls, + jlong jNativeVector, jint jToCapacityInBytes) +{ + auto nativeVector = reinterpret_cast> *>(jNativeVector); + char *newBuffAddress = + omniruntime::vec::unsafe::UnsafeStringVector::ExpandStringBuffer(nativeVector, jToCapacityInBytes); + return reinterpret_cast(reinterpret_cast(newBuffAddress)); +} + +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_vector_Vec_setNullFlagNative(JNIEnv *env, jclass jcls, + jlong jNativeVector, jboolean jHasNull) +{ + BaseVector *nativeVector = TransformVector(jNativeVector); + nativeVector->SetNullFlag(jHasNull); +} + +JNIEXPORT jboolean JNICALL Java_nova_hetu_omniruntime_vector_Vec_hasNullNative(JNIEnv *env, jclass jcls, + jlong jNativeVector) +{ + BaseVector *nativeVector = TransformVector(jNativeVector); + return nativeVector->HasNull(); +} diff --git a/bindings/java/src/main/cpp/src/jni_vector.h b/bindings/java/src/main/cpp/src/jni_vector.h new file mode 100644 index 0000000..2b70c96 --- /dev/null +++ b/bindings/java/src/main/cpp/src/jni_vector.h @@ -0,0 +1,208 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: JNI Vector Operations Header + */ +#ifndef JNI_VECTOR_H +#define JNI_VECTOR_H +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Class: nova_hetu_omniruntime_vector_Vec + * Method: newVectorNative + * Signature: (IIII)J + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_Vec_newVectorNative(JNIEnv *, jclass, jint, jint, jint, jint); + +/* + * Class: nova_hetu_omniruntime_vector_Vec + * Method: newDictionaryVectorNative + * Signature: (J[III)J + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_Vec_newDictionaryVectorNative(JNIEnv *, jclass, jlong, + jintArray, jint, jint); + +/* + * Class: nova_hetu_omniruntime_vector_Vec + * Method: sliceVectorNative + * Signature: (JII)J + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_Vec_sliceVectorNative(JNIEnv *, jclass, jlong, jint, jint); + +/* + * Class: nova_hetu_omniruntime_vector_Vec + * Method: copyPositionsNative + * Signature: (J[III)J + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_Vec_copyPositionsNative(JNIEnv *, jclass, jlong, jintArray, + jint, jint); + +/* + * Class: nova_hetu_omniruntime_vector_Vec + * Method: freeVectorNative + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_vector_Vec_freeVectorNative(JNIEnv *, jclass, jlong); + +/* + * Class: nova_hetu_omniruntime_vector_Vec + * Method: getCapacityInBytesNative + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_nova_hetu_omniruntime_vector_Vec_getCapacityInBytesNative(JNIEnv *, jclass, jlong); + +/* + * Class: nova_hetu_omniruntime_vector_Vec + * Method: getSizeNative + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_nova_hetu_omniruntime_vector_Vec_getSizeNative(JNIEnv *, jclass, jlong); + +/* + * Class: nova_hetu_omniruntime_vector_Vec + * Method: setSizeNative + * Signature: (JI)I + */ +JNIEXPORT jint JNICALL Java_nova_hetu_omniruntime_vector_Vec_setSizeNative(JNIEnv *, jclass, jlong, jint); + +/* + * Class: nova_hetu_omniruntime_vector_Vec + * Method: getValuesNative + * Signature: (J)J; + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_Vec_getValuesNative(JNIEnv *, jclass, jlong); + +/* + * Class: nova_hetu_omniruntime_vector_Vec + * Method: getValueNullsNative + * Signature: (J)J; + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_Vec_getValueNullsNative(JNIEnv *, jclass, jlong); + +/* + * Class: nova_hetu_omniruntime_vector_Vec + * Method: appendVectorNative + * Signature: (JIJI)V + */ +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_vector_Vec_appendVectorNative(JNIEnv *, jclass, jlong, jint, jlong, + jint); + +/* + * Class: nova_hetu_omniruntime_vector_ContainerVec + * Method: getPositionNative + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_nova_hetu_omniruntime_vector_ContainerVec_getPositionNative(JNIEnv *, jclass, jlong); + +/* + * Class: nova_hetu_omniruntime_vector_ContainerVec + * Method: setDataTypesNative + * Signature: (JLjava/lang/Sting;)V; + */ +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_vector_ContainerVec_setDataTypesNative(JNIEnv *, jclass, jlong, + jstring); + +/* + * Class: nova_hetu_omniruntime_vector_ContainerVec + * Method: getDataTypesNative + * Signature: (J)Ljava/lang/Sting; + */ +JNIEXPORT jstring JNICALL Java_nova_hetu_omniruntime_vector_ContainerVec_getDataTypesNative(JNIEnv *, jclass, jlong); + +/* + * Class: nova_hetu_omniruntime_vector_VariableWidthVec + * Method: getValueOffsetsNative + * Signature: (J)J; + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_VariableWidthVec_getValueOffsetsNative(JNIEnv *, jclass, + jlong); + +/* + * Class: Java_nova_hetu_omniruntime_memory_MemoryManager + * Method: setGlobalMemoryLimitNative + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_memory_MemoryManager_setGlobalMemoryLimitNative(JNIEnv *, + jclass, jlong); + +/* + * Class: Java_nova_hetu_omniruntime_memory_MemoryManager + * Method: getAllocatedMemoryNative + * Signature: (V)J + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_memory_MemoryManager_getAllocatedMemoryNative(JNIEnv *, + jclass); + +/* + * Class: Java_nova_hetu_omniruntime_memory_MemoryManager + * Method: memoryClearNative + * Signature: (V)V + */ +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_memory_MemoryManager_memoryClearNative(JNIEnv *env, + jclass jcls); + +/* + * Class: Java_nova_hetu_omniruntime_memory_MemoryManager + * Method: memoryReclamationNative + * Signature: (V)V + */ +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_memory_MemoryManager_memoryReclamationNative(JNIEnv *env, + jclass jcls); + +/* + * Class: nova_hetu_omniruntime_vector_VecBatch + * Method: newVectorBatchNative + * Signature: ([JI)J + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_VecBatch_newVectorBatchNative(JNIEnv *, jclass, jlongArray, + jint); + +/* + * Class: nova_hetu_omniruntime_vector_VecBatch + * Method: freeVectorBatchNative + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_vector_VecBatch_freeVectorBatchNative(JNIEnv *, jclass, jlong); + +/* + * Class: nova_hetu_omniruntime_vector_DictionaryVec + * Method: getDictionaryNative + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_DictionaryVec_getDictionaryNative(JNIEnv *, jclass, jlong); + +/* + * Class: nova_hetu_omniruntime_vector_Vec + * Method: getVecEncodingNative + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_nova_hetu_omniruntime_vector_Vec_getVecEncodingNative(JNIEnv *, jclass, jlong); +/* + * Class: nova_hetu_omniruntime_vector_VarcharVec + * Method: expandDataCapacity + * Signature: (JI)J; + */ +JNIEXPORT jlong JNICALL Java_nova_hetu_omniruntime_vector_VarcharVec_expandDataCapacity(JNIEnv *, jclass, + jlong, jint); +/* + * Class: nova_hetu_omniruntime_vector_Vec + * Method: setNullFlagNative + * Signature: (JZ)V + */ +JNIEXPORT void JNICALL Java_nova_hetu_omniruntime_vector_Vec_setNullFlagNative(JNIEnv *, jclass, jlong, jboolean); + +/* + * Class: nova_hetu_omniruntime_vector_Vec + * Method: hasNullNative + * Signature: (J)Z + */ +JNIEXPORT jboolean JNICALL Java_nova_hetu_omniruntime_vector_Vec_hasNullNative(JNIEnv *, jclass, jlong); + + +#ifdef __cplusplus +} +#endif +#endif diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/OmniLibs.java b/bindings/java/src/main/java/nova/hetu/omniruntime/OmniLibs.java new file mode 100644 index 0000000..4b86ed2 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/OmniLibs.java @@ -0,0 +1,79 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2024. All rights reserved. + */ + +package nova.hetu.omniruntime; + +import nova.hetu.omniruntime.utils.NativeLog; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; + +/** + * load libomni_runtime.so. + * + * @since 2021-07-17 + */ +public class OmniLibs { + private static volatile OmniLibs instance; + + private static final String LIBRARY_NAME = "boostkit-omniop-java-binding-1.9.0-aarch64"; + + private static final Logger LOG = LoggerFactory.getLogger(OmniLibs.class); + + private static final int BUFFER_SIZE = 1024; + + private OmniLibs() { + File tempFile = null; + try { + String nativeLibraryPath = File.separator + System.mapLibraryName(LIBRARY_NAME); + tempFile = File.createTempFile(LIBRARY_NAME, ".so"); + try (InputStream in = OmniLibs.class.getResourceAsStream(nativeLibraryPath); + FileOutputStream fos = new FileOutputStream(tempFile)) { + int i; + byte[] buf = new byte[BUFFER_SIZE]; + while ((i = in.read(buf)) != -1) { + fos.write(buf, 0, i); + } + System.load(tempFile.getCanonicalPath()); + } + } catch (IOException e) { + LOG.warn("fail to load library from Jar!errmsg:{}", e.getMessage()); + } finally { + if (tempFile != null) { + tempFile.deleteOnExit(); + } + } + } + + public static OmniLibs getInstance() { + if (instance == null) { + synchronized (OmniLibs.class) { + if (instance == null) { + instance = new OmniLibs(); + NativeLog.getInstance(); + } + } + } + return instance; + } + + /** + * Loading the dll. + */ + public static void load() { + getInstance(); + } + + /** + * Geting the version + * + * @return the version string + */ + public static native String getVersion(); +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/constants/BuildSide.java b/bindings/java/src/main/java/nova/hetu/omniruntime/constants/BuildSide.java new file mode 100644 index 0000000..cbb919e --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/constants/BuildSide.java @@ -0,0 +1,38 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ + +package nova.hetu.omniruntime.constants; + +/** + * The Build Side. + * + * @since 2025-01-14 + */ +public class BuildSide extends Constant { + /** + * The constant BUILD_UNKNOWN. The representative doesn't need to know + */ + public static final BuildSide BUILD_UNKNOWN = new BuildSide(0); + + /** + * The constant BUILD_LEFT. + */ + public static final BuildSide BUILD_LEFT = new BuildSide(1); + + /** + * The constant BUILD_RIGHT. + */ + public static final BuildSide BUILD_RIGHT = new BuildSide(2); + + private static final long serialVersionUID = -4047841645954651422L; + + /** + * Instantiates a new Build Side. + * + * @param value the value + */ + public BuildSide(int value) { + super(value); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/constants/Constant.java b/bindings/java/src/main/java/nova/hetu/omniruntime/constants/Constant.java new file mode 100644 index 0000000..71cefc5 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/constants/Constant.java @@ -0,0 +1,58 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.constants; + +import java.io.Serializable; +import java.util.Objects; + +/** + * The type Constant. The abstract class of all enum constant class + * + * @since 2021-06-30 + */ +public abstract class Constant implements Serializable { + private static final long serialVersionUID = -2589766491699675794L; + + private final int value; + + /** + * Instantiates a new Constant. + * + * @param value the value + */ + public Constant(int value) { + this.value = value; + } + + /** + * Gets value. + * + * @return the value + */ + public int getValue() { + return value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + return ((Constant) obj).getValue() == value; + } + + @Override + public int hashCode() { + return Objects.hash(value); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/constants/ConstantHelper.java b/bindings/java/src/main/java/nova/hetu/omniruntime/constants/ConstantHelper.java new file mode 100644 index 0000000..d824993 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/constants/ConstantHelper.java @@ -0,0 +1,27 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.constants; + +import java.util.Arrays; + +/** + * The type Constant helper. + * + * @since 2021-06-30 + */ +public class ConstantHelper { + private ConstantHelper() { + } + + /** + * To native constants int [ ]. + * + * @param constants the constants + * @return the int [ ] + */ + public static int[] toNativeConstants(Constant[] constants) { + return Arrays.stream(constants).map(Constant::getValue).mapToInt(Integer::intValue).toArray(); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/constants/FunctionType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/constants/FunctionType.java new file mode 100644 index 0000000..fce9337 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/constants/FunctionType.java @@ -0,0 +1,98 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.constants; + +/** + * The type Agg type. + * + * @since 2021-06-30 + */ +public class FunctionType extends Constant { + /** + * The constant OMNI_AGGREGATION_TYPE_SUM. + */ + public static final FunctionType OMNI_AGGREGATION_TYPE_SUM = new FunctionType(0); + + /** + * The constant OMNI_AGGREGATION_TYPE_COUNT_COLUMN. + */ + public static final FunctionType OMNI_AGGREGATION_TYPE_COUNT_COLUMN = new FunctionType(1); + + /** + * The constant OMNI_AGGREGATION_TYPE_COUNT_ALL. + */ + public static final FunctionType OMNI_AGGREGATION_TYPE_COUNT_ALL = new FunctionType(2); + + /** + * The constant OMNI_AGGREGATION_TYPE_AVG. + */ + public static final FunctionType OMNI_AGGREGATION_TYPE_AVG = new FunctionType(3); + + /** + * The constant OMNI_AGGREGATION_TYPE_SAMP. + */ + public static final FunctionType OMNI_AGGREGATION_TYPE_SAMP = new FunctionType(4); + + /** + * The constant OMNI_AGGREGATION_TYPE_MAX. + */ + public static final FunctionType OMNI_AGGREGATION_TYPE_MAX = new FunctionType(5); + + /** + * The constant OMNI_AGGREGATION_TYPE_MIN. + */ + public static final FunctionType OMNI_AGGREGATION_TYPE_MIN = new FunctionType(6); + + /** + * The constant OMNI_AGGREGATION_TYPE_DNV. + */ + public static final FunctionType OMNI_AGGREGATION_TYPE_DNV = new FunctionType(7); + + /** + * The constant OMNI_AGGREGATION_TYPE_FIRST_IGNORENULL. + */ + public static final FunctionType OMNI_AGGREGATION_TYPE_FIRST_IGNORENULL = new FunctionType(8); + + /** + * The constant OMNI_AGGREGATION_TYPE_FIRST_INCLUDENULL. + */ + public static final FunctionType OMNI_AGGREGATION_TYPE_FIRST_INCLUDENULL = new FunctionType(9); + + /** + * The constant OMNI_AGGREGATION_TYPE_INVALID. + */ + public static final FunctionType OMNI_AGGREGATION_TYPE_INVALID = new FunctionType(10); + + /** + * The constant OMNI_WINDOW_TYPE_ROW_NUMBER. + */ + public static final FunctionType OMNI_WINDOW_TYPE_ROW_NUMBER = new FunctionType(11); + + /** + * The constant OMNI_WINDOW_TYPE_RANK. + */ + public static final FunctionType OMNI_WINDOW_TYPE_RANK = new FunctionType(12); + + /** + * The constant OMNI_AGGREGATION_TYPE_TRY_SUM. + */ + public static final FunctionType OMNI_AGGREGATION_TYPE_TRY_SUM = new FunctionType(13); + + /** + * The constant OMNI_AGGREGATION_TYPE_TRY_AVG. + */ + public static final FunctionType OMNI_AGGREGATION_TYPE_TRY_AVG = new FunctionType(14); + + private static final long serialVersionUID = 5337378607473315604L; + + /** + * Instantiates a new Agg type. + * + * @param value the value + */ + public FunctionType(int value) { + super(value); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/constants/JoinType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/constants/JoinType.java new file mode 100644 index 0000000..d978f0e --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/constants/JoinType.java @@ -0,0 +1,58 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.constants; + +/** + * The Join type. + * + * @since 2021-06-30 + */ +public class JoinType extends Constant { + /** + * The constant OMNI_JOIN_TYPE_INNER. + */ + public static final JoinType OMNI_JOIN_TYPE_INNER = new JoinType(0); + + /** + * The constant OMNI_JOIN_TYPE_LEFT. + */ + public static final JoinType OMNI_JOIN_TYPE_LEFT = new JoinType(1); + + /** + * The constant OMNI_JOIN_TYPE_RIGHT. + */ + public static final JoinType OMNI_JOIN_TYPE_RIGHT = new JoinType(2); + + /** + * The constant OMNI_JOIN_TYPE_FULL. + */ + public static final JoinType OMNI_JOIN_TYPE_FULL = new JoinType(3); + + /** + * The constant OMNI_JOIN_TYPE_LEFT_SEMI. + */ + public static final JoinType OMNI_JOIN_TYPE_LEFT_SEMI = new JoinType(4); + + /** + * The constant OMNI_JOIN_TYPE_LEFT_ANTI. + */ + public static final JoinType OMNI_JOIN_TYPE_LEFT_ANTI = new JoinType(5); + + /** + * The constant OMNI_JOIN_TYPE_EXISTENCE. + */ + public static final JoinType OMNI_JOIN_TYPE_EXISTENCE = new JoinType(6); + + private static final long serialVersionUID = -4086671645951741450L; + + /** + * Instantiates a new Join type. + * + * @param value the value + */ + public JoinType(int value) { + super(value); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/constants/OmniWindowFrameBoundType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/constants/OmniWindowFrameBoundType.java new file mode 100644 index 0000000..a9d381a --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/constants/OmniWindowFrameBoundType.java @@ -0,0 +1,48 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.constants; + +/** + * The type frame bound type for window operator. + * + * @since 2022-04-15 + */ +public class OmniWindowFrameBoundType extends Constant { + /** + * The constant OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING. + */ + public static final OmniWindowFrameBoundType OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING = new OmniWindowFrameBoundType(0); + + /** + * The constant OMNI_FRAME_BOUND_PRECEDING. + */ + public static final OmniWindowFrameBoundType OMNI_FRAME_BOUND_PRECEDING = new OmniWindowFrameBoundType(1); + + /** + * The constant OMNI_FRAME_BOUND_CURRENT_ROW. + */ + public static final OmniWindowFrameBoundType OMNI_FRAME_BOUND_CURRENT_ROW = new OmniWindowFrameBoundType(2); + + /** + * The constant OMNI_FRAME_BOUND_FOLLOWING. + */ + public static final OmniWindowFrameBoundType OMNI_FRAME_BOUND_FOLLOWING = new OmniWindowFrameBoundType(3); + + /** + * The constant OMNI_FRAME_BOUND_UNBOUNDED_FOLLOWING. + */ + public static final OmniWindowFrameBoundType OMNI_FRAME_BOUND_UNBOUNDED_FOLLOWING = new OmniWindowFrameBoundType(4); + + private static final long serialVersionUID = 3646147886114670835L; + + /** + * Instantiates a new frame bound type. + * + * @param value the value + */ + public OmniWindowFrameBoundType(int value) { + super(value); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/constants/OmniWindowFrameType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/constants/OmniWindowFrameType.java new file mode 100644 index 0000000..458bcac --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/constants/OmniWindowFrameType.java @@ -0,0 +1,33 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.constants; + +/** + * The type window frame type. + * + * @since 2022-04-15 + */ +public class OmniWindowFrameType extends Constant { + /** + * The constant OMNI_FRAME_TYPE_RANGE. + */ + public static final OmniWindowFrameType OMNI_FRAME_TYPE_RANGE = new OmniWindowFrameType(0); + + /** + * The constant OMNI_FRAME_TYPE_ROWS. + */ + public static final OmniWindowFrameType OMNI_FRAME_TYPE_ROWS = new OmniWindowFrameType(1); + + private static final long serialVersionUID = -3453499979001168717L; + + /** + * Instantiates a new window frame type. + * + * @param value the value + */ + public OmniWindowFrameType(int value) { + super(value); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/constants/OperatorType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/constants/OperatorType.java new file mode 100644 index 0000000..e57c6e9 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/constants/OperatorType.java @@ -0,0 +1,108 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.constants; + +/** + * OperatorType + * + * @since 2022-12-19 + */ +public class OperatorType extends Constant { + /** + * The constant OMNI_FILTER_AND_PROJECT. + */ + public static final OperatorType OMNI_FILTER_AND_PROJECT = new OperatorType(0); + + /** + * The constant OMNI_PROJECT. + */ + public static final OperatorType OMNI_PROJECT = new OperatorType(1); + + /** + * The constant OMNI_LIMIT. + */ + public static final OperatorType OMNI_LIMIT = new OperatorType(2); + + /** + * The constant OMNI_DISTINCT_LIMIT. + */ + public static final OperatorType OMNI_DISTINCT_LIMIT = new OperatorType(3); + + /** + * The constant OMNI_SORT. + */ + public static final OperatorType OMNI_SORT = new OperatorType(4); + + /** + * The constant OMNI_TOPN. + */ + public static final OperatorType OMNI_TOPN = new OperatorType(5); + + /** + * The constant OMNI_AGGREGATION. + */ + public static final OperatorType OMNI_AGGREGATION = new OperatorType(6); + + /** + * The constant OMNI_HASH_AGGREGATION. + */ + public static final OperatorType OMNI_HASH_AGGREGATION = new OperatorType(7); + + /** + * The constant OMNI_WINDOW. + */ + public static final OperatorType OMNI_WINDOW = new OperatorType(8); + + /** + * The constant OMNI_HASH_BUILDER. + */ + public static final OperatorType OMNI_HASH_BUILDER = new OperatorType(9); + + /** + * The constant OMNI_LOOKUP_JOIN. + */ + public static final OperatorType OMNI_LOOKUP_JOIN = new OperatorType(10); + + /** + * The constant OMNI_LOOKUP_OUTER_JOIN. + */ + public static final OperatorType OMNI_LOOKUP_OUTER_JOIN = new OperatorType(11); + + /** + * The constant OMNI_SMJ_BUFFER. + */ + public static final OperatorType OMNI_SMJ_BUFFER = new OperatorType(12); + + /** + * The constant OMNI_SMJ_STREAM. + */ + public static final OperatorType OMNI_SMJ_STREAM = new OperatorType(13); + + /** + * The constant OMNI_PARTITIONED_OUTPUT. + */ + public static final OperatorType OMNI_PARTITIONED_OUTPUT = new OperatorType(14); + + /** + * The constant OMNI_UNION. + */ + public static final OperatorType OMNI_UNION = new OperatorType(15); + + /** + * The constant OMNI_FUSION. + */ + public static final OperatorType OMNI_FUSION = new OperatorType(16); + + private static final long serialVersionUID = -4350951859220483149L; + + /** + * Instantiates a new operator type. + * + * @param value the value + */ + public OperatorType(int value) { + super(value); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/constants/Status.java b/bindings/java/src/main/java/nova/hetu/omniruntime/constants/Status.java new file mode 100644 index 0000000..a0f21c7 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/constants/Status.java @@ -0,0 +1,38 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.constants; + +/** + * The type Status. + * + * @since 2021-06-30 + */ +public class Status extends Constant { + /** + * The constant OMNI_STATUS_NORMAL. + */ + public static final Status OMNI_STATUS_NORMAL = new Status(0); + + /** + * The constant OMNI_STATUS_ERROR. + */ + public static final Status OMNI_STATUS_ERROR = new Status(-1); + + /** + * The constant OMNI_STATUS_FINISHED. + */ + public static final Status OMNI_STATUS_FINISHED = new Status(1); + + private static final long serialVersionUID = -3424552555224669902L; + + /** + * Instantiates a new Status. + * + * @param value the value + */ + public Status(int value) { + super(value); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/memory/MemoryManager.java b/bindings/java/src/main/java/nova/hetu/omniruntime/memory/MemoryManager.java new file mode 100644 index 0000000..bc6c0fb --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/memory/MemoryManager.java @@ -0,0 +1,80 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.memory; + +import nova.hetu.omniruntime.OmniLibs; +import nova.hetu.omniruntime.utils.ParseUtil; +import sun.misc.VM; + +/** + * memory manager. + * + * @since 2023-01-17 + */ +public class MemoryManager implements AutoCloseable { + /** + * -1 means no memory limit + */ + public static final long UNLIMITED = -1; + + static { + OmniLibs.load(); + } + + /** + * set the size of the memory that can be used by the root allocator, + * this can limit memory usage at the process level + */ + public static void setGlobalMemoryLimit() { + // parse environment variable OMNI_OFFHEAP_MEMORY_SIZE + String memorySize = System.getenv("OMNI_OFFHEAP_MEMORY_SIZE"); + long rootLimit = memorySize == null ? VM.maxDirectMemory() : ParseUtil.parserMemoryParameters(memorySize); + // the off heap memory from director or environment variable, set global memory limit + setGlobalMemoryLimit(rootLimit); + } + + /** + * set global memory limit about off-heap + * + * @param limit the number of global memory limit about off-heap + * */ + public static void setGlobalMemoryLimit(long limit) { + setGlobalMemoryLimitNative(limit); + } + + /** + * get allocated memory of current allocator + * + * @return allocated memory in bytes + */ + public long getAllocatedMemory() { + return getAllocatedMemoryNative(); + } + + /** + * clear memory of current task and current executor + * */ + public static void clearMemory() { + memoryClearNative(); + } + + /** + * Reclaim memory of current task if memory leak exists + * */ + public static void reclaimMemory() { + memoryReclamationNative(); + } + + @Override + public void close() {} + + private static native void setGlobalMemoryLimitNative(long limit); + + private static native long getAllocatedMemoryNative(); + + private static native long memoryClearNative(); + + private static native long memoryReclamationNative(); +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniExprVerify.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniExprVerify.java new file mode 100644 index 0000000..781b9e5 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniExprVerify.java @@ -0,0 +1,37 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import nova.hetu.omniruntime.OmniLibs; + +/** + * To verify if expr is supported before codegen + * + * @since 2022-05-16 + */ +public class OmniExprVerify { + static { + OmniLibs.load(); + } + + private static native long exprVerify(String inputTypes, int inputLength, String expression, Object[] projections, + int projectLength, int parseFormat); + + /** + * exprVerifyNative + * + * @param inputTypes the input types + * @param inputLength the length of input types + * @param filterExpr filter expression + * @param projections a set of projection expressions + * @param projectLength the length of projection expressions + * @param parseFormat json or string + * @return if expr is supported + */ + public long exprVerifyNative(String inputTypes, int inputLength, String filterExpr, Object[] projections, + int projectLength, int parseFormat) { + return exprVerify(inputTypes, inputLength, filterExpr, projections, projectLength, parseFormat); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniOperator.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniOperator.java new file mode 100644 index 0000000..c925862 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniOperator.java @@ -0,0 +1,189 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.constants.Status.OMNI_STATUS_NORMAL; + +import nova.hetu.omniruntime.vector.VecBatch; + +import java.util.Iterator; + +/** + * The type Omni operator. + * + * @since 2021-06-30 + */ +public final class OmniOperator implements AutoCloseable { + /** + * The Native operator. + */ + protected final long nativeOperator; + + private VecBatchIterator outputIterator; + + /** + * Instantiates a new Omni operator. + * + * @param nativeOperator the native operator + */ + protected OmniOperator(long nativeOperator) { + this.nativeOperator = nativeOperator; + } + + // addInput + private static native int addInputNative(long nativeOperator, long nativeVectorBatch); + + // getOutput + private static native OmniResults getOutputNative(long nativeOperator); + + // close + private static native void closeNative(long nativeOperator); + + // getSpilledBytes + private static native long getSpilledBytesNative(long nativeOperator); + + // getMetricsInfo + private static native long[] getMetricsInfoNative(long nativeOperator); + + // getHashMapUniqueKeys called by the adaptive partial hashagg optimization + private static native long getHashMapUniqueKeysNative(long nativeOperator); + + // called by the adaptive partial hashagg optimization + private static native VecBatch alignSchemaNative(long nativeOperator, long inputVecBatchNative); + + + /** + * Add input. + * + * @param vecBatch the vec batch + * @return the int + */ + public int addInput(VecBatch vecBatch) { + return addInputNative(nativeOperator, vecBatch.getNativeVectorBatch()); + } + + /** + * Gets output. + * + * @return the output + */ + public Iterator getOutput() { + if (outputIterator == null) { + outputIterator = new VecBatchIterator(); + } + outputIterator.reset(); + return outputIterator; + } + + /** + * Close native operator. + */ + public void close() { + closeNative(nativeOperator); + } + + /** + * Get spill size. + * + * @return the spilled size + */ + public long getSpilledBytes() { + return getSpilledBytesNative(nativeOperator); + } + + /** + * Get all Metrics info. + * + * @return the metrics info array + */ + public long[] getMetricsInfo() { + return getMetricsInfoNative(nativeOperator); + } + + /** + * Get the number of hashmap unique key. + * + * @return the unique key number + */ + public long getHashMapUniqueKeys() { + return getHashMapUniqueKeysNative(nativeOperator); + } + + /** + * The input vecBatch is aligned based on the operator schema. + * + * @param inputVecBatch the input vec batch + * @return aligned vecBatch + */ + public VecBatch alignSchema(VecBatch inputVecBatch) { + return alignSchemaNative(nativeOperator, inputVecBatch.getNativeVectorBatch()); + } + + private class VecBatchIterator implements Iterator { + private boolean hasNext; + + private OmniResults results; + + private VecBatch next; + + /** + * Instantiates a new Vec batch iterator. + */ + public VecBatchIterator() { + resetIterator(); + advanced(); + hasNext = true; + } + + public void reset() { + hasNext = true; + } + + @Override + public boolean hasNext() { + if (!hasNext) { + return false; + } + // if it first, the results is null, + // or index reach the count of vector batches but it don't finished, + // then advanced(). + if (results == null || (next == results.getVecBatch() && !isFinished())) { + resetIterator(); + advanced(); + } + + // after advanced(), if results is still null, + // or vectorBatch hash been pulled, or vecBatch is null + // means there is no more data. + if (results == null || next == results.getVecBatch()) { + resetIterator(); + hasNext = false; + return false; + } + + hasNext = true; + return true; + } + + @Override + public VecBatch next() { + next = results.getVecBatch(); + return next; + } + + private void resetIterator() { + results = null; + next = null; + } + + private void advanced() { + results = getOutputNative(nativeOperator); + } + + private boolean isFinished() { + return !OMNI_STATUS_NORMAL.equals(results.getStatus()); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniOperatorFactory.java new file mode 100644 index 0000000..c9b166c --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniOperatorFactory.java @@ -0,0 +1,95 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import com.google.common.util.concurrent.UncheckedExecutionException; + +import nova.hetu.omniruntime.OmniLibs; +import nova.hetu.omniruntime.utils.OmniRuntimeException; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +/** + * The type Omni operator factory. + * + * @param the type parameter + * @since 2021-06-30 + */ +public abstract class OmniOperatorFactory { + private static final Cache FACTORY_CACHE = CacheBuilder.newBuilder() + .expireAfterAccess(24, TimeUnit.HOURS).maximumSize(100000).build(); + + static { + OmniLibs.load(); + } + + private final long nativeOperatorFactory; + + private OmniOperatorFactoryContext context; + + /** + * Instantiates a new Omni operator factory. + * + * @param context the context + */ + public OmniOperatorFactory(OmniOperatorFactoryContext context) { + try { + if (context.isNeedCache()) { + nativeOperatorFactory = FACTORY_CACHE.get(context, () -> createNativeOperatorFactory((T) context)); + } else { + nativeOperatorFactory = createNativeOperatorFactory((T) context); + } + this.context = context; + } catch (ExecutionException e) { + throw new RuntimeException("Get operator factory instance failed."); + } catch (UncheckedExecutionException e) { + throw new OmniRuntimeException(e.getCause().getMessage()); + } + } + + // createOperator + private static native long createOperatorNative(long factoryAddress); + + /** + * Gets native operator factory. + * + * @return the native operator factory + */ + public long getNativeOperatorFactory() { + return nativeOperatorFactory; + } + + /** + * Create operator omni operator. + * + * @return the omni operator + */ + public OmniOperator createOperator() { + long nativeOperator = createOperatorNative(nativeOperatorFactory); + return new OmniOperator(nativeOperator); + } + + /** + * Create native operator factory long. + * + * @param context the context + * @return the long + */ + protected abstract long createNativeOperatorFactory(T context); + + /** + * release operator factory + */ + public void close() { + if (!context.isNeedCache()) { + closeNativeOperatorFactory(nativeOperatorFactory); + } + } + + private static native void closeNativeOperatorFactory(long nativeOperatorFactory); +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniOperatorFactoryContext.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniOperatorFactoryContext.java new file mode 100644 index 0000000..5c9f04a --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniOperatorFactoryContext.java @@ -0,0 +1,55 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +/** + * The type Omni operator factory context. + * + * @since 2021-06-30 + */ +public abstract class OmniOperatorFactoryContext { + /** + * Switch for configuring factory cache defaults + */ + private static boolean defaultNeedCacheValue = true; + + /** + * Whether the omni operator factory needs to be cached. + */ + private boolean isNeedCache = defaultNeedCacheValue; + + /** + * Instantiates a new Omni operator factory context. + */ + public OmniOperatorFactoryContext() { + } + + /** + * Interface for setting default values for engine initialization + * + * @param value enable op factory cache + */ + public static void setDefaultNeedCacheValue(boolean value) { + defaultNeedCacheValue = value; + } + + /** + * Get the flag needCache whether the omni operator factory needs to be cached. + * + * @return the flag needCache + */ + public boolean isNeedCache() { + return isNeedCache; + } + + /** + * Set the flag needCache whether the omni operator factory needs to be cached. + * + * @param isNeedCache the flag needCache + */ + public void setNeedCache(boolean isNeedCache) { + this.isNeedCache = isNeedCache; + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniResults.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniResults.java new file mode 100644 index 0000000..ab2154b --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniResults.java @@ -0,0 +1,58 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import nova.hetu.omniruntime.constants.Status; +import nova.hetu.omniruntime.vector.VecBatch; + +import java.io.Closeable; + +/** + * The type Omni results. + * + * @since 2021-06-30 + */ +public class OmniResults implements Closeable { + private final VecBatch vecBatch; + + private final Status status; + + /** + * Instantiates a new Omni results. + * + * @param vecBatch the vec batch + * @param status the status + */ + public OmniResults(VecBatch vecBatch, int status) { + this.vecBatch = vecBatch; + this.status = new Status(status); + } + + /** + * Get vec batches vec batch [ ]. + * + * @return the vec batch [ ] + */ + public VecBatch getVecBatch() { + return vecBatch; + } + + /** + * Gets status. + * + * @return the status + */ + public Status getStatus() { + return status; + } + + @Override + public void close() { + if (vecBatch != null) { + vecBatch.releaseAllVectors(); + vecBatch.close(); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniRowResults.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniRowResults.java new file mode 100644 index 0000000..f9e7d97 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/OmniRowResults.java @@ -0,0 +1,57 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import nova.hetu.omniruntime.constants.Status; +import nova.hetu.omniruntime.vector.RowBatch; + +import java.io.Closeable; + +/** + * The type Omni results. + * + * @since 2024-05-16 + */ +public class OmniRowResults implements Closeable { + private final RowBatch rowBatch; + + private final Status status; + + /** + * Instantiates a new Omni results. + * + * @param rowBatch the vec batch + * @param status the status + */ + public OmniRowResults(RowBatch rowBatch, int status) { + this.rowBatch = rowBatch; + this.status = new Status(status); + } + + /** + * Get vec batches vec batch [ ]. + * + * @return the vec batch [ ] + */ + public RowBatch getRowBatch() { + return rowBatch; + } + + /** + * Gets status. + * + * @return the status + */ + public Status getStatus() { + return status; + } + + @Override + public void close() { + if (rowBatch != null) { + rowBatch.close(); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/aggregator/OmniAggregationOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/aggregator/OmniAggregationOperatorFactory.java new file mode 100644 index 0000000..bba1dc8 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/aggregator/OmniAggregationOperatorFactory.java @@ -0,0 +1,146 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.aggregator; + +import static java.util.Objects.requireNonNull; +import static nova.hetu.omniruntime.constants.ConstantHelper.toNativeConstants; + +import nova.hetu.omniruntime.constants.FunctionType; +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; + +/** + * The type Omni aggregation operator factory. + * + * @since 2021-06-30 + */ +public class OmniAggregationOperatorFactory extends OmniOperatorFactory { + /** + * Instantiates a new Omni aggregation operator factory. + * + * @param sourceTypes the aggregation source types + * @param aggFunctionTypes the aggregation function types + * @param aggInputChannels the aggregation function input channels + * @param maskChannels mask chennels array for aggregetions + * @param aggOutputTypes the aggregation output types + * @param isInputRaw the input raw + * @param isOutputPartial the output partial + * @param operatorConfig the operator config + */ + public OmniAggregationOperatorFactory(DataType[] sourceTypes, FunctionType[] aggFunctionTypes, + int[] aggInputChannels, int[] maskChannels, DataType[] aggOutputTypes, boolean isInputRaw, + boolean isOutputPartial, OperatorConfig operatorConfig) { + super(new FactoryContext(sourceTypes, aggFunctionTypes, aggInputChannels, maskChannels, aggOutputTypes, + isInputRaw, isOutputPartial, operatorConfig)); + } + + /** + * Instantiates a new Omni aggregation operator factory with default operator + * config. + * + * @param sourceTypes the aggregation source types + * @param aggFunctionTypes the aggregation function types + * @param aggInputChannels the aggregation function input channels + * @param maskChannels mask chennels array for aggregetions + * @param aggOutputTypes the aggregation output types + * @param isInputRaw the input raw + * @param isOutputPartial the output partial + */ + public OmniAggregationOperatorFactory(DataType[] sourceTypes, FunctionType[] aggFunctionTypes, + int[] aggInputChannels, int[] maskChannels, DataType[] aggOutputTypes, boolean isInputRaw, + boolean isOutputPartial) { + this(sourceTypes, aggFunctionTypes, aggInputChannels, maskChannels, aggOutputTypes, isInputRaw, isOutputPartial, + new OperatorConfig()); + } + + private static native long createAggregationOperatorFactory(String sourceTypes, int[] aggFunctionTypes, + int[] aggInputChannels, int[] maskChannels, String aggOutputTypes, boolean isInputRaw, + boolean isOutputPartial); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createAggregationOperatorFactory(DataTypeSerializer.serialize(context.sourceTypes), + toNativeConstants(context.aggFunctionTypes), context.aggInputChannels, context.maskChannels, + DataTypeSerializer.serialize(context.aggOutputTypes), context.isInputRaw, context.isOutputPartial); + } + + /** + * The type Factory context. + * + * @since 2021-06-30 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] sourceTypes; + + private final FunctionType[] aggFunctionTypes; + + private final int[] aggInputChannels; + + private final int[] maskChannels; + + private final DataType[] aggOutputTypes; + + private final boolean isInputRaw; + + private final boolean isOutputPartial; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param sourceTypes the source types + * @param aggFunctionTypes the aggregation function types + * @param aggInputChannels the aggregation input channels + * @param maskChannels the aggregation mask channels + * @param aggOutputTypes the aggregation output types + * @param isInputRaw the input raw + * @param isOutputPartial the output partial + * @param operatorConfig the operator config + */ + public FactoryContext(DataType[] sourceTypes, FunctionType[] aggFunctionTypes, int[] aggInputChannels, + int[] maskChannels, DataType[] aggOutputTypes, boolean isInputRaw, boolean isOutputPartial, + OperatorConfig operatorConfig) { + this.sourceTypes = requireNonNull(sourceTypes, "sourceTypes is null"); + this.aggFunctionTypes = requireNonNull(aggFunctionTypes, "aggFunctionTypes is null"); + this.aggInputChannels = requireNonNull(aggInputChannels, "aggInputChannels is null"); + this.maskChannels = requireNonNull(maskChannels, "maskChannels is null"); + this.aggOutputTypes = requireNonNull(aggOutputTypes, "aggOutputTypes is null"); + this.isInputRaw = isInputRaw; + this.isOutputPartial = isOutputPartial; + this.operatorConfig = operatorConfig; + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(sourceTypes), Arrays.hashCode(aggFunctionTypes), + Arrays.hashCode(aggInputChannels), Arrays.hashCode(maskChannels), Arrays.hashCode(aggOutputTypes), + isInputRaw, isOutputPartial, operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return Arrays.equals(sourceTypes, that.sourceTypes) + && Arrays.equals(aggFunctionTypes, that.aggFunctionTypes) + && Arrays.equals(aggInputChannels, that.aggInputChannels) + && Arrays.equals(maskChannels, that.maskChannels) + && Arrays.equals(aggOutputTypes, that.aggOutputTypes) && isInputRaw == that.isInputRaw + && isOutputPartial == that.isOutputPartial && operatorConfig.equals(that.operatorConfig); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/aggregator/OmniAggregationWithExprOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/aggregator/OmniAggregationWithExprOperatorFactory.java new file mode 100644 index 0000000..c4d877f --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/aggregator/OmniAggregationWithExprOperatorFactory.java @@ -0,0 +1,224 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.aggregator; + +import static java.util.Objects.requireNonNull; +import static nova.hetu.omniruntime.constants.ConstantHelper.toNativeConstants; + +import nova.hetu.omniruntime.constants.FunctionType; +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; +import nova.hetu.omniruntime.utils.JsonUtils; + +import java.util.Arrays; +import java.util.Objects; + +/** + * The type Omni aggregation with expression operator factory. + * + * @since 2021-10-21 + */ +public class OmniAggregationWithExprOperatorFactory + extends OmniOperatorFactory { + /** + * Instantiates a new Omni hash aggregation with expression operator factory. + * + * @param groupByChanel the group by chanel + * @param aggChannels the agg channels + * @param aggChannelsFilter the agg filter Expr + * @param sourceTypes the source types + * @param aggFunctionTypes the agg function types + * @param aggOutputTypes the agg output types + * @param isInputRaws the input raw flags + * @param isOutputPartials the output partial flags + * @param operatorConfig the operator config + */ + public OmniAggregationWithExprOperatorFactory(String[] groupByChanel, String[][] aggChannels, + String[] aggChannelsFilter, DataType[] sourceTypes, FunctionType[] aggFunctionTypes, + DataType[][] aggOutputTypes, boolean[] isInputRaws, boolean[] isOutputPartials, + OperatorConfig operatorConfig) { + super(new FactoryContext(groupByChanel, aggChannels, aggChannelsFilter, sourceTypes, aggFunctionTypes, + aggOutputTypes, isInputRaws, isOutputPartials, operatorConfig)); + } + + /** + * Instantiates a new Omni hash aggregation with expression operator factory. + * + * @param groupByChanel the group by chanel + * @param aggChannels the agg channels + * @param aggChannelsFilter the agg filter Expr + * @param sourceTypes the source types + * @param aggFunctionTypes the agg function types + * @param maskChannels mask channel list for aggregators + * @param aggOutputTypes the agg output types + * @param isInputRaws the input raw flags + * @param isOutputPartials the output partial flags + * @param operatorConfig the operator config + */ + public OmniAggregationWithExprOperatorFactory(String[] groupByChanel, String[][] aggChannels, + String[] aggChannelsFilter, DataType[] sourceTypes, FunctionType[] aggFunctionTypes, int[] maskChannels, + DataType[][] aggOutputTypes, boolean[] isInputRaws, boolean[] isOutputPartials, + OperatorConfig operatorConfig) { + super(new FactoryContext(groupByChanel, aggChannels, aggChannelsFilter, sourceTypes, aggFunctionTypes, + maskChannels, aggOutputTypes, isInputRaws, isOutputPartials, operatorConfig)); + } + + /** + * Instantiates a new Omni hash aggregation with expression operator factory + * with jit default. + * + * @param groupByChanel the group by chanel + * @param aggChannels the agg channels + * @param aggChannelsFilter the agg filter Expr + * @param sourceTypes the source types + * @param aggFunctionTypes the agg function types + * @param aggOutputTypes the agg output types + * @param isInputRaws the input raw flags + * @param isOutputPartials the output partial flags + */ + public OmniAggregationWithExprOperatorFactory(String[] groupByChanel, String[][] aggChannels, + String[] aggChannelsFilter, DataType[] sourceTypes, FunctionType[] aggFunctionTypes, + DataType[][] aggOutputTypes, boolean[] isInputRaws, boolean[] isOutputPartials) { + this(groupByChanel, aggChannels, aggChannelsFilter, sourceTypes, aggFunctionTypes, aggOutputTypes, isInputRaws, + isOutputPartials, new OperatorConfig()); + } + + private static native long createAggregationWithExprOperatorFactory(String[] groupByChanel, String[] aggChannels, + String[] aggChannelsFilter, String sourceTypes, int[] aggFunctionTypes, int[] maskChannels, + String[] aggOutputTypes, boolean[] isInputRaws, boolean[] isOutputPartials, String operatorConfig); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createAggregationWithExprOperatorFactory(context.groupByChanel, + JsonUtils.jsonStringArray(context.aggChannels), context.aggChannelsFilter, + DataTypeSerializer.serialize(context.sourceTypes), toNativeConstants(context.aggFunctionTypes), + context.maskChannels, DataTypeSerializer.serialize(context.aggOutputTypes), context.isInputRaws, + context.isOutputPartials, OperatorConfig.serialize(context.operatorConfig)); + } + + /** + * The type Factory context. + * + * @since 20210630 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private static final int INVALID_MASK_CHANNEL = -1; + + private final String[] groupByChanel; + + private final String[][] aggChannels; + + private final String[] aggChannelsFilter; + + private final DataType[] sourceTypes; + + private final FunctionType[] aggFunctionTypes; + + private final int[] maskChannels; + + private final DataType[][] aggOutputTypes; + + private final boolean[] isInputRaws; + + private final boolean[] isOutputPartials; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param groupByChanel the group by chanel + * @param aggChannels the agg channels + * @param aggChannelsFilter the agg filter Expr + * @param sourceTypes the source types + * @param aggFunctionTypes the agg function types + * @param maskChannels mask channel list for aggregators + * @param aggOutputTypes the agg output types + * @param isInputRaws the input raw flags + * @param isOutputPartials the output partial flags + * @param operatorConfig the operator config + */ + public FactoryContext(String[] groupByChanel, String[][] aggChannels, String[] aggChannelsFilter, + DataType[] sourceTypes, FunctionType[] aggFunctionTypes, int[] maskChannels, + DataType[][] aggOutputTypes, boolean[] isInputRaws, boolean[] isOutputPartials, + OperatorConfig operatorConfig) { + this.groupByChanel = requireNonNull(groupByChanel, "requireNonNull"); + this.aggChannels = requireNonNull(aggChannels, "aggChannels"); + this.aggChannelsFilter = checkAggChannelsFilter(aggChannelsFilter); + this.sourceTypes = requireNonNull(sourceTypes, "sourceTypes"); + this.aggFunctionTypes = requireNonNull(aggFunctionTypes, "aggFunctionTypes"); + this.maskChannels = requireNonNull(maskChannels, "maskChannels is null"); + this.aggOutputTypes = requireNonNull(aggOutputTypes, "aggOutputTypes"); + this.isInputRaws = requireNonNull(isInputRaws, "isInputRaws"); + this.isOutputPartials = requireNonNull(isOutputPartials, "isInputRaws"); + this.operatorConfig = operatorConfig; + } + + /** + * Instantiates a new Context. + * + * @param groupByChanel the group by chanel + * @param aggChannels the agg channels + * @param aggChannelsFilter the agg filter Expr + * @param sourceTypes the source types + * @param aggFunctionTypes the agg function types + * @param aggOutputTypes the agg output types + * @param isInputRaws the input raw flags + * @param isOutputPartials the output partial flags + * @param operatorConfig the operator config + */ + public FactoryContext(String[] groupByChanel, String[][] aggChannels, String[] aggChannelsFilter, + DataType[] sourceTypes, FunctionType[] aggFunctionTypes, DataType[][] aggOutputTypes, + boolean[] isInputRaws, boolean[] isOutputPartials, OperatorConfig operatorConfig) { + this(groupByChanel, aggChannels, aggChannelsFilter, sourceTypes, aggFunctionTypes, + getDefaultMaskChannel(aggFunctionTypes), aggOutputTypes, isInputRaws, isOutputPartials, + operatorConfig); + } + + private static String[] checkAggChannelsFilter(String[] aggChannelsFilter) { + for (int i = 0; i < aggChannelsFilter.length; i++) { + aggChannelsFilter[i] = aggChannelsFilter[i] == null ? "" : aggChannelsFilter[i]; + } + return aggChannelsFilter; + } + + private static int[] getDefaultMaskChannel(FunctionType[] aggFunctionTypes) { + int[] maskChannelArray = new int[aggFunctionTypes.length]; // one mask channel for each function + Arrays.fill(maskChannelArray, INVALID_MASK_CHANNEL); + return maskChannelArray; + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(groupByChanel), Arrays.deepHashCode(aggChannels), + Arrays.hashCode(aggChannelsFilter), Arrays.hashCode(sourceTypes), Arrays.hashCode(aggFunctionTypes), + Arrays.hashCode(maskChannels), Arrays.deepHashCode(aggOutputTypes), Arrays.hashCode(isInputRaws), + Arrays.hashCode(isOutputPartials), operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return Arrays.equals(groupByChanel, that.groupByChanel) && Arrays.deepEquals(aggChannels, that.aggChannels) + && Arrays.equals(aggChannelsFilter, that.aggChannelsFilter) + && Arrays.equals(sourceTypes, that.sourceTypes) + && Arrays.equals(aggFunctionTypes, that.aggFunctionTypes) + && Arrays.equals(maskChannels, that.maskChannels) + && Arrays.deepEquals(aggOutputTypes, that.aggOutputTypes) + && Arrays.equals(isInputRaws, that.isInputRaws) + && Arrays.equals(isOutputPartials, that.isOutputPartials) + && operatorConfig.equals(that.operatorConfig); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/aggregator/OmniHashAggregationOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/aggregator/OmniHashAggregationOperatorFactory.java new file mode 100644 index 0000000..7abc738 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/aggregator/OmniHashAggregationOperatorFactory.java @@ -0,0 +1,231 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.aggregator; + +import static java.util.Objects.requireNonNull; +import static nova.hetu.omniruntime.constants.ConstantHelper.toNativeConstants; + +import nova.hetu.omniruntime.constants.FunctionType; +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; + +/** + * The type Omni hash aggregation operator factory. + * + * @since 2021-06-30 + */ +public class OmniHashAggregationOperatorFactory + extends OmniOperatorFactory { + /** + * Instantiates a new Omni hash aggregation operator factory. + * + * @param groupByChanel the group by chanel + * @param groupByTypes the group by types + * @param aggChannels the agg channels + * @param aggTypes the agg types + * @param aggFunctionTypes the agg function types + * @param aggOutputTypes the agg output types + * @param isInputRaw the input raw + * @param isOutputPartial the output partial + * @param operatorConfig the operator config + */ + public OmniHashAggregationOperatorFactory(String[] groupByChanel, DataType[] groupByTypes, String[] aggChannels, + DataType[] aggTypes, FunctionType[] aggFunctionTypes, DataType[] aggOutputTypes, boolean isInputRaw, + boolean isOutputPartial, OperatorConfig operatorConfig) { + super(new FactoryContext(groupByChanel, groupByTypes, aggChannels, aggTypes, aggFunctionTypes, aggOutputTypes, + isInputRaw, isOutputPartial, operatorConfig)); + } + + /** + * Instantiates a new Omni hash aggregation operator factory. + * + * @param groupByChanel the group by chanel + * @param groupByTypes the group by types + * @param aggChannels the agg channels + * @param aggTypes the agg types + * @param aggFunctionTypes the agg function types + * @param aggOutputTypes the agg output types + * @param maskChannels mask channel list for aggregators + * @param isInputRaw the input raw + * @param isOutputPartial the output partial + * @param operatorConfig the operator config + */ + public OmniHashAggregationOperatorFactory(String[] groupByChanel, DataType[] groupByTypes, String[] aggChannels, + DataType[] aggTypes, FunctionType[] aggFunctionTypes, int[] maskChannels, DataType[] aggOutputTypes, + boolean isInputRaw, boolean isOutputPartial, OperatorConfig operatorConfig) { + super(new FactoryContext(groupByChanel, groupByTypes, aggChannels, aggTypes, aggFunctionTypes, maskChannels, + aggOutputTypes, isInputRaw, isOutputPartial, operatorConfig)); + } + + /** + * Instantiates a new Omni hash aggregation operator factory with jit + * default. + * + * @param groupByChanel the group by chanel + * @param groupByTypes the group by types + * @param aggChannels the agg channels + * @param aggTypes the agg types + * @param aggFunctionTypes the agg function types + * @param aggOutputTypes the agg output types + * @param isInputRaw the input raw + * @param isOutputPartial the output partial + */ + public OmniHashAggregationOperatorFactory(String[] groupByChanel, DataType[] groupByTypes, String[] aggChannels, + DataType[] aggTypes, FunctionType[] aggFunctionTypes, DataType[] aggOutputTypes, boolean isInputRaw, + boolean isOutputPartial) { + this(groupByChanel, groupByTypes, aggChannels, aggTypes, aggFunctionTypes, aggOutputTypes, isInputRaw, + isOutputPartial, new OperatorConfig()); + } + + /** + * Instantiates a new Omni hash aggregation operator factory with jit + * default. + * + * @param groupByChanel the group by chanel + * @param groupByTypes the group by types + * @param aggChannels the agg channels + * @param aggTypes the agg types + * @param aggFunctionTypes the agg function types + * @param maskChannels mask channel list for aggregators + * @param aggOutputTypes the agg output types + * @param isInputRaw the input raw + * @param isOutputPartial the output partial + */ + public OmniHashAggregationOperatorFactory(String[] groupByChanel, DataType[] groupByTypes, String[] aggChannels, + DataType[] aggTypes, FunctionType[] aggFunctionTypes, int[] maskChannels, DataType[] aggOutputTypes, + boolean isInputRaw, boolean isOutputPartial) { + this(groupByChanel, groupByTypes, aggChannels, aggTypes, aggFunctionTypes, maskChannels, aggOutputTypes, + isInputRaw, isOutputPartial, new OperatorConfig()); + } + + private static native long createHashAggregationOperatorFactory(String[] groupByChanel, String groupByTypes, + String[] aggChannels, String aggTypes, int[] aggFunctionTypes, int[] maskChannels, String aggOutputTypes, + boolean isInputRaw, boolean isOutputPartial, String operatorConfig); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createHashAggregationOperatorFactory(context.groupByChanel, + DataTypeSerializer.serialize(context.groupByTypes), context.aggChannels, + DataTypeSerializer.serialize(context.aggTypes), toNativeConstants(context.aggFunctionTypes), + context.maskChannels, DataTypeSerializer.serialize(context.aggOutputTypes), context.isInputRaw, + context.isOutputPartial, OperatorConfig.serialize(context.operatorConfig)); + } + + /** + * The type Factory context. + * + * @since 2021-06-30 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private static final int INVALID_MASK_CHANNEL = -1; + + private final String[] groupByChanel; + + private final DataType[] groupByTypes; + + private final String[] aggChannels; + + private final DataType[] aggTypes; + + private final FunctionType[] aggFunctionTypes; + + private final int[] maskChannels; + + private final DataType[] aggOutputTypes; + + private final boolean isInputRaw; + + private final boolean isOutputPartial; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param groupByChanel the group by chanel + * @param groupByTypes the group by types + * @param aggChannels the agg channels + * @param aggTypes the agg types + * @param aggFunctionTypes the agg function types + * @param maskChannels mask channel list for aggregators + * @param aggOutputTypes the agg output types + * @param isInputRaw the input raw + * @param isOutputPartial the output partial + * @param operatorConfig the operator config + */ + public FactoryContext(String[] groupByChanel, DataType[] groupByTypes, String[] aggChannels, + DataType[] aggTypes, FunctionType[] aggFunctionTypes, int[] maskChannels, DataType[] aggOutputTypes, + boolean isInputRaw, boolean isOutputPartial, OperatorConfig operatorConfig) { + this.groupByChanel = requireNonNull(groupByChanel, "requireNonNull"); + this.groupByTypes = requireNonNull(groupByTypes, "groupByTypes"); + this.aggChannels = requireNonNull(aggChannels, "aggChannels"); + this.aggTypes = requireNonNull(aggTypes, "aggTypes"); + this.aggFunctionTypes = requireNonNull(aggFunctionTypes, "aggFunctionTypes"); + this.maskChannels = requireNonNull(maskChannels, "maskChannels is null"); + this.aggOutputTypes = requireNonNull(aggOutputTypes, "aggOutputTypes"); + this.isInputRaw = isInputRaw; + this.isOutputPartial = isOutputPartial; + this.operatorConfig = operatorConfig; + } + + /** + * Instantiates a new Context. + * + * @param groupByChanel the group by chanel + * @param groupByTypes the group by types + * @param aggChannels the agg channels + * @param aggTypes the agg types + * @param aggFunctionTypes the agg function types + * @param aggOutputTypes the agg output types + * @param isInputRaw the input raw + * @param isOutputPartial the output partial + * @param operatorConfig the operator config + */ + public FactoryContext(String[] groupByChanel, DataType[] groupByTypes, String[] aggChannels, + DataType[] aggTypes, FunctionType[] aggFunctionTypes, DataType[] aggOutputTypes, boolean isInputRaw, + boolean isOutputPartial, OperatorConfig operatorConfig) { + this(groupByChanel, groupByTypes, aggChannels, aggTypes, aggFunctionTypes, + getDefaultMaskChannel(aggFunctionTypes), aggOutputTypes, isInputRaw, isOutputPartial, + operatorConfig); + } + + private static int[] getDefaultMaskChannel(FunctionType[] aggFunctionTypes) { + int[] maskChannelArray = new int[aggFunctionTypes.length]; // one mask channel for each function + Arrays.fill(maskChannelArray, INVALID_MASK_CHANNEL); + return maskChannelArray; + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(groupByChanel), Arrays.hashCode(groupByTypes), + Arrays.hashCode(aggChannels), Arrays.hashCode(aggTypes), Arrays.hashCode(aggFunctionTypes), + Arrays.hashCode(maskChannels), Arrays.hashCode(aggOutputTypes), isInputRaw, isOutputPartial, + operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return Arrays.equals(groupByChanel, that.groupByChanel) && Arrays.equals(groupByTypes, that.groupByTypes) + && Arrays.equals(aggChannels, that.aggChannels) && Arrays.equals(aggTypes, that.aggTypes) + && Arrays.equals(aggFunctionTypes, that.aggFunctionTypes) + && Arrays.equals(maskChannels, that.maskChannels) + && Arrays.equals(aggOutputTypes, that.aggOutputTypes) && isInputRaw == that.isInputRaw + && isOutputPartial == that.isOutputPartial && operatorConfig.equals(that.operatorConfig); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/aggregator/OmniHashAggregationWithExprOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/aggregator/OmniHashAggregationWithExprOperatorFactory.java new file mode 100644 index 0000000..d8a9e02 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/aggregator/OmniHashAggregationWithExprOperatorFactory.java @@ -0,0 +1,225 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.aggregator; + +import static java.util.Objects.requireNonNull; +import static nova.hetu.omniruntime.constants.ConstantHelper.toNativeConstants; + +import nova.hetu.omniruntime.constants.FunctionType; +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; +import nova.hetu.omniruntime.utils.JsonUtils; + +import java.util.Arrays; +import java.util.Objects; + +/** + * The type Omni aggregation with expression operator factory. + * + * @since 2021-10-21 + */ +public class OmniHashAggregationWithExprOperatorFactory + extends OmniOperatorFactory { + /** + * Instantiates a new Omni hash aggregation with expression operator factory. + * + * @param groupByChanel the group by chanel + * @param aggChannels the agg channels + * @param aggChannelsFilter the agg filter Expr + * @param sourceTypes the source types + * @param aggFunctionTypes the agg function types + * @param aggOutputTypes the agg output types + * @param isInputRaws the input raw flags + * @param isOutputPartials the output partial flags + * @param operatorConfig the operator config + */ + public OmniHashAggregationWithExprOperatorFactory(String[] groupByChanel, String[][] aggChannels, + String[] aggChannelsFilter, DataType[] sourceTypes, FunctionType[] aggFunctionTypes, + DataType[][] aggOutputTypes, boolean[] isInputRaws, boolean[] isOutputPartials, + OperatorConfig operatorConfig) { + super(new FactoryContext(groupByChanel, aggChannels, aggChannelsFilter, sourceTypes, aggFunctionTypes, + aggOutputTypes, isInputRaws, isOutputPartials, operatorConfig)); + } + + /** + * Instantiates a new Omni hash aggregation with expression operator factory. + * + * @param groupByChanel the group by chanel + * @param aggChannels the agg channels + * @param aggChannelsFilter the agg filter Expr + * @param sourceTypes the source types + * @param aggFunctionTypes the agg function types + * @param maskChannels mask channel list for aggregators + * @param aggOutputTypes the agg output types + * @param isInputRaws the input raw flags + * @param isOutputPartials the output partial flags + * @param operatorConfig the operator config + */ + public OmniHashAggregationWithExprOperatorFactory(String[] groupByChanel, String[][] aggChannels, + String[] aggChannelsFilter, DataType[] sourceTypes, FunctionType[] aggFunctionTypes, int[] maskChannels, + DataType[][] aggOutputTypes, boolean[] isInputRaws, boolean[] isOutputPartials, + OperatorConfig operatorConfig) { + super(new FactoryContext(groupByChanel, aggChannels, aggChannelsFilter, sourceTypes, aggFunctionTypes, + maskChannels, aggOutputTypes, isInputRaws, isOutputPartials, operatorConfig)); + } + + /** + * Instantiates a new Omni hash aggregation with expression operator factory + * with jit default. + * + * @param groupByChanel the group by chanel + * @param aggChannels the agg channels + * @param aggChannelsFilter the agg filter Expr + * @param sourceTypes the source types + * @param aggFunctionTypes the agg function types + * @param aggOutputTypes the agg output types + * @param isInputRaws the input raw flags + * @param isOutputPartials the output partial flags + */ + public OmniHashAggregationWithExprOperatorFactory(String[] groupByChanel, String[][] aggChannels, + String[] aggChannelsFilter, DataType[] sourceTypes, FunctionType[] aggFunctionTypes, + DataType[][] aggOutputTypes, boolean[] isInputRaws, boolean[] isOutputPartials) { + this(groupByChanel, aggChannels, aggChannelsFilter, sourceTypes, aggFunctionTypes, aggOutputTypes, isInputRaws, + isOutputPartials, new OperatorConfig()); + } + + private static native long createHashAggregationWithExprOperatorFactory(String[] groupByChanel, + String[] aggChannels, String[] aggChannelsFilter, String sourceTypes, int[] aggFunctionTypes, + int[] maskChannels, String[] aggOutputTypes, boolean[] isInputRaws, boolean[] isOutputPartials, + String operatorConfig); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createHashAggregationWithExprOperatorFactory(context.groupByChanel, + JsonUtils.jsonStringArray(context.aggChannels), context.aggChannelsFilter, + DataTypeSerializer.serialize(context.sourceTypes), toNativeConstants(context.aggFunctionTypes), + context.maskChannels, DataTypeSerializer.serialize(context.aggOutputTypes), context.isInputRaws, + context.isOutputPartials, OperatorConfig.serialize(context.operatorConfig)); + } + + /** + * The type Factory context. + * + * @since 20210630 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private static final int INVALID_MASK_CHANNEL = -1; + + private final String[] groupByChanel; + + private final String[][] aggChannels; + + private final String[] aggChannelsFilter; + + private final DataType[] sourceTypes; + + private final FunctionType[] aggFunctionTypes; + + private final int[] maskChannels; + + private final DataType[][] aggOutputTypes; + + private final boolean[] isInputRaws; + + private final boolean[] isOutputPartials; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param groupByChanel the group by chanel + * @param aggChannels the agg channels + * @param aggChannelsFilter the agg filter Expr + * @param sourceTypes the source types + * @param aggFunctionTypes the agg function types + * @param maskChannels mask channel list for aggregators + * @param aggOutputTypes the agg output types + * @param isInputRaws the input raw flags + * @param isOutputPartials the output partial flags + * @param operatorConfig the operator config + */ + public FactoryContext(String[] groupByChanel, String[][] aggChannels, String[] aggChannelsFilter, + DataType[] sourceTypes, FunctionType[] aggFunctionTypes, int[] maskChannels, + DataType[][] aggOutputTypes, boolean[] isInputRaws, boolean[] isOutputPartials, + OperatorConfig operatorConfig) { + this.groupByChanel = requireNonNull(groupByChanel, "requireNonNull"); + this.aggChannels = requireNonNull(aggChannels, "aggChannels"); + this.aggChannelsFilter = checkAggChannelsFilter(aggChannelsFilter); + this.sourceTypes = requireNonNull(sourceTypes, "sourceTypes"); + this.aggFunctionTypes = requireNonNull(aggFunctionTypes, "aggFunctionTypes"); + this.maskChannels = requireNonNull(maskChannels, "maskChannels is null"); + this.aggOutputTypes = requireNonNull(aggOutputTypes, "aggOutputTypes"); + this.isInputRaws = requireNonNull(isInputRaws, "isInputRaws"); + this.isOutputPartials = requireNonNull(isOutputPartials, "isInputRaws"); + this.operatorConfig = operatorConfig; + } + + /** + * Instantiates a new Context. + * + * @param groupByChanel the group by chanel + * @param aggChannels the agg channels + * @param aggChannelsFilter the agg filter Expr + * @param sourceTypes the source types + * @param aggFunctionTypes the agg function types + * @param aggOutputTypes the agg output types + * @param isInputRaws the input raw flags + * @param isOutputPartials the output partial flags + * @param operatorConfig the operator config + */ + public FactoryContext(String[] groupByChanel, String[][] aggChannels, String[] aggChannelsFilter, + DataType[] sourceTypes, FunctionType[] aggFunctionTypes, DataType[][] aggOutputTypes, + boolean[] isInputRaws, boolean[] isOutputPartials, OperatorConfig operatorConfig) { + this(groupByChanel, aggChannels, aggChannelsFilter, sourceTypes, aggFunctionTypes, + getDefaultMaskChannel(aggFunctionTypes), aggOutputTypes, isInputRaws, isOutputPartials, + operatorConfig); + } + + private static String[] checkAggChannelsFilter(String[] aggChannelsFilter) { + for (int i = 0; i < aggChannelsFilter.length; i++) { + aggChannelsFilter[i] = aggChannelsFilter[i] == null ? "" : aggChannelsFilter[i]; + } + return aggChannelsFilter; + } + + private static int[] getDefaultMaskChannel(FunctionType[] aggFunctionTypes) { + int[] maskChannelArray = new int[aggFunctionTypes.length]; // one mask channel for each function + Arrays.fill(maskChannelArray, INVALID_MASK_CHANNEL); + return maskChannelArray; + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(groupByChanel), Arrays.deepHashCode(aggChannels), + Arrays.hashCode(aggChannelsFilter), Arrays.hashCode(sourceTypes), Arrays.hashCode(aggFunctionTypes), + Arrays.hashCode(maskChannels), Arrays.deepHashCode(aggOutputTypes), Arrays.hashCode(isInputRaws), + Arrays.hashCode(isOutputPartials), operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return Arrays.equals(groupByChanel, that.groupByChanel) && Arrays.deepEquals(aggChannels, that.aggChannels) + && Arrays.equals(aggChannelsFilter, that.aggChannelsFilter) + && Arrays.equals(sourceTypes, that.sourceTypes) + && Arrays.equals(aggFunctionTypes, that.aggFunctionTypes) + && Arrays.equals(maskChannels, that.maskChannels) + && Arrays.deepEquals(aggOutputTypes, that.aggOutputTypes) + && Arrays.equals(isInputRaws, that.isInputRaws) + && Arrays.equals(isOutputPartials, that.isOutputPartials) + && operatorConfig.equals(that.operatorConfig); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/config/OperatorConfig.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/config/OperatorConfig.java new file mode 100644 index 0000000..7c1948c --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/config/OperatorConfig.java @@ -0,0 +1,278 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.config; + +import static nova.hetu.omniruntime.utils.OmniErrorType.OMNI_INNER_ERROR; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import nova.hetu.omniruntime.utils.OmniRuntimeException; + +import java.util.Objects; + +/** + * operator config. + * + * @since 2022-04-16 + */ +public class OperatorConfig { + /** + * NONE operator config. + */ + public static final OperatorConfig NONE = new OperatorConfig(); + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + private SpillConfig spillConfig; + + private OverflowConfig overflowConfig; + + /** + * When set to true, statistical aggregate function returns Double.NaN + * if divide by zero occurred during expression evaluation, otherwise, it returns null. + * Before Spark version 3.1.0, it returns NaN in divideByZero case by default. + */ + private boolean isStatisticalAggregate; + + private boolean isSkipExpressionVerify; + + private int adaptivityThreshold = -1; + + private boolean isRowOutput = false; + + /** + * Operator config default constructor. + */ + public OperatorConfig() { + this(SpillConfig.NONE, new OverflowConfig(), false); + } + + /** + * Operator config constructor. + * + * @param spillConfig the spill config + */ + public OperatorConfig(SpillConfig spillConfig) { + this(spillConfig, new OverflowConfig(), false); + } + + /** + * Operator config constructor. + * + * @param overflowConfig overflowConfig + */ + public OperatorConfig(OverflowConfig overflowConfig) { + this(SpillConfig.NONE, overflowConfig, false); + } + + /** + * Operator config constructor. + * + * @param spillConfig spillConfig + * @param overflowConfig overflowConfig + */ + public OperatorConfig(SpillConfig spillConfig, OverflowConfig overflowConfig) { + this(spillConfig, overflowConfig, false); + } + + /** + * Operator config constructor. + * + * @param spillConfig the spill config + * @param overflowConfig the overflow config + * @param isSkipExpressionVerify whether to skip exprVerify + */ + public OperatorConfig(SpillConfig spillConfig, OverflowConfig overflowConfig, boolean isSkipExpressionVerify) { + this.spillConfig = spillConfig; + this.overflowConfig = overflowConfig; + this.isSkipExpressionVerify = isSkipExpressionVerify; + this.isStatisticalAggregate = false; + } + + public OperatorConfig(SpillConfig spillConfig, boolean isStatisticalAggregate, OverflowConfig overflowConfig, + boolean isSkipExpressionVerify) { + this.spillConfig = spillConfig; + this.overflowConfig = overflowConfig; + this.isSkipExpressionVerify = isSkipExpressionVerify; + this.isStatisticalAggregate = isStatisticalAggregate; + } + + /** + * Operator config constructor. + * + * @param spillConfig the spill config + * @param overflowConfig the overflow config + * @param isSkipExpressionVerify whether to skip exprVerify + * @param adaptivityThreshold an int for adaptivity of operator. For example, + * radix sort threshold for Sort + */ + public OperatorConfig(SpillConfig spillConfig, OverflowConfig overflowConfig, boolean isSkipExpressionVerify, + int adaptivityThreshold) { + this(spillConfig, overflowConfig, isSkipExpressionVerify); + this.adaptivityThreshold = adaptivityThreshold; + } + + /** + * Operator config constructor. + * + * @param spillConfig the spill config + * @param overflowConfig the overflow config + * @param isSkipExpressionVerify whether to skip exprVerify + * @param isRowOutput true mean operator need to output row batch, + */ + public OperatorConfig(SpillConfig spillConfig, OverflowConfig overflowConfig, boolean isSkipExpressionVerify, + boolean isRowOutput) { + this(spillConfig, overflowConfig, isSkipExpressionVerify); + this.isRowOutput = isRowOutput; + } + + /** + * Get spill config. + * + * @return the spill config + */ + public SpillConfig getSpillConfig() { + return spillConfig; + } + + /** + * Get overflow config. + * + * @return the overflow config + */ + public OverflowConfig getOverflowConfig() { + return overflowConfig; + } + + /** + * Set spill config. + * + * @param spillConfig the spill config + */ + public void setSpillConfig(SpillConfig spillConfig) { + this.spillConfig = spillConfig; + } + + /** + * Set overflow config. + * + * @param overflowConfig overflowConfig + */ + public void setOverflowConfig(OverflowConfig overflowConfig) { + this.overflowConfig = overflowConfig; + } + + /** + * Set skipExpressionVerify + * + * @param isSkipExpressionVerify whether to skip exprVerify + */ + public void setSkipExpressionVerify(boolean isSkipExpressionVerify) { + this.isSkipExpressionVerify = isSkipExpressionVerify; + } + + /** + * Get skipExpressionVerify + * + * @return skipExpressionVerify + */ + public boolean isSkipExpressionVerify() { + return isSkipExpressionVerify; + } + + /** + * Get statisticalAggregate + * + * @return statisticalAggregate + */ + public boolean isStatisticalAggregate() { + return isStatisticalAggregate; + } + + /** + * Set statisticalAggregate. + * + * @param isStatisticalAggregate boolean + */ + public void setStatisticalAggregate(boolean isStatisticalAggregate) { + this.isStatisticalAggregate = isStatisticalAggregate; + } + + /** + * Set adaptivityThreshold + * + * @param adaptivityThreshold a threshold for some kind of adaptivity in operator + */ + public void setAdaptivityThreshold(int adaptivityThreshold) { + this.adaptivityThreshold = adaptivityThreshold; + } + + public void setIsRowOutput(boolean inputRowOutput) { + this.isRowOutput = inputRowOutput; + } + + /** + * Get adaptivityThreshold + * + * @return adaptivityThreshold + */ + public int getAdaptivityThreshold() { + return adaptivityThreshold; + } + + public boolean getIsRowOutput() { + return isRowOutput; + } + + /** + * Serialize operator config to string. + * + * @param operatorConfig the operator config + * @return the string result of serialization + */ + public static String serialize(OperatorConfig operatorConfig) { + try { + return OBJECT_MAPPER.writeValueAsString(operatorConfig); + } catch (JsonProcessingException e) { + throw new OmniRuntimeException(OMNI_INNER_ERROR, "Serialization failed.", e); + } + } + + /** + * Deserialize string to the operator config. + * + * @param operatorConfigString the operator config string + * @return the operator config of deserialization + */ + public static OperatorConfig deserialize(String operatorConfigString) { + try { + return OBJECT_MAPPER.readerFor(OperatorConfig.class).readValue(operatorConfigString); + } catch (JsonProcessingException e) { + throw new OmniRuntimeException(OMNI_INNER_ERROR, "Deserialization failed.", e); + } + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + OperatorConfig that = (OperatorConfig) obj; + return Objects.equals(spillConfig, that.spillConfig) && Objects.equals(overflowConfig, that.overflowConfig) + && isSkipExpressionVerify == that.isSkipExpressionVerify + && isStatisticalAggregate == that.isStatisticalAggregate + && adaptivityThreshold == that.adaptivityThreshold && isRowOutput == that.isRowOutput; + } + + @Override + public int hashCode() { + return Objects.hash(spillConfig, overflowConfig, isSkipExpressionVerify, isStatisticalAggregate, + adaptivityThreshold, isRowOutput); + } +} \ No newline at end of file diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/config/OverflowConfig.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/config/OverflowConfig.java new file mode 100644 index 0000000..8a7dbcd --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/config/OverflowConfig.java @@ -0,0 +1,62 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.config; + +import java.io.Serializable; +import java.util.Objects; + +/** + * OverflowConfig + * + * @since 2022-08-09 + */ +public class OverflowConfig implements Serializable { + private static final long serialVersionUID = 393901615896834842L; + + private OverflowConfigId overflowConfigId; + + public OverflowConfig() { + this(OverflowConfigId.OVERFLOW_CONFIG_EXCEPTION); + } + + public OverflowConfig(OverflowConfigId overflowConfigId) { + this.overflowConfigId = overflowConfigId; + } + + public void setOverflowConfigId(OverflowConfigId overflowConfigId) { + this.overflowConfigId = overflowConfigId; + } + + public OverflowConfigId getOverflowConfigId() { + return overflowConfigId; + } + + /** + * OverflowConfigId + * + * @since 2022-08-09 + */ + public enum OverflowConfigId { + OVERFLOW_CONFIG_EXCEPTION, + OVERFLOW_CONFIG_NULL + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + OverflowConfig that = (OverflowConfig) obj; + return overflowConfigId == that.overflowConfigId; + } + + @Override + public int hashCode() { + return Objects.hash(overflowConfigId); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/config/SparkSpillConfig.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/config/SparkSpillConfig.java new file mode 100644 index 0000000..bc16fde --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/config/SparkSpillConfig.java @@ -0,0 +1,127 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.config; + +import java.util.Objects; + +/** + * spark spill config. + * + * @since 2022-04-16 + */ +public class SparkSpillConfig extends SpillConfig { + private int numElementsForSpillThreshold; + private int memUsagePctForSpillThreshold; + + /** + * Instantiates a new spark spill config. + */ + public SparkSpillConfig() { + super(); + numElementsForSpillThreshold = Integer.MAX_VALUE; + memUsagePctForSpillThreshold = 90; + } + + /** + * Instantiates a new spark spill config. + * + * @param spillPath the spill path + * @param numElementsForSpillThreshold the num elements for spill threshold + */ + public SparkSpillConfig(String spillPath, int numElementsForSpillThreshold) { + this(true, spillPath, DEFAULT_MAX_SPILL_BYTES, numElementsForSpillThreshold); + this.memUsagePctForSpillThreshold = 90; // default memory usage percentage for spill threshold + } + + /** + * Instantiates a new spark spill config. + * + * @param isSpillEnabled the spill enabled + * @param spillPath the spill path + * @param maxSpillBytes the max spill bytes + * @param numElementsForSpillThreshold the num elements for spill threshold + */ + public SparkSpillConfig(boolean isSpillEnabled, String spillPath, long maxSpillBytes, + int numElementsForSpillThreshold) { + super(SpillConfigId.SPILL_CONFIG_SPARK, isSpillEnabled, spillPath, maxSpillBytes, DEFAULT_WRITE_BUFFER_SIZE); + this.numElementsForSpillThreshold = numElementsForSpillThreshold; + this.memUsagePctForSpillThreshold = 90; // default memory usage percentage for spill threshold + } + + /** + * Instantiates a new spark spill config. + * + * @param isSpillEnabled the spill enabled + * @param spillPath the spill path + * @param maxSpillBytes the max spill bytes + * @param numElementsForSpillThreshold the num elements for spill threshold + * @param memUsagePctForSpillThreshold the memory usage percentage for spill threshold + * @param writeBufferSize the spill write buffer size + */ + public SparkSpillConfig(boolean isSpillEnabled, String spillPath, long maxSpillBytes, + int numElementsForSpillThreshold, int memUsagePctForSpillThreshold, long writeBufferSize) { + super(SpillConfigId.SPILL_CONFIG_SPARK, isSpillEnabled, spillPath, maxSpillBytes, writeBufferSize); + this.numElementsForSpillThreshold = numElementsForSpillThreshold; + this.memUsagePctForSpillThreshold = memUsagePctForSpillThreshold; + } + + /** + * get the num elements for spill threshold. + * + * @return the num elements for spill threshold + */ + public int getNumElementsForSpillThreshold() { + return numElementsForSpillThreshold; + } + + /** + * set the num elements for spill threshold. + * + * @param numElementsForSpillThreshold the num elements for spill threshold + */ + public void setNumElementsForSpillThreshold(int numElementsForSpillThreshold) { + this.numElementsForSpillThreshold = numElementsForSpillThreshold; + } + + /** + * set the memory usage percentage for spill threshold. + * + * @param memUsagePctForSpillThreshold the memory usage percentage for spill + * threshold + */ + public void setMemUsagePctForSpillThreshold(int memUsagePctForSpillThreshold) { + this.memUsagePctForSpillThreshold = memUsagePctForSpillThreshold; + } + + /** + * get the memory usage percentage for spill threshold. + * + * @return the num elements for spill threshold + */ + public int getMemUsagePctForSpillThreshold() { + return memUsagePctForSpillThreshold; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + if (!super.equals(obj)) { + return false; + } + SparkSpillConfig that = (SparkSpillConfig) obj; + return numElementsForSpillThreshold == that.numElementsForSpillThreshold + && memUsagePctForSpillThreshold == that.memUsagePctForSpillThreshold; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), numElementsForSpillThreshold, memUsagePctForSpillThreshold); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/config/SpillConfig.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/config/SpillConfig.java new file mode 100644 index 0000000..9d8e56d --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/config/SpillConfig.java @@ -0,0 +1,223 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.config; + +import static nova.hetu.omniruntime.utils.OmniErrorType.OMNI_PARAM_ERROR; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; + +import nova.hetu.omniruntime.utils.OmniRuntimeException; + +import java.io.Serializable; +import java.util.Objects; + +/** + * spill config. + * + * @since 2022-04-16 + */ +@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "name") +@JsonSubTypes(value = {@JsonSubTypes.Type(value = SparkSpillConfig.class, name = "SparkSpillConfig")}) +public class SpillConfig implements Serializable { + /** + * NONE spill config. + */ + public static final SpillConfig NONE = new SpillConfig(SpillConfigId.SPILL_CONFIG_NONE); + + /** + * INVALID spill config. + */ + public static final SpillConfig INVALID = new SpillConfig(SpillConfigId.SPILL_CONFIG_INVALID); + + /** + * The default max spill bytes. + */ + public static final long DEFAULT_MAX_SPILL_BYTES = 100L * (1 << 30); // 100GB + + /** + * The default spill write buffer size. + */ + public static final long DEFAULT_WRITE_BUFFER_SIZE = 4 * (1 << 20); // 4MB + + private static final long serialVersionUID = -1420544948753374714L; + + private SpillConfigId spillConfigId; + + private boolean isSpillEnabled; + + private String spillPath; + + private long maxSpillBytes; + + private long writeBufferSize; + + /** + * Spill config default constructor. + */ + public SpillConfig() { + this(SpillConfigId.SPILL_CONFIG_NONE, false, "", DEFAULT_MAX_SPILL_BYTES, DEFAULT_WRITE_BUFFER_SIZE); + } + + /** + * Spill config constructor. + * + * @param spillConfigId the spill config id + */ + public SpillConfig(SpillConfigId spillConfigId) { + this(spillConfigId, false, "", DEFAULT_MAX_SPILL_BYTES, DEFAULT_WRITE_BUFFER_SIZE); + } + + /** + * Spill config constructor. + * + * @param spillConfigId the spill config id + * @param isSpillEnabled whether the spill enabled + * @param spillPath the spill path + */ + public SpillConfig(SpillConfigId spillConfigId, boolean isSpillEnabled, String spillPath) { + this(spillConfigId, isSpillEnabled, spillPath, DEFAULT_MAX_SPILL_BYTES, DEFAULT_WRITE_BUFFER_SIZE); + } + + /** + * Spill config constructor. + * + * @param spillConfigId the spill config id + * @param isSpillEnabled whether the spill enabled + * @param spillPath the spill path + * @param maxSpillBytes the max spill bytes + * @param writeBufferSize the sill write buffer size + */ + public SpillConfig(SpillConfigId spillConfigId, boolean isSpillEnabled, String spillPath, long maxSpillBytes, + long writeBufferSize) { + if (isSpillEnabled && (spillPath == null || spillPath.isEmpty())) { + throw new OmniRuntimeException(OMNI_PARAM_ERROR, "Enable spill but do not config spill path."); + } + this.spillConfigId = spillConfigId; + this.isSpillEnabled = isSpillEnabled; + this.spillPath = spillPath; + this.maxSpillBytes = maxSpillBytes; + this.writeBufferSize = writeBufferSize; + } + + /** + * get the spill config id. + * + * @return the spill config id + */ + public SpillConfigId getSpillConfigId() { + return spillConfigId; + } + + /** + * set the spill config id. + * + * @param spillConfigId the spill config id + */ + public void setSpillConfigId(SpillConfigId spillConfigId) { + this.spillConfigId = spillConfigId; + } + + /** + * get whether the spill enabled. + * + * @return return true if enable spill, return false if disable spill + */ + public boolean isSpillEnabled() { + return isSpillEnabled; + } + + /** + * set whether spill enabled. + * + * @param isSpillEnabled the status of spill enabled + */ + public void setSpillEnabled(boolean isSpillEnabled) { + this.isSpillEnabled = isSpillEnabled; + } + + /** + * get the spill path. + * + * @return the spill path + */ + public String getSpillPath() { + return spillPath; + } + + /** + * set the spill path. + * + * @param spillPath the spill path + */ + public void setSpillPath(String spillPath) { + this.spillPath = spillPath; + } + + /** + * get the max spill bytes. + * + * @return the max spill bytes + */ + public long getMaxSpillBytes() { + return maxSpillBytes; + } + + /** + * set the max spill bytes. + * + * @param maxSpillBytes the max spill bytes + */ + public void setMaxSpillBytes(long maxSpillBytes) { + this.maxSpillBytes = maxSpillBytes; + } + + /** + * get the spill write buffer size. + * + * @return the spill write buffer size + */ + public long getWriteBufferSize() { + return writeBufferSize; + } + + /** + * set the spill write buffer size. + * + * @param writeBufferSize the spill write buffer size + */ + public void setWriteBufferSize(long writeBufferSize) { + this.writeBufferSize = writeBufferSize; + } + + /** + * The enum for spill config id. + */ + public enum SpillConfigId { + SPILL_CONFIG_NONE, + SPILL_CONFIG_OLK, + SPILL_CONFIG_SPARK, + SPILL_CONFIG_INVALID + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + SpillConfig spillConfig = (SpillConfig) obj; + return spillConfigId == spillConfig.spillConfigId && isSpillEnabled == isSpillEnabled + && spillPath.equals(spillConfig.spillPath) && maxSpillBytes == spillConfig.maxSpillBytes + && writeBufferSize == spillConfig.writeBufferSize; + } + + @Override + public int hashCode() { + return Objects.hash(spillConfigId, isSpillEnabled, spillPath, maxSpillBytes, writeBufferSize); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/filter/OmniBloomFilterOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/filter/OmniBloomFilterOperatorFactory.java new file mode 100644 index 0000000..63c00d4 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/filter/OmniBloomFilterOperatorFactory.java @@ -0,0 +1,72 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.filter; + +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; + +import java.util.Objects; + +/** + * The type Omni bloom filter operator factory. + * + * @since 2023-03-03 + */ +public class OmniBloomFilterOperatorFactory extends OmniOperatorFactory { + public OmniBloomFilterOperatorFactory(int version, OperatorConfig operatorConfig) { + super(new OmniBloomFilterOperatorFactory.FactoryContext(version, operatorConfig)); + } + + public OmniBloomFilterOperatorFactory(int version) { + this(version, new OperatorConfig()); + } + + @Override + protected long createNativeOperatorFactory(OmniBloomFilterOperatorFactory.FactoryContext context) { + return createBloomFilterOperatorFactory(context.version); + } + + private static native long createBloomFilterOperatorFactory(int version); + + /** + * The type Factory context. + * + * @since 20230303 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final int version; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param version the bloom filter version + * @param operatorConfig the operator config + */ + public FactoryContext(int version, OperatorConfig operatorConfig) { + this.version = version; + this.operatorConfig = operatorConfig; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext context = (FactoryContext) obj; + return version == context.version && operatorConfig.equals(context.operatorConfig); + } + + @Override + public int hashCode() { + return Objects.hash(this.version, operatorConfig); + } + } +} \ No newline at end of file diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/filter/OmniFilterAndProjectOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/filter/OmniFilterAndProjectOperatorFactory.java new file mode 100644 index 0000000..218664f --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/filter/OmniFilterAndProjectOperatorFactory.java @@ -0,0 +1,169 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.filter; + +import static java.util.Objects.requireNonNull; + +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +/** + * The type Omni filter and project operator factory. + * + * @since 2021-06-30 + */ +public class OmniFilterAndProjectOperatorFactory + extends OmniOperatorFactory { + private boolean isSupported; + + /** + * Instantiates a new Omni filter and project operator factory. + * + * @param expression the expression + * @param inputTypes the input types + * @param projections the projections + * @param operatorConfig the operator config + */ + public OmniFilterAndProjectOperatorFactory(String expression, DataType[] inputTypes, List projections, + OperatorConfig operatorConfig) { + super(new FactoryContext(expression, inputTypes, projections, operatorConfig)); + } + + /** + * Instantiates a new Omni filter and project operator factory with default + * operator config. + * + * @param expression the expression + * @param inputTypes the input types + * @param projections the projections + */ + public OmniFilterAndProjectOperatorFactory(String expression, DataType[] inputTypes, List projections) { + this(expression, inputTypes, projections, new OperatorConfig()); + } + + /** + * Instantiates a new Omni filter and project operator factory with configured + * expression parsing format. + * + * @param expression the expression + * @param inputTypes the input types + * @param projections the projections + * @param parseFormat the format to parse expression into + * @param operatorConfig the operator config + */ + public OmniFilterAndProjectOperatorFactory(String expression, DataType[] inputTypes, List projections, + int parseFormat, OperatorConfig operatorConfig) { + super(new FactoryContext(expression, inputTypes, projections, parseFormat, operatorConfig)); + } + + /** + * Instantiates a new Omni filter and project operator factory with configured + * expression parsing format with default operator config. + * + * @param expression the expression + * @param inputTypes the input types + * @param projections the projections + * @param parseFormat the format to parse expression into + */ + public OmniFilterAndProjectOperatorFactory(String expression, DataType[] inputTypes, List projections, + int parseFormat) { + this(expression, inputTypes, projections, parseFormat, new OperatorConfig()); + } + + private static native long createFilterAndProjectOperatorFactory(String inputTypes, int inputLength, + String expression, Object[] projections, int projectLength, int parseFormat, String operatorConfig); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + long factoryAddr = createFilterAndProjectOperatorFactory(DataTypeSerializer.serialize(context.inputTypes), + context.inputTypes.length, context.expression, context.projections.toArray(), + context.projections.size(), context.parseFormat, OperatorConfig.serialize(context.operatorConfig)); + if (factoryAddr != 0) { + isSupported = true; + } + return factoryAddr; + } + + public boolean isSupported() { + return isSupported; + } + + /** + * The type Factory context. + * + * @since 20210630 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] inputTypes; + + private final String expression; + + private final List projections; + + private final int parseFormat; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param expression the expression + * @param inputTypes the input types + * @param projections the projections + * @param operatorConfig the operator config + */ + public FactoryContext(String expression, DataType[] inputTypes, List projections, + OperatorConfig operatorConfig) { + this(expression, inputTypes, projections, 0, operatorConfig); + } + + /** + * Instantiates a new Context with configured parsing format of the expression. + * + * @param expression the expression + * @param inputTypes the input types + * @param projections the projections + * @param parseFormat the parsing format of expressions + * @param operatorConfig the operator config + */ + public FactoryContext(String expression, DataType[] inputTypes, List projections, int parseFormat, + OperatorConfig operatorConfig) { + this.inputTypes = requireNonNull(inputTypes, "Input types array is null."); + this.expression = requireNonNull(expression, "Expression is null."); + this.projections = requireNonNull(projections, "Project indices is null."); + this.parseFormat = parseFormat; + this.operatorConfig = operatorConfig; + } + + @Override + public int hashCode() { + return Objects.hash(expression, Arrays.hashCode(inputTypes), Objects.hashCode(projections), parseFormat, + operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return Objects.equals(expression, that.expression) && Arrays.equals(inputTypes, that.inputTypes) + && Objects.equals(projections.stream().sorted().collect(Collectors.toList()), + that.projections.stream().sorted().collect(Collectors.toList())) + && parseFormat == that.parseFormat && operatorConfig.equals(that.operatorConfig); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniHashBuilderOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniHashBuilderOperatorFactory.java new file mode 100644 index 0000000..8ab0c5f --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniHashBuilderOperatorFactory.java @@ -0,0 +1,118 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.join; + +import static java.util.Objects.requireNonNull; + +import nova.hetu.omniruntime.constants.JoinType; +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; + +/** + * The type Omni hash builder operator factory. + * + * @since 2021-06-30 + */ +public class OmniHashBuilderOperatorFactory extends OmniOperatorFactory { + /** + * Instantiates a new Omni hash builder operator factory. + * + * @param joinType the join type + * @param buildTypes the build types + * @param buildHashCols the build hash cols + * @param operatorCount the operator count + * @param operatorConfig the operator config + */ + public OmniHashBuilderOperatorFactory(JoinType joinType, DataType[] buildTypes, int[] buildHashCols, + int operatorCount, OperatorConfig operatorConfig) { + super(new FactoryContext(joinType, buildTypes, buildHashCols, operatorCount, operatorConfig)); + } + + /** + * Instantiates a new Omni hash builder operator factory with default operator + * config. + * + * @param joinType the join type + * @param buildTypes the build types + * @param buildHashCols the build hash cols + * @param operatorCount the operator count + */ + public OmniHashBuilderOperatorFactory(JoinType joinType, DataType[] buildTypes, int[] buildHashCols, + int operatorCount) { + this(joinType, buildTypes, buildHashCols, operatorCount, new OperatorConfig()); + } + + private static native long createHashBuilderOperatorFactory(int joinType, String buildTypes, int[] buildHashCols, + int operatorCount, String operatorConfig); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createHashBuilderOperatorFactory(context.joinType.getValue(), + DataTypeSerializer.serialize(context.buildTypes), context.buildHashCols, context.operatorCount, + OperatorConfig.serialize(context.operatorConfig)); + } + + /** + * The type Factory context. + * + * @since 20210630 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final JoinType joinType; + + private final DataType[] buildTypes; + + private final int[] buildHashCols; + + private final int operatorCount; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param joinType the join type + * @param buildTypes the build types + * @param buildHashCols the build hash cols + * @param operatorCount the operator count + * @param operatorConfig the operator config + */ + public FactoryContext(JoinType joinType, DataType[] buildTypes, int[] buildHashCols, int operatorCount, + OperatorConfig operatorConfig) { + this.joinType = requireNonNull(joinType, "joinType"); + this.buildTypes = requireNonNull(buildTypes, "buildTypes"); + this.buildHashCols = requireNonNull(buildHashCols, "buildHashCols"); + this.operatorCount = operatorCount; + this.operatorConfig = operatorConfig; + setNeedCache(false); + } + + @Override + public int hashCode() { + return Objects.hash(joinType, Arrays.hashCode(buildTypes), Arrays.hashCode(buildHashCols), operatorCount, + operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return joinType.equals(that.joinType) && Arrays.equals(buildTypes, that.buildTypes) + && Arrays.equals(buildHashCols, that.buildHashCols) && operatorCount == that.operatorCount + && operatorConfig.equals(that.operatorConfig); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniHashBuilderWithExprOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniHashBuilderWithExprOperatorFactory.java new file mode 100644 index 0000000..8fd91af --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniHashBuilderWithExprOperatorFactory.java @@ -0,0 +1,224 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.join; + +import static java.util.Objects.requireNonNull; + +import nova.hetu.omniruntime.constants.JoinType; +import nova.hetu.omniruntime.constants.BuildSide; +import nova.hetu.omniruntime.operator.OmniOperator; +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; + +/** + * The Omni hash builder with expression operator factory. + * + * @since 2021-10-16 + */ +public class OmniHashBuilderWithExprOperatorFactory + extends OmniOperatorFactory { + /** + * The global lock is used for sharing hash builder operator and factory + * concurrently. + */ + public static final Lock gLock = new ReentrantLock(); + + /** + * The hashmap cached the shared hash builder operator, the key is the build + * plan node id. + */ + private static Map operatorCache = new HashMap<>(); + + /** + * The hashmap cached the shared hash builder operator factory, the key is the + * build plan node id. + */ + private static Map factoryCache = new HashMap<>(); + + /** + * The hashmap stores the num of lookup which is in-use shared hash builder + * operator + * the key is the build plan node id. + */ + private static Map ref = new HashMap<>(); + + /** + * Instantiates a new Omni hash builder with expression operator factory. + * + * @param joinType the join type + * @param buildTypes the build input types + * @param buildHashKeys the build hash keys + * @param operatorCount the operator count + * @param operatorConfig the operator config + */ + public OmniHashBuilderWithExprOperatorFactory(JoinType joinType, DataType[] buildTypes, String[] buildHashKeys, + int operatorCount, OperatorConfig operatorConfig) { + super(new FactoryContext(joinType, buildTypes, buildHashKeys, operatorCount, operatorConfig)); + } + + /** + * Instantiates a new Omni hash builder with expression operator factory. + * + * @param joinType the join type + * @param buildSide the build side + * @param buildTypes the build input types + * @param buildHashKeys the build hash keys + * @param operatorCount the operator count + * @param operatorConfig the operator config + */ + public OmniHashBuilderWithExprOperatorFactory(JoinType joinType, BuildSide buildSide, DataType[] buildTypes, + String[] buildHashKeys, int operatorCount, OperatorConfig operatorConfig) { + super(new FactoryContext(joinType, buildSide, buildTypes, buildHashKeys, operatorCount, operatorConfig)); + } + + /** + * Instantiates a new Omni hash builder with expression operator factory with + * default operator config. + * + * @param joinType the join type + * @param buildTypes the build input types + * @param buildHashKeys the build hash keys + * @param operatorCount the operator count + */ + public OmniHashBuilderWithExprOperatorFactory(JoinType joinType, DataType[] buildTypes, String[] buildHashKeys, + int operatorCount) { + this(joinType, buildTypes, buildHashKeys, operatorCount, new OperatorConfig()); + } + + private static native long createHashBuilderWithExprOperatorFactory(int joinType, int buildSide, String buildTypes, + String[] buildHashKeys, int operatorCount, String operatorConfig); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createHashBuilderWithExprOperatorFactory(context.joinType.getValue(), context.buildSide.getValue(), + DataTypeSerializer.serialize(context.buildTypes), context.buildHashKeys, context.operatorCount, + OperatorConfig.serialize(context.operatorConfig)); + } + + /** + * save a new shared hash builder operator and factory into cache + * + * @param builderNodeId the build plan node id + * @param factory the shared hash builder operator factory + * @param operator the shared hash builder operator + */ + public static void saveHashBuilderOperatorAndFactory(Integer builderNodeId, + OmniHashBuilderWithExprOperatorFactory factory, OmniOperator operator) { + operatorCache.put(builderNodeId, operator); + factoryCache.put(builderNodeId, factory); + } + + /** + * try get a shared hash builder factory form cache + * + * @param builderNodeId the build plan node id + * @return the shared hash builder operator factory + */ + public static OmniHashBuilderWithExprOperatorFactory getHashBuilderOperatorFactory(Integer builderNodeId) { + ref.computeIfAbsent(builderNodeId, key -> new AtomicInteger(0)); + ref.get(builderNodeId).incrementAndGet(); + return factoryCache.get(builderNodeId); + } + + /** + * try close and remove a shared hash builder factory form cache + * + * @param builderNodeId the build plan node id + */ + public static void dereferenceHashBuilderOperatorAndFactory(Integer builderNodeId) { + if (ref.get(builderNodeId).decrementAndGet() == 0) { + ref.remove(builderNodeId); + operatorCache.get(builderNodeId).close(); + operatorCache.remove(builderNodeId); + factoryCache.get(builderNodeId).close(); + factoryCache.remove(builderNodeId); + } + } + + /** + * The Factory context. + * + * @since 2021-10-16 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final JoinType joinType; + + private final DataType[] buildTypes; + + private final String[] buildHashKeys; + + private final int operatorCount; + + private final OperatorConfig operatorConfig; + + private BuildSide buildSide = BuildSide.BUILD_UNKNOWN; + + /** + * Instantiates a new Context. + * + * @param joinType the join type + * @param buildTypes the build types + * @param buildHashKeys the build hash keys + * @param operatorCount the operator count + * @param operatorConfig the operator config + */ + public FactoryContext(JoinType joinType, DataType[] buildTypes, String[] buildHashKeys, int operatorCount, + OperatorConfig operatorConfig) { + this.joinType = requireNonNull(joinType, "joinType"); + this.buildTypes = requireNonNull(buildTypes, "buildTypes"); + this.buildHashKeys = requireNonNull(buildHashKeys, "buildHashKeys"); + this.operatorCount = operatorCount; + this.operatorConfig = operatorConfig; + setNeedCache(false); + } + + /** + * Instantiates a new Context. + * + * @param joinType the join type + * @param buildSide the build side + * @param buildTypes the build types + * @param buildHashKeys the build hash keys + * @param operatorCount the operator count + * @param operatorConfig the operator config + */ + public FactoryContext(JoinType joinType, BuildSide buildSide, DataType[] buildTypes, String[] buildHashKeys, + int operatorCount, OperatorConfig operatorConfig) { + this(joinType, buildTypes, buildHashKeys, operatorCount, operatorConfig); + this.buildSide = buildSide; + } + + @Override + public int hashCode() { + return Objects.hash(joinType, Arrays.hashCode(buildTypes), Arrays.hashCode(buildHashKeys), operatorCount, + operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return joinType.equals(that.joinType) && Arrays.equals(buildTypes, that.buildTypes) + && Arrays.equals(buildHashKeys, that.buildHashKeys) && operatorCount == that.operatorCount + && operatorConfig.equals(that.operatorConfig); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniLookupJoinOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniLookupJoinOperatorFactory.java new file mode 100644 index 0000000..1435310 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniLookupJoinOperatorFactory.java @@ -0,0 +1,245 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.join; + +import static java.util.Objects.requireNonNull; + +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; +import java.util.Optional; + +/** + * The Omni lookup join operator factory. + * + * @since 2021-06-30 + */ +public class OmniLookupJoinOperatorFactory extends OmniOperatorFactory { + /** + * Instantiates a new Omni lookup join operator factory. + * + * @param probeTypes the probe types + * @param probeOutputCols the probe output cols + * @param probeHashCols the probe hash cols + * @param buildOutputCols the build output cols + * @param buildOutputTypes the build output types + * @param hashBuilderOperatorFactory the hash builder operator factory + * @param filterExpression the join filter expression + * @param isShuffleExchangeBuildPlan build plan is shuffleExchange + * @param operatorConfig the operator config + */ + public OmniLookupJoinOperatorFactory(DataType[] probeTypes, int[] probeOutputCols, int[] probeHashCols, + int[] buildOutputCols, DataType[] buildOutputTypes, + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory, + Optional filterExpression, boolean isShuffleExchangeBuildPlan, OperatorConfig operatorConfig) { + super(new FactoryContext(probeTypes, probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, filterExpression, isShuffleExchangeBuildPlan, operatorConfig)); + } + + + /** + * Instantiates a new Omni lookup join operator factory. + * + * @param probeTypes the probe types + * @param probeOutputCols the probe output cols + * @param probeHashCols the probe hash cols + * @param buildOutputCols the build output cols + * @param buildOutputTypes the build output types + * @param hashBuilderOperatorFactory the hash builder operator factory + * @param filterExpression the join filter expression + * @param operatorConfig the operator config + */ + public OmniLookupJoinOperatorFactory(DataType[] probeTypes, int[] probeOutputCols, int[] probeHashCols, + int[] buildOutputCols, DataType[] buildOutputTypes, + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory, + Optional filterExpression, OperatorConfig operatorConfig) { + super(new FactoryContext(probeTypes, probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, filterExpression, false, operatorConfig)); + } + + /** + * Instantiates a new Omni lookup join operator factory with default operator + * config. + * + * @param probeTypes the probe types + * @param probeOutputCols the probe output cols + * @param probeHashCols the probe hash cols + * @param buildOutputCols the build output cols + * @param buildOutputTypes the build output types + * @param hashBuilderOperatorFactory the hash builder operator factory + * @param isShuffleExchangeBuildPlan build plan is shuffleExchange + */ + public OmniLookupJoinOperatorFactory(DataType[] probeTypes, int[] probeOutputCols, int[] probeHashCols, + int[] buildOutputCols, DataType[] buildOutputTypes, + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory, boolean isShuffleExchangeBuildPlan) { + this(probeTypes, probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, hashBuilderOperatorFactory, + Optional.empty(), isShuffleExchangeBuildPlan, new OperatorConfig()); + } + + /** + * Instantiates a new Omni lookup join operator factory with default operator + * config. + * + * @param probeTypes the probe types + * @param probeOutputCols the probe output cols + * @param probeHashCols the probe hash cols + * @param buildOutputCols the build output cols + * @param buildOutputTypes the build output types + * @param hashBuilderOperatorFactory the hash builder operator factory + */ + public OmniLookupJoinOperatorFactory(DataType[] probeTypes, int[] probeOutputCols, int[] probeHashCols, + int[] buildOutputCols, DataType[] buildOutputTypes, + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory) { + this(probeTypes, probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, hashBuilderOperatorFactory, + Optional.empty(), false, new OperatorConfig()); + } + + /** + * Instantiates a new Omni lookup join operator factory with default operator + * config. + * + * @param probeTypes the probe types + * @param probeOutputCols the probe output cols + * @param probeHashCols the probe hash cols + * @param buildOutputCols the build output cols + * @param buildOutputTypes the build output types + * @param hashBuilderOperatorFactory the hash builder operator factory + * @param filterExpression the join filter expression + * @param isShuffleExchangeBuildPlan build plan is shuffleExchange + */ + public OmniLookupJoinOperatorFactory(DataType[] probeTypes, int[] probeOutputCols, int[] probeHashCols, + int[] buildOutputCols, DataType[] buildOutputTypes, + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory, + Optional filterExpression, boolean isShuffleExchangeBuildPlan) { + this(probeTypes, probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, hashBuilderOperatorFactory, + filterExpression, isShuffleExchangeBuildPlan, new OperatorConfig()); + } + + /** + * Instantiates a new Omni lookup join operator factory with default operator + * config. + * + * @param probeTypes the probe types + * @param probeOutputCols the probe output cols + * @param probeHashCols the probe hash cols + * @param buildOutputCols the build output cols + * @param buildOutputTypes the build output types + * @param hashBuilderOperatorFactory the hash builder operator factory + * @param filterExpression the join filter expression + */ + public OmniLookupJoinOperatorFactory(DataType[] probeTypes, int[] probeOutputCols, int[] probeHashCols, + int[] buildOutputCols, DataType[] buildOutputTypes, + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory, + Optional filterExpression) { + this(probeTypes, probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, hashBuilderOperatorFactory, + filterExpression, false, new OperatorConfig()); + } + + private static native long createLookupJoinOperatorFactory(String probeTypes, int[] probeOutputCols, + int[] probeHashCols, int[] buildOutputCols, String buildOutputTypes, long hashBuilderOperatorFactory, + String filterExpression, boolean isShuffleExchangeBuildPlan, String operatorConfig); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createLookupJoinOperatorFactory(DataTypeSerializer.serialize(context.probeTypes), + context.probeOutputCols, context.probeHashCols, context.buildOutputCols, + DataTypeSerializer.serialize(context.buildOutputTypes), + context.getHashBuilderOperatorFactory(), context.filterExpression, + context.isShuffleExchangeBuildPlan, OperatorConfig.serialize(context.operatorConfig)); + } + + /** + * The type Factory context. + * + * @since 20210630 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] probeTypes; + + private final int[] probeOutputCols; + + private final int[] probeHashCols; + + private final int[] buildOutputCols; + + private final DataType[] buildOutputTypes; + + private final long hashBuilderOperatorFactory; + + private final String filterExpression; + + private final boolean isShuffleExchangeBuildPlan; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param probeTypes the probe types + * @param probeOutputCols the probe output cols + * @param probeHashCols the probe hash cols + * @param buildOutputCols the build output cols + * @param buildOutputTypes the build output types + * @param hashBuilderOperatorFactory hashBuilderOperatorFactory + * @param filterExpression the join filter expression + * @param isShuffleExchangeBuildPlan build plan is shuffleExchange + * @param operatorConfig the operator config + */ + public FactoryContext(DataType[] probeTypes, int[] probeOutputCols, int[] probeHashCols, int[] buildOutputCols, + DataType[] buildOutputTypes, OmniHashBuilderOperatorFactory hashBuilderOperatorFactory, + Optional filterExpression, boolean isShuffleExchangeBuildPlan, OperatorConfig operatorConfig) { + this.probeTypes = requireNonNull(probeTypes, "probeTypes"); + this.probeOutputCols = requireNonNull(probeOutputCols, "probeOutputCols"); + this.probeHashCols = requireNonNull(probeHashCols, "probeHashCols"); + this.buildOutputCols = requireNonNull(buildOutputCols, "buildOutputCols"); + this.buildOutputTypes = requireNonNull(buildOutputTypes, "buildOutputTypes"); + this.hashBuilderOperatorFactory = hashBuilderOperatorFactory.getNativeOperatorFactory(); + this.filterExpression = filterExpression.isPresent() ? filterExpression.get() : ""; + this.isShuffleExchangeBuildPlan = isShuffleExchangeBuildPlan; + this.operatorConfig = operatorConfig; + setNeedCache(false); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(probeTypes), Arrays.hashCode(probeOutputCols), + Arrays.hashCode(probeHashCols), Arrays.hashCode(buildOutputCols), Arrays.hashCode(buildOutputTypes), + hashBuilderOperatorFactory, filterExpression, isShuffleExchangeBuildPlan, operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return Arrays.equals(probeTypes, that.probeTypes) && Arrays.equals(probeOutputCols, that.probeOutputCols) + && Arrays.equals(probeHashCols, that.probeHashCols) + && Arrays.equals(buildOutputCols, that.buildOutputCols) + && Arrays.equals(buildOutputTypes, that.buildOutputTypes) + && hashBuilderOperatorFactory == that.hashBuilderOperatorFactory + && Objects.equals(filterExpression, that.filterExpression) + && isShuffleExchangeBuildPlan == that.isShuffleExchangeBuildPlan + && Objects.equals(operatorConfig, that.operatorConfig); + } + + /** + * Gets hash builder operator factory. + * + * @return the hash builder operator factory + */ + public long getHashBuilderOperatorFactory() { + return hashBuilderOperatorFactory; + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniLookupJoinWithExprOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniLookupJoinWithExprOperatorFactory.java new file mode 100644 index 0000000..5d97785 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniLookupJoinWithExprOperatorFactory.java @@ -0,0 +1,243 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.join; + +import static java.util.Objects.requireNonNull; + +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; +import java.util.Optional; + +/** + * The Omni lookup join with expression operator factory. + * + * @since 2021-10-16 + */ +public class OmniLookupJoinWithExprOperatorFactory + extends OmniOperatorFactory { + /** + * Instantiates a new Omni lookup join with expression operator factory. + * + * @param probeTypes the probe input types + * @param probeOutputCols the probe output columns + * @param probeHashKeys the probe hash keys + * @param buildOutputCols the build output columns + * @param buildOutputTypes the build output column types + * @param hashBuilderWithExprOperatorFactory the hash builder operator factory + * @param filterExpression the join filter expression + * @param isShuffleExchangeBuildPlan build plan is shuffleExchange + * @param operatorConfig the operator config + */ + public OmniLookupJoinWithExprOperatorFactory(DataType[] probeTypes, int[] probeOutputCols, String[] probeHashKeys, + int[] buildOutputCols, DataType[] buildOutputTypes, + OmniHashBuilderWithExprOperatorFactory hashBuilderWithExprOperatorFactory, + Optional filterExpression, boolean isShuffleExchangeBuildPlan, OperatorConfig operatorConfig) { + super(new FactoryContext(probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + hashBuilderWithExprOperatorFactory, filterExpression, isShuffleExchangeBuildPlan, operatorConfig)); + } + + /** + * Instantiates a new Omni lookup join with expression operator factory. + * + * @param probeTypes the probe input types + * @param probeOutputCols the probe output columns + * @param probeHashKeys the probe hash keys + * @param buildOutputCols the build output columns + * @param buildOutputTypes the build output column types + * @param hashBuilderWithExprOperatorFactory the hash builder operator factory + * @param filterExpression the join filter expression + * @param operatorConfig the operator config + */ + public OmniLookupJoinWithExprOperatorFactory(DataType[] probeTypes, int[] probeOutputCols, String[] probeHashKeys, + int[] buildOutputCols, DataType[] buildOutputTypes, + OmniHashBuilderWithExprOperatorFactory hashBuilderWithExprOperatorFactory, + Optional filterExpression, OperatorConfig operatorConfig) { + super(new FactoryContext(probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + hashBuilderWithExprOperatorFactory, filterExpression, false, operatorConfig)); + } + + /** + * Instantiates a new Omni lookup join with expression operator factory with + * default operator config. + * + * @param probeTypes the probe input types + * @param probeOutputCols the probe output columns + * @param probeHashKeys the probe hash keys + * @param buildOutputCols the build output columns + * @param buildOutputTypes the build output column types + * @param hashBuilderWithExprOperatorFactory the hash builder operator factory + * @param filterExpression the join filter expression + * @param isShuffleExchangeBuildPlan build plan is shuffleExchange + */ + public OmniLookupJoinWithExprOperatorFactory(DataType[] probeTypes, int[] probeOutputCols, String[] probeHashKeys, + int[] buildOutputCols, DataType[] buildOutputTypes, + OmniHashBuilderWithExprOperatorFactory hashBuilderWithExprOperatorFactory, + Optional filterExpression, boolean isShuffleExchangeBuildPlan) { + this(probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + hashBuilderWithExprOperatorFactory, filterExpression, isShuffleExchangeBuildPlan, new OperatorConfig()); + } + + /** + * Instantiates a new Omni lookup join with expression operator factory with + * default operator config. + * + * @param probeTypes the probe input types + * @param probeOutputCols the probe output columns + * @param probeHashKeys the probe hash keys + * @param buildOutputCols the build output columns + * @param buildOutputTypes the build output column types + * @param hashBuilderWithExprOperatorFactory the hash builder operator factory + * @param filterExpression the join filter expression + */ + public OmniLookupJoinWithExprOperatorFactory(DataType[] probeTypes, int[] probeOutputCols, String[] probeHashKeys, + int[] buildOutputCols, DataType[] buildOutputTypes, + OmniHashBuilderWithExprOperatorFactory hashBuilderWithExprOperatorFactory, + Optional filterExpression) { + this(probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + hashBuilderWithExprOperatorFactory, filterExpression, false, new OperatorConfig()); + } + + /** + * Instantiates a new Omni lookup join with expression operator factory with + * default operator config. + * + * @param probeTypes the probe input types + * @param probeOutputCols the probe output columns + * @param probeHashKeys the probe hash keys + * @param buildOutputCols the build output columns + * @param buildOutputTypes the build output column types + * @param hashBuilderWithExprOperatorFactory the hash builder operator factory + * @param isShuffleExchangeBuildPlan build plan is shuffleExchange + */ + public OmniLookupJoinWithExprOperatorFactory(DataType[] probeTypes, int[] probeOutputCols, String[] probeHashKeys, + int[] buildOutputCols, DataType[] buildOutputTypes, + OmniHashBuilderWithExprOperatorFactory hashBuilderWithExprOperatorFactory, + boolean isShuffleExchangeBuildPlan) { + this(probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + hashBuilderWithExprOperatorFactory, Optional.empty(), isShuffleExchangeBuildPlan, new OperatorConfig()); + } + + /** + * Instantiates a new Omni lookup join with expression operator factory with + * default operator config. + * + * @param probeTypes the probe input types + * @param probeOutputCols the probe output columns + * @param probeHashKeys the probe hash keys + * @param buildOutputCols the build output columns + * @param buildOutputTypes the build output column types + * @param hashBuilderWithExprOperatorFactory the hash builder operator factory + */ + public OmniLookupJoinWithExprOperatorFactory(DataType[] probeTypes, int[] probeOutputCols, String[] probeHashKeys, + int[] buildOutputCols, DataType[] buildOutputTypes, + OmniHashBuilderWithExprOperatorFactory hashBuilderWithExprOperatorFactory) { + this(probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + hashBuilderWithExprOperatorFactory, Optional.empty(), false, new OperatorConfig()); + } + + private static native long createLookupJoinWithExprOperatorFactory(String probeTypes, int[] probeOutputCols, + String[] probeHashKeys, int[] buildOutputCols, String buildOutputTypes, + long hashBuilderWithExprOperatorFactory, String filterExpression, + boolean isShuffleExchangeBuildPlan, String operatorConfig); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createLookupJoinWithExprOperatorFactory(DataTypeSerializer.serialize(context.probeTypes), + context.probeOutputCols, context.probeHashKeys, context.buildOutputCols, + DataTypeSerializer.serialize(context.buildOutputTypes), + context.getHashBuilderWithExprOperatorFactory(), context.filterExpression, + context.isShuffleExchangeBuildPlan, OperatorConfig.serialize(context.operatorConfig)); + } + + /** + * The Factory context. + * + * @since 2021-10-16 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] probeTypes; + + private final int[] probeOutputCols; + + private final String[] probeHashKeys; + + private final int[] buildOutputCols; + + private final DataType[] buildOutputTypes; + + private final long hashBuilderWithExprOperatorFactory; + + private final String filterExpression; + + private final boolean isShuffleExchangeBuildPlan; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param probeTypes the probe types + * @param probeOutputCols the probe output cols + * @param probeHashKeys the probe hash keys + * @param buildOutputCols the build output cols + * @param buildOutputTypes the build output types + * @param hashBuilderWithExprOperatorFactory hashBuilderWithExprOperatorFactory + * @param filterExpression the join filter expression + * @param isShuffleExchangeBuildPlan build plan is shuffleExchange + * @param operatorConfig the operator config + */ + public FactoryContext(DataType[] probeTypes, int[] probeOutputCols, String[] probeHashKeys, + int[] buildOutputCols, DataType[] buildOutputTypes, + OmniHashBuilderWithExprOperatorFactory hashBuilderWithExprOperatorFactory, + Optional filterExpression, boolean isShuffleExchangeBuildPlan, OperatorConfig operatorConfig) { + this.probeTypes = requireNonNull(probeTypes, "probeTypes"); + this.probeOutputCols = requireNonNull(probeOutputCols, "probeOutputCols"); + this.probeHashKeys = requireNonNull(probeHashKeys, "probeHashKeys"); + this.buildOutputCols = requireNonNull(buildOutputCols, "buildOutputCols"); + this.buildOutputTypes = requireNonNull(buildOutputTypes, "buildOutputTypes"); + this.hashBuilderWithExprOperatorFactory = hashBuilderWithExprOperatorFactory.getNativeOperatorFactory(); + this.filterExpression = filterExpression.isPresent() ? filterExpression.get() : ""; + this.isShuffleExchangeBuildPlan = isShuffleExchangeBuildPlan; + this.operatorConfig = operatorConfig; + setNeedCache(false); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(probeTypes), Arrays.hashCode(probeOutputCols), + Arrays.hashCode(probeHashKeys), Arrays.hashCode(buildOutputCols), Arrays.hashCode(buildOutputTypes), + hashBuilderWithExprOperatorFactory, filterExpression, isShuffleExchangeBuildPlan, operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return Arrays.equals(probeTypes, that.probeTypes) && Arrays.equals(probeOutputCols, that.probeOutputCols) + && Arrays.equals(probeHashKeys, that.probeHashKeys) + && Arrays.equals(buildOutputCols, that.buildOutputCols) + && Arrays.equals(buildOutputTypes, that.buildOutputTypes) + && (hashBuilderWithExprOperatorFactory == that.hashBuilderWithExprOperatorFactory) + && filterExpression.equals(that.filterExpression) + && isShuffleExchangeBuildPlan == that.isShuffleExchangeBuildPlan + && operatorConfig.equals(that.operatorConfig); + } + + public long getHashBuilderWithExprOperatorFactory() { + return hashBuilderWithExprOperatorFactory; + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniLookupOuterJoinOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniLookupOuterJoinOperatorFactory.java new file mode 100644 index 0000000..f59ac46 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniLookupOuterJoinOperatorFactory.java @@ -0,0 +1,139 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.join; + +import static java.util.Objects.requireNonNull; + +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; + +/** + * The Omni lookup outer join operator factory. + * + * @since 2022-08-30 + */ +public class OmniLookupOuterJoinOperatorFactory + extends OmniOperatorFactory { + /** + * Instantiates a new Omni lookup outer join operator factory. + * + * @param probeTypes the probe types + * @param probeOutputCols the probe output cols + * @param buildOutputCols the build output cols + * @param buildOutputTypes the build output types + * @param hashBuilderOperatorFactory the hash builder operator factory + * @param operatorConfig the operator config + */ + public OmniLookupOuterJoinOperatorFactory(DataType[] probeTypes, int[] probeOutputCols, int[] buildOutputCols, + DataType[] buildOutputTypes, OmniHashBuilderOperatorFactory hashBuilderOperatorFactory, + OperatorConfig operatorConfig) { + super(new FactoryContext(probeTypes, probeOutputCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, operatorConfig)); + } + + /** + * Instantiates a new Omni lookup outer join operator factory with default operator config. + * + * @param probeTypes the probe types + * @param probeOutputCols the probe output cols + * @param buildOutputCols the build output cols + * @param buildOutputTypes the build output types + * @param hashBuilderOperatorFactory the hash builder operator factory + */ + public OmniLookupOuterJoinOperatorFactory(DataType[] probeTypes, int[] probeOutputCols, int[] buildOutputCols, + DataType[] buildOutputTypes, OmniHashBuilderOperatorFactory hashBuilderOperatorFactory) { + this(probeTypes, probeOutputCols, buildOutputCols, buildOutputTypes, hashBuilderOperatorFactory, + new OperatorConfig()); + } + + private static native long createLookupOuterJoinOperatorFactory(String probeTypes, int[] probeOutputCols, + int[] buildOutputCols, String buildOutputTypes, long hashBuilderOperatorFactory); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createLookupOuterJoinOperatorFactory(DataTypeSerializer.serialize(context.probeTypes), + context.probeOutputCols, context.buildOutputCols, + DataTypeSerializer.serialize(context.buildOutputTypes), context.getHashBuilderOperatorFactory()); + } + + /** + * The type Factory context. + * + * @since 20220830 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] probeTypes; + + private final int[] probeOutputCols; + + private final int[] buildOutputCols; + + private final DataType[] buildOutputTypes; + + private final long hashBuilderOperatorFactory; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param probeTypes the probe types + * @param probeOutputCols the probe output cols + * @param buildOutputCols the build output cols + * @param buildOutputTypes the build output types + * @param hashBuilderOperatorFactory hashBuilderOperatorFactory + * @param operatorConfig the operator config + */ + public FactoryContext(DataType[] probeTypes, int[] probeOutputCols, int[] buildOutputCols, + DataType[] buildOutputTypes, OmniHashBuilderOperatorFactory hashBuilderOperatorFactory, + OperatorConfig operatorConfig) { + this.probeTypes = requireNonNull(probeTypes, "probeTypes"); + this.probeOutputCols = requireNonNull(probeOutputCols, "probeOutputCols"); + this.buildOutputCols = requireNonNull(buildOutputCols, "buildOutputCols"); + this.buildOutputTypes = requireNonNull(buildOutputTypes, "buildOutputTypes"); + this.hashBuilderOperatorFactory = hashBuilderOperatorFactory.getNativeOperatorFactory(); + this.operatorConfig = operatorConfig; + setNeedCache(false); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(probeTypes), Arrays.hashCode(probeOutputCols), + Arrays.hashCode(buildOutputCols), Arrays.hashCode(buildOutputTypes), hashBuilderOperatorFactory, + operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return Arrays.equals(probeTypes, that.probeTypes) && Arrays.equals(probeOutputCols, that.probeOutputCols) + && Arrays.equals(buildOutputCols, that.buildOutputCols) + && Arrays.equals(buildOutputTypes, that.buildOutputTypes) + && hashBuilderOperatorFactory == that.hashBuilderOperatorFactory + && operatorConfig.equals(that.operatorConfig); + } + + /** + * Gets hash builder operator factory. + * + * @return the hash builder operator factory + */ + public long getHashBuilderOperatorFactory() { + return hashBuilderOperatorFactory; + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniLookupOuterJoinWithExprOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniLookupOuterJoinWithExprOperatorFactory.java new file mode 100644 index 0000000..2da4d70 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniLookupOuterJoinWithExprOperatorFactory.java @@ -0,0 +1,143 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.join; + +import static java.util.Objects.requireNonNull; + +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; + +/** + * The Omni lookup outer join with expression operator factory. + * + * @since 2022-9-1 + */ +public class OmniLookupOuterJoinWithExprOperatorFactory + extends OmniOperatorFactory { + /** + * Instantiates a new Omni lookup outer join with expression operator factory. + * + * @param probeTypes the probe input types + * @param probeOutputCols the probe output columns + * @param probeHashKeys the probe hash keys + * @param buildOutputCols the build output columns + * @param buildOutputTypes the build output column types + * @param hashBuilderWithExprOperatorFactory the hash builder operator factory + * @param operatorConfig the operator config + */ + public OmniLookupOuterJoinWithExprOperatorFactory(DataType[] probeTypes, int[] probeOutputCols, + String[] probeHashKeys, int[] buildOutputCols, DataType[] buildOutputTypes, + OmniHashBuilderWithExprOperatorFactory hashBuilderWithExprOperatorFactory, OperatorConfig operatorConfig) { + super(new FactoryContext(probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + operatorConfig, hashBuilderWithExprOperatorFactory)); + } + + /** + * Instantiates a new Omni lookup outer join with expression operator factory with + * default operator config. + * + * @param probeTypes the probe input types + * @param probeOutputCols the probe output columns + * @param probeHashKeys the probe hash keys + * @param buildOutputCols the build output columns + * @param buildOutputTypes the build output column types + * @param hashBuilderWithExprOperatorFactory the hash builder operator factory + */ + public OmniLookupOuterJoinWithExprOperatorFactory(DataType[] probeTypes, int[] probeOutputCols, + String[] probeHashKeys, int[] buildOutputCols, DataType[] buildOutputTypes, + OmniHashBuilderWithExprOperatorFactory hashBuilderWithExprOperatorFactory) { + this(probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + hashBuilderWithExprOperatorFactory, new OperatorConfig()); + } + + private static native long createLookupOuterJoinWithExprOperatorFactory(String probeTypes, int[] probeOutputCols, + String[] probeHashKeys, int[] buildOutputCols, String buildOutputTypes, + long hashBuilderWithExprOperatorFactory); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createLookupOuterJoinWithExprOperatorFactory(DataTypeSerializer.serialize(context.probeTypes), + context.probeOutputCols, context.probeHashKeys, context.buildOutputCols, + DataTypeSerializer.serialize(context.buildOutputTypes), + context.getHashBuilderWithExprOperatorFactory()); + } + + /** + * The factory Context. + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] probeTypes; + + private final int[] probeOutputCols; + + private final String[] probeHashKeys; + + private final int[] buildOutputCols; + + private final DataType[] buildOutputTypes; + + private final OperatorConfig operatorConfig; + + private final long hashBuilderWithExprOperatorFactory; + + /** + * Instantiates a new Context. + * + * @param probeTypes the probe types + * @param probeOutputCols the probe output cols + * @param probeHashKeys the probe hash keys + * @param buildOutputCols the build output cols + * @param buildOutputTypes the build output types + * @param operatorConfig the operator config + * @param hashBuilderWithExprOperatorFactory hashBuilderWithExprOperatorFactory + */ + public FactoryContext(DataType[] probeTypes, int[] probeOutputCols, String[] probeHashKeys, + int[] buildOutputCols, DataType[] buildOutputTypes, OperatorConfig operatorConfig, + OmniHashBuilderWithExprOperatorFactory hashBuilderWithExprOperatorFactory) { + this.probeTypes = requireNonNull(probeTypes, "probeTypes"); + this.probeOutputCols = requireNonNull(probeOutputCols, "probeOutputCols"); + this.probeHashKeys = requireNonNull(probeHashKeys, "probeHashKeys"); + this.buildOutputCols = requireNonNull(buildOutputCols, "buildOutputCols"); + this.buildOutputTypes = requireNonNull(buildOutputTypes, "buildOutputTypes"); + this.operatorConfig = operatorConfig; + this.hashBuilderWithExprOperatorFactory = hashBuilderWithExprOperatorFactory.getNativeOperatorFactory(); + setNeedCache(false); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(probeTypes), Arrays.hashCode(probeOutputCols), + Arrays.hashCode(probeHashKeys), Arrays.hashCode(buildOutputCols), Arrays.hashCode(buildOutputTypes), + operatorConfig, hashBuilderWithExprOperatorFactory); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return Arrays.equals(probeTypes, that.probeTypes) && Arrays.equals(probeOutputCols, that.probeOutputCols) + && Arrays.equals(probeHashKeys, that.probeHashKeys) + && Arrays.equals(buildOutputCols, that.buildOutputCols) + && Arrays.equals(buildOutputTypes, that.buildOutputTypes) + && operatorConfig.equals(that.operatorConfig) + && hashBuilderWithExprOperatorFactory == that.hashBuilderWithExprOperatorFactory; + } + + public long getHashBuilderWithExprOperatorFactory() { + return hashBuilderWithExprOperatorFactory; + } + } +} \ No newline at end of file diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniNestedLoopJoinBuildOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniNestedLoopJoinBuildOperatorFactory.java new file mode 100644 index 0000000..cfe91e4 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniNestedLoopJoinBuildOperatorFactory.java @@ -0,0 +1,153 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.join; + +import static java.util.Objects.requireNonNull; + +import nova.hetu.omniruntime.operator.OmniOperator; +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; + +/** + * The type Omni nested loop join builder operator factory. + * + * @since 2021-06-30 + */ +public class OmniNestedLoopJoinBuildOperatorFactory + extends OmniOperatorFactory { + /** + * The global lock is used for sharing nested builder operator and factory + * concurrently. + */ + public static final Lock gLock = new ReentrantLock(); + + /** + * The hashmap cached the shared nested builder operator, the key is the build + * plan node id. + */ + private static Map operatorCache = new HashMap<>(); + + /** + * The hashmap cached the shared nested builder operator factory, the key is the + * build plan node id. + */ + private static Map factoryCache = new HashMap<>(); + + /** + * The hashmap stores the num of lookup which is in-use shared nested builder + * operator + * the key is the build plan node id. + */ + private static Map ref = new HashMap<>(); + + /** + * Instantiates a new Omni nested loop join builder operator factory. + * + * @param buildTypes the build types + * @param buildOutputCols the build output cols + */ + public OmniNestedLoopJoinBuildOperatorFactory(DataType[] buildTypes, int[] buildOutputCols) { + super(new FactoryContext(buildTypes, buildOutputCols)); + } + + /** + * save a new shared builder operator and factory into cache + * + * @param builderNodeId the build plan node id + * @param factory the shared hash builder operator factory + * @param operator the shared hash builder operator + */ + public static void saveNestedLoopJoinBuilderOperatorAndFactory(Integer builderNodeId, + OmniNestedLoopJoinBuildOperatorFactory factory, OmniOperator operator) { + operatorCache.put(builderNodeId, operator); + factoryCache.put(builderNodeId, factory); + } + + /** + * try get a shared nested builder factory form cache + * + * @param builderNodeId the build plan node id + * @return the shared nested builder operator factory + */ + public static OmniNestedLoopJoinBuildOperatorFactory getNestedLoopJoinBuilderOperatorFactory( + Integer builderNodeId) { + ref.computeIfAbsent(builderNodeId, key -> new AtomicInteger(0)); + ref.get(builderNodeId).incrementAndGet(); + return factoryCache.get(builderNodeId); + } + + private static native long createNestedLoopJoinBuildOperatorFactory(String buildTypes, int[] buildOutputCols); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createNestedLoopJoinBuildOperatorFactory(DataTypeSerializer.serialize(context.buildTypes), + context.buildOutputCols); + } + + /** + * try close and remove a shared nested builder factory form cache + * + * @param builderNodeId the build plan node id + */ + public static void dereferenceNestedBuilderOperatorAndFactory(Integer builderNodeId) { + if (ref.get(builderNodeId).decrementAndGet() == 0) { + ref.remove(builderNodeId); + operatorCache.get(builderNodeId).close(); + operatorCache.remove(builderNodeId); + factoryCache.get(builderNodeId).close(); + factoryCache.remove(builderNodeId); + } + } + + /** + * The type Factory context. + * + * @since 20241210 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] buildTypes; + + private final int[] buildOutputCols; + + /** + * Instantiates a new Context. + * + * @param buildTypes the build types + * @param buildOutputCols the build nested loop join cols + */ + public FactoryContext(DataType[] buildTypes, int[] buildOutputCols) { + this.buildTypes = requireNonNull(buildTypes, "buildTypes"); + this.buildOutputCols = requireNonNull(buildOutputCols, "buildOutputCols"); + setNeedCache(false); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(buildTypes), Arrays.hashCode(buildOutputCols)); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return Arrays.equals(buildTypes, that.buildTypes) && Arrays.equals(buildOutputCols, that.buildOutputCols); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniNestedLoopJoinLookupOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniNestedLoopJoinLookupOperatorFactory.java new file mode 100644 index 0000000..002dd6b --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniNestedLoopJoinLookupOperatorFactory.java @@ -0,0 +1,121 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.join; + +import static java.util.Objects.requireNonNull; + +import nova.hetu.omniruntime.constants.JoinType; +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; +import java.util.Optional; + +/** + * The Omni nested loop lookup join operator factory. + * + * @since 2024-12-10 + */ +public class OmniNestedLoopJoinLookupOperatorFactory + extends OmniOperatorFactory { + /** + * Instantiates a new Omni lookup join operator factory. + * + * @param joinType the join types + * @param probeTypes the probe types + * @param probeOutputCols the probe output cols + * @param buildOpFactory the NestedLoopJoinBuildOperatorFactory + * @param filter the json string for connecting conditional expressions + * @param operatorConfig the operator config + */ + public OmniNestedLoopJoinLookupOperatorFactory(JoinType joinType, DataType[] probeTypes, int[] probeOutputCols, + Optional filter, OmniNestedLoopJoinBuildOperatorFactory buildOpFactory, + OperatorConfig operatorConfig) { + super(new FactoryContext(joinType, probeTypes, probeOutputCols, filter, buildOpFactory, operatorConfig)); + } + + private static native long createNestedLoopJoinLookupOperatorFactory(int joinType, String probeTypes, + int[] probeOutputCols, String filter, long buildOpFactory, String operatorConfig); + + @Override + protected long createNativeOperatorFactory(OmniNestedLoopJoinLookupOperatorFactory.FactoryContext context) { + return createNestedLoopJoinLookupOperatorFactory(context.joinType.getValue(), + DataTypeSerializer.serialize(context.probeTypes), context.probeOutputCols, context.filter, + context.getNestedLoopJoinBuildOperatorFactory(), OperatorConfig.serialize(context.operatorConfig)); + } + + /** + * The type Factory context. + * + * @since 20241210 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final JoinType joinType; + + private final DataType[] probeTypes; + + private final int[] probeOutputCols; + + private final long buildOpFactory; + + private final String filter; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param joinType the join types + * @param probeTypes the probe types + * @param probeOutputCols the probe output cols + * @param buildOpFactory the NestedLoopJoinBuildOperatorFactory + * @param filter the json string for connecting conditional expressions + * @param operatorConfig the operator config + */ + public FactoryContext(JoinType joinType, DataType[] probeTypes, int[] probeOutputCols, Optional filter, + OmniNestedLoopJoinBuildOperatorFactory buildOpFactory, OperatorConfig operatorConfig) { + this.joinType = requireNonNull(joinType, "joinType"); + this.probeTypes = requireNonNull(probeTypes, "probeTypes"); + this.probeOutputCols = requireNonNull(probeOutputCols, "probeOutputCols"); + this.filter = filter.orElse(""); + this.buildOpFactory = buildOpFactory.getNativeOperatorFactory(); + this.operatorConfig = operatorConfig; + setNeedCache(false); + } + + @Override + public int hashCode() { + return Objects.hash(joinType, Arrays.hashCode(probeTypes), Arrays.hashCode(probeOutputCols), filter, + buildOpFactory, operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return joinType.equals(that.joinType) && Arrays.equals(probeTypes, that.probeTypes) + && Arrays.equals(probeOutputCols, that.probeOutputCols) && Objects.equals(filter, that.filter) + && buildOpFactory == that.buildOpFactory && Objects.equals(operatorConfig, that.operatorConfig); + } + + /** + * Gets nested builder operator factory. + * + * @return the nested loop join builder operator factory + */ + public long getNestedLoopJoinBuildOperatorFactory() { + return buildOpFactory; + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniSmjBufferedTableWithExprOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniSmjBufferedTableWithExprOperatorFactory.java new file mode 100644 index 0000000..e8298bc --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniSmjBufferedTableWithExprOperatorFactory.java @@ -0,0 +1,127 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.join; + +import static java.util.Objects.requireNonNull; + +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; + +/** + * The type Omni sort merge buffered table with expression operator factory. + * + * @since 2021-10-30 + */ +public class OmniSmjBufferedTableWithExprOperatorFactory + extends OmniOperatorFactory { + /** + * Instantiates a new Omni sort merge buffered table factory. + * + * @param soruceTypes the all input vector types + * @param equalKeyExprs equal condition key expressions + * @param outputChannels output of streamed table + * @param smjStreamedTableOperatorFactory streamed table operator factory + * instance + * @param operatorConfig the operator config + */ + public OmniSmjBufferedTableWithExprOperatorFactory(DataType[] soruceTypes, String[] equalKeyExprs, + int[] outputChannels, OmniSmjStreamedTableWithExprOperatorFactory smjStreamedTableOperatorFactory, + OperatorConfig operatorConfig) { + super(new FactoryContext(soruceTypes, equalKeyExprs, outputChannels, operatorConfig, + smjStreamedTableOperatorFactory)); + } + + /** + * Instantiates a new Omni sort merge buffered table factory with default + * operator config. + * + * @param soruceTypes the all input vector types + * @param equalKeyExprs equal condition key expressions + * @param outputChannels output of streamed table + * @param smjStreamedTableOperatorFactory streamed table operator factory + * instance + */ + public OmniSmjBufferedTableWithExprOperatorFactory(DataType[] soruceTypes, String[] equalKeyExprs, + int[] outputChannels, OmniSmjStreamedTableWithExprOperatorFactory smjStreamedTableOperatorFactory) { + this(soruceTypes, equalKeyExprs, outputChannels, smjStreamedTableOperatorFactory, new OperatorConfig()); + } + + private static native long createSmjBufferedTableWithExprOperatorFactory(String soruceTypes, String[] equalKeyExprs, + int[] outputChannels, long smjStreamedTableOperatorFactory, String operatorConfig); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createSmjBufferedTableWithExprOperatorFactory(DataTypeSerializer.serialize(context.soruceTypes), + context.equalKeyExprs, context.outputChannels, context.getStreamedTableOperatorFactory(), + OperatorConfig.serialize(context.operatorConfig)); + } + + /** + * The type Factory context. + * + * @since 2021-10-30 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] soruceTypes; + + private final String[] equalKeyExprs; + + private final int[] outputChannels; + + private final OperatorConfig operatorConfig; + + private final long streamedTableOperatorFactory; + + /** + * Instantiates a new Context. + * + * @param soruceTypes the all input vector types + * @param equalKeyExps equal condition key expressions + * @param outputChannels output of streamed table + * @param operatorConfig the operator config + * @param streamedTableOperatorFactory streamedTableOperatorFactory + */ + public FactoryContext(DataType[] soruceTypes, String[] equalKeyExps, int[] outputChannels, + OperatorConfig operatorConfig, + OmniSmjStreamedTableWithExprOperatorFactory streamedTableOperatorFactory) { + this.soruceTypes = requireNonNull(soruceTypes, "soruceTypes"); + this.equalKeyExprs = requireNonNull(equalKeyExps, "equalKeyExprs"); + this.outputChannels = requireNonNull(outputChannels, "outputChannels"); + this.operatorConfig = requireNonNull(operatorConfig, "operatorConfig"); + this.streamedTableOperatorFactory = streamedTableOperatorFactory.getNativeOperatorFactory(); + setNeedCache(false); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(soruceTypes), Arrays.hashCode(equalKeyExprs), + Arrays.hashCode(outputChannels), operatorConfig, streamedTableOperatorFactory); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return Arrays.equals(soruceTypes, that.soruceTypes) && Arrays.equals(equalKeyExprs, that.equalKeyExprs) + && Arrays.equals(outputChannels, that.outputChannels) && operatorConfig.equals(that.operatorConfig) + && streamedTableOperatorFactory == that.streamedTableOperatorFactory; + } + + public long getStreamedTableOperatorFactory() { + return streamedTableOperatorFactory; + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniSmjBufferedTableWithExprOperatorFactoryV3.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniSmjBufferedTableWithExprOperatorFactoryV3.java new file mode 100644 index 0000000..cb5dd1f --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniSmjBufferedTableWithExprOperatorFactoryV3.java @@ -0,0 +1,128 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.join; + +import static java.util.Objects.requireNonNull; + +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; + +/** + * The type Omni sort merge buffered table with expression operator factory. + * + * @since 2023-07-30 + */ +public class OmniSmjBufferedTableWithExprOperatorFactoryV3 + extends OmniOperatorFactory { + /** + * Instantiates a new Omni sort merge buffered table factory. + * + * @param soruceTypes the all input vector types + * @param equalKeyExprs equal condition key expressions + * @param outputChannels output of streamed table + * @param smjStreamedTableOperatorFactoryV3 streamed table operator factory + * instance + * @param operatorConfig the operator config + */ + public OmniSmjBufferedTableWithExprOperatorFactoryV3(DataType[] soruceTypes, String[] equalKeyExprs, + int[] outputChannels, OmniSmjStreamedTableWithExprOperatorFactoryV3 smjStreamedTableOperatorFactoryV3, + OperatorConfig operatorConfig) { + super(new FactoryContext(soruceTypes, equalKeyExprs, outputChannels, operatorConfig, + smjStreamedTableOperatorFactoryV3)); + } + + /** + * Instantiates a new Omni sort merge buffered table factory with default + * operator config. + * + * @param soruceTypes the all input vector types + * @param equalKeyExprs equal condition key expressions + * @param outputChannels output of streamed table + * @param smjStreamedTableOperatorFactoryV3 streamed table operator factory + * instance + */ + public OmniSmjBufferedTableWithExprOperatorFactoryV3(DataType[] soruceTypes, String[] equalKeyExprs, + int[] outputChannels, OmniSmjStreamedTableWithExprOperatorFactoryV3 smjStreamedTableOperatorFactoryV3) { + this(soruceTypes, equalKeyExprs, outputChannels, smjStreamedTableOperatorFactoryV3, new OperatorConfig()); + } + + private static native long createSmjBufferedTableWithExprOperatorFactoryV3(String soruceTypes, + String[] equalKeyExprs, int[] outputChannels, long smjStreamedTableOperatorFactoryV3, + String operatorConfig); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createSmjBufferedTableWithExprOperatorFactoryV3(DataTypeSerializer.serialize(context.soruceTypes), + context.equalKeyExprs, context.outputChannels, context.getStreamedTableOperatorFactoryV3(), + OperatorConfig.serialize(context.operatorConfig)); + } + + /** + * The type Factory context. + * + * @since 2023-07-30 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] soruceTypes; + + private final String[] equalKeyExprs; + + private final int[] outputChannels; + + private final OperatorConfig operatorConfig; + + private final long streamedTableOperatorFactoryV3; + + /** + * Instantiates a new Context. + * + * @param soruceTypes the all input vector types + * @param equalKeyExps equal condition key expressions + * @param outputChannels output of streamed table + * @param operatorConfig the operator config + * @param streamedTableOperatorFactoryV3 streamedTableOperatorFactory + */ + public FactoryContext(DataType[] soruceTypes, String[] equalKeyExps, int[] outputChannels, + OperatorConfig operatorConfig, + OmniSmjStreamedTableWithExprOperatorFactoryV3 streamedTableOperatorFactoryV3) { + this.soruceTypes = requireNonNull(soruceTypes, "soruceTypes"); + this.equalKeyExprs = requireNonNull(equalKeyExps, "equalKeyExprs"); + this.outputChannels = requireNonNull(outputChannels, "outputChannels"); + this.operatorConfig = requireNonNull(operatorConfig, "operatorConfig"); + this.streamedTableOperatorFactoryV3 = streamedTableOperatorFactoryV3.getNativeOperatorFactory(); + setNeedCache(false); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(soruceTypes), Arrays.hashCode(equalKeyExprs), + Arrays.hashCode(outputChannels), operatorConfig, streamedTableOperatorFactoryV3); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return Arrays.equals(soruceTypes, that.soruceTypes) && Arrays.equals(equalKeyExprs, that.equalKeyExprs) + && Arrays.equals(outputChannels, that.outputChannels) && operatorConfig.equals(that.operatorConfig) + && streamedTableOperatorFactoryV3 == that.streamedTableOperatorFactoryV3; + } + + public long getStreamedTableOperatorFactoryV3() { + return streamedTableOperatorFactoryV3; + } + } +} \ No newline at end of file diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniSmjStreamedTableWithExprOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniSmjStreamedTableWithExprOperatorFactory.java new file mode 100644 index 0000000..63263c3 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniSmjStreamedTableWithExprOperatorFactory.java @@ -0,0 +1,127 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.join; + +import static java.util.Objects.requireNonNull; + +import nova.hetu.omniruntime.constants.JoinType; +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; +import java.util.Optional; + +/** + * The type Omni sort merge streamed table with expression operator factory. + * + * @since 2021-10-30 + */ +public class OmniSmjStreamedTableWithExprOperatorFactory + extends OmniOperatorFactory { + /** + * Instantiates a new Omni sort merge streamed table factory. + * + * @param sourceTypes the all input vector types + * @param equalKeyExprs equal condition key expressions + * @param outputChannels output of streamed table + * @param joinType join type + * @param filter condition for not equal expression + * @param operatorConfig the operator config + */ + public OmniSmjStreamedTableWithExprOperatorFactory(DataType[] sourceTypes, String[] equalKeyExprs, + int[] outputChannels, JoinType joinType, Optional filter, OperatorConfig operatorConfig) { + super(new FactoryContext(sourceTypes, equalKeyExprs, outputChannels, joinType, filter, operatorConfig)); + } + + /** + * Instantiates a new Omni sort merge streamed table factory with default + * operator config. + * + * @param sourceTypes the all input vector types + * @param equalKeyExprs equal condition key expressions + * @param outputChannels output of streamed table + * @param joinType join type + * @param filter condition for not equal expression + */ + public OmniSmjStreamedTableWithExprOperatorFactory(DataType[] sourceTypes, String[] equalKeyExprs, + int[] outputChannels, JoinType joinType, Optional filter) { + this(sourceTypes, equalKeyExprs, outputChannels, joinType, filter, new OperatorConfig()); + } + + private static native long createSmjStreamedTableWithExprOperatorFactory(String sourceTypes, String[] equalKeyExprs, + int[] outputChannels, int joinType, String filter, String operatorConfig); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + String filter = context.filter.isPresent() ? context.filter.get() : null; + return createSmjStreamedTableWithExprOperatorFactory(DataTypeSerializer.serialize(context.sourceTypes), + context.equalKeyExprs, context.outputChannels, context.joinType.getValue(), filter, + OperatorConfig.serialize(context.operatorConfig)); + } + + /** + * The type Factory context. + * + * @since 2021-10-30 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] sourceTypes; + + private final String[] equalKeyExprs; + + private final int[] outputChannels; + + private final JoinType joinType; + + private final Optional filter; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param sourceTypes the all input vector types + * @param equalKeyExprs equal condition key expressions + * @param outputChannels output of streamed table + * @param joinType join type + * @param filter condition for not equal expression + * @param operatorConfig the operator config + */ + public FactoryContext(DataType[] sourceTypes, String[] equalKeyExprs, int[] outputChannels, JoinType joinType, + Optional filter, OperatorConfig operatorConfig) { + this.sourceTypes = requireNonNull(sourceTypes, "sourceTypes"); + this.equalKeyExprs = requireNonNull(equalKeyExprs, "equalKeyExprs"); + this.outputChannels = requireNonNull(outputChannels, "outputChannels"); + this.joinType = requireNonNull(joinType, "joinType"); + this.filter = requireNonNull(filter, "filter"); + this.operatorConfig = operatorConfig; + setNeedCache(false); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(sourceTypes), Arrays.hashCode(equalKeyExprs), + Arrays.hashCode(outputChannels), joinType, filter, operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return Arrays.equals(sourceTypes, that.sourceTypes) && Arrays.equals(equalKeyExprs, that.equalKeyExprs) + && Arrays.equals(outputChannels, that.outputChannels) && joinType == that.joinType + && filter.equals(that.filter) && operatorConfig.equals(that.operatorConfig); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniSmjStreamedTableWithExprOperatorFactoryV3.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniSmjStreamedTableWithExprOperatorFactoryV3.java new file mode 100644 index 0000000..a7cad0b --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/join/OmniSmjStreamedTableWithExprOperatorFactoryV3.java @@ -0,0 +1,127 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.join; + +import static java.util.Objects.requireNonNull; + +import nova.hetu.omniruntime.constants.JoinType; +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; +import java.util.Optional; + +/** + * The type Omni sort merge streamed table with expression operator factory. + * + * @since 2023-07-30 + */ +public class OmniSmjStreamedTableWithExprOperatorFactoryV3 + extends OmniOperatorFactory { + /** + * Instantiates a new Omni sort merge streamed table factory. + * + * @param sourceTypes the all input vector types + * @param equalKeyExprs equal condition key expressions + * @param outputChannels output of streamed table + * @param joinType join type + * @param filter condition for not equal expression + * @param operatorConfig the operator config + */ + public OmniSmjStreamedTableWithExprOperatorFactoryV3(DataType[] sourceTypes, String[] equalKeyExprs, + int[] outputChannels, JoinType joinType, Optional filter, OperatorConfig operatorConfig) { + super(new FactoryContext(sourceTypes, equalKeyExprs, outputChannels, joinType, filter, operatorConfig)); + } + + /** + * Instantiates a new Omni sort merge streamed table factory with default + * operator config. + * + * @param sourceTypes the all input vector types + * @param equalKeyExprs equal condition key expressions + * @param outputChannels output of streamed table + * @param joinType join type + * @param filter condition for not equal expression + */ + public OmniSmjStreamedTableWithExprOperatorFactoryV3(DataType[] sourceTypes, String[] equalKeyExprs, + int[] outputChannels, JoinType joinType, Optional filter) { + this(sourceTypes, equalKeyExprs, outputChannels, joinType, filter, new OperatorConfig()); + } + + private static native long createSmjStreamedTableWithExprOperatorFactoryV3(String sourceTypes, + String[] equalKeyExprs, int[] outputChannels, int joinType, String filter, String operatorConfig); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + String filter = context.filter.orElse(null); + return createSmjStreamedTableWithExprOperatorFactoryV3(DataTypeSerializer.serialize(context.sourceTypes), + context.equalKeyExprs, context.outputChannels, context.joinType.getValue(), filter, + OperatorConfig.serialize(context.operatorConfig)); + } + + /** + * The type Factory context. + * + * @since 2023-07-30 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] sourceTypes; + + private final String[] equalKeyExprs; + + private final int[] outputChannels; + + private final JoinType joinType; + + private final Optional filter; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param sourceTypes the all input vector types + * @param equalKeyExprs equal condition key expressions + * @param outputChannels output of streamed table + * @param joinType join type + * @param filter condition for not equal expression + * @param operatorConfig the operator config + */ + public FactoryContext(DataType[] sourceTypes, String[] equalKeyExprs, int[] outputChannels, JoinType joinType, + Optional filter, OperatorConfig operatorConfig) { + this.sourceTypes = requireNonNull(sourceTypes, "sourceTypes"); + this.equalKeyExprs = requireNonNull(equalKeyExprs, "equalKeyExprs"); + this.outputChannels = requireNonNull(outputChannels, "outputChannels"); + this.joinType = requireNonNull(joinType, "joinType"); + this.filter = requireNonNull(filter, "filter"); + this.operatorConfig = operatorConfig; + setNeedCache(false); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(sourceTypes), Arrays.hashCode(equalKeyExprs), + Arrays.hashCode(outputChannels), joinType, filter, operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return Arrays.equals(sourceTypes, that.sourceTypes) && Arrays.equals(equalKeyExprs, that.equalKeyExprs) + && Arrays.equals(outputChannels, that.outputChannels) && joinType == that.joinType + && filter.equals(that.filter) && operatorConfig.equals(that.operatorConfig); + } + } +} \ No newline at end of file diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/limit/OmniDistinctLimitOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/limit/OmniDistinctLimitOperatorFactory.java new file mode 100644 index 0000000..ffe938a --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/limit/OmniDistinctLimitOperatorFactory.java @@ -0,0 +1,114 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.limit; + +import static java.util.Objects.requireNonNull; + +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; + +/** + * The type Omni distinct limit operator factory. + * + * @since 2021-06-30 + */ +public class OmniDistinctLimitOperatorFactory + extends OmniOperatorFactory { + /** + * Instantiates a new Omni distinct limit operator factory. + * + * @param sourceTypes the data types of each column + * @param distinctCols the column index + * @param hashCol col index of precomputed hash values + * @param limit the limit count + * @param operatorConfig the operator config + */ + public OmniDistinctLimitOperatorFactory(DataType[] sourceTypes, int[] distinctCols, int hashCol, long limit, + OperatorConfig operatorConfig) { + super(new FactoryContext(sourceTypes, distinctCols, hashCol, limit, operatorConfig)); + } + + /** + * Instantiates a new Omni distinct limit operator factory with default operator + * config. + * + * @param sourceTypes the data types of each column + * @param distinctCols the column index + * @param hashCol col index of precomputed hash values + * @param limit the limit count + */ + public OmniDistinctLimitOperatorFactory(DataType[] sourceTypes, int[] distinctCols, int hashCol, long limit) { + this(sourceTypes, distinctCols, hashCol, limit, new OperatorConfig()); + } + + private static native long createDistinctLimitOperatorFactory(String sourceTypes, int[] distinctCols, int hashCol, + long limit); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createDistinctLimitOperatorFactory(DataTypeSerializer.serialize(context.sourceTypes), + context.distinctCols, context.hashCol, context.limit); + } + + /** + * The type Factory context. + * + * @since 2021-06-30 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] sourceTypes; + + private final int[] distinctCols; + + private final int hashCol; + + private final long limit; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param sourceTypes the data types of each column + * @param distinctCols the column index + * @param hashCol col index of precomputed hash values + * @param limit the limit count + * @param operatorConfig the operator config + */ + public FactoryContext(DataType[] sourceTypes, int[] distinctCols, int hashCol, long limit, + OperatorConfig operatorConfig) { + this.sourceTypes = requireNonNull(sourceTypes, "Source types array is null."); + this.distinctCols = requireNonNull(distinctCols, "Distinct cols array is null."); + this.limit = limit; + this.hashCol = hashCol; + this.operatorConfig = operatorConfig; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return Arrays.equals(sourceTypes, that.sourceTypes) && Arrays.equals(distinctCols, that.distinctCols) + && limit == that.limit && hashCol == that.hashCol && operatorConfig.equals(that.operatorConfig); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(sourceTypes), Arrays.hashCode(distinctCols), limit, hashCol, + operatorConfig); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/limit/OmniLimitOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/limit/OmniLimitOperatorFactory.java new file mode 100644 index 0000000..73b377a --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/limit/OmniLimitOperatorFactory.java @@ -0,0 +1,81 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.limit; + +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; + +import java.util.Objects; + +/** + * The type Omni limit operator factory. + * + * @since 2021-06-30 + */ +public class OmniLimitOperatorFactory extends OmniOperatorFactory { + /** + * Instantiates a new Omni limit operator factory. + * + * @param limit the limit count + */ + public OmniLimitOperatorFactory(int limit) { + super(new FactoryContext(limit, 0)); + } + + /** + * Instantiates a new Omni limit operator factory. + * + * @param limit the limit count + * @param offset the offset count + */ + public OmniLimitOperatorFactory(int limit, int offset) { + super(new FactoryContext(limit, offset)); + } + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createLimitOperatorFactory(context.limit, context.offset); + } + + private static native long createLimitOperatorFactory(int limit, int offset); + + /** + * The type Factory context. + * + * @since 2021-06-30 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final int limit; + private final int offset; + + /** + * Instantiates a new Context. + * + * @param limit the limit count + * @param offset the offset count + */ + public FactoryContext(int limit, int offset) { + this.limit = limit; + this.offset = offset; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext context = (FactoryContext) obj; + return limit == context.limit && offset == context.offset; + } + + @Override + public int hashCode() { + return Objects.hash(this.limit, this.offset); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/partitionedoutput/OmniPartitionedOutPutOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/partitionedoutput/OmniPartitionedOutPutOperatorFactory.java new file mode 100644 index 0000000..4c86daf --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/partitionedoutput/OmniPartitionedOutPutOperatorFactory.java @@ -0,0 +1,162 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.partitionedoutput; + +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; +import java.util.OptionalInt; + +/** + * The type Omni partitionedoutput operator factory. + * + * @since 2021-06-30 + */ +public class OmniPartitionedOutPutOperatorFactory + extends OmniOperatorFactory { + /** + * Instantiates a new Omni partitioned out put operator factory. + * + * @param sourceTypes the source types + * @param isReplicatesAnyRow the replicates any row + * @param nullChannel the null channel + * @param partitionChannels the partition channels + * @param partitionCount the partition count + * @param bucketToPartition the bucket to partition + * @param isHashPrecomputed the is hash precomputed + * @param hashChannelTypes the hash channel types + * @param hashChannels the hash channels + * @param operatorConfig the operator config + */ + public OmniPartitionedOutPutOperatorFactory(DataType[] sourceTypes, boolean isReplicatesAnyRow, + OptionalInt nullChannel, int[] partitionChannels, int partitionCount, int[] bucketToPartition, + boolean isHashPrecomputed, DataType[] hashChannelTypes, int[] hashChannels, OperatorConfig operatorConfig) { + super(new FactoryContext(sourceTypes, isReplicatesAnyRow, nullChannel, partitionChannels, partitionCount, + bucketToPartition, isHashPrecomputed, hashChannelTypes, hashChannels, operatorConfig)); + } + + /** + * Instantiates a new Omni partitioned out put operator factory with default + * operator config. + * + * @param sourceTypes the source types + * @param isReplicatesAnyRow the replicates any row + * @param nullChannel the null channel + * @param partitionChannels the partition channels + * @param partitionCount the partition count + * @param bucketToPartition the bucket to partition + * @param isHashPrecomputed the is hash precomputed + * @param hashChannelTypes the hash channel types + * @param hashChannels the hash channels + */ + public OmniPartitionedOutPutOperatorFactory(DataType[] sourceTypes, boolean isReplicatesAnyRow, + OptionalInt nullChannel, int[] partitionChannels, int partitionCount, int[] bucketToPartition, + boolean isHashPrecomputed, DataType[] hashChannelTypes, int[] hashChannels) { + this(sourceTypes, isReplicatesAnyRow, nullChannel, partitionChannels, partitionCount, bucketToPartition, + isHashPrecomputed, hashChannelTypes, hashChannels, new OperatorConfig()); + } + + private static native long createPartitionedOutputOperatorFactory(String sourceTypes, boolean isReplicatesAnyRow, + int nullChannel, int[] partitionChannels, int partitionCount, int[] bucketToPartition, + boolean isHashPrecomputed, String hashChannelTypes, int[] hashChannels); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + int nullChannel = context.nullChannel.isPresent() ? context.nullChannel.getAsInt() : -1; + return createPartitionedOutputOperatorFactory(DataTypeSerializer.serialize(context.sourceTypes), + context.isReplicatesAnyRow, nullChannel, context.partitionChannels, context.partitionCount, + context.bucketToPartition, context.isHashPrecomputed, + DataTypeSerializer.serialize(context.hashChannelTypes), context.hashChannels); + } + + /** + * The type Factory context. + * + * @since 2021-06-30 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] sourceTypes; + + private final boolean isReplicatesAnyRow; + + private final OptionalInt nullChannel; + + private final int[] partitionChannels; + + private final int partitionCount; + + private final int[] bucketToPartition; + + private final boolean isHashPrecomputed; + + private final DataType[] hashChannelTypes; + + private final int[] hashChannels; + + private OperatorConfig operatorConfig; + + /** + * Instantiates a new Jit context. + * + * @param sourceTypes the source types + * @param isReplicatesAnyRow the replicates any row + * @param nullChannel the null channel + * @param partitionChannels the partition channels + * @param partitionCount the partition count + * @param bucketToPartition the bucket to partition + * @param isHashPrecomputed the is hash precomputed + * @param hashChannelTypes the hash channel types + * @param hashChannels the hash channels + * @param operatorConfig the operator config + */ + public FactoryContext(DataType[] sourceTypes, boolean isReplicatesAnyRow, OptionalInt nullChannel, + int[] partitionChannels, int partitionCount, int[] bucketToPartition, boolean isHashPrecomputed, + DataType[] hashChannelTypes, int[] hashChannels, OperatorConfig operatorConfig) { + this.sourceTypes = sourceTypes; + this.isReplicatesAnyRow = isReplicatesAnyRow; + this.nullChannel = nullChannel; + this.partitionChannels = partitionChannels; + this.partitionCount = partitionCount; + this.bucketToPartition = bucketToPartition; + this.isHashPrecomputed = isHashPrecomputed; + this.hashChannelTypes = hashChannelTypes; + this.hashChannels = hashChannels; + this.operatorConfig = operatorConfig; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext context = (FactoryContext) obj; + return isReplicatesAnyRow == context.isReplicatesAnyRow && partitionCount == context.partitionCount + && Arrays.equals(sourceTypes, context.sourceTypes) + && Objects.equals(nullChannel, context.nullChannel) + && Arrays.equals(partitionChannels, context.partitionChannels) + && Arrays.equals(bucketToPartition, context.bucketToPartition) + && context.isHashPrecomputed == isHashPrecomputed + && Arrays.equals(hashChannelTypes, context.hashChannelTypes) + && Arrays.equals(hashChannels, context.hashChannels) + && operatorConfig.equals(context.operatorConfig); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(sourceTypes), isReplicatesAnyRow, nullChannel, + Arrays.hashCode(partitionChannels), partitionCount, Arrays.hashCode(bucketToPartition), + isHashPrecomputed, Arrays.hashCode(hashChannelTypes), Arrays.hashCode(hashChannels), + operatorConfig); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/project/OmniProjectOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/project/OmniProjectOperatorFactory.java new file mode 100644 index 0000000..3cd048c --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/project/OmniProjectOperatorFactory.java @@ -0,0 +1,151 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.project; + +import static java.util.Objects.requireNonNull; + +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; + +/** + * The type Omni project operator factory. + * + * @since 2021-06-30 + */ +public class OmniProjectOperatorFactory extends OmniOperatorFactory { + private boolean isSupported; + + /** + * Instantiates a new Omni project operator factory. + * + * @param expressions the expressions + * @param inputTypes the input types + * @param operatorConfig the operator config + */ + public OmniProjectOperatorFactory(String[] expressions, DataType[] inputTypes, OperatorConfig operatorConfig) { + super(new FactoryContext(expressions, inputTypes, operatorConfig)); + } + + /** + * Instantiates a new Omni project operator factory with default operator + * config. + * + * @param expressions the expressions + * @param inputTypes the input types + */ + public OmniProjectOperatorFactory(String[] expressions, DataType[] inputTypes) { + this(expressions, inputTypes, new OperatorConfig()); + } + + /** + * Instantiates a new Omni project operator factory with configured expression + * parsing format. + * + * @param expressions the expressions + * @param inputTypes the input types + * @param parseFormat the parse format + * @param operatorConfig the operator config + */ + public OmniProjectOperatorFactory(String[] expressions, DataType[] inputTypes, int parseFormat, + OperatorConfig operatorConfig) { + super(new FactoryContext(expressions, inputTypes, parseFormat, operatorConfig)); + } + + /** + * Instantiates a new Omni project operator factory with configured expression + * parsing format with default operator config. + * + * @param expressions the expressions + * @param inputTypes the input types + * @param parseFormat the parse format + */ + public OmniProjectOperatorFactory(String[] expressions, DataType[] inputTypes, int parseFormat) { + this(expressions, inputTypes, parseFormat, new OperatorConfig()); + } + + private static native long createProjectOperatorFactory(String inputTypes, int inputLength, Object[] expressions, + int expressionsLength, int parseFormat, String operatorConfig); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + long factoryAddr = createProjectOperatorFactory(DataTypeSerializer.serialize(context.inputTypes), + context.inputTypes.length, context.expressions, context.expressions.length, context.parseFormat, + OperatorConfig.serialize(context.operatorConfig)); + if (factoryAddr != 0) { + isSupported = true; + } + return factoryAddr; + } + + public boolean isSupported() { + return isSupported; + } + + /** + * The type Factory context. + * + * @since 2021-06-30 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] inputTypes; + + private final String[] expressions; + + private final int parseFormat; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param expressions the expressions + * @param inputTypes the input types + * @param operatorConfig the operator config + */ + public FactoryContext(String[] expressions, DataType[] inputTypes, OperatorConfig operatorConfig) { + this(expressions, inputTypes, 0, operatorConfig); + } + + /** + * Instantiates a new Context with configured parsing format of the expression. + * + * @param expressions the expressions + * @param inputTypes the input types + * @param parseFormat the parse format + * @param operatorConfig the operator config + */ + public FactoryContext(String[] expressions, DataType[] inputTypes, int parseFormat, + OperatorConfig operatorConfig) { + this.inputTypes = requireNonNull(inputTypes, "Input types array is null."); + this.expressions = requireNonNull(expressions, "Expressions is null."); + this.parseFormat = parseFormat; + this.operatorConfig = operatorConfig; + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(inputTypes), Arrays.hashCode(expressions), parseFormat, operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return Arrays.equals(expressions, that.expressions) && Arrays.equals(inputTypes, that.inputTypes) + && parseFormat == that.parseFormat && operatorConfig.equals(that.operatorConfig); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/sort/OmniSortOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/sort/OmniSortOperatorFactory.java new file mode 100644 index 0000000..c2db1f5 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/sort/OmniSortOperatorFactory.java @@ -0,0 +1,123 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.sort; + +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; + +/** + * The type Omni sort operator factory. + * + * @since 2021-06-30 + */ +public class OmniSortOperatorFactory extends OmniOperatorFactory { + /** + * Instantiates a new Omni sort operator factory. + * + * @param sourceTypes the source types + * @param outputColumns the output columns + * @param sortColumns the sort columns + * @param sortAscendings the sort ascendings + * @param sortNullFirsts the sort null firsts + * @param operatorConfig the operator config + */ + public OmniSortOperatorFactory(DataType[] sourceTypes, int[] outputColumns, String[] sortColumns, + int[] sortAscendings, int[] sortNullFirsts, OperatorConfig operatorConfig) { + super(new FactoryContext(sourceTypes, outputColumns, sortColumns, sortAscendings, sortNullFirsts, + operatorConfig)); + } + + /** + * Instantiates a new Omni sort operator factory with default operator config. + * + * @param sourceTypes the source types + * @param outputColumns the output columns + * @param sortColumns the sort columns + * @param sortAscendings the sort ascendings + * @param sortNullFirsts the sort null firsts + */ + public OmniSortOperatorFactory(DataType[] sourceTypes, int[] outputColumns, String[] sortColumns, + int[] sortAscendings, int[] sortNullFirsts) { + this(sourceTypes, outputColumns, sortColumns, sortAscendings, sortNullFirsts, new OperatorConfig()); + } + + private static native long createSortOperatorFactory(String sourceTypes, int[] outputCols, String[] sortCols, + int[] ascendings, int[] nullFirsts, String operatorConfig); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createSortOperatorFactory(DataTypeSerializer.serialize(context.sourceTypes), context.outputColumns, + context.sortColumns, context.sortAscendings, context.sortNullFirsts, + OperatorConfig.serialize(context.operatorConfig)); + } + + /** + * The type Factory context. + * + * @since 2021-06-30 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] sourceTypes; + + private final int[] outputColumns; + + private final String[] sortColumns; + + private final int[] sortAscendings; + + private final int[] sortNullFirsts; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param sourceTypes the source types + * @param outputColumns the output columns + * @param sortColumns the sort columns + * @param sortAscendings the sort ascendings + * @param sortNullFirsts the sort null firsts + * @param operatorConfig the operator config + */ + public FactoryContext(DataType[] sourceTypes, int[] outputColumns, String[] sortColumns, int[] sortAscendings, + int[] sortNullFirsts, OperatorConfig operatorConfig) { + this.sourceTypes = sourceTypes; + this.outputColumns = outputColumns; + this.sortColumns = sortColumns; + this.sortAscendings = sortAscendings; + this.sortNullFirsts = sortNullFirsts; + this.operatorConfig = operatorConfig; + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(sourceTypes), Arrays.hashCode(outputColumns), + Arrays.hashCode(sortColumns), Arrays.hashCode(sortAscendings), Arrays.hashCode(sortNullFirsts), + operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return Arrays.equals(sourceTypes, that.sourceTypes) && Arrays.equals(outputColumns, that.outputColumns) + && Arrays.equals(sortColumns, that.sortColumns) + && Arrays.equals(sortAscendings, that.sortAscendings) + && Arrays.equals(sortNullFirsts, that.sortNullFirsts) + && Objects.equals(operatorConfig, that.operatorConfig); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/sort/OmniSortWithExprOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/sort/OmniSortWithExprOperatorFactory.java new file mode 100644 index 0000000..2ecb95e --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/sort/OmniSortWithExprOperatorFactory.java @@ -0,0 +1,123 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.sort; + +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; + +/** + * The Omni sort with expression operator factory. + * + * @since 2021-10-16 + */ +public class OmniSortWithExprOperatorFactory + extends OmniOperatorFactory { + /** + * Instantiates a new Omni sort with expression operator factory. + * + * @param sourceTypes the source types + * @param outputColumns the output columns + * @param sortKeys the sort keys + * @param sortAscendings the sort ascendings + * @param sortNullFirsts the sort null firsts + */ + public OmniSortWithExprOperatorFactory(DataType[] sourceTypes, int[] outputColumns, String[] sortKeys, + int[] sortAscendings, int[] sortNullFirsts) { + super(new FactoryContext(sourceTypes, outputColumns, sortKeys, sortAscendings, sortNullFirsts, + new OperatorConfig())); + } + + /** + * Instantiates a new Omni sort with expression operator factory with default + * operator config. + * + * @param sourceTypes the source types + * @param outputColumns the output columns + * @param sortKeys the sort keys + * @param sortAscendings the sort ascendings + * @param sortNullFirsts the sort null firsts + * @param operatorConfig the operator config + */ + public OmniSortWithExprOperatorFactory(DataType[] sourceTypes, int[] outputColumns, String[] sortKeys, + int[] sortAscendings, int[] sortNullFirsts, OperatorConfig operatorConfig) { + super(new FactoryContext(sourceTypes, outputColumns, sortKeys, sortAscendings, sortNullFirsts, operatorConfig)); + } + + private static native long createSortWithExprOperatorFactory(String sourceTypes, int[] outputCols, + String[] sortKeys, int[] ascendings, int[] nullFirsts, String operatorConfig); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createSortWithExprOperatorFactory(DataTypeSerializer.serialize(context.sourceTypes), + context.outputColumns, context.sortKeys, context.sortAscendings, context.sortNullFirsts, + OperatorConfig.serialize(context.operatorConfig)); + } + + /** + * The Factory context. + * + * @since 2021-10-16 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] sourceTypes; + + private final int[] outputColumns; + + private final String[] sortKeys; + + private final int[] sortAscendings; + + private final int[] sortNullFirsts; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param sourceTypes the source types + * @param outputColumns the output columns + * @param sortKeys the sort keys + * @param sortAscendings the sort ascendings + * @param sortNullFirsts the sort null firsts + * @param operatorConfig the operator config + */ + public FactoryContext(DataType[] sourceTypes, int[] outputColumns, String[] sortKeys, int[] sortAscendings, + int[] sortNullFirsts, OperatorConfig operatorConfig) { + this.sourceTypes = sourceTypes; + this.outputColumns = outputColumns; + this.sortKeys = sortKeys; + this.sortAscendings = sortAscendings; + this.sortNullFirsts = sortNullFirsts; + this.operatorConfig = operatorConfig; + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(sourceTypes), Arrays.hashCode(outputColumns), Arrays.hashCode(sortKeys), + Arrays.hashCode(sortAscendings), Arrays.hashCode(sortNullFirsts), operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext that = (FactoryContext) obj; + return Arrays.equals(sourceTypes, that.sourceTypes) && Arrays.equals(outputColumns, that.outputColumns) + && Arrays.equals(sortKeys, that.sortKeys) && Arrays.equals(sortAscendings, that.sortAscendings) + && Arrays.equals(sortNullFirsts, that.sortNullFirsts) + && Objects.equals(operatorConfig, that.operatorConfig); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/topn/OmniTopNOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/topn/OmniTopNOperatorFactory.java new file mode 100644 index 0000000..3a6c04b --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/topn/OmniTopNOperatorFactory.java @@ -0,0 +1,156 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.topn; + +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; + +/** + * The type Omni top n operator factory. + * + * @since 2021-06-30 + */ +public class OmniTopNOperatorFactory extends OmniOperatorFactory { + /** + * Instantiates a new Omni top n operator factory. + * + * @param sourceTypes the source types + * @param limitN the limit n + * @param sortCols the sort cols + * @param sortAssendings the sort assendings + * @param sortNullFirsts the sort null firsts + * @param operatorConfig the operator config + */ + public OmniTopNOperatorFactory(DataType[] sourceTypes, int limitN, String[] sortCols, int[] sortAssendings, + int[] sortNullFirsts, OperatorConfig operatorConfig) { + super(new FactoryContext(sourceTypes, limitN, 0, sortCols, sortAssendings, sortNullFirsts, operatorConfig)); + } + + /** + * Instantiates a new Omni top n operator factory. + * + * @param sourceTypes the source types + * @param limitN the limit n + * @param offsetN the offset n + * @param sortCols the sort cols + * @param sortAssendings the sort assendings + * @param sortNullFirsts the sort null firsts + * @param operatorConfig the operator config + */ + public OmniTopNOperatorFactory(DataType[] sourceTypes, int limitN, int offsetN, String[] sortCols, + int[] sortAssendings, int[] sortNullFirsts, OperatorConfig operatorConfig) { + super(new FactoryContext(sourceTypes, limitN, offsetN, sortCols, sortAssendings, sortNullFirsts, + operatorConfig)); + } + + /** + * Instantiates a new Omni top n operator factory with default operator config. + * + * @param sourceTypes the source types + * @param limitN the limit n + * @param sortCols the sort cols + * @param sortAssendings the sort assendings + * @param sortNullFirsts the sort null firsts + */ + public OmniTopNOperatorFactory(DataType[] sourceTypes, int limitN, String[] sortCols, int[] sortAssendings, + int[] sortNullFirsts) { + this(sourceTypes, limitN, 0, sortCols, sortAssendings, sortNullFirsts, new OperatorConfig()); + } + + /** + * Instantiates a new Omni top n operator factory with default operator config. + * + * @param sourceTypes the source types + * @param limitN the limit n + * @param offsetN the offset n + * @param sortCols the sort cols + * @param sortAssendings the sort assendings + * @param sortNullFirsts the sort null firsts + */ + public OmniTopNOperatorFactory(DataType[] sourceTypes, int limitN, int offsetN, String[] sortCols, + int[] sortAssendings, int[] sortNullFirsts) { + this(sourceTypes, limitN, offsetN, sortCols, sortAssendings, sortNullFirsts, new OperatorConfig()); + } + + private static native long createTopNOperatorFactory(String sourceTypes, int limitN, int offsetN, String[] sortCols, + int[] sortAssendings, int[] sortNullFirsts); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createTopNOperatorFactory(DataTypeSerializer.serialize(context.sourceTypes), context.limitN, + context.offsetN, context.sortCols, context.sortAssendings, context.sortNullFirsts); + } + + /** + * The type Context. + * + * @since 2021-06-30 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] sourceTypes; + + private final int limitN; + + private final int offsetN; + + private final String[] sortCols; + + private final int[] sortAssendings; + + private final int[] sortNullFirsts; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param sourceTypes the source types + * @param limitN the limit n + * @param offsetN the offset n + * @param sortCols the sort cols + * @param sortAssendings the sort assendings + * @param sortNullFirsts the sort null firsts + * @param operatorConfig the operator config + */ + public FactoryContext(DataType[] sourceTypes, int limitN, int offsetN, String[] sortCols, int[] sortAssendings, + int[] sortNullFirsts, OperatorConfig operatorConfig) { + this.sourceTypes = sourceTypes; + this.limitN = limitN; + this.offsetN = offsetN; + this.sortCols = sortCols; + this.sortAssendings = sortAssendings; + this.sortNullFirsts = sortNullFirsts; + this.operatorConfig = operatorConfig; + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(sourceTypes), limitN, offsetN, Arrays.hashCode(sortCols), + Arrays.hashCode(sortAssendings), Arrays.hashCode(sortNullFirsts), operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext context = (FactoryContext) obj; + return limitN == context.limitN && offsetN == context.offsetN + && Arrays.equals(sourceTypes, context.sourceTypes) && Arrays.equals(sortCols, context.sortCols) + && Arrays.equals(sortAssendings, context.sortAssendings) + && Arrays.equals(sortNullFirsts, context.sortNullFirsts) + && operatorConfig.equals(context.operatorConfig); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/topn/OmniTopNWithExprOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/topn/OmniTopNWithExprOperatorFactory.java new file mode 100644 index 0000000..eb38c45 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/topn/OmniTopNWithExprOperatorFactory.java @@ -0,0 +1,160 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.topn; + +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; + +/** + * The type Omni top n with expression operator factory. + * + * @since 2021-10-26 + */ +public class OmniTopNWithExprOperatorFactory + extends OmniOperatorFactory { + /** + * Instantiates a new Omni top n with expression operator factory. + * + * @param sourceTypes the source types + * @param limitN the limit n + * @param sortKeys the sort keys + * @param sortAssendings the sort assendings + * @param sortNullFirsts the sort null firsts + * @param operatorConfig the operator config + */ + public OmniTopNWithExprOperatorFactory(DataType[] sourceTypes, int limitN, String[] sortKeys, int[] sortAssendings, + int[] sortNullFirsts, OperatorConfig operatorConfig) { + super(new FactoryContext(sourceTypes, limitN, 0, sortKeys, sortAssendings, sortNullFirsts, operatorConfig)); + } + + /** + * Instantiates a new Omni top n with expression operator factory. + * + * @param sourceTypes the source types + * @param limitN the limit n + * @param offsetN the offset n + * @param sortKeys the sort keys + * @param sortAssendings the sort assendings + * @param sortNullFirsts the sort null firsts + * @param operatorConfig the operator config + */ + public OmniTopNWithExprOperatorFactory(DataType[] sourceTypes, int limitN, int offsetN, String[] sortKeys, + int[] sortAssendings, int[] sortNullFirsts, OperatorConfig operatorConfig) { + super(new FactoryContext(sourceTypes, limitN, offsetN, sortKeys, sortAssendings, sortNullFirsts, + operatorConfig)); + } + + /** + * Instantiates a new Omni top n with expression operator factory with default + * operator config. + * + * @param sourceTypes the source types + * @param limitN the limit n + * @param sortKeys the sort keys + * @param sortAssendings the sort assendings + * @param sortNullFirsts the sort null firsts + */ + public OmniTopNWithExprOperatorFactory(DataType[] sourceTypes, int limitN, String[] sortKeys, int[] sortAssendings, + int[] sortNullFirsts) { + this(sourceTypes, limitN, 0, sortKeys, sortAssendings, sortNullFirsts, new OperatorConfig()); + } + + /** + * Instantiates a new Omni top n with expression operator factory with default + * operator config. + * + * @param sourceTypes the source types + * @param limitN the limit n + * @param offsetN the offset n + * @param sortKeys the sort keys + * @param sortAssendings the sort assendings + * @param sortNullFirsts the sort null firsts + */ + public OmniTopNWithExprOperatorFactory(DataType[] sourceTypes, int limitN, int offsetN, String[] sortKeys, + int[] sortAssendings, int[] sortNullFirsts) { + this(sourceTypes, limitN, offsetN, sortKeys, sortAssendings, sortNullFirsts, new OperatorConfig()); + } + + private static native long createTopNWithExprOperatorFactory(String sourceTypes, int limitN, int offsetN, + String[] sortKeys, int[] sortAssendings, int[] sortNullFirsts, String operatorConfig); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createTopNWithExprOperatorFactory(DataTypeSerializer.serialize(context.sourceTypes), context.limitN, + context.offsetN, context.sortKeys, context.sortAssendings, context.sortNullFirsts, + OperatorConfig.serialize(context.operatorConfig)); + } + + /** + * The type Context. + * + * @since 2021-10-26 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] sourceTypes; + + private final int limitN; + + private final int offsetN; + + private final String[] sortKeys; + + private final int[] sortAssendings; + + private final int[] sortNullFirsts; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param sourceTypes the source types + * @param limitN the limit n + * @param offsetN the offset n + * @param sortKeys the sort cols + * @param sortAssendings the sort assendings + * @param sortNullFirsts the sort null firsts + * @param operatorConfig the operator config + */ + public FactoryContext(DataType[] sourceTypes, int limitN, int offsetN, String[] sortKeys, int[] sortAssendings, + int[] sortNullFirsts, OperatorConfig operatorConfig) { + this.sourceTypes = sourceTypes; + this.limitN = limitN; + this.offsetN = offsetN; + this.sortKeys = sortKeys; + this.sortAssendings = sortAssendings; + this.sortNullFirsts = sortNullFirsts; + this.operatorConfig = operatorConfig; + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(sourceTypes), limitN, offsetN, Arrays.hashCode(sortKeys), + Arrays.hashCode(sortAssendings), Arrays.hashCode(sortNullFirsts), operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext context = (FactoryContext) obj; + return limitN == context.limitN && offsetN == context.offsetN + && Arrays.equals(sourceTypes, context.sourceTypes) && Arrays.equals(sortKeys, context.sortKeys) + && Arrays.equals(sortAssendings, context.sortAssendings) + && Arrays.equals(sortNullFirsts, context.sortNullFirsts) + && operatorConfig.equals(context.operatorConfig); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/topnsort/OmniTopNSortWithExprOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/topnsort/OmniTopNSortWithExprOperatorFactory.java new file mode 100644 index 0000000..0f87278 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/topnsort/OmniTopNSortWithExprOperatorFactory.java @@ -0,0 +1,141 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.topnsort; + +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; + +/** + * The type omni top n sort with expression operator factory. + * + * @since 2023-5-26 + */ +public class OmniTopNSortWithExprOperatorFactory + extends OmniOperatorFactory { + /** + * Instantiates a new omni top n sort with expression operator factory. + * + * @param sourceTypes the source types + * @param limitN the limit n + * @param isStrictTopN true for window RowNumber, false for window Rank + * @param partitionKeys the partition keys + * @param sortKeys the sort keys + * @param sortAssendings the sort assendings + * @param sortNullFirsts the sort null firsts + * @param operatorConfig the operator config + */ + public OmniTopNSortWithExprOperatorFactory(DataType[] sourceTypes, int limitN, boolean isStrictTopN, + String[] partitionKeys, String[] sortKeys, int[] sortAssendings, int[] sortNullFirsts, + OperatorConfig operatorConfig) { + super(new FactoryContext(sourceTypes, limitN, isStrictTopN, partitionKeys, sortKeys, sortAssendings, + sortNullFirsts, operatorConfig)); + } + + /** + * Instantiates a new omni top n sort with expression operator factory with default + * operator config. + * + * @param sourceTypes the source types + * @param limitN the limit n + * @param isStrictTopN true for window RowNumber, false for window Rank + * @param partitionKeys the partition keys + * @param sortKeys the sort keys + * @param sortAssendings the sort assendings + * @param sortNullFirsts the sort null firsts + */ + public OmniTopNSortWithExprOperatorFactory(DataType[] sourceTypes, int limitN, boolean isStrictTopN, + String[] partitionKeys, String[] sortKeys, int[] sortAssendings, int[] sortNullFirsts) { + this(sourceTypes, limitN, isStrictTopN, partitionKeys, sortKeys, sortAssendings, sortNullFirsts, + new OperatorConfig()); + } + + private static native long createTopNSortWithExprOperatorFactory(String sourceTypes, int limitN, + boolean isStrictTopN, String[] partitionKeys, String[] sortKeys, int[] sortAssendings, int[] sortNullFirsts, + String operatorConfig); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createTopNSortWithExprOperatorFactory(DataTypeSerializer.serialize(context.sourceTypes), context.limitN, + context.isStrictTopN, context.partitionKeys, context.sortKeys, context.sortAssendings, + context.sortNullFirsts, OperatorConfig.serialize(context.operatorConfig)); + } + + /** + * The type Context. + * + * @since 2023-5-26 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] sourceTypes; + + private final int limitN; + + private final boolean isStrictTopN; + + private final String[] partitionKeys; + + private final String[] sortKeys; + + private final int[] sortAssendings; + + private final int[] sortNullFirsts; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param sourceTypes the source types + * @param limitN the limit n + * @param isStrictTopN true for window RowNumber, false for window Rank + * @param partitionKeys the partition keys + * @param sortKeys the sort keys + * @param sortAssendings the sort assendings + * @param sortNullFirsts the sort null firsts + * @param operatorConfig the operator config + */ + public FactoryContext(DataType[] sourceTypes, int limitN, boolean isStrictTopN, String[] partitionKeys, + String[] sortKeys, int[] sortAssendings, int[] sortNullFirsts, OperatorConfig operatorConfig) { + this.sourceTypes = sourceTypes; + this.limitN = limitN; + this.isStrictTopN = isStrictTopN; + this.partitionKeys = partitionKeys; + this.sortKeys = sortKeys; + this.sortAssendings = sortAssendings; + this.sortNullFirsts = sortNullFirsts; + this.operatorConfig = operatorConfig; + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(sourceTypes), limitN, isStrictTopN, Arrays.hashCode(partitionKeys), + Arrays.hashCode(sortKeys), Arrays.hashCode(sortAssendings), Arrays.hashCode(sortNullFirsts), + operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext context = (FactoryContext) obj; + return limitN == context.limitN && Arrays.equals(sourceTypes, context.sourceTypes) + && isStrictTopN == context.isStrictTopN && Arrays.equals(partitionKeys, context.partitionKeys) + && Arrays.equals(sortKeys, context.sortKeys) + && Arrays.equals(sortAssendings, context.sortAssendings) + && Arrays.equals(sortNullFirsts, context.sortNullFirsts) + && operatorConfig.equals(context.operatorConfig); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/union/OmniUnionOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/union/OmniUnionOperatorFactory.java new file mode 100644 index 0000000..9ed4fa5 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/union/OmniUnionOperatorFactory.java @@ -0,0 +1,93 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.union; + +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; + +/** + * The type Omni union operator factory. + * + * @since 2021-06-30 + */ +public class OmniUnionOperatorFactory extends OmniOperatorFactory { + /** + * Instantiates a new Omni union operator factory. + * + * @param sourceTypes the source type + * @param isDistinct mark union or union all + * @param operatorConfig the operator config + */ + public OmniUnionOperatorFactory(DataType[] sourceTypes, boolean isDistinct, OperatorConfig operatorConfig) { + super(new FactoryContext(sourceTypes, isDistinct, operatorConfig)); + } + + /** + * Instantiates a new Omni union operator factory with default operator config. + * + * @param sourceTypes the source type + * @param isDistinct mark union or union all + */ + public OmniUnionOperatorFactory(DataType[] sourceTypes, boolean isDistinct) { + this(sourceTypes, isDistinct, new OperatorConfig()); + } + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createUnionOperatorFactory(DataTypeSerializer.serialize(context.sourceTypes), context.isDistinct); + } + + private static native long createUnionOperatorFactory(String sourceTypes, boolean isDistinct); + + /** + * The type Factory context. + * + * @since 2021-06-30 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] sourceTypes; + + private final boolean isDistinct; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Jit context. + * + * @param sourceTypes the source types + * @param isDistinct the is distinct + * @param operatorConfig the operator config + */ + public FactoryContext(DataType[] sourceTypes, boolean isDistinct, OperatorConfig operatorConfig) { + this.sourceTypes = sourceTypes; + this.isDistinct = isDistinct; + this.operatorConfig = operatorConfig; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext context = (FactoryContext) obj; + return isDistinct == context.isDistinct && Arrays.equals(sourceTypes, context.sourceTypes) + && operatorConfig.equals(context.operatorConfig); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(sourceTypes), isDistinct, operatorConfig); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/window/OmniWindowGroupLimitWithExprOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/window/OmniWindowGroupLimitWithExprOperatorFactory.java new file mode 100644 index 0000000..77e01c5 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/window/OmniWindowGroupLimitWithExprOperatorFactory.java @@ -0,0 +1,140 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.window; + +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; + +/** + * The type omni window group limit n with expression operator factory. + * + * @since 2025-01-15 + */ +public class OmniWindowGroupLimitWithExprOperatorFactory + extends OmniOperatorFactory { + /** + * Instantiates a new omni top n sort with expression operator factory. + * + * @param sourceTypes the source types + * @param n means limit n + * @param funcName row_number/rank/dense_rank + * @param partitionKeys the partition keys + * @param sortKeys the sort keys + * @param sortAssendings the sort assendings + * @param sortNullFirsts the sort null firsts + * @param operatorConfig the operator config + */ + public OmniWindowGroupLimitWithExprOperatorFactory(DataType[] sourceTypes, int n, String funcName, + String[] partitionKeys, String[] sortKeys, int[] sortAssendings, int[] sortNullFirsts, + OperatorConfig operatorConfig) { + super(new FactoryContext(sourceTypes, n, funcName, partitionKeys, sortKeys, sortAssendings, sortNullFirsts, + operatorConfig)); + } + + /** + * Instantiates a new omni window group limit with expression operator factory + * with default + * operator config. + * + * @param sourceTypes the source types + * @param n means limit n + * @param funcName row_number/rank/dense_rank + * @param partitionKeys the partition keys + * @param sortKeys the sort keys + * @param sortAssendings the sort assendings + * @param sortNullFirsts the sort null firsts + */ + public OmniWindowGroupLimitWithExprOperatorFactory(DataType[] sourceTypes, int n, String funcName, + String[] partitionKeys, String[] sortKeys, int[] sortAssendings, int[] sortNullFirsts) { + this(sourceTypes, n, funcName, partitionKeys, sortKeys, sortAssendings, sortNullFirsts, new OperatorConfig()); + } + + private static native long createWindowGroupLimitWithExprOperatorFactory(String sourceTypes, int n, String funcName, + String[] partitionKeys, String[] sortKeys, int[] sortAssendings, int[] sortNullFirsts, + String operatorConfig); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createWindowGroupLimitWithExprOperatorFactory(DataTypeSerializer.serialize(context.sourceTypes), + context.n, context.funcName, context.partitionKeys, context.sortKeys, context.sortAssendings, + context.sortNullFirsts, OperatorConfig.serialize(context.operatorConfig)); + } + + /** + * The type Context. + * + * @since 2025-1-15 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] sourceTypes; + + private final int n; + + private final String funcName; + + private final String[] partitionKeys; + + private final String[] sortKeys; + + private final int[] sortAssendings; + + private final int[] sortNullFirsts; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param sourceTypes the source types + * @param n means limit n + * @param funcName row_number/rank/dense_rank + * @param partitionKeys the partition keys + * @param sortKeys the sort keys + * @param sortAssendings the sort assendings + * @param sortNullFirsts the sort null firsts + * @param operatorConfig the operator config + */ + public FactoryContext(DataType[] sourceTypes, int n, String funcName, String[] partitionKeys, String[] sortKeys, + int[] sortAssendings, int[] sortNullFirsts, OperatorConfig operatorConfig) { + this.sourceTypes = sourceTypes; + this.n = n; + this.funcName = funcName; + this.partitionKeys = partitionKeys; + this.sortKeys = sortKeys; + this.sortAssendings = sortAssendings; + this.sortNullFirsts = sortNullFirsts; + this.operatorConfig = operatorConfig; + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(sourceTypes), n, funcName, Arrays.hashCode(partitionKeys), + Arrays.hashCode(sortKeys), Arrays.hashCode(sortAssendings), Arrays.hashCode(sortNullFirsts), + operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext context = (FactoryContext) obj; + return n == context.n && Arrays.equals(sourceTypes, context.sourceTypes) && funcName == context.funcName + && Arrays.equals(partitionKeys, context.partitionKeys) && Arrays.equals(sortKeys, context.sortKeys) + && Arrays.equals(sortAssendings, context.sortAssendings) + && Arrays.equals(sortNullFirsts, context.sortNullFirsts) + && operatorConfig.equals(context.operatorConfig); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/window/OmniWindowOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/window/OmniWindowOperatorFactory.java new file mode 100644 index 0000000..25688a3 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/window/OmniWindowOperatorFactory.java @@ -0,0 +1,243 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.window; + +import static nova.hetu.omniruntime.constants.ConstantHelper.toNativeConstants; + +import nova.hetu.omniruntime.constants.FunctionType; +import nova.hetu.omniruntime.constants.OmniWindowFrameBoundType; +import nova.hetu.omniruntime.constants.OmniWindowFrameType; +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; + +/** + * The type Omni window operator factory. + * + * @since 20210630 + */ +public class OmniWindowOperatorFactory extends OmniOperatorFactory { + /** + * Instantiates a new Omni window operator factory. + * + * @param sourceTypes the source types + * @param outputChannels the output channels + * @param windowFunction the window function + * @param partitionChannels the partition channels + * @param preGroupedChannels the pre grouped channels + * @param sortChannels the sort channels + * @param sortOrder the sort order + * @param sortNullFirsts the sort null firsts + * @param preSortedChannelPrefix the pre sorted channel prefix + * @param expectedPositions the expected positions + * @param argumentChannels the argument channels + * @param windowFunctionReturnType the window function return type + * @param windowFrameTypes frame types of the window + * @param windowFrameStartTypes start types of frame in window + * @param winddowFrameStartChannels channels value of frame start value + * @param windowFrameEndTypes end types of frame in window + * @param winddowFrameEndChannels channel values of frame end value + * @param operatorConfig the operator config + */ + public OmniWindowOperatorFactory(DataType[] sourceTypes, int[] outputChannels, FunctionType[] windowFunction, + int[] partitionChannels, int[] preGroupedChannels, int[] sortChannels, int[] sortOrder, + int[] sortNullFirsts, int preSortedChannelPrefix, int expectedPositions, int[] argumentChannels, + DataType[] windowFunctionReturnType, OmniWindowFrameType[] windowFrameTypes, + OmniWindowFrameBoundType[] windowFrameStartTypes, int[] winddowFrameStartChannels, + OmniWindowFrameBoundType[] windowFrameEndTypes, int[] winddowFrameEndChannels, + OperatorConfig operatorConfig) { + super(new FactoryContext(sourceTypes, outputChannels, windowFunction, partitionChannels, preGroupedChannels, + sortChannels, sortOrder, sortNullFirsts, preSortedChannelPrefix, expectedPositions, argumentChannels, + windowFunctionReturnType, windowFrameTypes, windowFrameStartTypes, winddowFrameStartChannels, + windowFrameEndTypes, winddowFrameEndChannels, operatorConfig)); + } + + /** + * Instantiates a new Omni window operator factory with default operator config. + * + * @param sourceTypes the source types + * @param outputChannels the output channels + * @param windowFunction the window function + * @param partitionChannels the partition channels + * @param preGroupedChannels the pre grouped channels + * @param sortChannels the sort channels + * @param sortOrder the sort order + * @param sortNullFirsts the sort null firsts + * @param preSortedChannelPrefix the pre sorted channel prefix + * @param expectedPositions the expected positions + * @param argumentChannels the argument channels + * @param windowFunctionReturnType the window function return type + * @param windowFrameTypes frame types of the window + * @param windowFrameStartTypes start types of frame in window + * @param winddowFrameStartChannels channels value of frame start value + * @param windowFrameEndTypes end types of frame in window + * @param winddowFrameEndChannels channel values of frame end value + */ + public OmniWindowOperatorFactory(DataType[] sourceTypes, int[] outputChannels, FunctionType[] windowFunction, + int[] partitionChannels, int[] preGroupedChannels, int[] sortChannels, int[] sortOrder, + int[] sortNullFirsts, int preSortedChannelPrefix, int expectedPositions, int[] argumentChannels, + DataType[] windowFunctionReturnType, OmniWindowFrameType[] windowFrameTypes, + OmniWindowFrameBoundType[] windowFrameStartTypes, int[] winddowFrameStartChannels, + OmniWindowFrameBoundType[] windowFrameEndTypes, int[] winddowFrameEndChannels) { + this(sourceTypes, outputChannels, windowFunction, partitionChannels, preGroupedChannels, sortChannels, + sortOrder, sortNullFirsts, preSortedChannelPrefix, expectedPositions, argumentChannels, + windowFunctionReturnType, windowFrameTypes, windowFrameStartTypes, winddowFrameStartChannels, + windowFrameEndTypes, winddowFrameEndChannels, new OperatorConfig()); + } + + private static native long createWindowOperatorFactory(String sourceTypes, int[] outputChannels, int[] windFunction, + int[] partitionChannels, int[] preGroupedChannels, int[] sortChannels, int[] sortOrder, + int[] sortNullFirsts, int preSortedChannelPrefix, int expectedPositions, int[] argumentChannels, + String windowFunctionReturnType, int[] windowFrameTypes, int[] windowFrameStartTypes, + int[] windowFrameStartChannels, int[] windowFrameEndTypes, int[] windowFrameEndChannels, + String operatorConfig); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createWindowOperatorFactory(DataTypeSerializer.serialize(context.sourceTypes), context.outputChannels, + toNativeConstants(context.windFunction), context.partitionChannels, context.preGroupedChannels, + context.sortChannels, context.sortOrder, context.sortNullFirsts, context.preSortedChannelPrefix, + context.expectedPositions, context.argumentChannels, + DataTypeSerializer.serialize(context.windowFunctionReturnType), toNativeConstants(context.frameTypes), + toNativeConstants(context.frameStartTypes), context.frameStartChannels, + toNativeConstants(context.frameEndTypes), context.frameEndChannels, + OperatorConfig.serialize(context.operatorConfig)); + } + + /** + * The type Factory context. + * + * @since 20210630 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] sourceTypes; + + private final int[] outputChannels; + + private final FunctionType[] windFunction; + + private final int[] partitionChannels; + + private final int[] sortChannels; + + private final int[] sortOrder; + + private final int[] preGroupedChannels; + + private final int[] sortNullFirsts; + + private final int preSortedChannelPrefix; + + private final int expectedPositions; + + private final int[] argumentChannels; + + private final DataType[] windowFunctionReturnType; + + private final OmniWindowFrameType[] frameTypes; + + private final OmniWindowFrameBoundType[] frameStartTypes; + + private final int[] frameStartChannels; + + private final OmniWindowFrameBoundType[] frameEndTypes; + + private final int[] frameEndChannels; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param sourceTypes the source types + * @param outputChannels the output channels + * @param windowFunction the window function + * @param partitionChannels the partition channels + * @param preGroupedChannels the pre grouped channels + * @param sortChannels the sort channels + * @param sortOrder the sort order + * @param sortNullFirsts the sort null firsts + * @param preSortedChannelPrefix the pre sorted channel prefix + * @param expectedPositions the expected positions + * @param argumentChannels the argument channels + * @param windowFunctionReturnType the window function return type + * @param windowFrameTypes frame types of the window + * @param windowFrameStartTypes start types of frame in window + * @param winddowFrameStartChannels channels value of frame start value + * @param windowFrameEndTypes end types of frame in window + * @param winddowFrameEndChannels channel values of frame end value + * @param operatorConfig the operator config + */ + public FactoryContext(DataType[] sourceTypes, int[] outputChannels, FunctionType[] windowFunction, + int[] partitionChannels, int[] preGroupedChannels, int[] sortChannels, int[] sortOrder, + int[] sortNullFirsts, int preSortedChannelPrefix, int expectedPositions, int[] argumentChannels, + DataType[] windowFunctionReturnType, OmniWindowFrameType[] windowFrameTypes, + OmniWindowFrameBoundType[] windowFrameStartTypes, int[] winddowFrameStartChannels, + OmniWindowFrameBoundType[] windowFrameEndTypes, int[] winddowFrameEndChannels, + OperatorConfig operatorConfig) { + this.sourceTypes = sourceTypes; + this.outputChannels = outputChannels; + this.windFunction = windowFunction; + this.partitionChannels = partitionChannels; + this.preGroupedChannels = preGroupedChannels; + this.sortChannels = sortChannels; + this.sortOrder = sortOrder; + this.sortNullFirsts = sortNullFirsts; + this.preSortedChannelPrefix = preSortedChannelPrefix; + this.expectedPositions = expectedPositions; + this.argumentChannels = argumentChannels; + this.windowFunctionReturnType = windowFunctionReturnType; + this.frameTypes = windowFrameTypes; + this.frameStartTypes = windowFrameStartTypes; + this.frameStartChannels = winddowFrameStartChannels; + this.frameEndTypes = windowFrameEndTypes; + this.frameEndChannels = winddowFrameEndChannels; + this.operatorConfig = operatorConfig; + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(sourceTypes), Arrays.hashCode(outputChannels), + Arrays.hashCode(windFunction), Arrays.hashCode(partitionChannels), + Arrays.hashCode(preGroupedChannels), Arrays.hashCode(sortChannels), Arrays.hashCode(sortOrder), + Arrays.hashCode(sortNullFirsts), preSortedChannelPrefix, expectedPositions, + Arrays.hashCode(argumentChannels), Arrays.hashCode(windowFunctionReturnType), + Arrays.hashCode(frameTypes), Arrays.hashCode(frameStartTypes), Arrays.hashCode(frameStartChannels), + Arrays.hashCode(frameEndTypes), Arrays.hashCode(frameEndChannels), operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext context = (FactoryContext) obj; + return preSortedChannelPrefix == context.preSortedChannelPrefix + && expectedPositions == context.expectedPositions && Arrays.equals(sourceTypes, context.sourceTypes) + && Arrays.equals(outputChannels, context.outputChannels) + && Arrays.equals(windFunction, context.windFunction) + && Arrays.equals(partitionChannels, context.partitionChannels) + && Arrays.equals(preGroupedChannels, context.preGroupedChannels) + && Arrays.equals(sortChannels, context.sortChannels) && Arrays.equals(sortOrder, context.sortOrder) + && Arrays.equals(sortNullFirsts, context.sortNullFirsts) + && Arrays.equals(argumentChannels, context.argumentChannels) + && Arrays.equals(windowFunctionReturnType, context.windowFunctionReturnType) + && Arrays.equals(frameTypes, context.frameTypes) + && Arrays.equals(frameStartTypes, context.frameStartTypes) + && Arrays.equals(frameStartChannels, context.frameStartChannels) + && Arrays.equals(frameEndTypes, context.frameEndTypes) + && Arrays.equals(frameEndChannels, context.frameEndChannels) + && operatorConfig.equals(context.operatorConfig); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/operator/window/OmniWindowWithExprOperatorFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/window/OmniWindowWithExprOperatorFactory.java new file mode 100644 index 0000000..6f1d538 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/operator/window/OmniWindowWithExprOperatorFactory.java @@ -0,0 +1,244 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator.window; + +import static nova.hetu.omniruntime.constants.ConstantHelper.toNativeConstants; + +import nova.hetu.omniruntime.constants.FunctionType; +import nova.hetu.omniruntime.constants.OmniWindowFrameBoundType; +import nova.hetu.omniruntime.constants.OmniWindowFrameType; +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.OmniOperatorFactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.util.Arrays; +import java.util.Objects; + +/** + * The type Omni window operator factory. + * + * @since 20210630 + */ +public class OmniWindowWithExprOperatorFactory + extends OmniOperatorFactory { + /** + * Instantiates a new Omni window operator factory. + * + * @param sourceTypes the source types + * @param outputChannels the output channels + * @param windowFunction the window function + * @param partitionChannels the partition channels + * @param preGroupedChannels the pre grouped channels + * @param sortChannels the sort channels + * @param sortOrder the sort order + * @param sortNullFirsts the sort null firsts + * @param preSortedChannelPrefix the pre sorted channel prefix + * @param expectedPositions the expected positions + * @param argumentKeys the argument keys + * @param windowFunctionReturnType the window function return type + * @param windowFrameTypes frame types of the window + * @param windowFrameStartTypes start types of frame in window + * @param winddowFrameStartChannels channels value of frame start value + * @param windowFrameEndTypes end types of frame in window + * @param winddowFrameEndChannels channel values of frame end value + * @param operatorConfig the operator config + */ + public OmniWindowWithExprOperatorFactory(DataType[] sourceTypes, int[] outputChannels, + FunctionType[] windowFunction, int[] partitionChannels, int[] preGroupedChannels, int[] sortChannels, + int[] sortOrder, int[] sortNullFirsts, int preSortedChannelPrefix, int expectedPositions, + String[] argumentKeys, DataType[] windowFunctionReturnType, OmniWindowFrameType[] windowFrameTypes, + OmniWindowFrameBoundType[] windowFrameStartTypes, int[] winddowFrameStartChannels, + OmniWindowFrameBoundType[] windowFrameEndTypes, int[] winddowFrameEndChannels, + OperatorConfig operatorConfig) { + super(new FactoryContext(sourceTypes, outputChannels, windowFunction, partitionChannels, preGroupedChannels, + sortChannels, sortOrder, sortNullFirsts, preSortedChannelPrefix, expectedPositions, argumentKeys, + windowFunctionReturnType, windowFrameTypes, windowFrameStartTypes, winddowFrameStartChannels, + windowFrameEndTypes, winddowFrameEndChannels, operatorConfig)); + } + + /** + * Instantiates a new Omni window operator factory with default operator config. + * + * @param sourceTypes the source types + * @param outputChannels the output channels + * @param windowFunction the window function + * @param partitionChannels the partition channels + * @param preGroupedChannels the pre grouped channels + * @param sortChannels the sort channels + * @param sortOrder the sort order + * @param sortNullFirsts the sort null firsts + * @param preSortedChannelPrefix the pre sorted channel prefix + * @param expectedPositions the expected positions + * @param argumentKeys the argument keys + * @param windowFunctionReturnType the window function return type + * @param windowFrameTypes frame types of the window + * @param windowFrameStartTypes start types of frame in window + * @param winddowFrameStartChannels channels value of frame start value + * @param windowFrameEndTypes end types of frame in window + * @param winddowFrameEndChannels channel values of frame end value + */ + public OmniWindowWithExprOperatorFactory(DataType[] sourceTypes, int[] outputChannels, + FunctionType[] windowFunction, int[] partitionChannels, int[] preGroupedChannels, int[] sortChannels, + int[] sortOrder, int[] sortNullFirsts, int preSortedChannelPrefix, int expectedPositions, + String[] argumentKeys, DataType[] windowFunctionReturnType, OmniWindowFrameType[] windowFrameTypes, + OmniWindowFrameBoundType[] windowFrameStartTypes, int[] winddowFrameStartChannels, + OmniWindowFrameBoundType[] windowFrameEndTypes, int[] winddowFrameEndChannels) { + this(sourceTypes, outputChannels, windowFunction, partitionChannels, preGroupedChannels, sortChannels, + sortOrder, sortNullFirsts, preSortedChannelPrefix, expectedPositions, argumentKeys, + windowFunctionReturnType, windowFrameTypes, windowFrameStartTypes, winddowFrameStartChannels, + windowFrameEndTypes, winddowFrameEndChannels, new OperatorConfig()); + } + + private static native long createWindowWithExprOperatorFactory(String sourceTypes, int[] outputChannels, + int[] windFunction, int[] partitionChannels, int[] preGroupedChannels, int[] sortChannels, int[] sortOrder, + int[] sortNullFirsts, int preSortedChannelPrefix, int expectedPositions, String[] argumentKeys, + String windowFunctionReturnType, int[] windowFrameTypes, int[] windowFrameStartTypes, + int[] windowFrameStartChannels, int[] windowFrameEndTypes, int[] windowFrameEndChannels, + String operatorConfig); + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return createWindowWithExprOperatorFactory(DataTypeSerializer.serialize(context.sourceTypes), + context.outputChannels, toNativeConstants(context.windFunction), context.partitionChannels, + context.preGroupedChannels, context.sortChannels, context.sortOrder, context.sortNullFirsts, + context.preSortedChannelPrefix, context.expectedPositions, context.argumentKeys, + DataTypeSerializer.serialize(context.windowFunctionReturnType), toNativeConstants(context.frameTypes), + toNativeConstants(context.frameStartTypes), context.frameStartChannels, + toNativeConstants(context.frameEndTypes), context.frameEndChannels, + OperatorConfig.serialize(context.operatorConfig)); + } + + /** + * The type Factory context. + * + * @since 20210630 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final DataType[] sourceTypes; + + private final int[] outputChannels; + + private final FunctionType[] windFunction; + + private final int[] partitionChannels; + + private final int[] sortChannels; + + private final int[] sortOrder; + + private final int[] preGroupedChannels; + + private final int[] sortNullFirsts; + + private final int preSortedChannelPrefix; + + private final int expectedPositions; + + private final String[] argumentKeys; + + private final DataType[] windowFunctionReturnType; + + private final OmniWindowFrameType[] frameTypes; + + private final OmniWindowFrameBoundType[] frameStartTypes; + + private final int[] frameStartChannels; + + private final OmniWindowFrameBoundType[] frameEndTypes; + + private final int[] frameEndChannels; + + private final OperatorConfig operatorConfig; + + /** + * Instantiates a new Context. + * + * @param sourceTypes the source types + * @param outputChannels the output channels + * @param windowFunction the window function + * @param partitionChannels the partition channels + * @param preGroupedChannels the pre grouped channels + * @param sortChannels the sort channels + * @param sortOrder the sort order + * @param sortNullFirsts the sort null firsts + * @param preSortedChannelPrefix the pre sorted channel prefix + * @param expectedPositions the expected positions + * @param argumentKeys the argument channels + * @param windowFunctionReturnType the window function return type + * @param windowFrameTypes frame types of the window + * @param windowFrameStartTypes start types of frame in window + * @param winddowFrameStartChannels channels value of frame start value + * @param windowFrameEndTypes end types of frame in window + * @param winddowFrameEndChannels channel values of frame end value + * @param operatorConfig the operator config + */ + public FactoryContext(DataType[] sourceTypes, int[] outputChannels, FunctionType[] windowFunction, + int[] partitionChannels, int[] preGroupedChannels, int[] sortChannels, int[] sortOrder, + int[] sortNullFirsts, int preSortedChannelPrefix, int expectedPositions, String[] argumentKeys, + DataType[] windowFunctionReturnType, OmniWindowFrameType[] windowFrameTypes, + OmniWindowFrameBoundType[] windowFrameStartTypes, int[] winddowFrameStartChannels, + OmniWindowFrameBoundType[] windowFrameEndTypes, int[] winddowFrameEndChannels, + OperatorConfig operatorConfig) { + this.sourceTypes = sourceTypes; + this.outputChannels = outputChannels; + this.windFunction = windowFunction; + this.partitionChannels = partitionChannels; + this.preGroupedChannels = preGroupedChannels; + this.sortChannels = sortChannels; + this.sortOrder = sortOrder; + this.sortNullFirsts = sortNullFirsts; + this.preSortedChannelPrefix = preSortedChannelPrefix; + this.expectedPositions = expectedPositions; + this.argumentKeys = argumentKeys; + this.windowFunctionReturnType = windowFunctionReturnType; + this.frameTypes = windowFrameTypes; + this.frameStartTypes = windowFrameStartTypes; + this.frameStartChannels = winddowFrameStartChannels; + this.frameEndTypes = windowFrameEndTypes; + this.frameEndChannels = winddowFrameEndChannels; + this.operatorConfig = operatorConfig; + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(sourceTypes), Arrays.hashCode(outputChannels), + Arrays.hashCode(windFunction), Arrays.hashCode(partitionChannels), + Arrays.hashCode(preGroupedChannels), Arrays.hashCode(sortChannels), Arrays.hashCode(sortOrder), + Arrays.hashCode(sortNullFirsts), preSortedChannelPrefix, expectedPositions, + Arrays.hashCode(argumentKeys), Arrays.hashCode(windowFunctionReturnType), + Arrays.hashCode(argumentKeys), Arrays.hashCode(windowFunctionReturnType), + Arrays.hashCode(frameStartChannels), Arrays.hashCode(frameEndTypes), + Arrays.hashCode(frameEndChannels), operatorConfig); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + FactoryContext context = (FactoryContext) obj; + return preSortedChannelPrefix == context.preSortedChannelPrefix + && expectedPositions == context.expectedPositions && Arrays.equals(sourceTypes, context.sourceTypes) + && Arrays.equals(outputChannels, context.outputChannels) + && Arrays.equals(windFunction, context.windFunction) + && Arrays.equals(partitionChannels, context.partitionChannels) + && Arrays.equals(preGroupedChannels, context.preGroupedChannels) + && Arrays.equals(sortChannels, context.sortChannels) && Arrays.equals(sortOrder, context.sortOrder) + && Arrays.equals(sortNullFirsts, context.sortNullFirsts) + && Arrays.equals(argumentKeys, context.argumentKeys) + && Arrays.equals(windowFunctionReturnType, context.windowFunctionReturnType) + && operatorConfig.equals(context.operatorConfig) && Arrays.equals(frameTypes, context.frameTypes) + && Arrays.equals(frameStartTypes, context.frameStartTypes) + && Arrays.equals(frameStartChannels, context.frameStartChannels) + && Arrays.equals(frameEndTypes, context.frameEndTypes) + && Arrays.equals(frameEndChannels, context.frameEndChannels); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/type/BooleanDataType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/type/BooleanDataType.java new file mode 100644 index 0000000..fdbbd6f --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/type/BooleanDataType.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +/** + * boolean data type. + * + * @since 2021-08-05 + */ +public class BooleanDataType extends DataType { + /** + * Boolean singleton. + */ + public static final BooleanDataType BOOLEAN = new BooleanDataType(); + + private static final long serialVersionUID = 8981310620537140335L; + + /** + * Boolean construct. + */ + public BooleanDataType() { + super(DataTypeId.OMNI_BOOLEAN); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/type/ByteDataType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/type/ByteDataType.java new file mode 100644 index 0000000..41bd18f --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/type/ByteDataType.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +/** + * byte data type. + * + * @since 2025-08-05 + */ +public class ByteDataType extends DataType { + /** + * Byte singleton. + */ + public static final ByteDataType BYTE = new ByteDataType(); + + private static final long serialVersionUID = -1145142315179689320L; + + /** + * The construct. + */ + public ByteDataType() { + super(DataTypeId.OMNI_BYTE); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/type/CharDataType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/type/CharDataType.java new file mode 100644 index 0000000..8e1ea21 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/type/CharDataType.java @@ -0,0 +1,54 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +/** + * char data type. + * + * @since 2021-11-30 + */ +public class CharDataType extends VarcharDataType { + /** + * max width for char data type. + */ + public static final int MAX_WIDTH = 65_536; + + /** + * char singleton. + */ + public static final CharDataType CHAR = new CharDataType(MAX_WIDTH); + + private static final long serialVersionUID = -8306919387371983633L; + + /** + * The construct of char data type. + * + * @param width the width of char + */ + public CharDataType(@JsonProperty("width") int width) { + super(width, DataTypeId.OMNI_CHAR); + } + + @Override + public int hashCode() { + return Objects.hash(width, super.getId()); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + CharDataType other = (CharDataType) obj; + return Objects.equals(width, other.getWidth()) && Objects.equals(super.getId(), other.getId()); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/type/ContainerDataType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/type/ContainerDataType.java new file mode 100644 index 0000000..c22be56 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/type/ContainerDataType.java @@ -0,0 +1,56 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +/** + * container data type. + * + * @since 2021-07-17 + */ +public class ContainerDataType extends DataType { + /** + * Container singleton. + */ + public static final ContainerDataType CONTAINER = new ContainerDataType(); + + private static final long serialVersionUID = 7653293048781110462L; + + private DataType[] fieldTypes; + + /** + * The construct of container data type. + * + * @param fieldTypes the types of data + */ + public ContainerDataType(DataType[] fieldTypes) { + super(DataTypeId.OMNI_CONTAINER); + this.fieldTypes = fieldTypes; + } + + /** + * Container construct. + */ + public ContainerDataType() { + super(DataTypeId.OMNI_CONTAINER); + } + + /** + * get number of filed types. + * + * @return the number of filedTypes + */ + public int size() { + return fieldTypes.length; + } + + /** + * get field types. + * + * @return field types + */ + public DataType[] getFieldTypes() { + return fieldTypes; + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/type/DataType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/type/DataType.java new file mode 100644 index 0000000..17985d1 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/type/DataType.java @@ -0,0 +1,154 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.annotation.JsonValue; +import com.fasterxml.jackson.databind.annotation.JsonTypeIdResolver; + +import java.io.Serializable; +import java.util.Objects; + +/** + * data type. + * + * @since 2021-08-05 + */ +@JsonTypeInfo(use = JsonTypeInfo.Id.CUSTOM, include = JsonTypeInfo.As.EXISTING_PROPERTY, property = "id") +@JsonTypeIdResolver(DataTypeSerializer.DataTypeResolver.class) +public class DataType implements Serializable { + private static final long serialVersionUID = 2589766491688675794L; + + @JsonProperty + private final DataTypeId id; + + public DataType(@JsonProperty("id") DataTypeId id) { + this.id = id; + } + + public DataTypeId getId() { + return id; + } + + /** + * Create a data type object. + * + * @param typeId create data type by data type id + * @return data type + */ + public static DataType create(int typeId) { + return new DataType(DataTypeId.values()[typeId]); + } + + /** + * The data type id. + */ + public enum DataTypeId { + OMNI_NONE(0), + OMNI_INT(1), + OMNI_LONG(2), + OMNI_DOUBLE(3), + OMNI_BOOLEAN(4), + OMNI_SHORT(5), + OMNI_DECIMAL64(6), + OMNI_DECIMAL128(7), + OMNI_DATE32(8), + OMNI_DATE64(9), + OMNI_TIME32(10), + OMNI_TIME64(11), + OMNI_TIMESTAMP(12), + OMNI_INTERVAL_MONTHS(13), + OMNI_INTERVAL_DAY_TIME(14), + OMNI_VARCHAR(15), + OMNI_CHAR(16), + OMNI_CONTAINER(17), + OMNI_BYTE(18), + OMNI_INVALID(19); + + private final int value; + + DataTypeId(int value) { + this.value = value; + } + + /** + * Serialize the ordinal of enum. + * + * @return the ordinal. + */ + @JsonValue + public int toValue() { + return this.value; + } + } + + /** + * The unit of date. + */ + public enum DateUnit { + DAY(0), + MILLI(1); + + private final int value; + + DateUnit(int value) { + this.value = value; + } + + /** + * Serialize the value of enum. + * + * @return the value. + */ + @JsonValue + public int toValue() { + return this.value; + } + } + + /** + * The unit of time. + */ + public enum TimeUnit { + SEC(0), + MILLISEC(1), + MICROSEC(2), + NANOSEC(3); + + private final int value; + + TimeUnit(int value) { + this.value = value; + } + + /** + * Serialize the value of enum. + * + * @return the value. + */ + @JsonValue + public int toValue() { + return this.value; + } + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + DataType dataType = (DataType) obj; + return id == dataType.id; + } + + @Override + public int hashCode() { + return Objects.hash(id); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/type/DataTypeSerializer.java b/bindings/java/src/main/java/nova/hetu/omniruntime/type/DataTypeSerializer.java new file mode 100644 index 0000000..a794b40 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/type/DataTypeSerializer.java @@ -0,0 +1,184 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +import static nova.hetu.omniruntime.utils.OmniErrorType.OMNI_INNER_ERROR; +import static nova.hetu.omniruntime.utils.OmniErrorType.OMNI_NOSUPPORT; + +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DatabindContext; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.jsontype.impl.TypeIdResolverBase; + +import nova.hetu.omniruntime.utils.OmniRuntimeException; + +import java.io.IOException; + +/** + * Data type serializer, used for serialize and deserialize the data type. + * + * @since 2021-08-05 + */ +public class DataTypeSerializer { + /** + * Object mapper singleton. + */ + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + /** + * Serialize a single data type. + * + * @param dataType serialize a single data type + * @return return the string after serialization + */ + public static String serializeSingle(DataType dataType) { + try { + return OBJECT_MAPPER.writeValueAsString(dataType); + } catch (JsonProcessingException e) { + throw new OmniRuntimeException(OMNI_INNER_ERROR, "Serialization failed.", e); + } + } + + /** + * Serialize data types. + * + * @param dataTypes serialize data types + * @return return the string after serialization + */ + public static String serialize(DataType[] dataTypes) { + try { + return OBJECT_MAPPER.writeValueAsString(dataTypes); + } catch (JsonProcessingException e) { + throw new OmniRuntimeException(OMNI_INNER_ERROR, "Serialization failed.", e); + } + } + + /** + * Serialize data types. + * + * @param dataTypes serialize data types[][] + * @return return the string[] after serialization + */ + public static String[] serialize(DataType[][] dataTypes) { + String[] strings = new String[dataTypes.length]; + for (int i = 0; i < dataTypes.length; i++) { + strings[i] = serialize(dataTypes[i]); + } + return strings; + } + + /** + * Deserialize a single data type. + * + * @param type the string need to be deserialization + * @return return the vector type + */ + public static DataType deserializeSingle(String type) { + try { + return OBJECT_MAPPER.readerFor(DataType.class).readValue(type); + } catch (JsonProcessingException e) { + throw new OmniRuntimeException(OMNI_INNER_ERROR, "Deserialization failed.", e); + } + } + + /** + * Deserialize data types. + * + * @param types the string need to be deserialization + * @return return the vector types + */ + public static DataType[] deserialize(String types) { + try { + return OBJECT_MAPPER.readerFor(DataType[].class).readValue(types); + } catch (JsonProcessingException e) { + throw new OmniRuntimeException(OMNI_INNER_ERROR, "Deserialization failed.", e); + } + } + + static class DataTypeResolver extends TypeIdResolverBase { + private JavaType superType; + + @Override + public JavaType typeFromId(DatabindContext context, String id) throws IOException { + Class subType = null; + DataType.DataTypeId dataTypeId = DataType.DataTypeId.values()[Integer.parseInt(id)]; + switch (dataTypeId) { + case OMNI_INT: + subType = IntDataType.class; + break; + case OMNI_LONG: + subType = LongDataType.class; + break; + case OMNI_DOUBLE: + subType = DoubleDataType.class; + break; + case OMNI_TIMESTAMP: + subType = TimestampDataType.class; + break; + case OMNI_BOOLEAN: + subType = BooleanDataType.class; + break; + case OMNI_SHORT: + subType = ShortDataType.class; + break; + case OMNI_BYTE: + subType = ByteDataType.class; + break; + case OMNI_CONTAINER: + subType = ContainerDataType.class; + break; + case OMNI_NONE: + subType = NoneDataType.class; + break; + case OMNI_INVALID: + subType = InvalidDataType.class; + break; + case OMNI_VARCHAR: + subType = VarcharDataType.class; + break; + case OMNI_CHAR: + subType = CharDataType.class; + break; + case OMNI_DECIMAL64: + subType = Decimal64DataType.class; + break; + case OMNI_DECIMAL128: + subType = Decimal128DataType.class; + break; + case OMNI_DATE32: + subType = Date32DataType.class; + break; + case OMNI_DATE64: + subType = Date64DataType.class; + break; + default: + throw new OmniRuntimeException(OMNI_NOSUPPORT, "Unsupported data type : " + id); + } + return context.constructSpecializedType(superType, subType); + } + + @Override + public JsonTypeInfo.Id getMechanism() { + return JsonTypeInfo.Id.CUSTOM; + } + + @Override + public String idFromValue(Object value) { + return value.getClass().toString(); + } + + @Override + public String idFromValueAndType(Object value, Class suggestedType) { + return null; + } + + @Override + public void init(JavaType baseType) { + superType = baseType; + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/type/Date32DataType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/type/Date32DataType.java new file mode 100644 index 0000000..c56026e --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/type/Date32DataType.java @@ -0,0 +1,57 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +/** + * date32 data type. + * + * @since 2021-08-05 + */ +public class Date32DataType extends DataType { + /** + * Date32 singleton. + */ + public static final Date32DataType DATE32 = new Date32DataType(DateUnit.DAY); + + private static final long serialVersionUID = 8120887624931817382L; + + @JsonProperty + private final DateUnit dateUnit; + + /** + * Date32 construct. + * + * @param dateUnit the unit of date + */ + public Date32DataType(@JsonProperty("dateUnit") DateUnit dateUnit) { + super(DataTypeId.OMNI_DATE32); + this.dateUnit = dateUnit; + } + + public DateUnit getDateUnit() { + return dateUnit; + } + + @Override + public int hashCode() { + return Objects.hash(dateUnit, super.getId()); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + Date32DataType other = (Date32DataType) obj; + return (Objects.equals(dateUnit, other.getDateUnit()) && Objects.equals(super.getId(), other.getId())); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/type/Date64DataType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/type/Date64DataType.java new file mode 100644 index 0000000..a678c05 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/type/Date64DataType.java @@ -0,0 +1,57 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +/** + * date64 data type. + * + * @since 2021-08-05 + */ +public class Date64DataType extends DataType { + /** + * Date64 singleton. + */ + public static final Date64DataType DATE64 = new Date64DataType(DateUnit.DAY); + + private static final long serialVersionUID = -6927167052418618260L; + + @JsonProperty + private final DataType.DateUnit dateUnit; + + /** + * date 64 construct. + * + * @param dateUnit the unit of date + */ + public Date64DataType(@JsonProperty("dateUnit") DataType.DateUnit dateUnit) { + super(DataTypeId.OMNI_DATE64); + this.dateUnit = dateUnit; + } + + public DateUnit getDateUnit() { + return dateUnit; + } + + @Override + public int hashCode() { + return Objects.hash(dateUnit, super.getId()); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + Date64DataType other = (Date64DataType) obj; + return (Objects.equals(dateUnit, other.getDateUnit()) && Objects.equals(super.getId(), other.getId())); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/type/Decimal128DataType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/type/Decimal128DataType.java new file mode 100644 index 0000000..c4e8025 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/type/Decimal128DataType.java @@ -0,0 +1,44 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Decimal128 data type. + * + * @since 2021-08-05 + */ +public class Decimal128DataType extends DecimalDataType { + /** + * Default precision value of decimal. + */ + public static final int DEFAULT_PRECISION = 38; + + /** + * Default scale value of decimal. + */ + public static final int DEFAULT_SCALE = 0; + + /** + * Decimal128 singleton. + */ + public static final Decimal128DataType DECIMAL128 = new Decimal128DataType(DEFAULT_PRECISION, DEFAULT_SCALE); + + /** + * Construct of decimal128 data type. + * + * @param precision the precision of decimal + * @param scale the scale of decimal + */ + + private static final long serialVersionUID = 7504240180082236146L; + + @JsonCreator + public Decimal128DataType(@JsonProperty("precision") int precision, @JsonProperty("scale") int scale) { + super(precision, scale, DataTypeId.OMNI_DECIMAL128); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/type/Decimal64DataType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/type/Decimal64DataType.java new file mode 100644 index 0000000..4e9e9bd --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/type/Decimal64DataType.java @@ -0,0 +1,41 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Decimal64 data type. + * + * @since 2021-08-05 + */ +public class Decimal64DataType extends DecimalDataType { + /** + * Default precision value of decimal. + */ + public static final int DEFAULT_PRECISION = 18; + + /** + * Default scale value of decimal. + */ + public static final int DEFAULT_SCALE = 0; + + /** + * Decimal64 singleton. + */ + public static final Decimal64DataType DECIMAL64 = new Decimal64DataType(DEFAULT_PRECISION, DEFAULT_SCALE); + + private static final long serialVersionUID = -1858555622202917305L; + + /** + * Construct of decimal64 data type. + * + * @param precision the precision of decimal + * @param scale the scale of decimal + */ + public Decimal64DataType(@JsonProperty("precision") int precision, @JsonProperty("scale") int scale) { + super(precision, scale, DataTypeId.OMNI_DECIMAL64); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/type/DecimalDataType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/type/DecimalDataType.java new file mode 100644 index 0000000..cfd5ab5 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/type/DecimalDataType.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +/** + * Decimal data type. + * + * @since 2022-03-10 + */ +public abstract class DecimalDataType extends DataType { + private static final long serialVersionUID = -3389964658615782592L; + + @JsonProperty + private final int precision; + + @JsonProperty + private final int scale; + + /** + * Construct of decimal data type. + * + * @param precision the precision of decimal + * @param scale the scale of decimal + * @param typeId the data typeId + */ + public DecimalDataType(@JsonProperty("precision") int precision, @JsonProperty("scale") int scale, + @JsonProperty("id") DataTypeId typeId) { + super(typeId); + this.precision = precision; + this.scale = scale; + } + + public int getPrecision() { + return precision; + } + + public int getScale() { + return scale; + } + + @Override + public int hashCode() { + return Objects.hash(precision, scale, super.getId()); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + DecimalDataType other = (DecimalDataType) obj; + return (Objects.equals(precision, other.getPrecision()) && Objects.equals(scale, other.getScale()) + && Objects.equals(super.getId(), other.getId())); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/type/DoubleDataType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/type/DoubleDataType.java new file mode 100644 index 0000000..0af1a55 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/type/DoubleDataType.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +/** + * double data type. + * + * @since 2021-08-05 + */ +public class DoubleDataType extends DataType { + /** + * Double singleton. + */ + public static final DoubleDataType DOUBLE = new DoubleDataType(); + + private static final long serialVersionUID = -5517157056853810138L; + + /** + * The construct. + */ + public DoubleDataType() { + super(DataTypeId.OMNI_DOUBLE); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/type/IntDataType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/type/IntDataType.java new file mode 100644 index 0000000..f3ac887 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/type/IntDataType.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +/** + * int data type. + * + * @since 2021-08-05 + */ +public class IntDataType extends DataType { + /** + * Integer singleton. + */ + public static final IntDataType INTEGER = new IntDataType(); + + private static final long serialVersionUID = -5622723228847479686L; + + /** + * The construct. + */ + public IntDataType() { + super(DataTypeId.OMNI_INT); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/type/InvalidDataType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/type/InvalidDataType.java new file mode 100644 index 0000000..3513271 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/type/InvalidDataType.java @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +/** + * invalid data type. The data type of unsupported/invalid data. + * + * @since 2022-04-01 + */ +public class InvalidDataType extends DataType { + /** + * Invalid singleton. + */ + public static final InvalidDataType INVALID = new InvalidDataType(); + + /** + * The construct. + */ + public InvalidDataType() { + super(DataTypeId.OMNI_INVALID); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/type/LongDataType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/type/LongDataType.java new file mode 100644 index 0000000..cd531f4 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/type/LongDataType.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +/** + * long data type. + * + * @since 2021-08-05 + */ +public class LongDataType extends DataType { + /** + * Long singleton. + */ + public static final LongDataType LONG = new LongDataType(); + + private static final long serialVersionUID = -1589352305079680921L; + + /** + * The construct. + */ + public LongDataType() { + super(DataTypeId.OMNI_LONG); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/type/NoneDataType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/type/NoneDataType.java new file mode 100644 index 0000000..000d058 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/type/NoneDataType.java @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +/** + * None data type. The data type of NULL data. + * + * @since 2022-04-01 + */ +public class NoneDataType extends DataType { + /** + * None singleton. + */ + public static final NoneDataType NONE = new NoneDataType(); + + /** + * The construct. + */ + public NoneDataType() { + super(DataTypeId.OMNI_NONE); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/type/ShortDataType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/type/ShortDataType.java new file mode 100644 index 0000000..8761422 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/type/ShortDataType.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +/** + * short data type. + * + * @since 2021-08-05 + */ +public class ShortDataType extends DataType { + /** + * Short singleton. + */ + public static final ShortDataType SHORT = new ShortDataType(); + + private static final long serialVersionUID = -1938040225939461L; + + /** + * The construct. + */ + public ShortDataType() { + super(DataTypeId.OMNI_SHORT); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/type/TimestampDataType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/type/TimestampDataType.java new file mode 100644 index 0000000..116365a --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/type/TimestampDataType.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +/** + * timestamp data type. + * + * @since 2024-09-09 + */ +public class TimestampDataType extends DataType { + /** + * timestamp singleton. + */ + public static final TimestampDataType TIMESTAMP = new TimestampDataType(); + + private static final long serialVersionUID = -165184964293631557L; + + /** + * The construct. + */ + public TimestampDataType() { + super(DataTypeId.OMNI_TIMESTAMP); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/type/VarcharDataType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/type/VarcharDataType.java new file mode 100644 index 0000000..22a818d --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/type/VarcharDataType.java @@ -0,0 +1,76 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +/** + * varchar data type. + * + * @since 2021-08-05 + */ +public class VarcharDataType extends DataType { + /** + * max width for varchar data type. + */ + public static final int MAX_WIDTH = 1024 * 1024; + + /** + * Varchar singleton. + */ + public static final VarcharDataType VARCHAR = new VarcharDataType(MAX_WIDTH); + + private static final long serialVersionUID = -4778484134512020833L; + + /** + * average length of a varchar. + */ + @JsonProperty + protected final int width; + + /** + * The construct of varchar data type. + * + * @param width the width of varchar + */ + public VarcharDataType(@JsonProperty("width") int width) { + super(DataTypeId.OMNI_VARCHAR); + this.width = Math.min(MAX_WIDTH, width); + } + + /** + * The construct of varchar data type. + * + * @param width the width of varchar + * @param dataTypeId the types of data + */ + protected VarcharDataType(int width, DataTypeId dataTypeId) { + super(dataTypeId); + this.width = width; + } + + public int getWidth() { + return width; + } + + @Override + public int hashCode() { + return Objects.hash(width, super.getId()); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + VarcharDataType other = (VarcharDataType) obj; + return (Objects.equals(width, other.getWidth()) && Objects.equals(super.getId(), other.getId())); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/utils/JsonUtils.java b/bindings/java/src/main/java/nova/hetu/omniruntime/utils/JsonUtils.java new file mode 100644 index 0000000..f7ea792 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/utils/JsonUtils.java @@ -0,0 +1,83 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.utils; + +import static nova.hetu.omniruntime.utils.OmniErrorType.OMNI_INNER_ERROR; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.util.ArrayList; +import java.util.List; + +/** + * json serialize/deserialize util. + * + * @since 2022-9-22 + */ +public class JsonUtils { + /** + * Object mapper singleton. + */ + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + /** + * transform String array to json + * + * @param array String[] + * @return String + * @throws OmniRuntimeException omniRuntimeException + */ + public static String jsonStringArray(String[] array) { + try { + return OBJECT_MAPPER.writeValueAsString(array); + } catch (JsonProcessingException e) { + throw new OmniRuntimeException(OMNI_INNER_ERROR, "Serialization failed.", e); + } + } + + /** + * transform each row in String[][] to json + * + * @param array String[][] + * @return String[] + */ + public static String[] jsonStringArray(String[][] array) { + List stringList = new ArrayList<>(); + for (String[] arr : array) { + stringList.add(jsonStringArray(arr)); + } + return stringList.toArray(new String[stringList.size()]); + } + + /** + * Deserialize a single json. + * + * @param json the string need to be deserialization + * @return String[] + * @throws OmniRuntimeException omniRuntimeException + */ + public static String[] deserializeJson(String json) { + try { + return OBJECT_MAPPER.readerFor(String[].class).readValue(json); + } catch (JsonProcessingException e) { + throw new OmniRuntimeException(OMNI_INNER_ERROR, "Deserialization failed.", e); + } + } + + /** + * Deserialize a single json. + * + * @param jsons the string[] need to be deserialization + * @return String[][] + */ + public static String[][] deserializeJson(String[] jsons) { + String[][] strings = new String[jsons.length][]; + for (int i = 0; i < jsons.length; i++) { + strings[i] = deserializeJson(jsons[i]); + } + return strings; + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/utils/NativeLog.java b/bindings/java/src/main/java/nova/hetu/omniruntime/utils/NativeLog.java new file mode 100644 index 0000000..5c346de --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/utils/NativeLog.java @@ -0,0 +1,50 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * The global log init. + * + * @since 20220320 + */ + +public class NativeLog { + private static volatile NativeLog instance; + + private static final Logger logger = LoggerFactory.getLogger(NativeLog.class); + + /** + * Native Log init. + * + * @since 20220320 + */ + private NativeLog() { + initLog(); + } + + /** + * Native getInstance. + * + * @return new NativeLog + */ + public static NativeLog getInstance() { + if (instance == null) { + synchronized (NativeLog.class) { + if (instance == null) { + instance = new NativeLog(); + } + } + } + return instance; + } + + /** + * Init global logger. + */ + public static native void initLog(); +} \ No newline at end of file diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/utils/NullsBufHelper.java b/bindings/java/src/main/java/nova/hetu/omniruntime/utils/NullsBufHelper.java new file mode 100644 index 0000000..d22f7c7 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/utils/NullsBufHelper.java @@ -0,0 +1,170 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ + +package nova.hetu.omniruntime.utils; + +import nova.hetu.omniruntime.vector.OmniBuffer; + +/** + * nulls buffer util. + * + * @since 2025-02-15 + */ +public class NullsBufHelper { + /** + * Byte index obtained from the bit field index position + * + * @param absoluteBitIndex - Bit field index position + * @return byte index + */ + public static int byteIndex(int absoluteBitIndex) { + return absoluteBitIndex >> 3; + } + + /** + * Obtain the bit field where the actual byte position is located based on the bit field index position. + * + * @param absoluteBitIndex Bit field index position + * @return Bit field in which the actual byte position is located + */ + public static int bitIndex(int absoluteBitIndex) { + return absoluteBitIndex & 7; + } + + /** + * Calculate the number of bytes occupied by the number of bits. + * + * @param bits Number of bit fields + * @return number of bytes + */ + public static int nBytes(int bits) { + return (bits + (8 - 1)) / 8; + } + + /** + * Indicates whether the position specified by the bit field is marked. + * + * @param nullsBuf bit field memory + * @param index Indexes + * @return flag + */ + public static int isSet(OmniBuffer nullsBuf, int index) { + int byteIndex = byteIndex(index); + int bitIndex = bitIndex(index); + int currentByte = nullsBuf.getByte(byteIndex); + return currentByte >> bitIndex & 1; + } + + /** + * Indicates whether the position specified by the bit field is marked. + * + * @param isNulls bit field memory + * @param index Indexes + * @return flag + */ + public static int isSet(byte[] isNulls, int index) { + return isNulls[byteIndex(index)] >> bitIndex(index) & 1; + } + + /** + * Indicates the flag of the specified position of the clear bit field. + * + * @param nullsBuf bit field memory + * @param index Indexes + */ + public static void unsetBit(OmniBuffer nullsBuf, int index) { + int byteIndex = byteIndex(index); + int bitIndex = bitIndex(index); + int currentByte = nullsBuf.getByte(byteIndex); + int bitMask = 1 << bitIndex; + currentByte &= ~bitMask; + nullsBuf.setByte(byteIndex, (byte) currentByte); + } + + /** + * Specifies the position of the flag bit field. + * + * @param nullsBuf bit field memory + * @param index Indexes + */ + public static void setBit(OmniBuffer nullsBuf, int index) { + int byteIndex = byteIndex(index); + int bitIndex = bitIndex(index); + int currentByte = nullsBuf.getByte(byteIndex); + int bitMask = 1 << bitIndex; + currentByte |= bitMask; + nullsBuf.setByte(byteIndex, (byte) currentByte); + } + + /** + * Bulk Mark Bit Field Memory + * + * @param nullsBuf bit field memory + * @param index Mark Start Position + * @param isNulls Tag Value array + * @param srcStart Start position of the tag value array + * @param length Tag Value array length + */ + public static void setBit(OmniBuffer nullsBuf, int index, byte[] isNulls, int srcStart, int length) { + int i = 0; + while (i < length) { + setValidityBit(nullsBuf, index + i, isNulls[srcStart + i]); + i++; + } + } + + /** + * Bulk Mark Bit Field Memory + * + * @param nullsBuf bit field memory + * @param index Mark Start Position + * @param isBitNulls Tag Value array (Bit) + * @param srcStart Start position of the tag value array (Bit) + * @param length Tag Value array length (Bit) + */ + public static void setBitByBits(OmniBuffer nullsBuf, int index, byte[] isBitNulls, int srcStart, int length) { + int i = 0; + while (i < length) { + setValidityBit(nullsBuf, index + i, isSet(isBitNulls, srcStart + i)); + i++; + } + } + + /** + * Obtains bit field tags in batches and adds them to an array. + * + * @param nullsBuf bit field memory + * @param index Mark Start Position + * @param nullsArray Tag Value array + * @param targetStart Start position of the tag value array + * @param length Tag Value array length + */ + public static void getBytes(OmniBuffer nullsBuf, int index, byte[] nullsArray, int targetStart, int length) { + int i = 0; + while (i < length) { + nullsArray[targetStart + i] = (byte) isSet(nullsBuf, index + i); + i++; + } + } + + /** + * Marks the specified position in the bit field memory with a specified value. + * + * @param nullsBuf bit field memory + * @param index Indexes + * @param value value mark + */ + public static void setValidityBit(OmniBuffer nullsBuf, int index, int value) { + int byteIndex = byteIndex(index); + int bitIndex = bitIndex(index); + int currentByte = nullsBuf.getByte(byteIndex); + int bitMask = 1 << bitIndex; + if (value != 0) { + currentByte |= bitMask; + } else { + currentByte &= ~bitMask; + } + nullsBuf.setByte(byteIndex, (byte) currentByte); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/utils/OmniErrorType.java b/bindings/java/src/main/java/nova/hetu/omniruntime/utils/OmniErrorType.java new file mode 100644 index 0000000..2ef0679 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/utils/OmniErrorType.java @@ -0,0 +1,60 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.utils; + +/** + * The enum Omni error type. + * + * @since 2021-06-30 + */ +public enum OmniErrorType { + /** + * Omni undefined omni error type. + */ + OMNI_UNDEFINED(1), + /** + * Omni nosupport omni error type. + */ + OMNI_NOSUPPORT(2), + /** + * Omni native error omni error type. + */ + OMNI_NATIVE_ERROR(3), + + /** + * Omni runtime param error. + */ + OMNI_PARAM_ERROR(4), + + /** + * Omni inner error. + */ + OMNI_INNER_ERROR(5), + + /** + * Omni vec or vectbatch double free + */ + OMNI_DOUBLE_FREE(6), + + /** + * Omni java udf error + */ + OMNI_JAVA_UDF_ERROR(7); + + private final int value; + + OmniErrorType(int value) { + this.value = value; + } + + /** + * Gets value. + * + * @return the value + */ + public int getValue() { + return value; + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/utils/OmniRuntimeException.java b/bindings/java/src/main/java/nova/hetu/omniruntime/utils/OmniRuntimeException.java new file mode 100644 index 0000000..96e5d70 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/utils/OmniRuntimeException.java @@ -0,0 +1,62 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.utils; + +import static nova.hetu.omniruntime.utils.OmniErrorType.OMNI_NATIVE_ERROR; + +/** + * The type Omni runtime exception. + * + * @since 2021-06-30 + */ +public class OmniRuntimeException extends RuntimeException { + private static final long serialVersionUID = -4352889723335051173L; + + private final OmniErrorType errorType; + + /** + * this method for jni method call + * + * @param msg error message + */ + public OmniRuntimeException(String msg) { + super(msg); + this.errorType = OMNI_NATIVE_ERROR; + } + + /** + * Instantiates a new Omni runtime exception. + * + * @param errorType the error type + * @param msg the msg + */ + public OmniRuntimeException(OmniErrorType errorType, String msg) { + super(msg); + this.errorType = errorType; + } + + /** + * Instantiates a new Omni runtime exception. + * + * @param errorType the error type + * @param string the string + * @param throwable the throwable + */ + public OmniRuntimeException(OmniErrorType errorType, String string, Throwable throwable) { + super(string, throwable); + this.errorType = errorType; + } + + /** + * Instantiates a new Omni runtime exception. + * + * @param errorType the error type + * @param throwable the throwable + */ + public OmniRuntimeException(OmniErrorType errorType, Throwable throwable) { + super(throwable); + this.errorType = errorType; + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/utils/ParseUtil.java b/bindings/java/src/main/java/nova/hetu/omniruntime/utils/ParseUtil.java new file mode 100644 index 0000000..ebca6c7 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/utils/ParseUtil.java @@ -0,0 +1,88 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.utils; + +import static nova.hetu.omniruntime.memory.MemoryManager.UNLIMITED; + +import com.sun.management.OperatingSystemMXBean; + +import java.lang.management.ManagementFactory; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * parse memory size + * + * @since 2022-04-02 + */ +public class ParseUtil { + private static final Pattern PATTERN = Pattern.compile("^\\s*(\\d+(?:\\.\\d+)?)\\s*([a-zA-Z]+)\\s*$"); + + private ParseUtil() { + } + + /** + * parse memory size to byte, like 1B, 1KB, 1MB, 1GB. + * + * @param size capacity size with unit + * @return size in bytes + */ + public static long parserMemoryParameters(String size) { + Matcher matcher = PATTERN.matcher(size); + if (!matcher.matches()) { + throw new OmniRuntimeException(OmniErrorType.OMNI_PARAM_ERROR, + "size is not a valid data size string" + size); + } + + long value = Long.parseLong(matcher.group(1)); + String unitString = matcher.group(2); + + for (Unit unit : Unit.values()) { + if (unit.getUnitString().equals(unitString)) { + long limit = value * unit.getFactor(); + long systemFreeMemory = getOperatorSystemFreeMemorySize(); + if (limit >= systemFreeMemory || limit < UNLIMITED) { + throw new OmniRuntimeException(OmniErrorType.OMNI_PARAM_ERROR, + "OMNI_OFFHEAP_MEMORY_SIZE exceeds system free memorySize:" + systemFreeMemory); + } + return limit; + } + } + throw new OmniRuntimeException(OmniErrorType.OMNI_PARAM_ERROR, "Unknown unit:" + unitString); + } + + private static long getOperatorSystemFreeMemorySize() { + java.lang.management.OperatingSystemMXBean langOSMXBean = ManagementFactory.getOperatingSystemMXBean(); + if (langOSMXBean instanceof OperatingSystemMXBean) { + OperatingSystemMXBean osmxb = (OperatingSystemMXBean) langOSMXBean; + return osmxb.getFreePhysicalMemorySize(); + } + throw new OmniRuntimeException(OmniErrorType.OMNI_UNDEFINED, "Cannot get system freeMemorySize"); + } + + enum Unit { + BYTE(1L, "B"), + KILOBYTE(1L << 10, "KB"), + MEGABYTE(1L << 20, "MB"), + GIGABYTE(1L << 30, "GB"), + TERABYTE(1L << 40, "TB"); + + private final long factor; + private final String unitString; + + Unit(long factor, String unitString) { + this.factor = factor; + this.unitString = unitString; + } + + long getFactor() { + return factor; + } + + String getUnitString() { + return unitString; + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/utils/ShuffleHashHelper.java b/bindings/java/src/main/java/nova/hetu/omniruntime/utils/ShuffleHashHelper.java new file mode 100644 index 0000000..323b511 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/utils/ShuffleHashHelper.java @@ -0,0 +1,22 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ + +package nova.hetu.omniruntime.utils; + +/** + * shuffle hash + * + * @since 2025-03-21 + */ +public class ShuffleHashHelper { + /** + * use to compute shuffle hash partitionIds + * + * @param vecAddrArray the array of nativeVec + * @param partitionNum the num of partition + * @param rowCount the num of row + * @return the partitionIds of vec + */ + public static native long computePartitionIds(long[] vecAddrArray, int partitionNum, int rowCount); +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/utils/TraceUtil.java b/bindings/java/src/main/java/nova/hetu/omniruntime/utils/TraceUtil.java new file mode 100644 index 0000000..c04e57d --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/utils/TraceUtil.java @@ -0,0 +1,32 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2020. All rights reserved. + */ + +package nova.hetu.omniruntime.utils; + +import static java.lang.Math.min; +import static java.lang.System.lineSeparator; + +import java.util.Arrays; +import java.util.StringJoiner; + +/** + * trace util. + * + * @since 2021-10-21 + */ +public class TraceUtil { + /** + * used for get stack trace of current thread. + * + * @return the stack trace of current thread + */ + public static String stack() { + StackTraceElement[] elements = Thread.currentThread().getStackTrace(); + StringJoiner stack = new StringJoiner(lineSeparator() + "\t"); + Arrays.stream(elements).skip(2).limit(min(25, elements.length - 1)).forEach(item -> { + stack.add(item.toString()); + }); + return stack.toString(); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/BooleanVec.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/BooleanVec.java new file mode 100644 index 0000000..9e25a91 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/BooleanVec.java @@ -0,0 +1,114 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import nova.hetu.omniruntime.type.BooleanDataType; + +/** + * boolean vec. + * + * @since 2021-07-17 + */ +public class BooleanVec extends FixedWidthVec { + private static final int BYTES = 1; + + public BooleanVec(int size) { + super(size * BYTES, size, VecEncoding.OMNI_VEC_ENCODING_FLAT, BooleanDataType.BOOLEAN); + } + + public BooleanVec(long nativeVector) { + super(nativeVector, BooleanDataType.BOOLEAN, BYTES); + } + + public BooleanVec(long nativeVector, long nativeValueBufAddress, long nativeVectorNullBufAddress, int size) { + super(nativeVector, nativeValueBufAddress, nativeVectorNullBufAddress, BYTES * size, size, + BooleanDataType.BOOLEAN); + } + + private BooleanVec(BooleanVec vector, int offset, int length) { + super(vector, offset, length, length * BYTES); + } + + private BooleanVec(BooleanVec vector, int[] positions, int offset, int length) { + super(vector, positions, offset, length, length * BYTES); + } + + /** + * Sets the specified boolean at the specified absolute. + * + * @param index the element offset in vec + * @param value the value of vec + */ + public void set(int index, boolean value) { + valuesBuf.setByte(index, value ? (byte) 1 : (byte) 0); + } + + /** + * get the specified boolean at the specified absolute. + * + * @param index the element offset in vec + * @return if the value of 1 returns true, otherwise it returns false + */ + public boolean get(int index) { + return valuesBuf.getByte(index * BYTES) == 1; + } + + /** + * get boolean values from the specified position. + * + * @param index the position of element + * @param length the number of element + * @return boolean value array + */ + public boolean[] get(int index, int length) { + byte[] target = valuesBuf.getBytes(index * BYTES, length); + return transformByteToBoolean(target, 0, length); + } + + /** + * Batch sets the specified boolean at the specified absolute. + * + * @param values the value of the element to be written + * @param index the element index in vec + * @param start the element index in values + * @param length the number of elements that need to written + */ + public void put(boolean[] values, int index, int start, int length) { + byte[] data = transformBooleanToByte(values, start, length); + valuesBuf.setBytes(index, data, 0, length); + } + + /** + * Batch sets the specified byte at the specified absolute. + * + * @param values the value of the element to be written + * @param index the element index in vec + * @param start the element index in values + * @param length the number of elements that need to written + */ + public void put(byte[] values, int index, int start, int length) { + valuesBuf.setBytes(index * BYTES, values, start * BYTES, length * BYTES); + } + + @Override + public BooleanVec slice(int startIdx, int length) { + return new BooleanVec(this, startIdx, length); + } + + @Override + public BooleanVec copyPositions(int[] positions, int offset, int length) { + return new BooleanVec(this, positions, offset, length); + } + + @Override + public int getRealValueBufCapacityInBytes() { + return size * BYTES; + } + + @Override + public int getCapacityInBytes() { + return size * BYTES; + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/ByteVec.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/ByteVec.java new file mode 100644 index 0000000..e38e6d5 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/ByteVec.java @@ -0,0 +1,101 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import nova.hetu.omniruntime.type.ByteDataType; + +/** + * byte vec. + * + * @since 2025-08-05 + */ +public class ByteVec extends FixedWidthVec { + private static final int BYTES = Byte.BYTES; + + public ByteVec(int size) { + super(size * BYTES, size, VecEncoding.OMNI_VEC_ENCODING_FLAT, ByteDataType.BYTE); + } + + public ByteVec(long nativeVector) { + super(nativeVector, ByteDataType.BYTE, BYTES); + } + + public ByteVec(long nativeVector, long nativeValueBufAddress, long nativeVectorNullBufAddress, int size) { + super(nativeVector, nativeValueBufAddress, nativeVectorNullBufAddress, BYTES * size, size, ByteDataType.BYTE); + } + + private ByteVec(ByteVec vector, int offset, int length) { + super(vector, offset, length, length * BYTES); + } + + private ByteVec(ByteVec vector, int[] positions, int offset, int length) { + super(vector, positions, offset, length, length * BYTES); + } + + /** + * get the specified byte at the specified absolute. + * + * @param index the element offset in vec + * @return byte value + */ + public byte get(int index) { + return valuesBuf.getByte(index * BYTES); + } + + /** + * get byte values from the specified position. + * + * @param index the position of element + * @param length the number of element + * @return byte value array + */ + public byte[] get(int index, int length) { + byte[] target = new byte[length]; + valuesBuf.getBytes(index * BYTES, target, 0, length * BYTES); + return target; + } + + /** + * Sets the specified byte at the specified absolute. + * + * @param index the element offset in vec + * @param value the value of vec + */ + public void set(int index, byte value) { + valuesBuf.setByte(index * BYTES, value); + } + + /** + * Batch sets the specified byte at the specified absolute. + * + * @param values the value of the element to be written + * @param offset the element offset in vec + * @param start the element index in values + * @param length the number of elements that need to written + */ + public void put(byte[] values, int offset, int start, int length) { + valuesBuf.setBytes(offset * BYTES, values, start * BYTES, length * BYTES); + } + + @Override + public ByteVec slice(int start, int length) { + return new ByteVec(this, start, length); + } + + @Override + public ByteVec copyPositions(int[] positions, int offset, int length) { + return new ByteVec(this, positions, offset, length); + } + + @Override + public int getRealValueBufCapacityInBytes() { + return size * BYTES; + } + + @Override + public int getCapacityInBytes() { + return size * BYTES; + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/ContainerVec.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/ContainerVec.java new file mode 100644 index 0000000..685343e --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/ContainerVec.java @@ -0,0 +1,201 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static nova.hetu.omniruntime.vector.VecEncoding.OMNI_VEC_ENCODING_CONTAINER; + +import nova.hetu.omniruntime.type.ContainerDataType; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; + +import java.nio.ByteBuffer; + +/** + * container vec. + * + * @since 2021-08-05 + */ +public class ContainerVec extends FixedWidthVec { + private static final int BYTES = Long.BYTES; + + private int positionCount; + + private DataType[] dataTypes; + + /** + * The routine will use the specialized vector allocator to allocate new vector. + * + * @param vectorCount the number of vector + * @param positionCount the actual number of value of vector + * @param vectorAddresses the address of vector + * @param dataTypes the data type of this vector + */ + public ContainerVec(int vectorCount, int positionCount, long[] vectorAddresses, DataType[] dataTypes) { + super(vectorCount * BYTES, positionCount, OMNI_VEC_ENCODING_CONTAINER, ContainerDataType.CONTAINER); + this.positionCount = positionCount; + this.dataTypes = dataTypes; + setDataTypesNative(getNativeVector(), DataTypeSerializer.serialize(dataTypes)); + put(vectorAddresses, 0, 0, vectorAddresses.length); + } + + /** + * this constructor is for JNI to initiate. The 'positionCount' is the number of + * row of all vectors in this ContainerVec. The number of element in this + * ContainerVec is the 'size' attribute of Vec. + * + * @param nativeVector native vector address + */ + public ContainerVec(long nativeVector) { + super(nativeVector, ContainerDataType.CONTAINER, BYTES); + // get other attributes from native + this.positionCount = getPositionNative(nativeVector); + this.dataTypes = DataTypeSerializer.deserialize(getDataTypesNative(nativeVector)); + } + + /** + * The routine will use native vector to initialize a new vector. + * + * @param nativeVector native vector address + * @param nativeValueBufAddress valueBuf address of native vector + * @param nativeVectorNullBufAddress nullBuf address of native vector + * @param size the actual number of value of vector + */ + public ContainerVec(long nativeVector, long nativeValueBufAddress, long nativeVectorNullBufAddress, int size) { + super(nativeVector, nativeValueBufAddress, nativeVectorNullBufAddress, size * BYTES, size, + ContainerDataType.CONTAINER); + // get other attributes from native + this.positionCount = getPositionNative(nativeVector); + this.dataTypes = DataTypeSerializer.deserialize(getDataTypesNative(nativeVector)); + } + + /** + * This constructor of vector is just for shuffle compilation to pass, it will + * be removed later. + * + * @param data data of vector + * @param capacityInBytes size in bytes of data + */ + @Deprecated + public ContainerVec(ByteBuffer data, int capacityInBytes) { + super(capacityInBytes, data.limit(), OMNI_VEC_ENCODING_CONTAINER, ContainerDataType.CONTAINER); + } + + private ContainerVec(ContainerVec containerVec, int start, int length, DataType[] dataTypes) { + super(containerVec, start, length, dataTypes.length * BYTES); + this.positionCount = length; + this.dataTypes = dataTypes; + } + + private ContainerVec(ContainerVec vector, int[] positions, int offset, int length, DataType[] dataTypes) { + super(vector, positions, offset, length, dataTypes.length * BYTES); + this.positionCount = length; + this.dataTypes = dataTypes; + } + + private static native int getPositionNative(long nativeVector); + + private static native String getDataTypesNative(long nativeVector); + + private static native void setDataTypesNative(long nativeVector, String dataTypes); + + /** + * get the specified long at the specified absolute. + * + * @param index the element offset in vec + * @return the value of long + */ + public long get(int index) { + return valuesBuf.getLong(index * BYTES); + } + + /** + * Sets the specified long at the specified absolute. + * + * @param index the element offset in vec + * @param value the value of vec + */ + public void set(int index, long value) { + valuesBuf.setLong((index) * BYTES, value); + } + + /** + * Batch sets the specified long at the specified absolute. + * + * @param values the value of the element to be written + * @param offset the element offset in vec + * @param start the element index in values + * @param length the number of elements that need to written + */ + public void put(long[] values, int offset, int start, int length) { + valuesBuf.setLongArray(offset, values, start, length * BYTES); + } + + /** + * get position count. + * + * @return positionCount + */ + public int getPositionCount() { + return this.positionCount; + } + + /** + * get data types. + * + * @return dataTypes + */ + public DataType[] getDataTypes() { + return this.dataTypes; + } + + /** + * get the specified long at the specified absolute. + * + * @param index the element offset in vec + * @return get(index) + */ + public long getVector(int index) { + return get(index); + } + + /** + * get position count. + * + * @return positionCount + */ + public int getSize() { + return positionCount; + } + + @Override + public ContainerVec slice(int start, int length) { + return new ContainerVec(this, start, length, dataTypes); + } + + @Override + public ContainerVec copyPositions(int[] positions, int offset, int length) { + return new ContainerVec(this, positions, offset, length, dataTypes); + } + + @Override + public int getRealValueBufCapacityInBytes() { + return getCapacityInBytes(); + } + + @Override + public VecEncoding getEncoding() { + return OMNI_VEC_ENCODING_CONTAINER; + } + + /** + * get the encoding of vector at index position. + * + * @param index the element offset in vec + * @return encoding of vector at index position + */ + public VecEncoding getVecEncoding(int index) { + return VecEncoding.values()[getVecEncodingNative(getVector(index))]; + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/Decimal128Vec.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/Decimal128Vec.java new file mode 100644 index 0000000..34979bb --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/Decimal128Vec.java @@ -0,0 +1,189 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2020. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import nova.hetu.omniruntime.type.Decimal128DataType; + +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** + * 128-bit decimal vec. + * + * @since 2021-07-17 + */ +public class Decimal128Vec extends DecimalVec { + private static final int BYTES = Long.BYTES * 2; + + public Decimal128Vec(int size) { + super(size, BYTES, Decimal128DataType.DECIMAL128); + } + + public Decimal128Vec(long nativeVector) { + super(nativeVector, BYTES, Decimal128DataType.DECIMAL128); + } + + public Decimal128Vec(long nativeVector, long nativeValueBufAddress, long nativeVectorNullBufAddress, int size) { + super(nativeVector, nativeValueBufAddress, nativeVectorNullBufAddress, size * BYTES, size, BYTES, + Decimal128DataType.DECIMAL128); + } + + private Decimal128Vec(Decimal128Vec vector, int offset, int length) { + super(vector, offset, length); + } + + private Decimal128Vec(Decimal128Vec vector, int[] positions, int offset, int length) { + super(vector, positions, offset, length); + } + + /** + * split a vec into two vec according to the specified index and length. + * + * @param start starting index + * @param length slice length + * @return new vec + */ + @Override + public Decimal128Vec slice(int start, int length) { + return new Decimal128Vec(this, start, length); + } + + /** + * copy a new vec based on the positions. + * + * @param positions all positions in vec + * @param offset position offset + * @param length the number of elements to be copied + * @return new vec + */ + @Override + public Decimal128Vec copyPositions(int[] positions, int offset, int length) { + return new Decimal128Vec(this, positions, offset, length); + } + + /** + * transfer long to bytes + * + * @param input input long + * @return new bytes + */ + public static byte[] longToBytes(long input) { + ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES); + buffer.putLong(input); + return buffer.array(); + } + + /** + * transfer bytes to long + * + * @param bytes input bytes + * @return new long + */ + public static long bytesToLong(byte[] bytes) { + ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES); + buffer.put(bytes); + buffer.flip(); // need flip + return buffer.getLong(); + } + + /** + * get long array from 128-bit BigInteger + * + * @param bigInteger input + * @return new long array + */ + public static long[] putDecimal(BigInteger bigInteger) { + return putDecimal(bigInteger.toByteArray(), bigInteger.compareTo(BigInteger.ZERO) == -1); + } + + /** + * get long array from 128-bit BigInteger bytes + * + * @param bytes BigInteger bytes + * @param isNegative isNegative + * @return new long array + */ + public static long[] putDecimal(byte[] bytes, boolean isNegative) { + ByteBuffer d128Buffer = ByteBuffer.allocate(Long.BYTES * 2); + int byteArrayLength = bytes.length; + for (int i = 0; i < byteArrayLength; i++) { + d128Buffer.put(bytes[byteArrayLength - i - 1]); + } + if (isNegative) { + for (int i = byteArrayLength; i < 2 * Long.BYTES; i++) { + d128Buffer.put((byte) -1); + } + } + d128Buffer.clear(); + d128Buffer.order(ByteOrder.LITTLE_ENDIAN); + long[] result = new long[2]; + result[0] = d128Buffer.getLong(); + result[1] = d128Buffer.getLong(); + return result; + } + + /** + * get 128-bit BigInteger from long array + * + * @param longs input + * @return new BigInteger + */ + public static BigInteger getDecimal(long[] longs) { + byte[] bytes = new byte[Long.BYTES * 2]; + byte[] highBytes = longToBytes(longs[1]); + byte[] lowBytes = longToBytes(longs[0]); + System.arraycopy(highBytes, 0, bytes, 0, Long.BYTES); + System.arraycopy(lowBytes, 0, bytes, 8, Long.BYTES); + return new BigInteger(bytes); + } + + /** + * please use this method to set jdk BigInteger to Decimal128Vec + * + * @param index row index + * @param decimal input value + */ + public void setBigInteger(int index, BigInteger decimal) { + super.set(index, putDecimal(decimal)); + } + + /** + * set BigInteger bytes to Decimal128Vec + * + * @param index row index + * @param bigIntegerBytes BigInteger bytes + * @param isNegative isNegative + */ + public void setBigInteger(int index, byte[] bigIntegerBytes, boolean isNegative) { + super.set(index, putDecimal(bigIntegerBytes, isNegative)); + } + + /** + * please use this method to get jdk BigInteger from Decimal128Vec + * + * @param index row index + * @return new BigInteger + */ + public BigInteger getBigInteger(int index) { + return getDecimal(super.get(index)); + } + + /** + * use this method to get jdk BigInteger bytes and isNegative from Decimal128Vec + * + * @param index row index + * @return isNegative and BigInteger bytes + */ + public byte[] getBytes(int index) { + long[] longs = super.get(index); + byte[] bytes = new byte[Long.BYTES * 2]; + byte[] highBytes = longToBytes(longs[1]); + byte[] lowBytes = longToBytes(longs[0]); + System.arraycopy(highBytes, 0, bytes, 0, Long.BYTES); + System.arraycopy(lowBytes, 0, bytes, 8, Long.BYTES); + return bytes; + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/DecimalVec.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/DecimalVec.java new file mode 100644 index 0000000..27b4abf --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/DecimalVec.java @@ -0,0 +1,158 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static nova.hetu.omniruntime.utils.OmniErrorType.OMNI_PARAM_ERROR; + +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.utils.OmniRuntimeException; + +/** + * base class of decimal vec. + * + * @since 2021-07-17 + */ +public abstract class DecimalVec extends FixedWidthVec { + private final int typeWidth; + + /** + * Ihe routine will use GLOBAL memory pool when there is no specialized vector + * allocator. + * + * @param size the actual number of value of vector + * @param typeLength the length of this data type + * @param type the data type of this vector + */ + public DecimalVec(int size, int typeLength, DataType type) { + super(size * typeLength, size, VecEncoding.OMNI_VEC_ENCODING_FLAT, type); + this.typeWidth = getTypeWidth(typeLength); + } + + /** + * The routine will use native vector to initialize a new vector. + * + * @param nativeVector native vector address + * @param typeLength the length of this data type + * @param type the type of this vector + */ + public DecimalVec(long nativeVector, int typeLength, DataType type) { + super(nativeVector, type, typeLength); + this.typeWidth = getTypeWidth(typeLength); + } + + /** + * The routine will use native vector to initialize a new vector. + * + * @param nativeVector native vector address + * @param nativeValueBufAddress valueBuf address of native vector + * @param nativeVectorNullBufAddress nullBuf address of native vector + * @param capacityInBytes capacity in bytes of vector + * @param size the actual number of value of vector + * @param typeLength the length of this data type + * @param type the type of this vector + */ + public DecimalVec(long nativeVector, long nativeValueBufAddress, long nativeVectorNullBufAddress, + int capacityInBytes, int size, int typeLength, DataType type) { + super(nativeVector, nativeValueBufAddress, nativeVectorNullBufAddress, capacityInBytes, size, type); + this.typeWidth = getTypeWidth(typeLength); + } + + /** + * The routine is just for slicing and copyRegion vector operator. + * + * @param vector the vector need to be sliced or copyRegion + * @param offset When a vector has been sliced or copyRegion, this value will + * point to where is the new slice {@link Vec} start + * @param length the number of value + */ + protected DecimalVec(DecimalVec vector, int offset, int length) { + super(vector, offset, length, length * vector.typeWidth * Long.BYTES); + this.typeWidth = vector.typeWidth; + } + + /** + * The routine is just for copyPosition vector operator. + * + * @param vector the vector need to be copy + * @param positions the original vector positions + * @param offset offset of positions in the input parameter + * @param length number of elements copied + */ + protected DecimalVec(DecimalVec vector, int[] positions, int offset, int length) { + super(vector, positions, offset, length, length * vector.typeWidth * Long.BYTES); + this.typeWidth = vector.typeWidth; + } + + private int getTypeWidth(int typeLength) { + return typeLength / Long.BYTES; + } + + /** + * get the specified decimal at the specified absolute + * + * @param index the element offset in vec + * @return decimal value, from high to low + */ + public long[] get(int index) { + long[] value = new long[this.typeWidth]; + int offset = index * this.typeWidth; + for (int i = 0; i < this.typeWidth; i++) { + value[i] = valuesBuf.getLong((offset + i) * Long.BYTES); + } + return value; + } + + /** + * get long values from the specified position with element number + * + * @param index the position of element + * @param length the number of element + * @return long value array + */ + public long[] get(int index, int length) { + long[] value = new long[this.typeWidth * length]; + int offset = index * this.typeWidth; + valuesBuf.getLongArray(offset * Long.BYTES, value, 0, value.length * Long.BYTES); + return value; + } + + /** + * Sets the specified decimal value at the specified absolute + * + * @param index the element offset in vec + * @param value the value of the element to be written, from high to low + */ + public void set(int index, long[] value) { + int offset = index * this.typeWidth; + for (int i = 0; i < this.typeWidth; i++) { + valuesBuf.setLong((offset + i) * Long.BYTES, value[i]); + } + } + + /** + * Batch sets the specified long at the specified absolute + * + * @param values the value of the element to be written + * @param offset the element offset in vec + * @param start the element index in values + * @param length the number of elements that need to written + */ + public void put(long[] values, int offset, int start, int length) { + if (length % this.typeWidth != 0) { + throw new OmniRuntimeException(OMNI_PARAM_ERROR, "length " + length + "is error."); + } + valuesBuf.setLongArray(offset * typeWidth * Long.BYTES, values, start * Long.BYTES, length * Long.BYTES); + } + + @Override + public int getRealValueBufCapacityInBytes() { + return size * typeWidth * Long.BYTES; + } + + @Override + public int getCapacityInBytes() { + return size * typeWidth * Long.BYTES; + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/DictionaryVec.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/DictionaryVec.java new file mode 100644 index 0000000..e5e562a --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/DictionaryVec.java @@ -0,0 +1,351 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2020. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static nova.hetu.omniruntime.vector.VariableWidthVec.getValueOffsetsNative; + +import com.google.common.annotations.VisibleForTesting; + +import nova.hetu.omniruntime.type.DataType; +import sun.misc.Unsafe; + +/** + * dictionary vec. + * + * @since 2021-07-17 + */ +public class DictionaryVec extends FixedWidthVec { + private static final int BYTES = Integer.BYTES; + + private Vec dictionary; + + private long dataAddress; + private long offsetsAddress; + + /** + * The routine will use native vector to initialize a new dictionary vector. + * + * @param nativeVector native dictionary vector address + * @param dataType vector datatype + */ + public DictionaryVec(long nativeVector, DataType dataType) { + super(nativeVector, dataType, BYTES); + } + + /** + * The routine will use native vector to initialize a new dictionary vector. + * + * @param nativeVector native vector address + * @param nativeValueBufAddress valueBuf address of native vector + * @param nativeVectorNullBufAddress nullBuf address of native vector + * @param size the actual number of value of vector(ids) + * @param dataType the dataType of native vector + */ + public DictionaryVec(long nativeVector, long nativeValueBufAddress, long nativeVectorNullBufAddress, int size, + DataType dataType) { + super(nativeVector, nativeValueBufAddress, nativeVectorNullBufAddress, size * BYTES, size, dataType); + dataAddress = getDictionaryNative(nativeVector, getType().getId().toValue()); + } + + /** + * The routine will use the specialized vector allocator to allocate new vector. + * + * @param dictionary the specialized vector + * @param ids the int array + */ + public DictionaryVec(Vec dictionary, int[] ids) { + super(dictionary, ids, ids.length * BYTES, dictionary.getType()); + dataAddress = getDictionaryNative(getNativeVector(), getType().getId().toValue()); + } + + private DictionaryVec(DictionaryVec vector, int offset, int length) { + super(vector, offset, length, length * BYTES); + dataAddress = getDictionaryNative(vector.getNativeVector(), vector.getType().getId().toValue()); + } + + private DictionaryVec(DictionaryVec vector, int[] positions, int offset, int length) { + super(vector, positions, offset, length, length * BYTES); + dataAddress = getDictionaryNative(vector.getNativeVector(), vector.getType().getId().toValue()); + } + + private static native long getDictionaryNative(long nativeVector, int dataTypeId); + + public Vec getDictionary() { + return dictionary; + } + + /** + * for the UT of getDictionaryNative with empty strings. + * + * @return the address of dictionary container + */ + @VisibleForTesting + public long getDataAddress() { + return dataAddress; + } + + /** + * v2 need expand dictionary + * @return expanded vector + */ + public Vec expandDictionary() { + int size = getSize(); + DataType dataType = getType(); + Vec vector = VecFactory.createFlatVec(size, dataType); + vector.setNulls(0, getValuesNulls(0, size), 0, size); + switch (dataType.getId()) { + case OMNI_INT: + case OMNI_DATE32: + setValue(size, (IntVec) vector); + break; + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DATE64: + case OMNI_DECIMAL64: + setValue(size, (LongVec) vector); + break; + case OMNI_DOUBLE: + setValue(size, (DoubleVec) vector); + break; + case OMNI_SHORT: + setValue(size, (ShortVec) vector); + break; + case OMNI_BYTE: + setValue(size, (ByteVec) vector); + break; + case OMNI_BOOLEAN: + setValue(size, (BooleanVec) vector); + break; + case OMNI_VARCHAR: + case OMNI_CHAR: + setValue(size, (VarcharVec) vector); + break; + case OMNI_DECIMAL128: + setValue(size, (Decimal128Vec) vector); + break; + default: + throw new IllegalArgumentException("Not Support Data Type " + dataType.getId()); + } + return vector; + } + + private void setValue(int size, Decimal128Vec vector) { + for (int i = 0; i < size; i++) { + if (!vector.isNull(i)) { + vector.set(i, getDecimal128(i)); + } + } + } + + private void setValue(int size, VarcharVec vector) { + for (int i = 0; i < size; i++) { + if (!vector.isNull(i)) { + vector.set(i, getBytes(i)); + } else { + vector.setNull(i); + } + } + } + + private void setValue(int size, BooleanVec vector) { + for (int i = 0; i < size; i++) { + if (!vector.isNull(i)) { + vector.set(i, getBoolean(i)); + } + } + } + + private void setValue(int size, ShortVec vector) { + for (int i = 0; i < size; i++) { + if (!vector.isNull(i)) { + vector.set(i, getShort(i)); + } + } + } + + private void setValue(int size, ByteVec vector) { + for (int i = 0; i < size; i++) { + if (!vector.isNull(i)) { + vector.set(i, getByte(i)); + } + } + } + + private void setValue(int size, DoubleVec vector) { + for (int i = 0; i < size; i++) { + if (!vector.isNull(i)) { + vector.set(i, getDouble(i)); + } + } + } + + private void setValue(int size, LongVec vector) { + for (int i = 0; i < size; i++) { + if (!vector.isNull(i)) { + vector.set(i, getLong(i)); + } + } + } + + private void setValue(int size, IntVec vector) { + for (int i = 0; i < size; i++) { + if (!vector.isNull(i)) { + vector.set(i, getInt(i)); + } + } + } + + /** + * get ids from valuesBuf + * + * @return ids array + * */ + public int[] getIds() { + int[] ids = new int[size]; + valuesBuf.getIntArray(0, ids, 0, size * BYTES); + return ids; + } + + /** + * get the specified integer at the specified absolute. + * + * @param index the element offset in vec + * @return int value + */ + public int getId(int index) { + return valuesBuf.getInt(index * BYTES); + } + + /** + * get the specified short at the specified absolute. + * + * @param index the element offset in vec + * @return short value + */ + public short getShort(int index) { + int originIndex = getId(index); + return JvmUtils.UNSAFE.getShort(dataAddress + originIndex * Short.BYTES); + } + + /** + * get the specified byte at the specified absolute. + * + * @param index the element offset in vec + * @return byte value + */ + public byte getByte(int index) { + int originIndex = getId(index); + return JvmUtils.UNSAFE.getByte(dataAddress + originIndex * Byte.BYTES); + } + + /** + * get the specified integer at the specified absolute. + * + * @param index the element offset in vec + * @return integer value + */ + public int getInt(int index) { + int originIndex = getId(index); + return JvmUtils.UNSAFE.getInt(dataAddress + originIndex * Integer.BYTES); + } + + /** + * get the specified long at the specified absolute. + * + * @param index the element offset in vec + * @return long value + */ + public long getLong(int index) { + int originIndex = getId(index); + return JvmUtils.UNSAFE.getLong(dataAddress + originIndex * Long.BYTES); + } + + /** + * get the specified double at the specified absolute. + * + * @param index the element offset in vec + * @return double value + */ + public double getDouble(int index) { + int originIndex = getId(index); + return JvmUtils.UNSAFE.getDouble(dataAddress + originIndex * Double.BYTES); + } + + /** + * get the specified boolean at the specified absolute. + * + * @param index the element offset in vec + * @return boolean value + */ + public boolean getBoolean(int index) { + int originIndex = getId(index); + return JvmUtils.UNSAFE.getByte(dataAddress + originIndex) == 1; + } + + /** + * get the offset value of the specified position. + * + * @param index the element offset in vec + * @return offset value + */ + public int getValueOffset(int index) { + return JvmUtils.UNSAFE.getInt(offsetsAddress + index * Integer.BYTES); + } + + /** + * get the specified bytes at the specified absolute. + * + * @param index the element offset in vec + * @return byte array + */ + public byte[] getBytes(int index) { + if (offsetsAddress == 0) { + offsetsAddress = getValueOffsetsNative(nativeVector); + } + + int originIndex = getId(index); + final int stringLen = getValueOffset(originIndex + 1) - getValueOffset(originIndex); + final long stringAddr = dataAddress + getValueOffset(originIndex); + byte[] target = new byte[stringLen]; + JvmUtils.UNSAFE.copyMemory(null, stringAddr, target, Unsafe.ARRAY_BYTE_BASE_OFFSET, stringLen); + return target; + } + + /** + * get the specified decimal at the specified absolute. + * + * @param index the element offset in vec + * @return long array + */ + public long[] getDecimal128(int index) { + int originIndex = getId(index); + long[] value = new long[2]; + int valueIndex = originIndex * 2; + for (int i = 0; i < 2; i++) { + value[i] = JvmUtils.UNSAFE.getLong(dataAddress + (valueIndex + i) * Long.BYTES); + } + return value; + } + + @Override + public DictionaryVec slice(int start, int length) { + return new DictionaryVec(this, start, length); + } + + @Override + public DictionaryVec copyPositions(int[] positions, int offset, int length) { + return new DictionaryVec(this, positions, offset, length); + } + + @Override + public int getRealValueBufCapacityInBytes() { + return size * BYTES; + } + + @Override + public VecEncoding getEncoding() { + return VecEncoding.OMNI_VEC_ENCODING_DICTIONARY; + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/DoubleVec.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/DoubleVec.java new file mode 100644 index 0000000..594844c --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/DoubleVec.java @@ -0,0 +1,102 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import nova.hetu.omniruntime.type.DoubleDataType; + +/** + * double vec. + * + * @since 2021-07-17 + */ +public class DoubleVec extends FixedWidthVec { + private static final int BYTES = Double.BYTES; + + public DoubleVec(int size) { + super(size * BYTES, size, VecEncoding.OMNI_VEC_ENCODING_FLAT, DoubleDataType.DOUBLE); + } + + public DoubleVec(long nativeVector) { + super(nativeVector, DoubleDataType.DOUBLE, BYTES); + } + + public DoubleVec(long nativeVector, long nativeValueBufAddress, long nativeVectorNullBufAddress, int size) { + super(nativeVector, nativeValueBufAddress, nativeVectorNullBufAddress, BYTES * size, size, + DoubleDataType.DOUBLE); + } + + private DoubleVec(DoubleVec vector, int offset, int length) { + super(vector, offset, length, length * BYTES); + } + + private DoubleVec(DoubleVec vector, int[] positions, int offset, int length) { + super(vector, positions, offset, length, length * BYTES); + } + + /** + * get the specified double at the specified absolute. + * + * @param index the element offset in vec + * @return double value + */ + public double get(int index) { + return valuesBuf.getDouble(index * BYTES); + } + + /** + * get double values from the specified position. + * + * @param index the position of element + * @param length the number of element + * @return double value array + */ + public double[] get(int index, int length) { + double[] target = new double[length]; + valuesBuf.getDoubleArray(index * BYTES, target, 0, length * BYTES); + return target; + } + + /** + * Sets the specified double at the specified absolute. + * + * @param index the element offset in vec + * @param value the value of vec + */ + public void set(int index, double value) { + valuesBuf.setDouble(index * BYTES, value); + } + + /** + * Batch sets the specified double at the specified absolute. + * + * @param values the value of the element to be written + * @param offset the element offset in vec + * @param start the element index in values + * @param length the number of elements that need to written + */ + public void put(double[] values, int offset, int start, int length) { + valuesBuf.setDoubleArray(offset * BYTES, values, start * BYTES, length * BYTES); + } + + @Override + public DoubleVec slice(int start, int length) { + return new DoubleVec(this, start, length); + } + + @Override + public DoubleVec copyPositions(int[] positions, int offset, int length) { + return new DoubleVec(this, positions, offset, length); + } + + @Override + public int getRealValueBufCapacityInBytes() { + return size * BYTES; + } + + @Override + public int getCapacityInBytes() { + return BYTES * size; + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/FixedWidthVec.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/FixedWidthVec.java new file mode 100644 index 0000000..1cde919 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/FixedWidthVec.java @@ -0,0 +1,39 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2020. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import nova.hetu.omniruntime.type.DataType; + +/** + * base class of fixed width vec. + * + * @since 2021-07-17 + */ +public abstract class FixedWidthVec extends Vec { + public FixedWidthVec(int capacityInBytes, int size, VecEncoding encoding, DataType type) { + super(capacityInBytes, size, encoding, type); + } + + public FixedWidthVec(Vec dictionary, int[] ids, int capacityInBytes, DataType type) { + super(dictionary, ids, capacityInBytes, type); + } + + public FixedWidthVec(FixedWidthVec vector, int offset, int length, int capacityInBytes) { + super(vector, offset, length, capacityInBytes); + } + + public FixedWidthVec(FixedWidthVec vector, int[] positions, int offset, int length, int capacityInBytes) { + super(vector, positions, offset, length, capacityInBytes); + } + + public FixedWidthVec(long nativeVector, DataType type, int typeLength) { + super(nativeVector, type, typeLength); + } + + public FixedWidthVec(long nativeVector, long nativeVectorValueBufAddress, long nativeVectorNullBufAddress, + int capacityInBytes, int size, DataType type) { + super(nativeVector, nativeVectorValueBufAddress, nativeVectorNullBufAddress, capacityInBytes, size, type); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/IntVec.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/IntVec.java new file mode 100644 index 0000000..eed095d --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/IntVec.java @@ -0,0 +1,101 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import nova.hetu.omniruntime.type.IntDataType; + +/** + * int vec. + * + * @since 2021-07-17 + */ +public class IntVec extends FixedWidthVec { + private static final int BYTES = Integer.BYTES; + + public IntVec(int size) { + super(size * BYTES, size, VecEncoding.OMNI_VEC_ENCODING_FLAT, IntDataType.INTEGER); + } + + public IntVec(long nativeVector) { + super(nativeVector, IntDataType.INTEGER, BYTES); + } + + public IntVec(long nativeVector, long nativeValueBufAddress, long nativeVectorNullBufAddress, int size) { + super(nativeVector, nativeValueBufAddress, nativeVectorNullBufAddress, BYTES * size, size, IntDataType.INTEGER); + } + + private IntVec(IntVec vector, int offset, int length) { + super(vector, offset, length, length * BYTES); + } + + private IntVec(IntVec vector, int[] positions, int offset, int length) { + super(vector, positions, offset, length, length * BYTES); + } + + /** + * get the specified integer at the specified absolute. + * + * @param index the element offset in vec + * @return int value + */ + public int get(int index) { + return valuesBuf.getInt(index * BYTES); + } + + /** + * get int values from the specified position. + * + * @param index the position of element + * @param length the number of element + * @return int value array + */ + public int[] get(int index, int length) { + int[] target = new int[length]; + valuesBuf.getIntArray(index * BYTES, target, 0, length * BYTES); + return target; + } + + /** + * Sets the specified integer at the specified absolute. + * + * @param index the element offset in vec + * @param value the value of vec + */ + public void set(int index, int value) { + valuesBuf.setInt(index * BYTES, value); + } + + /** + * Batch sets the specified integer at the specified absolute. + * + * @param values the value of the element to be written + * @param offset the element offset in vec + * @param start the element index in values + * @param length the number of elements that need to written + */ + public void put(int[] values, int offset, int start, int length) { + valuesBuf.setIntArray(offset * BYTES, values, start * BYTES, length * BYTES); + } + + @Override + public IntVec slice(int start, int length) { + return new IntVec(this, start, length); + } + + @Override + public IntVec copyPositions(int[] positions, int offset, int length) { + return new IntVec(this, positions, offset, length); + } + + @Override + public int getRealValueBufCapacityInBytes() { + return size * BYTES; + } + + @Override + public int getCapacityInBytes() { + return BYTES * size; + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/JvmUtils.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/JvmUtils.java new file mode 100644 index 0000000..2919635 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/JvmUtils.java @@ -0,0 +1,120 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import nova.hetu.omniruntime.utils.OmniErrorType; +import nova.hetu.omniruntime.utils.OmniRuntimeException; +import sun.misc.Unsafe; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.security.AccessController; +import java.security.PrivilegedAction; + +/** + * jvm utils. + * + * @since 2021-08-05 + */ +public final class JvmUtils { + /** + * jvm unsafe. + */ + public static final Unsafe UNSAFE; + + private static final Constructor DIRECT_BUFFER_CONSTRUCTOR; + + static { + try { + Field field = Unsafe.class.getDeclaredField("theUnsafe"); + field.setAccessible(true); + Object obj = field.get(null); + UNSAFE = obj instanceof Unsafe ? (Unsafe) obj : null; + if (UNSAFE == null) { + throw new OmniRuntimeException(OmniErrorType.OMNI_NATIVE_ERROR, "Unsafe access not available"); + } + + assertArrayIndexScale("Boolean", Unsafe.ARRAY_BOOLEAN_INDEX_SCALE, 1); + assertArrayIndexScale("Byte", Unsafe.ARRAY_BYTE_INDEX_SCALE, 1); + assertArrayIndexScale("Short", Unsafe.ARRAY_SHORT_INDEX_SCALE, 2); + assertArrayIndexScale("Int", Unsafe.ARRAY_INT_INDEX_SCALE, 4); + assertArrayIndexScale("Long", Unsafe.ARRAY_LONG_INDEX_SCALE, 8); + assertArrayIndexScale("Float", Unsafe.ARRAY_FLOAT_INDEX_SCALE, 4); + assertArrayIndexScale("Double", Unsafe.ARRAY_DOUBLE_INDEX_SCALE, 8); + + long address = -1; + final ByteBuffer direct = ByteBuffer.allocateDirect(1); + try { + final Object directBufferConstructor = AccessController.doPrivileged((PrivilegedAction) () -> { + final Constructor constructor; + try { + constructor = direct.getClass().getDeclaredConstructor(long.class, int.class); + constructor.setAccessible(true); + return constructor; + } catch (NoSuchMethodException e) { + throw new OmniRuntimeException(OmniErrorType.OMNI_NOSUPPORT, e); + } + }); + + if (directBufferConstructor instanceof Constructor) { + address = UNSAFE.allocateMemory(1); + // try to use the constructor + try { + ((Constructor) directBufferConstructor).newInstance(address, 1); + DIRECT_BUFFER_CONSTRUCTOR = (Constructor) directBufferConstructor; + } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { + throw new OmniRuntimeException(OmniErrorType.OMNI_NOSUPPORT, e); + } + } else { + throw new OmniRuntimeException(OmniErrorType.OMNI_NOSUPPORT, + "get the director byte buffer constructor failed."); + } + } finally { + if (address != -1) { + UNSAFE.freeMemory(address); + } + } + } catch (ReflectiveOperationException var1) { + throw new OmniRuntimeException(OmniErrorType.OMNI_NOSUPPORT, var1); + } + } + + private JvmUtils() { + } + + private static void assertArrayIndexScale(String name, int actualIndexScale, int expectedIndexScale) { + if (actualIndexScale != expectedIndexScale) { + throw new IllegalStateException( + name + " array index scale must be " + expectedIndexScale + ", but is " + actualIndexScale); + } + } + + /** + * construct a director byte buffer by address and capacity. + * + * @param omniBuffer the address of byte buffer + * @return director byte buffer + */ + public static ByteBuffer directBuffer(OmniBuffer omniBuffer) { + if (omniBuffer.getCapacity() < 0) { + throw new OmniRuntimeException(OmniErrorType.OMNI_PARAM_ERROR, + "Capacity is negative, has to be positive or 0"); + } + + if (DIRECT_BUFFER_CONSTRUCTOR == null) { + throw new OmniRuntimeException(OmniErrorType.OMNI_NOSUPPORT, + "DirectByteBuffer.(long, int) not available"); + } + try { + return ((ByteBuffer) DIRECT_BUFFER_CONSTRUCTOR.newInstance(omniBuffer.getAddress(), + omniBuffer.getCapacity())).order(ByteOrder.LITTLE_ENDIAN); + } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { + throw new OmniRuntimeException(OmniErrorType.OMNI_NOSUPPORT, e); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/LongVec.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/LongVec.java new file mode 100644 index 0000000..0911241 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/LongVec.java @@ -0,0 +1,115 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import nova.hetu.omniruntime.type.LongDataType; + +import java.nio.ByteBuffer; + +/** + * long vec. + * + * @since 2021-07-17 + */ +public class LongVec extends FixedWidthVec { + private static final int BYTES = Long.BYTES; + + public LongVec(int size) { + super(size * BYTES, size, VecEncoding.OMNI_VEC_ENCODING_FLAT, LongDataType.LONG); + } + + public LongVec(long nativeVector) { + super(nativeVector, LongDataType.LONG, BYTES); + } + + public LongVec(long nativeVector, long nativeValueBufAddress, long nativeVectorNullBufAddress, int size) { + super(nativeVector, nativeValueBufAddress, nativeVectorNullBufAddress, BYTES * size, size, LongDataType.LONG); + } + + /** + * This constructor of vector is just for shuffle compilation to pass, it will + * be removed later. + * + * @param data data of vector + * @param capacityInBytes size in bytes of data + */ + @Deprecated + public LongVec(ByteBuffer data, int capacityInBytes) { + super(capacityInBytes, data.limit(), VecEncoding.OMNI_VEC_ENCODING_FLAT, LongDataType.LONG); + } + + private LongVec(LongVec vector, int offset, int length) { + super(vector, offset, length, length * BYTES); + } + + private LongVec(LongVec vector, int[] positions, int offset, int length) { + super(vector, positions, offset, length, length * BYTES); + } + + /** + * get the specified long at the specified absolute. + * + * @param index the element offset in vec + * @return long value + */ + public long get(int index) { + return valuesBuf.getLong(index * BYTES); + } + + /** + * get long values from the specified position. + * + * @param index the position of element + * @param length the number of element + * @return long value array + */ + public long[] get(int index, int length) { + long[] target = new long[length]; + valuesBuf.getLongArray(index * BYTES, target, 0, length * BYTES); + return target; + } + + /** + * Sets the specified long at the specified absolute. + * + * @param index the element offset in vec + * @param value the value of vec + */ + public void set(int index, long value) { + valuesBuf.setLong(index * BYTES, value); + } + + /** + * Batch sets the specified long at the specified absolute. + * + * @param values the value of the element to be written + * @param offset the element offset in vec + * @param start the element index in values + * @param length the number of elements that need to written + */ + public void put(long[] values, int offset, int start, int length) { + valuesBuf.setLongArray(offset * BYTES, values, start * BYTES, length * BYTES); + } + + @Override + public LongVec slice(int start, int length) { + return new LongVec(this, start, length); + } + + @Override + public LongVec copyPositions(int[] positions, int offset, int length) { + return new LongVec(this, positions, offset, length); + } + + @Override + public int getRealValueBufCapacityInBytes() { + return size * BYTES; + } + + @Override + public int getCapacityInBytes() { + return BYTES * size; + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/OmniBuffer.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/OmniBuffer.java new file mode 100644 index 0000000..36d0896 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/OmniBuffer.java @@ -0,0 +1,215 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +/** + * encapsulate the ByteBuffer data interface and provide them to all vec. + * + * @since 2021-08-02 + */ +public interface OmniBuffer { + /** + * set byte. + * + * @param index the byte offset of the element + * @param value the value of the element + */ + void setByte(int index, byte value); + + /** + * get byte. + * + * @param index the byte offset of the element + * @return the value of element + */ + byte getByte(int index); + + /** + * Batch setting bytes. + * + * @param index the byte offset of the element + * @param src byte array + * @param srcStart array start index + * @param length byte size + */ + void setBytes(int index, byte[] src, int srcStart, int length); + + /** + * get bytes in batch. + * + * @param index the byte offset of the element + * @param length byte size + * @return byte array + */ + byte[] getBytes(int index, int length); + + /** + * get bytes in batch. + * + * @param index the byte offset of the element + * @param target target byte array + * @param targetIndex the offset of the target byte array + * @param length byte size + */ + void getBytes(int index, byte[] target, int targetIndex, int length); + + /** + * set short value. + * + * @param index the byte offset of the element + * @param value value of short + */ + void setShort(int index, short value); + + /** + * set short array. + * + * @param index the byte offset of the element + * @param src short array + * @param srcIndex the starting byte offset of the array + * @param length byte size + */ + void setShortArray(int index, short[] src, int srcIndex, int length); + + /** + * get short value. + * + * @param index the byte offset of the element + * @return short value + */ + short getShort(int index); + + /** + * get short array. + * + * @param index the byte offset of the element + * @param target target short array + * @param targetIndex the starting byte offset of the array + * @param length byte size + */ + void getShortArray(int index, short[] target, int targetIndex, int length); + + /** + * set int value. + * + * @param index the byte offset of the element + * @param value int value + */ + void setInt(int index, int value); + + /** + * get int value. + * + * @param index the byte offset of the element + * @return int value + */ + int getInt(int index); + + /** + * set int array. + * + * @param index the byte offset of the element + * @param src int array + * @param srcIndex the starting byte offset of the array + * @param length byte size + */ + void setIntArray(int index, int[] src, int srcIndex, int length); + + /** + * get int array. + * + * @param index the byte offset of the element + * @param target target int array + * @param targetIndex the starting byte offset of the array + * @param length byte size + */ + void getIntArray(int index, int[] target, int targetIndex, int length); + + /** + * set long value. + * + * @param index the byte offset of the element + * @param value long value + */ + void setLong(int index, long value); + + /** + * get long value. + * + * @param index the byte offset of the element + * @return long value + */ + long getLong(int index); + + /** + * set long array. + * + * @param index the byte offset of the element + * @param src long array + * @param srcIndex the starting byte offset of the array + * @param length byte size + */ + void setLongArray(int index, long[] src, int srcIndex, int length); + + /** + * get long array. + * + * @param index the byte offset of the element + * @param target target long array + * @param targetIndex the starting byte offset of the array + * @param length byte size + */ + void getLongArray(int index, long[] target, int targetIndex, int length); + + /** + * set double value. + * + * @param index the byte offset of the element + * @param value double value + */ + void setDouble(int index, double value); + + /** + * get double value. + * + * @param index the byte offset of the element + * @return double value + */ + double getDouble(int index); + + /** + * set double array. + * + * @param index the byte offset of the element + * @param src source double array + * @param srcIndex the starting byte offset of the array + * @param length byte size + */ + void setDoubleArray(int index, double[] src, int srcIndex, int length); + + /** + * get double array. + * + * @param index the byte offset of the element + * @param target target double array + * @param targetIndex the starting byte offset of the array + * @param length byte size + */ + void getDoubleArray(int index, double[] target, int targetIndex, int length); + + /** + * get data capacity from omnibuf. + * + * @return capacity of omnibuf + */ + int getCapacity(); + + /** + * get data address from omnibuf. + * + * @return data address of omnibuf + */ + long getAddress(); +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/OmniBufferFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/OmniBufferFactory.java new file mode 100644 index 0000000..432f1bc --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/OmniBufferFactory.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +/** + * Responsible for creating different type of omniBuffer. + * + * @since 2021-08-10 + */ +public class OmniBufferFactory { + private OmniBufferFactory() { + } + + /** + * create a new omnibuffer object. + * + * @param address the address of buffer object + * @param capacity the capacity of buffer object + * @return omnibuffer object + */ + public static OmniBuffer create(long address, int capacity) { + return new OmniBufferUnsafeV8(address, capacity); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/OmniBufferUnsafeV8.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/OmniBufferUnsafeV8.java new file mode 100644 index 0000000..3c5dcab --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/OmniBufferUnsafeV8.java @@ -0,0 +1,145 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import sun.misc.Unsafe; + +/** + * jdk8 unsafe interface implementation. + * + * @since 2021-08-10 + */ +public class OmniBufferUnsafeV8 implements OmniBuffer { + private final long address; + + private final int capacity; + + public OmniBufferUnsafeV8(long address, int capacity) { + this.address = address; + this.capacity = capacity; + } + + @Override + public void setByte(int index, byte value) { + JvmUtils.UNSAFE.putByte(addr(index), value); + } + + @Override + public byte getByte(int index) { + return JvmUtils.UNSAFE.getByte(addr(index)); + } + + @Override + public void setBytes(int index, byte[] src, int srcStart, int length) { + JvmUtils.UNSAFE.copyMemory(src, Unsafe.ARRAY_BYTE_BASE_OFFSET + srcStart, null, addr(index), length); + } + + @Override + public byte[] getBytes(int index, int length) { + byte[] target = new byte[length]; + getBytes(index, target, 0, length); + return target; + } + + @Override + public void getBytes(int index, byte[] target, int targetIndex, int length) { + JvmUtils.UNSAFE.copyMemory(null, addr(index), target, Unsafe.ARRAY_BYTE_BASE_OFFSET + targetIndex, length); + } + + @Override + public void setShort(int index, short value) { + JvmUtils.UNSAFE.putShort(addr(index), value); + } + + @Override + public void setShortArray(int index, short[] src, int srcIndex, int length) { + JvmUtils.UNSAFE.copyMemory(src, Unsafe.ARRAY_LONG_BASE_OFFSET + srcIndex, null, addr(index), length); + } + + @Override + public short getShort(int index) { + return JvmUtils.UNSAFE.getShort(addr(index)); + } + + @Override + public void getShortArray(int index, short[] target, int targetIndex, int length) { + JvmUtils.UNSAFE.copyMemory(null, addr(index), target, Unsafe.ARRAY_SHORT_BASE_OFFSET + targetIndex, length); + } + + @Override + public void setInt(int index, int value) { + JvmUtils.UNSAFE.putInt(addr(index), value); + } + + @Override + public int getInt(int index) { + return JvmUtils.UNSAFE.getInt(addr(index)); + } + + @Override + public void setIntArray(int index, int[] src, int srcIndex, int length) { + JvmUtils.UNSAFE.copyMemory(src, Unsafe.ARRAY_INT_BASE_OFFSET + srcIndex, null, addr(index), length); + } + + @Override + public void getIntArray(int index, int[] target, int targetIndex, int length) { + JvmUtils.UNSAFE.copyMemory(null, addr((long) index), target, Unsafe.ARRAY_INT_BASE_OFFSET + targetIndex, + length); + } + + @Override + public void setLong(int index, long value) { + JvmUtils.UNSAFE.putLong(addr(index), value); + } + + @Override + public long getLong(int index) { + return JvmUtils.UNSAFE.getLong(addr(index)); + } + + @Override + public void setLongArray(int index, long[] src, int srcIndex, int length) { + JvmUtils.UNSAFE.copyMemory(src, Unsafe.ARRAY_LONG_BASE_OFFSET + srcIndex, null, addr(index), length); + } + + @Override + public void getLongArray(int index, long[] target, int targetIndex, int length) { + JvmUtils.UNSAFE.copyMemory(null, addr(index), target, Unsafe.ARRAY_LONG_BASE_OFFSET + targetIndex, length); + } + + @Override + public void setDouble(int index, double value) { + JvmUtils.UNSAFE.putDouble(addr(index), value); + } + + @Override + public double getDouble(int index) { + return JvmUtils.UNSAFE.getDouble(addr(index)); + } + + @Override + public void setDoubleArray(int index, double[] src, int srcIndex, int length) { + JvmUtils.UNSAFE.copyMemory(src, Unsafe.ARRAY_DOUBLE_BASE_OFFSET + srcIndex, null, addr(index), length); + } + + @Override + public void getDoubleArray(int index, double[] target, int targetIndex, int length) { + JvmUtils.UNSAFE.copyMemory(null, addr(index), target, Unsafe.ARRAY_LONG_BASE_OFFSET + targetIndex, length); + } + + @Override + public int getCapacity() { + return capacity; + } + + @Override + public long getAddress() { + return address; + } + + private long addr(long offsetInBytes) { + return address + offsetInBytes; + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/Row.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/Row.java new file mode 100644 index 0000000..97b15ed --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/Row.java @@ -0,0 +1,79 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import java.io.Closeable; + +/** + * row vec. + * + * @since 2024-05-16 + */ +public class Row implements Closeable { + /** + * indicates native buffer address in cpp. + */ + protected final long nativeRow; + + private int length; + + private OmniBuffer rowBuf; + + /** + * new one row in c++ + * + * @param dataAddr native address in cpp + * @param len number of row's bytes + */ + public Row(long dataAddr, int len) { + nativeRow = dataAddr; + length = len; + this.rowBuf = OmniBufferFactory.create(dataAddr, len); + } + + /** + * return one row's capacity + * + * @return buf's capacity + */ + public int getCapacity() { + return rowBuf.getCapacity(); + } + + /** + * return one row's real length + * + * @return buf's length + */ + public int getLength() { + return length; + } + + /** + * return one row's native address + * + * @return row's native address + */ + public long getNativeRow() { + return nativeRow; + } + + /** + * it may close raw buffer 's memory, now memory is deleted by adaptor + */ + @Override + public void close() { + + } + + /** + * return heap bytes object + * + * @return bytes object + */ + public byte[] getBytes() { + return rowBuf.getBytes(0, length); + } +} \ No newline at end of file diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/RowBatch.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/RowBatch.java new file mode 100644 index 0000000..eb1297f --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/RowBatch.java @@ -0,0 +1,91 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import nova.hetu.omniruntime.utils.OmniErrorType; +import nova.hetu.omniruntime.utils.OmniRuntimeException; + +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * row batch : to collect all row. + * + * @since 2024-05-16 + */ +public class RowBatch implements AutoCloseable { + /** + * native address of row batch. + */ + protected final long nativeRowBatch; + + private Row[] rows; + + private int rowCount; + + private AtomicBoolean isClosed = new AtomicBoolean(false); + + /** + * construct row batch + * + * @param nativeAddress address of row batch + * @param rows actual rows + * @param rowCount total row count of batch + */ + public RowBatch(long nativeAddress, Row[] rows, int rowCount) { + this.rows = rows; + this.rowCount = rowCount; + this.nativeRowBatch = nativeAddress; + } + + /** + * construct row batch from vector batch + * + * @param vb vector batch + */ + public RowBatch(VecBatch vb) { + long rb = transFromVectorBatch(vb.getNativeVectorBatch()); + this.rowCount = vb.getRowCount(); + this.nativeRowBatch = rb; + } + + /** + * construct row batch from rows + * + * @param rows all rows of batch + * @param rowCount row count of row batch + */ + public RowBatch(Row[] rows, int rowCount) { + this(newRowBatchNative(rows, rowCount), rows, rowCount); + } + + // only release rowBatch + private static native void freeRowBatchNative(long nativeVectorBatch); + + private static native long newRowBatchNative(Row[] rows, int rowCount); + + private static native long transFromVectorBatch(long vbAddress); + + public long getNativeRowBatch() { + return nativeRowBatch; + } + + public Row[] getRows() { + return rows; + } + + public int getRowCount() { + return rowCount; + } + + @Override + public void close() { + if (isClosed.compareAndSet(false, true)) { + freeRowBatchNative(getNativeRowBatch()); + } else { + throw new OmniRuntimeException(OmniErrorType.OMNI_DOUBLE_FREE, "row batch has been closed:" + this + + ",threadName:" + Thread.currentThread().getName() + ",native:" + nativeRowBatch); + } + } +} \ No newline at end of file diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/ShortVec.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/ShortVec.java new file mode 100644 index 0000000..5a5dc84 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/ShortVec.java @@ -0,0 +1,101 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import nova.hetu.omniruntime.type.ShortDataType; + +/** + * short vec. + * + * @since 2021-07-17 + */ +public class ShortVec extends FixedWidthVec { + private static final int BYTES = Short.BYTES; + + public ShortVec(int size) { + super(size * BYTES, size, VecEncoding.OMNI_VEC_ENCODING_FLAT, ShortDataType.SHORT); + } + + public ShortVec(long nativeVector) { + super(nativeVector, ShortDataType.SHORT, BYTES); + } + + public ShortVec(long nativeVector, long nativeValueBufAddress, long nativeVectorNullBufAddress, int size) { + super(nativeVector, nativeValueBufAddress, nativeVectorNullBufAddress, BYTES * size, size, ShortDataType.SHORT); + } + + private ShortVec(ShortVec vector, int offset, int length) { + super(vector, offset, length, length * BYTES); + } + + private ShortVec(ShortVec vector, int[] positions, int offset, int length) { + super(vector, positions, offset, length, length * BYTES); + } + + /** + * get the specified short at the specified absolute. + * + * @param index the element offset in vec + * @return short value + */ + public short get(int index) { + return valuesBuf.getShort(index * BYTES); + } + + /** + * get short values from the specified position. + * + * @param index the position of element + * @param length the number of element + * @return short value array + */ + public short[] get(int index, int length) { + short[] target = new short[length]; + valuesBuf.getShortArray(index * BYTES, target, 0, length * BYTES); + return target; + } + + /** + * Sets the specified short at the specified absolute. + * + * @param index the element offset in vec + * @param value the value of vec + */ + public void set(int index, short value) { + valuesBuf.setShort(index * BYTES, value); + } + + /** + * Batch sets the specified short at the specified absolute. + * + * @param values the value of the element to be written + * @param offset the element offset in vec + * @param start the element index in values + * @param length the number of elements that need to written + */ + public void put(short[] values, int offset, int start, int length) { + valuesBuf.setShortArray(offset * BYTES, values, start * BYTES, length * BYTES); + } + + @Override + public ShortVec slice(int start, int length) { + return new ShortVec(this, start, length); + } + + @Override + public ShortVec copyPositions(int[] positions, int offset, int length) { + return new ShortVec(this, positions, offset, length); + } + + @Override + public int getRealValueBufCapacityInBytes() { + return size * BYTES; + } + + @Override + public int getCapacityInBytes() { + return size * BYTES; + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/VarcharVec.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/VarcharVec.java new file mode 100644 index 0000000..8417955 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/VarcharVec.java @@ -0,0 +1,196 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.utils.NullsBufHelper; + +/** + * varchar vec. + * + * @since 2021-07-17 + */ +public class VarcharVec extends VariableWidthVec { + /** + * The default capacity in bytes. + */ + public static final int INIT_CAPACITY_IN_BYTES = 32 * 1024; // 32K + private static final int EXPAND_FACTOR = 2; + + /** + * Warning: If this constructor is called, consider the capacity expansion scenario and update the vector info. + * + * @param size row count + */ + public VarcharVec(int size) { + super(4 * 1024, size, VarcharDataType.VARCHAR); + } + + public VarcharVec(int capacityInBytes, int size) { + super(capacityInBytes, size, VarcharDataType.VARCHAR); + } + + public VarcharVec(long nativeVector) { + super(nativeVector, VarcharDataType.VARCHAR); + } + + public VarcharVec(long nativeVector, long nativeValueBufAddress, long nativeVectorNullBufAddress, + long nativeVectorOffsetBufAddress, int size) { + super(nativeVector, nativeValueBufAddress, nativeVectorNullBufAddress, nativeVectorOffsetBufAddress, + getCapacityInBytesNative(nativeVector), size, VarcharDataType.VARCHAR); + } + + private VarcharVec(VarcharVec vector, int offset, int length) { + super(vector, offset, length); + } + + private VarcharVec(VarcharVec vector, int[] positions, int offset, int length) { + super(vector, positions, offset, length); + } + + /** + * get the specified bytes at the specified absolute. + * + * @param index the element offset in vec + * @return byte array + */ + public byte[] get(int index) { + final int startOffset = getValueOffset(index); + final int dataLen = getValueOffset(index + 1) - startOffset; + return getData(startOffset, dataLen); + } + + /** + * Batch gets the specified bytes at the specified absolute. + * + * @param index the element offset in vec + * @param length the number of element + * @return byte array + */ + public byte[] get(int index, int length) { + final int startOffset = getValueOffset(index); + final int dataLen = getValueOffset(index + length) - startOffset; + return getData(startOffset, dataLen); + } + + /** + * according to the specified offset and length, read data from the buffer. + * + * @param offsetInBytes offset bytes in buffer + * @param length the length of the data to be read + * @return byte array + */ + public byte[] getData(int offsetInBytes, int length) { + return valuesBuf.getBytes(offsetInBytes, length); + } + + /** + * Sets the specified bytes at the specified absolute. + * + * @param index the element offset in vec + * @param value byte array + */ + public void set(int index, byte[] value) { + fillSlots(index); + final int startOffset = getValueOffset(index); + setValueOffset(index + 1, startOffset + value.length); + setData(startOffset, value, 0, value.length); + lastOffsetPosition = index; + } + + private void checkCapacity(int needCapacityInBytes) { + if (needCapacityInBytes < 0) { + return; + } + int toCapacityInBytes = (capacityInBytes > 0) ? capacityInBytes : INIT_CAPACITY_IN_BYTES; + // the capacity is doubled for each calculation + while (toCapacityInBytes < needCapacityInBytes) { + toCapacityInBytes *= EXPAND_FACTOR; + } + if (toCapacityInBytes != capacityInBytes) { + // expand Data Capacity + long newValuesAddress = expandDataCapacity(getNativeVector(), toCapacityInBytes); + capacityInBytes = toCapacityInBytes; + // update data address + valuesBuf = OmniBufferFactory.create(newValuesAddress, capacityInBytes); + } + } + + /** + * Batch sets the specified bytes at the specified absolute. + * + * @param index the value of the element to be written + * @param values the bytes array + * @param offsetInArray the element offset in bytes array + * @param offsets offsets array + * @param offsetsIndex the offset of the offsets array + * @param length the number of elements + */ + public void put(int index, byte[] values, int offsetInArray, int[] offsets, int offsetsIndex, int length) { + if (values == null || length == 0) { + return; + } + fillSlots(index); + int startOffset = getValueOffset(index); + int[] newOffsets = compactOffsets(startOffset, offsets, offsetsIndex, length); + int dataLength = offsets[offsetsIndex + length] - offsets[offsetsIndex]; + // set offsets + offsetsBuf.setIntArray((index + 1) * Integer.BYTES, newOffsets, Integer.BYTES, + (newOffsets.length - 1) * Integer.BYTES); + // set data + setData(startOffset, values, offsetInArray + offsets[offsetsIndex], dataLength); + lastOffsetPosition = index + length - 1; + } + + private int[] compactOffsets(int startOffset, int[] srcOffsets, int offsetIndex, int length) { + int[] newOffsets = new int[length + 1]; + for (int i = 1; i <= length; i++) { + newOffsets[i] = srcOffsets[offsetIndex + i] - srcOffsets[offsetIndex] + startOffset; + } + return newOffsets; + } + + private void setData(int offsetInBytes, byte[] data, int start, int length) { + checkCapacity(offsetInBytes + length); + valuesBuf.setBytes(offsetInBytes, data, start, length); + } + + @Override + public VarcharVec slice(int start, int length) { + return new VarcharVec(this, start, length); + } + + @Override + public VarcharVec copyPositions(int[] positions, int offset, int length) { + return new VarcharVec(this, positions, offset, length); + } + + @Override + public void append(Vec other, int offset, int length) { + super.append(other, offset, length); + int newCapacityInBytes = getCapacityInBytesNative(nativeVector); + // check expand, update initial value if expansion happened. + if (newCapacityInBytes != capacityInBytes) { + capacityInBytes = newCapacityInBytes; + size = getSizeNative(nativeVector); + nullsBuf = OmniBufferFactory.create(getValueNullsNative(nativeVector), NullsBufHelper.nBytes(size)); + valuesBuf = OmniBufferFactory.create(getValuesNative(nativeVector), capacityInBytes); + offsetsBuf = OmniBufferFactory.create(getValueOffsetsNative(nativeVector), (size + 1) * Integer.BYTES); + } + } + + /** + * set offsets buffer. + * + * @param offsets offset of buf + * @param length the number of element + */ + public void setOffsetBuf(int[] offsets, int length) { + offsetsBuf.setIntArray(Integer.BYTES, offsets, Integer.BYTES, length * Integer.BYTES); + lastOffsetPosition = length - 1; + } + + private static native long expandDataCapacity(long nativeVector, int toCapacityInBytes); +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/VariableWidthVec.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/VariableWidthVec.java new file mode 100644 index 0000000..9e92f4a --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/VariableWidthVec.java @@ -0,0 +1,218 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static nova.hetu.omniruntime.vector.VecEncoding.OMNI_VEC_ENCODING_FLAT; + +import nova.hetu.omniruntime.type.DataType; + +/** + * base class of variable width vec. + * + * @since 2021-07-17 + */ +public abstract class VariableWidthVec extends Vec { + /** + * offsets buffer. + */ + protected OmniBuffer offsetsBuf; + + /** + * last set index. + */ + protected int lastOffsetPosition = -1; + + /** + * Ihe routine will use GLOBAL memory pool when there is no specialized vector + * allocator. + * + * @param capacityInBytes the number of value of vector + * @param size the actual number of value of vector + * @param type the data type of this vector + */ + public VariableWidthVec(int capacityInBytes, int size, DataType type) { + super(capacityInBytes, size, OMNI_VEC_ENCODING_FLAT, type); + this.offsetsBuf = OmniBufferFactory.create(getValueOffsetsNative(getNativeVector()), + (size + 1) * Integer.BYTES); + } + + /** + * The routine is just for slicing and copyRegion vector operator. + * + * @param vec the vector need to be sliced or copyRegion + * @param offset When a vector has been sliced or copyRegion, this value will + * point to where is the new slice {@link Vec} start + * @param length the number of value + */ + protected VariableWidthVec(Vec vec, int offset, int length) { + super(vec, offset, length, vec.getCapacityInBytes()); + this.offsetsBuf = OmniBufferFactory.create(getValueOffsetsNative(getNativeVector()), + (length + 1) * Integer.BYTES); + } + + /** + * The routine is just for copyPosition vector operator. + * + * @param vec the vector need to be copy + * @param positions the original vector positions + * @param offset offset of positions in the input parameter + * @param length number of elements copied + */ + protected VariableWidthVec(Vec vec, int[] positions, int offset, int length) { + super(vec, positions, offset, length, vec.getCapacityInBytes()); + this.offsetsBuf = OmniBufferFactory.create(getValueOffsetsNative(getNativeVector()), + (length + 1) * Integer.BYTES); + this.capacityInBytes = getCapacityInBytesNative(nativeVector); + } + + /** + * The routine will use native vector to initialize a new vector. + * + * @param nativeVector native vector address + * @param dataType the type of this vector + */ + protected VariableWidthVec(long nativeVector, DataType dataType) { + super(nativeVector, dataType); + this.offsetsBuf = OmniBufferFactory.create(getValueOffsetsNative(getNativeVector()), + (size + 1) * Integer.BYTES); + } + + /** + * The routine will use native vector to initialize a new vector. + * + * @param nativeVector native vector address + * @param nativeValueBufAddress valueBuf address of native vector + * @param nativeVectorNullBufAddress nullBuf address of native vector + * @param nativeVectorOffsetBufAddress offsetBuf address of native vector + * @param capacityInBytes capacity in bytes of vector + * @param size the actual number of value of vector + * @param dataType the type of this vector + */ + protected VariableWidthVec(long nativeVector, long nativeValueBufAddress, long nativeVectorNullBufAddress, + long nativeVectorOffsetBufAddress, int capacityInBytes, int size, DataType dataType) { + super(nativeVector, nativeValueBufAddress, nativeVectorNullBufAddress, capacityInBytes, size, dataType); + this.offsetsBuf = OmniBufferFactory.create(nativeVectorOffsetBufAddress, (size + 1) * Integer.BYTES); + } + + /** + * get value offset buffer from native vector. + * + * @param nativeVector native vector address + * @return value offset buffer + */ + protected static native long getValueOffsetsNative(long nativeVector); + + /** + * get the offset value of the specified position. + * + * @param index the element offset in vec + * @return offset value + */ + public int getValueOffset(int index) { + return offsetsBuf.getInt(index * Integer.BYTES); + } + + /** + * set the offset value of the specified position. + * + * @param index the element offset in vec + * @param offset offset value + */ + public void setValueOffset(int index, int offset) { + offsetsBuf.setInt(index * Integer.BYTES, offset); + } + + /** + * get the data length of the specified position. + * + * @param index the element offset in vec + * @return data length in bytes + */ + public int getDataLength(int index) { + return getValueOffset(index + 1) - getValueOffset(index); + } + + /** + * return null value array from 0 to size + offset length. + * + * @return raw value of offsets + */ + public int[] getRawValueOffset() { + // the length of the array is size + offset, so that the caller + // and vec can have the same offset. + int[] rawValueOffset = new int[size + 1]; + offsetsBuf.getIntArray(0, rawValueOffset, 0, rawValueOffset.length * Integer.BYTES); + return rawValueOffset; + } + + /** + * get the value offset of the specified position. + * + * @param index the offset of element + * @param length the number of element + * @return the offsets + */ + public int[] getValueOffset(int index, int length) { + int startOffset = getValueOffset(index); + int[] offsets = new int[length + 1]; + for (int i = 1; i <= length; i++) { + offsets[i] = getValueOffset(index + i) - startOffset; + } + return offsets; + } + + /** + * get the offsets of variablewidthvec. + * + * @return offsets byte buffer + */ + public OmniBuffer getOffsetsBuf() { + return offsetsBuf; + } + + /** + * set offsets buffer. + * + * @param buf buf of offset + */ + public void setOffsetsBuf(byte[] buf) { + offsetsBuf.setBytes(0, buf, 0, buf.length); + } + + @Override + public void setNull(int index) { + super.setNull(index); + fillSlots(index); + setValueOffset(index + 1, getValueOffset(index)); + lastOffsetPosition = index; + } + + /** + * fill offset. + * + * @param index index of want to set + */ + protected void fillSlots(int index) { + for (int i = lastOffsetPosition + 1; i < index; i++) { + setValueOffset(i + 1, getValueOffset(i)); + } + lastOffsetPosition = index - 1; + } + + @Override + public int getRealValueBufCapacityInBytes() { + return getValueOffset(size) - getValueOffset(0); + } + + /** + * returns the number of bytes of the offsets. + * + * @return length in bytes + */ + @Override + public int getRealOffsetBufCapacityInBytes() { + return (size + 1) * Integer.BYTES; + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/Vec.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/Vec.java new file mode 100644 index 0000000..ef9ca34 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/Vec.java @@ -0,0 +1,587 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import nova.hetu.omniruntime.OmniLibs; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.utils.NullsBufHelper; +import nova.hetu.omniruntime.utils.OmniErrorType; +import nova.hetu.omniruntime.utils.OmniRuntimeException; + +import java.io.Closeable; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * base class of vec. + * + * @since 2021-07-17 + */ +public abstract class Vec implements Closeable { + /** + * indicates a null value in a nulls buffer. + */ + public static final byte NULL = 1; + + /** + * indicates a not null value in a nulls buffer. + */ + public static final byte NOT_NULL = 0; + + static { + OmniLibs.load(); + } + + /** + * The value buffer. + */ + protected OmniBuffer valuesBuf; + + /** + * The nulls of vector, it is a bitmap. + */ + protected OmniBuffer nullsBuf; + + /** + * The native vector address. + */ + protected final long nativeVector; + + /** + * The capacity in bytes of this vector. + */ + protected int capacityInBytes; + + /** + * The actual number of value. + */ + protected int size; + + /** + * The {@link DataType} of this vector. + */ + private DataType dataType; + + /** + * When a vector has been sliced. The current vector and sliced vector are + * unwritable. + */ + private boolean isWritable = true; + + private boolean isCloseable = true; + + private AtomicBoolean isClosed = new AtomicBoolean(false); + + /** + * The routine will use the specialized vector allocator to allocate new vector. + * + * @param capacityInBytes the capacity in bytes of vector + * @param size the actual number of value of vector + * @param encoding the encoding type of vector + * @param datatype the data type of this vector + */ + public Vec(int capacityInBytes, int size, VecEncoding encoding, DataType datatype) { + this(newVectorNative(size, encoding.ordinal(), datatype.getId().toValue(), capacityInBytes), capacityInBytes, + size, datatype, true); + } + + public Vec(Vec dictionary, int[] ids, int capacityInBytes, DataType dataType) { + this(newDictionaryVectorNative(dictionary.nativeVector, ids, ids.length, dataType.getId().toValue()), + capacityInBytes, ids.length, dataType, true); + } + + /** + * The routine is just for slicing and copyRegion vector operator. + * + * @param vec the vector need to be sliced or copyRegion + * @param offset When a vector has been sliced or copyRegion, this value will + * point to where is the new slice {@link Vec} start + * @param length the number of value + * @param capacityInBytes the number of capacityInBytes + */ + protected Vec(Vec vec, int offset, int length, int capacityInBytes) { + this(sliceVectorNative(vec.nativeVector, offset, length), capacityInBytes, length, vec.dataType, false); + } + + /** + * The routine is just for copyPosition vector operator. + * + * @param vec the vector need to be copy + * @param positions the original vector positions + * @param offset offset of positions in the input parameter + * @param length number of elements copied + * @param capacityInBytes the number of capacityInBytes + */ + protected Vec(Vec vec, int[] positions, int offset, int length, int capacityInBytes) { + this(copyPositionsNative(vec.nativeVector, positions, offset, length), capacityInBytes, length, vec.dataType, + true); + } + + /** + * The routine will use native vector to initialize a new fixed width vector. + * + * @param nativeVector native vector address + * @param dataType the type of this vector + * @param typeLength the number of typeLength + */ + protected Vec(long nativeVector, DataType dataType, int typeLength) { + this(nativeVector, getSizeNative(nativeVector) * typeLength, getSizeNative(nativeVector), dataType, true); + } + + /** + * The routine will use native vector to initialize a new variable width vector. + * + * @param nativeVector native vector address + * @param dataType the type of this vector + */ + protected Vec(long nativeVector, DataType dataType) { + this(nativeVector, getCapacityInBytesNative(nativeVector), getSizeNative(nativeVector), dataType, true); + } + + /** + * The routine will use native vector to initialize a new vector. + * + * @param nativeVector native vector address + * @param nativeVectorValueBufAddress valueBuf address of native vector + * @param nativeVectorNullBufAddress nullBuf address of native vector + * @param capacityInBytes capacity in bytes of vector + * @param size the actual number of value of vector + * @param dataType the type of this vector + */ + protected Vec(long nativeVector, long nativeVectorValueBufAddress, long nativeVectorNullBufAddress, + int capacityInBytes, int size, DataType dataType) { + this(nativeVector, nativeVectorValueBufAddress, nativeVectorNullBufAddress, capacityInBytes, size, dataType, + true); + } + + private Vec(long nativeVector, long nativeVectorValueBufAddress, long nativeVectorNullBufAddress, + int capacityInBytes, int size, DataType dataType, boolean isWritable) { + this.capacityInBytes = capacityInBytes; + this.size = size; + this.dataType = dataType; + this.nativeVector = nativeVector; + this.valuesBuf = OmniBufferFactory.create(nativeVectorValueBufAddress, capacityInBytes); + this.nullsBuf = OmniBufferFactory.create(nativeVectorNullBufAddress, NullsBufHelper.nBytes(size)); + this.isWritable = isWritable; + } + + private Vec(long nativeVector, int capacityInBytes, int size, DataType dataType, boolean isWritable) { + this(nativeVector, getValuesNative(nativeVector), getValueNullsNative(nativeVector), capacityInBytes, size, + dataType, isWritable); + } + + private static native long newVectorNative(int size, int vecEncodingId, int dataTypeId, int capacityInBytes); + + private static native long newDictionaryVectorNative(long dictionaryNativeVector, int[] ids, int size, + int dataTypeId); + + private static native void freeVectorNative(long nativeVector); + + private static native long sliceVectorNative(long nativeVector, int offset, int length); + + private static native long copyPositionsNative(long nativeVector, int[] positions, int offset, int length); + + /** + * get capacity in Bytes from native vector + * + * @param nativeVector nativeVector address + * @return the CapacityInBytes of native vector + */ + protected static native int getCapacityInBytesNative(long nativeVector); + + /** + * get size of native vector. + * + * @param nativeVector native vector + * @return size + */ + protected static native int getSizeNative(long nativeVector); + + private static native int setSizeNative(long nativeVector, int valueCount); + + /** + * get value address from native vector. + * + * @param nativeVector native vector address + * @return value address of native vector + */ + protected static native long getValuesNative(long nativeVector); + + /** + * get encoding of native vector. + * + * @param nativeVector native vector + * @return encoding type id + */ + protected static native int getVecEncodingNative(long nativeVector); + + /** + * get null address of native vector. + * + * @param nativeVector native vector + * @return null address of native vector + */ + protected static native long getValueNullsNative(long nativeVector); + + /** + * merge two vectors. + * + * @param destNativeVector target native vector + * @param positionOffset position offset + * @param srcNativeVector source native vector + * @param length the number of element + */ + protected static native void appendVectorNative(long destNativeVector, int positionOffset, long srcNativeVector, + int length); + + private static native void setNullFlagNative(long nativeVector, boolean hasNull); + + private static native boolean hasNullNative(long nativeVector); + + /** + * get native vector. + * + * @return native vector address + */ + public long getNativeVector() { + return nativeVector; + } + + /** + * the size of vector. + * + * @return size + */ + public int getSize() { + return size; + } + + /** + * set size of vector. + * @param size size value + */ + public void setSize(int size) { + this.size = size; + setSizeNative(nativeVector, size); + } + + /** + * capacity in bytes of vector. + * + * @return capacity + */ + public int getCapacityInBytes() { + return capacityInBytes; + } + + /** + * vector data type. + * + * @return vector data type + */ + public DataType getType() { + return dataType; + } + + /** + * vector encoding. + * + * @return vector encoding + */ + public VecEncoding getEncoding() { + return VecEncoding.OMNI_VEC_ENCODING_FLAT; + } + + /** + * get values buffer. + * + * @return values buffer + */ + public OmniBuffer getValuesBuf() { + return valuesBuf; + } + + /** + * set values buffer. VarcharVec cannot use this interface. + * + * @param buf buf of data + */ + public void setValuesBuf(byte[] buf) { + valuesBuf.setBytes(0, buf, 0, buf.length); + } + + /** + * set values buffer and length. VarcharVec cannot use this interface. + * + * @param buf buf of data + * @param length the number of element + */ + public void setValuesBuf(byte[] buf, int length) { + valuesBuf.setBytes(0, buf, 0, length); + } + + /** + * get value nulls buffer. + * + * @return nulls value buffer + */ + public OmniBuffer getValueNullsBuf() { + return nullsBuf; + } + + /** + * specify whether the position element is null. + * + * @param index the element offset in vec + * @return if it is null, return true, otherwise return false + */ + public boolean isNull(int index) { + return NullsBufHelper.isSet(nullsBuf, index) == 1; + } + + /** + * set the element at the specified position to a null value. + * + * @param index the element offset in vec + */ + public void setNull(int index) { + NullsBufHelper.setBit(nullsBuf, index); + setNullFlagNative(nativeVector, true); + } + + /** + * set nulls in batch. + * + * @param index the offset of the element + * @param isNulls array of null values, true is null otherwise non-null. + * @param start array offset + * @param length number of elements + */ + public void setNulls(int index, boolean[] isNulls, int start, int length) { + byte[] values = transformBooleanToByte(isNulls, start, length); + NullsBufHelper.setBit(nullsBuf, index, values, 0, length); + setNullFlagNative(nativeVector, true); + } + + /** + * set nulls in batch. + * + * @param index the offset of the element + * @param isNulls array of null values, true is null otherwise non-null. + * @param start array offset + * @param length number of elements + */ + public void setNulls(int index, byte[] isNulls, int start, int length) { + NullsBufHelper.setBit(nullsBuf, index, isNulls, start, length); + setNullFlagNative(nativeVector, true); + } + + /** + * set nulls in batch. + * + * @param index the offset of the element + * @param isBitNulls array of null values, true is null otherwise non-null. (Bit) + * @param start array offset (Bit) + * @param length number of elements (Bit) + */ + public void setNullsByBits(int index, byte[] isBitNulls, int start, int length) { + NullsBufHelper.setBitByBits(nullsBuf, index, isBitNulls, start, length); + setNullFlagNative(nativeVector, true); + } + + /** + * set nulls buffer. + * + * @param buf buf of null + */ + public void setNullsBuf(byte[] buf) { + nullsBuf.setBytes(0, buf, 0, buf.length); + setNullFlagNative(nativeVector, true); + } + + /** + * set nulls buffer and length. + * + * @param buf buf of null + * @param length the number of element + */ + public void setNullsBuf(byte[] buf, int length) { + nullsBuf.setBytes(0, buf, 0, length); + setNullFlagNative(nativeVector, true); + } + + /** + * whether there is a null value. + * + * @return if yes, return true otherwise false + */ + public boolean hasNull() { + return hasNullNative(nativeVector); + } + + /** + * transform boolean array to byte array. + * + * @param values nulls array + * @param start array offset + * @param length number of elements + * @return byte array + */ + protected byte[] transformBooleanToByte(boolean[] values, int start, int length) { + byte[] transformedBytes = new byte[length]; + for (int i = 0; i < length; i++) { + if (values[i + start]) { + transformedBytes[i] = (byte) 1; + } else { + transformedBytes[i] = (byte) 0; + } + } + + return transformedBytes; + } + + /** + * transform byte array to boolean array. + * + * @param values byte array, 1 means null, 0 means non-null + * @param start array offset + * @param length number of elements + * @return boolean array + */ + protected boolean[] transformByteToBoolean(byte[] values, int start, int length) { + boolean[] transformedBoolean = new boolean[length]; + for (int i = 0; i < length; i++) { + transformedBoolean[i] = values[i + start] == 1; + } + return transformedBoolean; + } + + /** + * return null value array from 0 to size + offset length. + * + * @return raw value nulls + * @return raw nulls array + */ + public byte[] getRawValueNulls() { + // the length of the array is size + offset, so that the caller + // and vec can have the same offset. + byte[] rawValueNulls = new byte[NullsBufHelper.nBytes(size)]; + nullsBuf.getBytes(0, rawValueNulls, 0, rawValueNulls.length); + return rawValueNulls; + } + + /** + * get the specified nulls array at the specified absolute. + * + * @param index the offset of element in vec + * @param length the number of element + * @return boolean array + */ + public boolean[] getValuesNulls(int index, int length) { + byte[] nullsArray = new byte[length]; + NullsBufHelper.getBytes(nullsBuf, index, nullsArray, 0, length); + return transformByteToBoolean(nullsArray, 0, length); + } + + /** + * is vec writable. + * + * @return if it is writable, return true, otherwise return false + */ + public boolean isWritable() { + return isWritable; + } + + /** + * split a vec into two vec according to the specified index and length. + * + * @param start starting index + * @param length number of elements + * @return new vec + */ + public abstract Vec slice(int start, int length); + + /** + * copy a new vec based on the positions. + * + * @param positions all positions in vec + * @param offset position offset + * @param length the number of elements to be copied + * @return new vec + */ + public abstract Vec copyPositions(int[] positions, int offset, int length); + + /** + * This method takes input a source vector to append to the destination vector + * only If the destination vector has enough available positions. + * + * @param other Source Vector to be appended + * @param offset Number of Positions already occupied + * @param length Number of Positions in the Source Vector + */ + public void append(Vec other, int offset, int length) { + appendVectorNative(this.nativeVector, offset, other.nativeVector, length); + } + + @Override + public void close() { + if (!isCloseable) { + return; + } + if (isClosed.compareAndSet(false, true)) { + freeVectorNative(this.nativeVector); + } else { + throw new OmniRuntimeException(OmniErrorType.OMNI_DOUBLE_FREE, "vec has been closed:" + this + + ",threadName:" + Thread.currentThread().getName() + ",native:" + nativeVector); + } + } + + /** + * vec is closed. + * + * @return true is closed, otherwise it it not closed. + */ + public boolean isClosed() { + return isClosed.get(); + } + + /** + * set whether vec can be closed. + * + * @param isCloseable can vec be closed + */ + public void setClosable(boolean isCloseable) { + this.isCloseable = isCloseable; + } + + /** + * returns the number of bytes of the data written. + * + * @return length in bytes + */ + public abstract int getRealValueBufCapacityInBytes(); + + /** + * returns the number of bytes of the null buf. + * + * @return length in bytes + */ + public int getRealNullBufCapacityInBytes() { + return NullsBufHelper.nBytes(size); + } + + /** + * returns the number of bytes of the offsets, for VarcharVec returned according + * to size calculation, other types of vec return 0. + * + * @return length in bytes + */ + public int getRealOffsetBufCapacityInBytes() { + return 0; + } + + void setDataType(DataType dataType) { + this.dataType = dataType; + } +} \ No newline at end of file diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/VecBatch.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/VecBatch.java new file mode 100644 index 0000000..06916c8 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/VecBatch.java @@ -0,0 +1,174 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import nova.hetu.omniruntime.OmniLibs; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.utils.OmniErrorType; +import nova.hetu.omniruntime.utils.OmniRuntimeException; + +import java.io.Closeable; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * vec batch. + * + * @since 2021-07-17 + */ +public class VecBatch implements Closeable { + // 加载jni所需动态库 + static { + OmniLibs.load(); + } + + private final Vec[] vectors; + + private final int rowCount; + + private final long nativeVectorBatch; + + private AtomicBoolean isClosed = new AtomicBoolean(false); + + /** + * The routine will use vectors and row count to initialize a new vector. + * + * @param vectors vectors in vecBatch,it may be empty + * @param rowCount the row count of vector batch + */ + public VecBatch(Vec[] vectors, int rowCount) { + this.vectors = vectors; + this.rowCount = rowCount; + long[] nativeVectors = new long[vectors.length]; + for (int i = 0; i < vectors.length; i++) { + nativeVectors[i] = vectors[i].getNativeVector(); + } + this.nativeVectorBatch = newVectorBatchNative(nativeVectors, rowCount); + } + + /** + * it is recommended to use the VecBatch(Vec[] vectors, rowCount) construct to + * prevent exceptions caused by empty vectors. + * + * @param vectors vectors in vecBatch,it may be empty + */ + public VecBatch(Vec[] vectors) { + this(vectors, vectors[0].getSize()); + } + + public VecBatch(List vectors, int rowCount) { + this(vectors.toArray(new Vec[vectors.size()]), rowCount); + } + + /** + * it is recommended to use the VecBatch(List, rowCount) construct to + * prevent exceptions caused by empty vectors. + * + * @param vectors vectors in vecBatch,it may be empty + */ + public VecBatch(List vectors) { + this(vectors.toArray(new Vec[vectors.size()])); + } + + /** + * This constructor is for native to call. + * + * @param nativeVecBatch native vector batch address + * @param nativeVectors native vector array + * @param nativeVectorValueBufAddresses valueBuf address of native vector + * @param nativeVectorNullBufAddresses nullBuf address of native vector + * @param nativeVectorOffsetBufAddresses offsetBuf address of native vector + * @param encodings the encoding type array of vector batch + * @param dataTypeIds the type array of this vector batch + * @param rowCount the row count of vector batch + */ + public VecBatch(long nativeVecBatch, long[] nativeVectors, long[] nativeVectorValueBufAddresses, + long[] nativeVectorNullBufAddresses, long[] nativeVectorOffsetBufAddresses, int[] encodings, + int[] dataTypeIds, int rowCount) { + int vecCount = nativeVectors.length; + Vec[] newVectors = new Vec[vecCount]; + for (int idx = 0; idx < vecCount; idx++) { + long nativeVector = nativeVectors[idx]; + DataType dataType = DataType.create(dataTypeIds[idx]); + newVectors[idx] = VecFactory.create(nativeVector, nativeVectorValueBufAddresses[idx], + nativeVectorNullBufAddresses[idx], nativeVectorOffsetBufAddresses[idx], rowCount, + VecEncoding.values()[encodings[idx]], dataType); + } + this.rowCount = rowCount; + this.nativeVectorBatch = nativeVecBatch; + this.vectors = newVectors; + } + + /** + * create vector batch based on the number of vectors. + * + * @param nativeVectors native vector array + * @param rowCount the row count of vector batch + * @return vector batch address + */ + public static native long newVectorBatchNative(long[] nativeVectors, int rowCount); + + /** + * release vector batch. + * + * @param nativeVectorBatch vector batch address + */ + public static native void freeVectorBatchNative(long nativeVectorBatch); + + /** + * row count in the vecBatch. + * + * @return row count + */ + public int getRowCount() { + return rowCount; + } + + /** + * vector count in the vecBatch. + * + * @return vector count + */ + public int getVectorCount() { + return vectors.length; + } + + public Vec[] getVectors() { + return vectors; + } + + /** + * get specified vector at the specified absolute. + * + * @param index the element offset in vec + * @return vector + */ + public Vec getVector(int index) { + return vectors[index]; + } + + public long getNativeVectorBatch() { + return nativeVectorBatch; + } + + /** + * release all vectors resource of vector batch. + */ + public void releaseAllVectors() { + for (Vec vector : vectors) { + vector.close(); + } + } + + @Override + public void close() { + if (isClosed.compareAndSet(false, true)) { + freeVectorBatchNative(nativeVectorBatch); + } else { + throw new OmniRuntimeException(OmniErrorType.OMNI_DOUBLE_FREE, "vec batch has been closed:" + this + + ",threadName:" + Thread.currentThread().getName() + ",native:" + nativeVectorBatch); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/VecEncoding.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/VecEncoding.java new file mode 100644 index 0000000..9e5a360 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/VecEncoding.java @@ -0,0 +1,18 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +/** + * vec encoding. + * + * @since 2022-02-17 + */ +public enum VecEncoding { + OMNI_VEC_ENCODING_FLAT, + OMNI_VEC_ENCODING_DICTIONARY, + OMNI_VEC_ENCODING_CONTAINER, + OMNI_VEC_ENCODING_LAZY, + OMNI_VEC_ENCODING_INVALID +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/VecFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/VecFactory.java new file mode 100644 index 0000000..69325d8 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/VecFactory.java @@ -0,0 +1,166 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import nova.hetu.omniruntime.type.DataType; + +/** + * vec factory. + * + * @since 2021-08-05 + */ +public class VecFactory { + /** + * Create vector by native vector address and data type. + * + * @param nativeVector native vector address + * @param encoding vector encoding + * @param dataType vector data type + * @return a new {@link Vec} object instance + */ + public static Vec create(long nativeVector, VecEncoding encoding, DataType dataType) { + Vec vector; + switch (encoding) { + case OMNI_VEC_ENCODING_FLAT: + vector = createFlatVec(nativeVector, dataType); + break; + case OMNI_VEC_ENCODING_DICTIONARY: + vector = new DictionaryVec(nativeVector, dataType); + break; + case OMNI_VEC_ENCODING_CONTAINER: + vector = new ContainerVec(nativeVector); + break; + default: + throw new IllegalArgumentException("Not Support Vec Encoding " + encoding); + } + return vector; + } + + private static Vec createFlatVec(long nativeVector, DataType dataType) { + switch (dataType.getId()) { + case OMNI_INT: + case OMNI_DATE32: + return new IntVec(nativeVector); + case OMNI_LONG: + case OMNI_DATE64: + case OMNI_TIMESTAMP: + case OMNI_DECIMAL64: + return new LongVec(nativeVector); + case OMNI_DOUBLE: + return new DoubleVec(nativeVector); + case OMNI_SHORT: + return new ShortVec(nativeVector); + case OMNI_BYTE: + return new ByteVec(nativeVector); + case OMNI_BOOLEAN: + return new BooleanVec(nativeVector); + case OMNI_VARCHAR: + case OMNI_CHAR: + return new VarcharVec(nativeVector); + case OMNI_DECIMAL128: + return new Decimal128Vec(nativeVector); + default: + throw new IllegalArgumentException("Not Support Data Type " + dataType.getId()); + } + } + + /** + * Create vector by native vector address and vector type. + * + * @param nativeVector native vector address + * @param nativeVectorValueBufAddress native vector value buffer address + * @param nativeVectorNullBufAddress native vector nulls buffer address + * @param nativeVectorOffsetBufAddress native vector offset buffer address + * @param size size of vector + * @param encoding vector encoding type + * @param dataType vector data type + * @return Instance of {@link Vec} + */ + public static Vec create(long nativeVector, long nativeVectorValueBufAddress, long nativeVectorNullBufAddress, + long nativeVectorOffsetBufAddress, int size, VecEncoding encoding, DataType dataType) { + Vec vector; + switch (encoding) { + case OMNI_VEC_ENCODING_FLAT: + vector = createFlatVec(nativeVector, nativeVectorValueBufAddress, nativeVectorNullBufAddress, + nativeVectorOffsetBufAddress, size, dataType); + break; + case OMNI_VEC_ENCODING_DICTIONARY: + vector = new DictionaryVec(nativeVector, nativeVectorValueBufAddress, nativeVectorNullBufAddress, size, + dataType); + break; + case OMNI_VEC_ENCODING_CONTAINER: + vector = new ContainerVec(nativeVector, nativeVectorValueBufAddress, nativeVectorNullBufAddress, size); + break; + default: + throw new IllegalArgumentException("Not Support Vec Encoding " + encoding); + } + return vector; + } + + private static Vec createFlatVec(long nativeVector, long nativeVectorValueBufAddress, + long nativeVectorNullBufAddress, long nativeVectorOffsetBufAddress, int size, DataType dataType) { + switch (dataType.getId()) { + case OMNI_INT: + case OMNI_DATE32: + return new IntVec(nativeVector, nativeVectorValueBufAddress, nativeVectorNullBufAddress, size); + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DATE64: + case OMNI_DECIMAL64: + return new LongVec(nativeVector, nativeVectorValueBufAddress, nativeVectorNullBufAddress, size); + case OMNI_DOUBLE: + return new DoubleVec(nativeVector, nativeVectorValueBufAddress, nativeVectorNullBufAddress, size); + case OMNI_SHORT: + return new ShortVec(nativeVector, nativeVectorValueBufAddress, nativeVectorNullBufAddress, size); + case OMNI_BYTE: + return new ByteVec(nativeVector, nativeVectorValueBufAddress, nativeVectorNullBufAddress, size); + case OMNI_BOOLEAN: + return new BooleanVec(nativeVector, nativeVectorValueBufAddress, nativeVectorNullBufAddress, size); + case OMNI_VARCHAR: + case OMNI_CHAR: + return new VarcharVec(nativeVector, nativeVectorValueBufAddress, nativeVectorNullBufAddress, + nativeVectorOffsetBufAddress, size); + case OMNI_DECIMAL128: + return new Decimal128Vec(nativeVector, nativeVectorValueBufAddress, nativeVectorNullBufAddress, size); + default: + throw new IllegalArgumentException("Not Support Data Type " + dataType.getId()); + } + } + + /** + * Create empty vector by size and vector type, only use by expandDictionaryVec. + * + * @param size size of vector + * @param dataType vector data type + * @return Instance of {@link Vec} + */ + public static Vec createFlatVec(int size, DataType dataType) { + switch (dataType.getId()) { + case OMNI_INT: + case OMNI_DATE32: + return new IntVec(size); + case OMNI_TIMESTAMP: + case OMNI_LONG: + case OMNI_DATE64: + case OMNI_DECIMAL64: + return new LongVec(size); + case OMNI_DOUBLE: + return new DoubleVec(size); + case OMNI_SHORT: + return new ShortVec(size); + case OMNI_BYTE: + return new ByteVec(size); + case OMNI_BOOLEAN: + return new BooleanVec(size); + case OMNI_VARCHAR: + case OMNI_CHAR: + return new VarcharVec(size); + case OMNI_DECIMAL128: + return new Decimal128Vec(size); + default: + throw new IllegalArgumentException("Not Support Data Type " + dataType.getId()); + } + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/serialize/OmniRowDeserializer.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/serialize/OmniRowDeserializer.java new file mode 100644 index 0000000..2d2a3a6 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/serialize/OmniRowDeserializer.java @@ -0,0 +1,58 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.vector.serialize; + +/** + * VecBatchSerializer implementation of protobuf. + * + * @since 2024-05-16 + */ +public class OmniRowDeserializer { + private long nativeParser = 0L; + + public OmniRowDeserializer(int[] types, long[] vecs) { + nativeParser = newOmniRowDeserializer(types, vecs); + } + + public long getNativeParser() { + return nativeParser; + } + + private static native long newOmniRowDeserializer(int[] types, long[] vecs); + + private static native long freeOmniRowDeserializer(long addr); + + private static native void parseOneRow(long nativeParserAddr, byte[] bytes, int rowIdx); + + private static native void parseOneRowByAddr(long nativeParserAddr, long rowAddress, int rowIdx); + + private static native void parseAllRow(long nativeParserAddr, long rowBatchAddr); + + /** + * used when shuffle read + * + * @param address one row 's bytes + * @param rowIdx vector 's index + */ + public void parse(long address, int rowIdx) { + parseOneRowByAddr(getNativeParser(), address, rowIdx); + } + + /** + * used to parse native row to vector batch. + * + * @param rowBatchAddr address of native row batch + */ + public void parseAll(long rowBatchAddr) { + parseAllRow(getNativeParser(), rowBatchAddr); + } + + /** + * deserializer is created in native side, we must free it after we use. + */ + public void close() { + freeOmniRowDeserializer(nativeParser); + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/serialize/ProtoVecBatchSerializer.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/serialize/ProtoVecBatchSerializer.java new file mode 100644 index 0000000..5b04584 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/serialize/ProtoVecBatchSerializer.java @@ -0,0 +1,308 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.vector.serialize; + +import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; + +import nova.hetu.omniruntime.type.CharDataType; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.Date32DataType; +import nova.hetu.omniruntime.type.Date64DataType; +import nova.hetu.omniruntime.type.DecimalDataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.utils.OmniErrorType; +import nova.hetu.omniruntime.utils.OmniRuntimeException; +import nova.hetu.omniruntime.vector.BooleanVec; +import nova.hetu.omniruntime.vector.ContainerVec; +import nova.hetu.omniruntime.vector.Decimal128Vec; +import nova.hetu.omniruntime.vector.DictionaryVec; +import nova.hetu.omniruntime.vector.DoubleVec; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.JvmUtils; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.OmniBuffer; +import nova.hetu.omniruntime.vector.OmniBufferFactory; +import nova.hetu.omniruntime.vector.ShortVec; +import nova.hetu.omniruntime.vector.VarcharVec; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecBatch; +import nova.hetu.omniruntime.vector.VecEncoding; +import nova.hetu.omniruntime.vector.VecFactory; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** + * VecBatchSerializer implementation of protobuf. + * + * @since 2021-09-13 + */ +public class ProtoVecBatchSerializer implements VecBatchSerializer { + @Override + public byte[] serialize(VecBatch vecBatch) { + VecBatchSerde.VecBatch.Builder vecBatchBuilder = VecBatchSerde.VecBatch.newBuilder(); + + // set vectors + int index = 0; + for (Vec vector : vecBatch.getVectors()) { + VecBatchSerde.Vec vec = buildProtoVec(vector, null); + vecBatchBuilder.addVectors(index, vec); + index++; + } + + return vecBatchBuilder.setRowCount(vecBatch.getRowCount()).setVecCount(vecBatch.getVectorCount()).build() + .toByteArray(); + } + + private VecBatchSerde.Vec buildProtoVec(Vec vec, int[] ids) { + VecBatchSerde.Vec.Builder protoVecBuilder = VecBatchSerde.Vec.newBuilder(); + VecBatchSerde.DataTypeExt.Builder protoDataTypeExtBuild = VecBatchSerde.DataTypeExt.newBuilder(); + VecBatchSerde.VecEncoding.Builder protoVecEncodingBuild = VecBatchSerde.VecEncoding.newBuilder(); + DataType dataType = vec.getType(); + VecEncoding encoding = vec.getEncoding(); + protoDataTypeExtBuild.setId(VecBatchSerde.DataTypeExt.DataTypeId.valueOf(dataType.getId().name())); + protoVecEncodingBuild.setEncodingId(encoding.ordinal()); + switch (encoding) { + case OMNI_VEC_ENCODING_FLAT: + setProtoDataTypeExt(protoDataTypeExtBuild, dataType); + break; + case OMNI_VEC_ENCODING_DICTIONARY: { + DictionaryVec dictionaryVec = (DictionaryVec) vec; + Vec dictionary = dictionaryVec.expandDictionary(); + VecBatchSerde.Vec protoVec = buildProtoVec(dictionary, null); + dictionary.close(); + return protoVec; + } + case OMNI_VEC_ENCODING_CONTAINER: { + ContainerVec containerVec = (ContainerVec) vec; + int vecCount = containerVec.getDataTypes().length; + DataType[] subVecTypes = containerVec.getDataTypes(); + for (int i = 0; i < vecCount; i++) { + Vec subVec = VecFactory.create(containerVec.getVector(i), containerVec.getVecEncoding(i), + subVecTypes[i]); + VecBatchSerde.Vec subProtoVec = buildProtoVec(subVec, null); + protoVecBuilder.addSubVectors(subProtoVec); + } + break; + } + default: + throw new IllegalStateException("Unexpected encoding: " + encoding); + } + + Vec compactVec = compactVec(vec, ids); + + ByteBuffer valueBuf; + if (compactVec instanceof VarcharVec) { + VarcharVec varcharVec = (VarcharVec) compactVec; + valueBuf = serializeVarcharVector(protoVecBuilder, varcharVec); + } else { + // For fixed vector, only serialize value. + valueBuf = JvmUtils.directBuffer(compactVec.getValuesBuf()); + // only serialize the data actually written + valueBuf.limit(compactVec.getRealValueBufCapacityInBytes()); + } + + ByteBuffer valueNullsBuf = JvmUtils.directBuffer(compactVec.getValueNullsBuf()); + // only serialize the actual null size + valueNullsBuf.limit(compactVec.getRealNullBufCapacityInBytes()); + VecBatchSerde.Vec protoVec = protoVecBuilder.setTypeExt(protoDataTypeExtBuild.build()) + .setVecEncoding(protoVecEncodingBuild.build()).setSize(compactVec.getSize()) + .setValues(ByteString.copyFrom(valueBuf)).setNulls(ByteString.copyFrom(valueNullsBuf)).build(); + + if (compactVec != vec) { + compactVec.close(); + } + return protoVec; + } + + private void setProtoDataTypeExt(VecBatchSerde.DataTypeExt.Builder protoDataTypeExtBuild, DataType dataType) { + switch (dataType.getId()) { + case OMNI_INT: + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_SHORT: + case OMNI_BOOLEAN: + case OMNI_DOUBLE: + break; + case OMNI_DATE32: + protoDataTypeExtBuild.setDateUnit( + VecBatchSerde.DataTypeExt.DateUnit.valueOf(((Date32DataType) dataType).getDateUnit().name())); + break; + case OMNI_DATE64: + protoDataTypeExtBuild.setDateUnit( + VecBatchSerde.DataTypeExt.DateUnit.valueOf(((Date64DataType) dataType).getDateUnit().name())); + break; + case OMNI_VARCHAR: + protoDataTypeExtBuild.setWidth(((VarcharDataType) dataType).getWidth()); + break; + case OMNI_CHAR: + protoDataTypeExtBuild.setWidth(((CharDataType) dataType).getWidth()); + break; + case OMNI_INTERVAL_DAY_TIME: + protoDataTypeExtBuild.setDateUnit(VecBatchSerde.DataTypeExt.DateUnit.DAY); + break; + case OMNI_INTERVAL_MONTHS: + protoDataTypeExtBuild.setDateUnit(VecBatchSerde.DataTypeExt.DateUnit.MILLI); + break; + case OMNI_DECIMAL64: + case OMNI_DECIMAL128: { + if (dataType instanceof DecimalDataType) { + protoDataTypeExtBuild.setScale(((DecimalDataType) dataType).getScale()); + protoDataTypeExtBuild.setPrecision(((DecimalDataType) dataType).getPrecision()); + } else { + throw new IllegalStateException("Unexpected data type: " + dataType.getId()); + } + break; + } + case OMNI_TIME32: + case OMNI_TIME64: + break; + default: + throw new IllegalStateException("Unexpected data type: " + dataType.getId()); + } + } + + private Vec compactVec(Vec vec, int[] ids) { + // original vec is dictionary vec + if (ids != null) { + return vec.copyPositions(ids, 0, ids.length); + } + // original vec + return vec; + } + + private ByteBuffer serializeVarcharVector(VecBatchSerde.Vec.Builder protoVecBuilder, VarcharVec varcharVec) { + ByteBuffer valueBuf; + ByteBuffer offsetBuf; + int size = varcharVec.getSize(); + int startOffset = varcharVec.getValueOffset(0); + int realOffsetBufCapacity = varcharVec.getRealOffsetBufCapacityInBytes(); + int realValueBufCapacity = varcharVec.getRealValueBufCapacityInBytes(); + if (startOffset > 0) { + // For sliced varchar vector, offset the value base address and offsets to + // ensure that the serialized value is correct. + offsetBuf = ByteBuffer.allocate(realOffsetBufCapacity).order(ByteOrder.LITTLE_ENDIAN); + for (int i = 0; i < size + 1; i++) { + offsetBuf.putInt(varcharVec.getValueOffset(i) - startOffset); + } + offsetBuf.flip(); + protoVecBuilder.setOffsets(ByteString.copyFrom(offsetBuf)); + + long valueBufAddress = varcharVec.getValuesBuf().getAddress() + startOffset; + OmniBuffer omniValueBuf = OmniBufferFactory.create(valueBufAddress, realValueBufCapacity); + valueBuf = JvmUtils.directBuffer(omniValueBuf); + // only serialize the data actually written + valueBuf.limit(realValueBufCapacity); + } else { + // For not sliced varchar vector, serialize value and offset. + offsetBuf = JvmUtils.directBuffer(varcharVec.getOffsetsBuf()); + // only serialize the actual offset size + offsetBuf.limit(realOffsetBufCapacity); + protoVecBuilder.setOffsets(ByteString.copyFrom(offsetBuf)); + + valueBuf = JvmUtils.directBuffer(varcharVec.getValuesBuf()); + // only serialize the data actually written + valueBuf.limit(realValueBufCapacity); + } + return valueBuf; + } + + @Override + public VecBatch deserialize(byte[] bytes) { + VecBatchSerde.VecBatch protoVecBatch; + try { + protoVecBatch = VecBatchSerde.VecBatch.parseFrom(bytes); + } catch (Exception e) { + throw new OmniRuntimeException(OmniErrorType.OMNI_INNER_ERROR, "deserialize failed." + e.getCause()); + } + int vecCount = protoVecBatch.getVecCount(); + int rowCount = protoVecBatch.getRowCount(); + Vec[] vecs = new Vec[vecCount]; + try { + for (int i = 0; i < vecCount; i++) { + vecs[i] = buildVec(protoVecBatch.getVectors(i)); + } + } catch (Exception e) { + for (Vec v : vecs) { + if (v == null) { + continue; + } + v.close(); + } + } + + return new VecBatch(vecs, rowCount); + } + + private Vec buildVec(VecBatchSerde.Vec protoVec) { + VecBatchSerde.DataTypeExt protoTypeExt = protoVec.getTypeExt(); + VecEncoding vecEncoding = VecEncoding.values()[protoVec.getVecEncoding().getEncodingId()]; + VecBatchSerde.DataTypeExt.DataTypeId dataTypeId = protoTypeExt.getId(); + int vecSize = protoVec.getSize(); + Vec vec; + switch (vecEncoding) { + case OMNI_VEC_ENCODING_FLAT: + vec = createFlatVec(vecSize, dataTypeId, protoVec); + vec.setValuesBuf(protoVec.getValues().toByteArray()); + vec.setNullsBuf(protoVec.getNulls().toByteArray()); + return vec; + case OMNI_VEC_ENCODING_CONTAINER: + int vecCount = protoVec.getSubVectorsCount(); + long[] subVecAddresses = new long[vecCount]; + DataType[] subDataTypes = new DataType[vecCount]; + for (int i = 0; i < vecCount; i++) { + Vec subVec = buildVec(protoVec.getSubVectors(i)); + subVecAddresses[i] = subVec.getNativeVector(); + subDataTypes[i] = subVec.getType(); + } + return new ContainerVec(vecCount, protoVec.getSize(), subVecAddresses, subDataTypes); + default: + throw new IllegalStateException("Unexpected encoding: " + vecEncoding); + } + } + + private Vec createFlatVec(int vecSize, VecBatchSerde.DataTypeExt.DataTypeId dataTypeId, + VecBatchSerde.Vec protoVec) { + Vec vec; + switch (dataTypeId) { + case OMNI_INT: + case OMNI_DATE32: + vec = new IntVec(vecSize); + break; + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DATE64: + case OMNI_DECIMAL64: + vec = new LongVec(vecSize); + break; + case OMNI_SHORT: + vec = new ShortVec(vecSize); + break; + case OMNI_BOOLEAN: + vec = new BooleanVec(vecSize); + break; + case OMNI_DOUBLE: + vec = new DoubleVec(vecSize); + break; + case OMNI_VARCHAR: + case OMNI_CHAR: + VarcharVec varcharVec = new VarcharVec(protoVec.getValues().size(), vecSize); + varcharVec.setOffsetsBuf(protoVec.getOffsets().toByteArray()); + vec = varcharVec; + break; + case OMNI_DECIMAL128: + vec = new Decimal128Vec(vecSize); + break; + case OMNI_TIME32: + case OMNI_TIME64: + case OMNI_INTERVAL_DAY_TIME: + case OMNI_INTERVAL_MONTHS: + default: + throw new IllegalStateException("Unexpected data type: " + protoVec.getTypeExt().getId()); + } + return vec; + } +} diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/serialize/VecBatchSerializer.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/serialize/VecBatchSerializer.java new file mode 100644 index 0000000..4c032d1 --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/serialize/VecBatchSerializer.java @@ -0,0 +1,30 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.vector.serialize; + +import nova.hetu.omniruntime.vector.VecBatch; + +/** + * define the serialization interface of VecBatch. + * + * @since 2021-09-13 + */ +public interface VecBatchSerializer { + /** + * serialize vecBatch. + * + * @param vecBatch the vecbatch to be serialized + * @return byte array + */ + byte[] serialize(VecBatch vecBatch); + + /** + * deserialize vecbatch. + * + * @param bytes serialized vecbatch + * @return vec batch + */ + VecBatch deserialize(byte[] bytes); +} \ No newline at end of file diff --git a/bindings/java/src/main/java/nova/hetu/omniruntime/vector/serialize/VecBatchSerializerFactory.java b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/serialize/VecBatchSerializerFactory.java new file mode 100644 index 0000000..237939b --- /dev/null +++ b/bindings/java/src/main/java/nova/hetu/omniruntime/vector/serialize/VecBatchSerializerFactory.java @@ -0,0 +1,21 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.vector.serialize; + +/** + * create different types of serialization implementation. + * + * @since 2021-09-13 + */ +public class VecBatchSerializerFactory { + /** + * new a vec batch serializer object. + * + * @return protobuf vec batch serializer + */ + public static VecBatchSerializer create() { + return new ProtoVecBatchSerializer(); + } +} diff --git a/bindings/java/src/main/proto/vec_batch_serde.proto b/bindings/java/src/main/proto/vec_batch_serde.proto new file mode 100644 index 0000000..4c5cc1f --- /dev/null +++ b/bindings/java/src/main/proto/vec_batch_serde.proto @@ -0,0 +1,67 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ +syntax = "proto3"; + +package nova.hetu.omniruntime.vector.serialize; + +message VecBatch { + int32 rowCount = 1; + int32 vecCount = 2; + repeated Vec vectors = 3; +} + +message Vec { + DataTypeExt typeExt = 1; + VecEncoding vecEncoding = 2; + int32 size = 3; + int32 offset = 4; + bytes values = 5; + bytes nulls = 6; + bytes offsets = 7; + repeated Vec subVectors = 8; +} + +message VecEncoding{ + int32 encodingId = 1; +} + +message DataTypeExt { + enum DataTypeId { + OMNI_NONE = 0; + OMNI_INT = 1; + OMNI_LONG = 2; + OMNI_DOUBLE = 3; + OMNI_BOOLEAN = 4; + OMNI_SHORT = 5; + OMNI_DECIMAL64 = 6; + OMNI_DECIMAL128 = 7; + OMNI_DATE32 = 8; + OMNI_DATE64 = 9; + OMNI_TIME32 = 10; + OMNI_TIME64 = 11; + OMNI_TIMESTAMP = 12; + OMNI_INTERVAL_MONTHS = 13; + OMNI_INTERVAL_DAY_TIME = 14; + OMNI_VARCHAR = 15; + OMNI_CHAR = 16; + OMNI_CONTAINER = 17; + OMNI_INVALID = 18; + } + DataTypeId id = 1; + int32 width = 2; + int32 precision = 3; + int32 scale = 4; + enum DateUnit { + DAY = 0; + MILLI = 1; + } + DateUnit dateUnit = 5; + enum TimeUnit { + SEC = 0; + MILLISEC = 1; + MICROSEC = 2; + NANOSEC = 3; + } + TimeUnit timeUnit = 6; +} \ No newline at end of file diff --git a/bindings/java/src/main/scripts/build_cpp.sh b/bindings/java/src/main/scripts/build_cpp.sh new file mode 100644 index 0000000..4b4f490 --- /dev/null +++ b/bindings/java/src/main/scripts/build_cpp.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved. + +set -e + +if [ -z "$OMNI_HOME" ] +then + echo "[ERROR] can not found environment variable OMNI_HOME." >&2 + exit 1 +fi + +test -d $OMNI_HOME/java-binding && JAVA_BINDING=$(find $OMNI_HOME/java-binding -name '*.so') + +build_params=($1) + +if [ ! -z "${build_params}" ] && [ "${build_params}" != 'NONE' ] +then + echo "[INFO] build omni library." + export LIBRARY_PATH=$OMNI_HOME/lib:$LIBRARY_PATH + export LD_LIBRARY_PATH=$OMNI_HOME/lib:$LD_LIBRARY_PATH + + build_type=$(echo "${build_params[0]}" | awk -F ':' '{print $1}') + unset build_params[0] + + . build_scripts/build.sh ${build_type}:java ${build_params[@]} +else + if [ -z "$JAVA_BINDING" ] + then + echo "[INFO] no valid lib found in OMNI_HOME." >&2 + else + echo "[INFO] using $(basename $JAVA_BINDING) in OMNI_HOME" + fi +fi + + + diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/constant/ConstantLoadTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/constant/ConstantLoadTest.java new file mode 100644 index 0000000..91d4d63 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/constant/ConstantLoadTest.java @@ -0,0 +1,61 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.constant; + +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_SUM; +import static nova.hetu.omniruntime.constants.Status.OMNI_STATUS_ERROR; +import static nova.hetu.omniruntime.constants.Status.OMNI_STATUS_FINISHED; +import static nova.hetu.omniruntime.constants.Status.OMNI_STATUS_NORMAL; +import static org.testng.Assert.assertEquals; + +import nova.hetu.omniruntime.OmniLibs; +import nova.hetu.omniruntime.constants.FunctionType; +import nova.hetu.omniruntime.constants.Status; + +import org.testng.annotations.Test; + +/** + * The type Constant load test. + * + * @since 2021-07-13 + */ +public class ConstantLoadTest { + /** + * Test load constant. + */ + @Test + public void testLoadConstant() { + Status status = OMNI_STATUS_NORMAL; + assertEquals(status.getValue(), 0); + status = OMNI_STATUS_ERROR; + assertEquals(status.getValue(), -1); + status = OMNI_STATUS_FINISHED; + assertEquals(status.getValue(), 1); + + FunctionType functionType = OMNI_AGGREGATION_TYPE_SUM; + assertEquals(functionType.getValue(), 0); + } + + /** + * Test equals. + */ + @Test + public void testEquals() { + Status status = OMNI_STATUS_NORMAL; + assertEquals(status.getValue(), 0); + } + + /** + * Test getVersion. + */ + @Test + public void testGetVersion() { + String version = OmniLibs.getVersion(); + String expected = "Product Name: Kunpeng BoostKit" + System.lineSeparator() + + "Product Version: 25.0.0" + System.lineSeparator() + "Component Name: BoostKit-omniop" + + System.lineSeparator() + "Component Version: 1.9.0"; + assertEquals(version, expected); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/memory/TestMemoryManager.java b/bindings/java/src/test/java/nova/hetu/omniruntime/memory/TestMemoryManager.java new file mode 100644 index 0000000..83fff07 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/memory/TestMemoryManager.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.memory; + +import static nova.hetu.omniruntime.memory.MemoryManager.clearMemory; +import static nova.hetu.omniruntime.memory.MemoryManager.setGlobalMemoryLimit; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +import nova.hetu.omniruntime.utils.OmniRuntimeException; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.LongVec; + +import org.testng.annotations.Test; + +/** + * test memory manager + * + * @since 2023-01-17 + */ +public class TestMemoryManager { + @Test + public void testAllocatorBasic() { + long limit = -1; + clearMemory(); + setGlobalMemoryLimit(limit); + + MemoryManager memoryManager = new MemoryManager(); + int size = 1024 * 1024; + IntVec intVec = new IntVec(size); + // 4325384 = values(size * 4) + nulls(size) + other(CreateFlatVector_ptr(64)) + assertTrue(memoryManager.getAllocatedMemory() >= 4325384); + LongVec longVec = new LongVec(size); + // 12845072 = 5242944 + values(size * 8) + nulls(size) + other(CreateFlatVector_ptr(64)) + assertTrue(memoryManager.getAllocatedMemory() >= 12845072); + intVec.close(); + // 8519688 = 12845072 - 4325384 + assertTrue(memoryManager.getAllocatedMemory() >= 8519688); + longVec.close(); + assertEquals(memoryManager.getAllocatedMemory(), 0); + } + + @Test(expectedExceptions = OmniRuntimeException.class, expectedExceptionsMessageRegExp = "memory cap exceeded") + public void testMemoryManagerBeyondLimit() { + long limit = 1024L * 1024L; + int size = 1024 * 1024 + 1; + setGlobalMemoryLimit(limit); + + LongVec longVec = null; + try { + longVec = new LongVec(size); + } catch (OmniRuntimeException e) { + throw new OmniRuntimeException("memory cap exceeded"); + } finally { + if (longVec != null) { + longVec.close(); + } + clearMemory(); + } + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniAggregationOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniAggregationOperatorTest.java new file mode 100644 index 0000000..dc05298 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniAggregationOperatorTest.java @@ -0,0 +1,486 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static java.lang.String.format; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_AVG; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_COUNT_ALL; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_COUNT_COLUMN; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_MAX; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_MIN; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_SUM; +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.createVec; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.fail; + +import com.google.common.collect.ImmutableList; + +import nova.hetu.omniruntime.constants.FunctionType; +import nova.hetu.omniruntime.operator.aggregator.OmniAggregationOperatorFactory; +import nova.hetu.omniruntime.operator.aggregator.OmniAggregationOperatorFactory.FactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.CharDataType; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.Decimal128DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.TimestampDataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.vector.Decimal128Vec; +import nova.hetu.omniruntime.vector.DoubleVec; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.VarcharVec; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +/** + * The type Omni aggregation operator test. + * + * @since 2021-06-09 + */ +public class OmniAggregationOperatorTest { + /** + * test aggregation performance whether with jit or not. + */ + @Test + public void testAggregationComparePerf() { + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG}; + FunctionType[] aggFunctionTypes = {OMNI_AGGREGATION_TYPE_COUNT_COLUMN, OMNI_AGGREGATION_TYPE_COUNT_ALL, + OMNI_AGGREGATION_TYPE_COUNT_COLUMN}; + int[] aggInputChannels = {0, 1}; + int[] maskChannels = {-1, -1, -1}; + DataType[] aggOutputTypes = {LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}; + OmniAggregationOperatorFactory factoryWithJit = new OmniAggregationOperatorFactory(sourceTypes, + aggFunctionTypes, aggInputChannels, maskChannels, aggOutputTypes, true, false, new OperatorConfig()); + OmniOperator omniOperatorWithJit = factoryWithJit.createOperator(); + + ImmutableList.Builder vecBatchList1 = ImmutableList.builder(); + int rowNum = 100000; + int pageCount = 10; + for (int i = 0; i < pageCount; i++) { + vecBatchList1.add(new VecBatch(buildDataForCount(rowNum))); + } + + long start1 = System.currentTimeMillis(); + for (VecBatch vecBatch : vecBatchList1.build()) { + omniOperatorWithJit.addInput(vecBatch); + } + + Iterator outputWithJit = omniOperatorWithJit.getOutput(); + long end1 = System.currentTimeMillis(); + System.out.println("Aggregation with jit use " + (end1 - start1) + " ms."); + + OmniAggregationOperatorFactory factoryWithoutJit = new OmniAggregationOperatorFactory(sourceTypes, + aggFunctionTypes, aggInputChannels, maskChannels, aggOutputTypes, true, false, new OperatorConfig()); + OmniOperator omniOperatorWithoutJit = factoryWithoutJit.createOperator(); + + ImmutableList.Builder vecBatchList2 = ImmutableList.builder(); + for (int i = 0; i < pageCount; i++) { + vecBatchList2.add(new VecBatch(buildDataForCount(rowNum))); + } + + long start2 = System.currentTimeMillis(); + for (VecBatch vecBatch : vecBatchList2.build()) { + omniOperatorWithoutJit.addInput(vecBatch); + } + + Iterator outputWithoutJit = omniOperatorWithoutJit.getOutput(); + long end2 = System.currentTimeMillis(); + System.out.println("Aggregation without jit use " + (end2 - start2) + " ms."); + + while (outputWithJit.hasNext()) { + VecBatch resultWithJit = outputWithJit.next(); + VecBatch resultWithoutJit = outputWithoutJit.next(); + assertVecBatchEquals(resultWithJit, resultWithoutJit); + freeVecBatch(resultWithJit); + freeVecBatch(resultWithoutJit); + } + + omniOperatorWithJit.close(); + omniOperatorWithoutJit.close(); + factoryWithJit.close(); + factoryWithoutJit.close(); + } + + /** + * Test execute agg multiple page. + */ + @Test + public void testExecuteCountMultiplePage() { + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG}; + FunctionType[] aggFunctionTypes = {OMNI_AGGREGATION_TYPE_COUNT_COLUMN, OMNI_AGGREGATION_TYPE_COUNT_ALL, + OMNI_AGGREGATION_TYPE_COUNT_COLUMN}; + int[] aggInputChannels = {0, 1}; + int[] maskChannels = {-1, -1, -1}; + DataType[] aggOutputTypes = {LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}; + OmniAggregationOperatorFactory factory = new OmniAggregationOperatorFactory(sourceTypes, aggFunctionTypes, + aggInputChannels, maskChannels, aggOutputTypes, true, false); + + ImmutableList.Builder vecBatchList = ImmutableList.builder(); + int rowNum = 100; + int pageCount = 10; + for (int i = 0; i < pageCount; i++) { + vecBatchList.add(new VecBatch(buildDataForCount(rowNum))); + } + + OmniOperator omniOperator = factory.createOperator(); + for (VecBatch vecBatch : vecBatchList.build()) { + omniOperator.addInput(vecBatch); + } + + Iterator output = omniOperator.getOutput(); + while (output.hasNext()) { + VecBatch vecBatch = output.next(); + if (vecBatch.getVectors().length != aggOutputTypes.length) { + throw new IllegalArgumentException( + format("output vec size error: result size: %s, outputTypes size: %s,rows: %s", + vecBatch.getVectors().length, aggOutputTypes.length, vecBatch.getRowCount())); + } + + assertNotNull(vecBatch); + assertEquals(vecBatch.getVectors().length, 3); + Vec[] vectors = vecBatch.getVectors(); + assertEquals(((LongVec) vectors[0]).get(0), 1000L); + assertEquals(((LongVec) vectors[1]).get(0), 1000L); + assertEquals(((LongVec) vectors[2]).get(0), 500L); + + freeVecBatch(vecBatch); + } + + omniOperator.close(); + factory.close(); + } + + @Test + public void testFactoryContextEquals() { + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}; + FunctionType[] aggFunctionTypes = {OMNI_AGGREGATION_TYPE_SUM, OMNI_AGGREGATION_TYPE_SUM, + OMNI_AGGREGATION_TYPE_SUM, OMNI_AGGREGATION_TYPE_SUM}; + int[] aggInputChannels = {0, 1, 2, 3}; + int[] maskChannels = {-1, -1, -1, -1}; + DataType[] aggOutputTypes = {LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}; + FactoryContext factory1 = new FactoryContext(sourceTypes, aggFunctionTypes, aggInputChannels, maskChannels, + aggOutputTypes, true, false, new OperatorConfig()); + FactoryContext factory2 = new FactoryContext(sourceTypes, aggFunctionTypes, aggInputChannels, maskChannels, + aggOutputTypes, true, false, new OperatorConfig()); + FactoryContext factory3 = null; + + assertEquals(factory2, factory1); + assertEquals(factory1, factory1); + assertNotEquals(factory3, factory1); + } + + @Test + public void testExecuteMinMax() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG, DoubleDataType.DOUBLE, new CharDataType(20), + new VarcharDataType(20), new Decimal128DataType(20, 5), TimestampDataType.TIMESTAMP}; + FunctionType minFn = OMNI_AGGREGATION_TYPE_MIN; + FunctionType maxFn = OMNI_AGGREGATION_TYPE_MAX; + FunctionType[] aggFunctionTypes = {minFn, minFn, minFn, minFn, minFn, minFn, minFn, maxFn, maxFn, maxFn, maxFn, + maxFn, maxFn, maxFn}; + int[] aggInputChannels = {0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6}; + int[] maskChannels = {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}; + DataType[] aggOutputTypes = {IntDataType.INTEGER, LongDataType.LONG, DoubleDataType.DOUBLE, + new CharDataType(20), new VarcharDataType(20), new Decimal128DataType(20, 5), + TimestampDataType.TIMESTAMP, IntDataType.INTEGER, LongDataType.LONG, DoubleDataType.DOUBLE, + new CharDataType(20), new VarcharDataType(20), new Decimal128DataType(20, 5), + TimestampDataType.TIMESTAMP}; + OmniAggregationOperatorFactory factory = new OmniAggregationOperatorFactory(sourceTypes, aggFunctionTypes, + aggInputChannels, maskChannels, aggOutputTypes, true, false); + + Object[][] sampleDatas = {{2, 1, 5, 3, 1}, {3L, 10L, 2L, 7L, 3L}, {12.3, 7.2, 20.5, 6.1, 12.3}, + {"hello", "world", "c++", "shell", "golang"}, {"operator", "vectorBatch", "udf", "expression", "omni"}}; + Object[][] decimalDatas = {{4000L, 0L}, {2000L, 0L}, {1000L, 0L}, {2000L, 0L}, {5000L, 0L}}; + Object[] timestampDatas = {3000L, 1000L, 5000L, 2000L, 4000L}; + + Vec[] buildVecs = new Vec[sourceTypes.length]; + for (int i = 0; i < 5; i++) { + buildVecs[i] = createVec(sourceTypes[i], sampleDatas[i]); + } + buildVecs[5] = createVec(sourceTypes[5], decimalDatas); + buildVecs[6] = createVec(sourceTypes[6], timestampDatas); + + VecBatch inputData = new VecBatch(buildVecs); + OmniOperator omniOperator = factory.createOperator(); + omniOperator.addInput(inputData); + Iterator output = omniOperator.getOutput(); + + while (output.hasNext()) { + VecBatch vecBatch = output.next(); + if (vecBatch.getVectors().length != aggOutputTypes.length) { + throw new IllegalArgumentException( + format("output vec size error: result size: %s, outputTypes size: %s,rows: %s", + vecBatch.getVectors().length, aggOutputTypes.length, vecBatch.getRowCount())); + } + + assertNotNull(vecBatch); + assertEquals(vecBatch.getVectors().length, 14); + Vec[] vectors = vecBatch.getVectors(); + assertEquals(((IntVec) vectors[0]).get(0), 1); + assertEquals(((LongVec) vectors[1]).get(0), 2L); + assertEquals(((DoubleVec) vectors[2]).get(0), 6.1); + assertEquals(new String(((VarcharVec) vectors[3]).get(0)), "c++"); + assertEquals(new String(((VarcharVec) vectors[4]).get(0)), "expression"); + assertEquals(((Decimal128Vec) vectors[5]).get(0), new Object[]{1000L, 0L}); + assertEquals(((LongVec) vectors[6]).get(0), 1000L); + assertEquals(((IntVec) vectors[7]).get(0), 5); + assertEquals(((LongVec) vectors[8]).get(0), 10L); + assertEquals(((DoubleVec) vectors[9]).get(0), 20.5); + assertEquals(new String(((VarcharVec) vectors[10]).get(0)), "world"); + assertEquals(new String(((VarcharVec) vectors[11]).get(0)), "vectorBatch"); + assertEquals(((Decimal128Vec) vectors[12]).get(0), new Object[]{5000L, 0L}); + assertEquals(((LongVec) vectors[13]).get(0), 5000L); + + freeVecBatch(vecBatch); + } + omniOperator.close(); + factory.close(); + } + + @Test + public void testExecuteSumAvgMultipleStage() { + DataType[] sourceTypes = {LongDataType.LONG, new Decimal128DataType(20, 5)}; + FunctionType sumFn = OMNI_AGGREGATION_TYPE_SUM; + FunctionType avgFn = OMNI_AGGREGATION_TYPE_AVG; + FunctionType[] aggFunctionTypes = {sumFn, avgFn}; + int[] aggInputChannels = {0, 1}; + int[] maskChannels = {-1, -1}; + + Object[][] sampleDatas = {{3L, 10L, 2L, 7L, 3L}}; + Object[][] decimalDatas = {{4000L, 0L}, {2000L, 0L}, {1000L, 0L}, {3000L, 0L}, {5000L, 0L}}; + + Vec[] buildVecs = new Vec[sourceTypes.length]; + for (int i = 0; i < 1; i++) { + buildVecs[i] = createVec(sourceTypes[i], sampleDatas[i]); + } + buildVecs[1] = createVec(sourceTypes[1], decimalDatas); + VecBatch inputData = new VecBatch(buildVecs); + + DataType[] finalAggOutputTypes = {LongDataType.LONG, new Decimal128DataType(20, 5)}; + OmniAggregationOperatorFactory factory = new OmniAggregationOperatorFactory(sourceTypes, + aggFunctionTypes, aggInputChannels, maskChannels, finalAggOutputTypes, true, false); + + OmniOperator omniOperator = factory.createOperator(); + omniOperator.addInput(inputData); + + Iterator finalOutput = omniOperator.getOutput(); + while (finalOutput.hasNext()) { + VecBatch finalVecBatch = finalOutput.next(); + if (finalVecBatch.getVectors().length != finalAggOutputTypes.length) { + throw new IllegalArgumentException(format( + "output vec size error: result size: %s, outputTypes size: %s,rows: %s", + finalVecBatch.getVectors().length, finalAggOutputTypes.length, finalVecBatch.getRowCount())); + } + + assertNotNull(finalVecBatch); + assertEquals(finalVecBatch.getVectors().length, 2); + Vec[] vectors = finalVecBatch.getVectors(); + assertEquals(((LongVec) vectors[0]).get(0), 25L); + assertEquals(((Decimal128Vec) vectors[1]).get(0), new Object[]{3000L, 0L}); + freeVecBatch(finalVecBatch); + } + omniOperator.close(); + factory.close(); + } + + @Test + public void testExecuteAggMultiplePage() { + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}; + FunctionType[] aggFunctionTypes = {OMNI_AGGREGATION_TYPE_SUM, OMNI_AGGREGATION_TYPE_SUM, + OMNI_AGGREGATION_TYPE_SUM, OMNI_AGGREGATION_TYPE_SUM}; + int[] aggInputChannels = {0, 1, 2, 3}; + int[] maskChannels = {-1, -1, -1, -1}; + DataType[] aggOutputTypes = {LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}; + OmniAggregationOperatorFactory factory = new OmniAggregationOperatorFactory(sourceTypes, aggFunctionTypes, + aggInputChannels, maskChannels, aggOutputTypes, true, false); + + List inputData = new ArrayList<>(); + ImmutableList.Builder vecBatchList = ImmutableList.builder(); + int rowNum = 40000; + int pageCount = 10; + for (int i = 0; i < pageCount; i++) { + inputData.addAll(build4Columns(rowNum)); + vecBatchList.add(new VecBatch(build4Columns(rowNum))); + } + + OmniOperator omniOperator = factory.createOperator(); + for (VecBatch vecBatch : vecBatchList.build()) { + omniOperator.addInput(vecBatch); + } + // release input data memory + releaseVecMemory(inputData.toArray(new Vec[0])); + + Iterator output = omniOperator.getOutput(); + while (output.hasNext()) { + VecBatch vecBatch = output.next(); + if (vecBatch.getVectors().length != aggOutputTypes.length) { + throw new IllegalArgumentException( + format("output vec size error: result size: %s, outputTypes size: %s,rows: %s", + vecBatch.getVectors().length, aggOutputTypes.length, vecBatch.getRowCount())); + } + + assertNotNull(vecBatch); + assertEquals(vecBatch.getVectors().length, 4); + Vec[] vectors = vecBatch.getVectors(); + assertEquals(((LongVec) vectors[0]).get(0), 0); + assertEquals(((LongVec) vectors[1]).get(0), 0); + assertEquals(((LongVec) vectors[2]).get(0), rowNum * pageCount); + assertEquals(((LongVec) vectors[3]).get(0), rowNum * pageCount); + + freeVecBatch(vecBatch); + } + omniOperator.close(); + factory.close(); + } + + /** + * Test execute agg multiple thread. + */ + @Test + public void testExecuteAggMultipleThread() { + int pageCount = 10; + int threadCount = 10; + int rowNum = 100; + multiThreadExecution(threadCount, rowNum, pageCount); + } + + private void multiThreadExecution(int threadCount, int rowNum, int pageCount) { + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}; + FunctionType[] aggFunctionTypes = {OMNI_AGGREGATION_TYPE_SUM, OMNI_AGGREGATION_TYPE_SUM, + OMNI_AGGREGATION_TYPE_SUM, OMNI_AGGREGATION_TYPE_SUM}; + int[] aggInputChannels = {0, 1, 2, 3}; + int[] maskChannels = {-1, -1, -1, -1}; + DataType[] aggOutputTypes = {LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}; + OmniAggregationOperatorFactory factory = new OmniAggregationOperatorFactory(sourceTypes, aggFunctionTypes, + aggInputChannels, maskChannels, aggOutputTypes, true, false); + + CountDownLatch downLatch = new CountDownLatch(threadCount); + final int corePoolSize = 10; + final int maximumPoolSize = 50; + ThreadPoolExecutor threadPool = new ThreadPoolExecutor(corePoolSize, maximumPoolSize, 60L, TimeUnit.SECONDS, + new LinkedBlockingQueue<>(threadCount)); + + for (int tIdx = 0; tIdx < threadCount; tIdx++) { + CompletableFuture.runAsync(() -> { + try { + ImmutableList.Builder vecBatchList = ImmutableList.builder(); + for (int i = 0; i < pageCount; i++) { + vecBatchList.add(new VecBatch(build4Columns(rowNum))); + } + + OmniOperator omniOperator = factory.createOperator(); + for (VecBatch vecBatch : vecBatchList.build()) { + omniOperator.addInput(vecBatch); + } + + assertResult(rowNum, pageCount, aggOutputTypes, omniOperator); + omniOperator.close(); + } finally { + downLatch.countDown(); + } + }, threadPool); + } + + try { + downLatch.await(); + } catch (InterruptedException ex) { + fail(); + } + + threadPool.shutdown(); + factory.close(); + } + + private void assertResult(int rowNum, int pageCount, DataType[] aggOutputTypes, OmniOperator omniOperator) { + Iterator output = omniOperator.getOutput(); + while (output.hasNext()) { + VecBatch vecBatch = output.next(); + if (vecBatch.getVectors().length != aggOutputTypes.length) { + throw new IllegalArgumentException( + format("output vec size error: result size: %s, outputTypes size: %s,rows: %s", + vecBatch.getVectors().length, aggOutputTypes.length, vecBatch.getRowCount())); + } + + assertNotNull(vecBatch); + assertEquals(vecBatch.getVectors().length, 4); + Vec[] vectors = vecBatch.getVectors(); + assertEquals(((LongVec) vectors[0]).get(0), 0); + assertEquals(((LongVec) vectors[1]).get(0), 0); + assertEquals(((LongVec) vectors[2]).get(0), rowNum * pageCount); + assertEquals(((LongVec) vectors[3]).get(0), rowNum * pageCount); + + freeVecBatch(vecBatch); + } + } + + private List buildDataForCount(int rowNum) { + LongVec c1 = new LongVec(rowNum); + for (int i = 0; i < rowNum; i++) { + c1.set(i, 0); + } + + LongVec c2 = new LongVec(rowNum); + for (int i = 0; i < rowNum; i++) { + if (i % 2 == 0) { + c2.set(i, 1); + } else { + c2.setNull(i); + } + } + + List columns = new ArrayList<>(); + columns.add(c1); + columns.add(c2); + + return columns; + } + + private List build4Columns(int rowNum) { + LongVec c1 = new LongVec(rowNum); + LongVec c2 = new LongVec(rowNum); + for (int i = 0; i < rowNum; i++) { + c1.set(i, 0); + c2.set(i, 0); + } + + LongVec c3 = new LongVec(rowNum); + LongVec c4 = new LongVec(rowNum); + for (int i = 0; i < rowNum; i++) { + c3.set(i, 1); + c4.set(i, 1); + } + + List columns = new ArrayList<>(); + columns.add(c1); + columns.add(c2); + columns.add(c3); + columns.add(c4); + + return columns; + } + + private void releaseVecMemory(Vec[] vecs) { + for (Vec vec : vecs) { + vec.close(); + } + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniBloomFilterOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniBloomFilterOperatorTest.java new file mode 100644 index 0000000..63057b6 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniBloomFilterOperatorTest.java @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.util.TestUtils.createVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; +import static org.testng.Assert.assertEquals; + +import nova.hetu.omniruntime.operator.filter.OmniBloomFilterOperatorFactory; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.Iterator; + +/** + * The type Omni bloom filter operator test. + * + * @since 2023-03-03 + */ +public class OmniBloomFilterOperatorTest { + @Test + public void testCreateBloomFilter() { + DataType[] types = {IntDataType.INTEGER}; + Object[][] datas = {{1, 6, 4, 0, 0, 0, 0, 0, 0, 0, 0}}; + VecBatch inputVecBatch = createVecBatch(types, datas); + + OmniBloomFilterOperatorFactory factory = new OmniBloomFilterOperatorFactory(1); + OmniOperator op = factory.createOperator(); + op.addInput(inputVecBatch); + Iterator results = op.getOutput(); + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getVectorCount(), 1); + assertEquals(resultVecBatch.getRowCount(), 1); + freeVecBatch(resultVecBatch); + op.close(); + factory.close(); + } +} \ No newline at end of file diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniDistinctLimitOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniDistinctLimitOperatorTest.java new file mode 100644 index 0000000..cb91011 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniDistinctLimitOperatorTest.java @@ -0,0 +1,151 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.createVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; + +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.operator.limit.OmniDistinctLimitOperatorFactory; +import nova.hetu.omniruntime.operator.limit.OmniDistinctLimitOperatorFactory.FactoryContext; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.Iterator; + +/** + * The type Omni distinct limit operator test. + * + * @since 2021-11-27 + */ +public class OmniDistinctLimitOperatorTest { + @Test + public void testDistinctLimitBasic() { + DataType[] sourceTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE, VarcharDataType.VARCHAR}; + Object[][] sourceDatas1 = {{0, 1, 2, 0, 1, 2}, {6.6, 5.5, 4.4, 6.6, 5.5, 1.1}, + {"abc", "hello", "world", "abc", "helle", "test"}}; + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceDatas1); + + int[] distinctCols = {0, 1, 2}; + OmniDistinctLimitOperatorFactory distinctLimitOperatorFactory = new OmniDistinctLimitOperatorFactory( + sourceTypes, distinctCols, -1, sourceDatas1[0].length - 1); + OmniOperator distinctLimitOperator = distinctLimitOperatorFactory.createOperator(); + distinctLimitOperator.addInput(vecBatch1); + Iterator results = distinctLimitOperator.getOutput(); + + Object[][] expectedDatas1 = {{0, 1, 2, 1, 2}, {6.6, 5.5, 4.4, 5.5, 1.1}, + {"abc", "hello", "world", "helle", "test"}}; + + VecBatch resultVecBatch1 = results.next(); + assertVecBatchEquals(resultVecBatch1, expectedDatas1); + + freeVecBatch(resultVecBatch1); + distinctLimitOperator.close(); + distinctLimitOperatorFactory.close(); + } + + @Test + public void testDistinctLimitColTypesCover() { + DataType[] sourceTypes = {LongDataType.LONG, IntDataType.INTEGER, VarcharDataType.VARCHAR, IntDataType.INTEGER, + DoubleDataType.DOUBLE, VarcharDataType.VARCHAR}; + Object[][] sourceDatas1 = {{10000L, 20000L, 10000L}, {3, 4, 5}, {"aaa", "bbb", "ccc"}, {0, 1, 0}, + {6.6, 5.5, 6.6}, {"hello", "world", "hello"}}; + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceDatas1); + + int[] distinctCols = {3, 4, 5}; + OmniDistinctLimitOperatorFactory distinctLimitOperatorFactory = new OmniDistinctLimitOperatorFactory( + sourceTypes, distinctCols, 0, sourceDatas1[0].length); + OmniOperator distinctLimitOperator = distinctLimitOperatorFactory.createOperator(); + distinctLimitOperator.addInput(vecBatch1); + Iterator results = distinctLimitOperator.getOutput(); + + // out put order: distinct cols => normal cols => hash col + Object[][] expectedDatas1 = {{0, 1}, {6.6, 5.5}, {"hello", "world"}, {10000L, 20000L}}; + + VecBatch resultVecBatch1 = results.next(); + assertVecBatchEquals(resultVecBatch1, expectedDatas1); + + freeVecBatch(resultVecBatch1); + distinctLimitOperator.close(); + distinctLimitOperatorFactory.close(); + } + + @Test + public void testDistinctLimitWithNull() { + DataType[] sourceTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE, VarcharDataType.VARCHAR}; + Object[][] sourceDatas1 = {{0, 1, 2, 0, null, 2, null, null, 2, null}, + {6.6, 5.5, 4.4, 6.6, 5.5, null, null, 5.5, null, null}, + {"abc", "hello", "world", null, "hello", "world", null, "hello", "world", null}}; + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceDatas1); + + int[] distinctCols = {0, 1, 2}; + OmniDistinctLimitOperatorFactory distinctLimitOperatorFactory = new OmniDistinctLimitOperatorFactory( + sourceTypes, distinctCols, -1, sourceDatas1[0].length); + OmniOperator distinctLimitOperator = distinctLimitOperatorFactory.createOperator(); + distinctLimitOperator.addInput(vecBatch1); + Iterator results = distinctLimitOperator.getOutput(); + + Object[][] expectedDatas1 = {{0, 1, 2, 0, null, 2, null}, {6.6, 5.5, 4.4, 6.6, 5.5, null, null}, + {"abc", "hello", "world", null, "hello", "world", null}}; + + VecBatch resultVecBatch1 = results.next(); + assertVecBatchEquals(resultVecBatch1, expectedDatas1); + + freeVecBatch(resultVecBatch1); + distinctLimitOperator.close(); + distinctLimitOperatorFactory.close(); + } + + @Test + public void testDistinctLimitWithHashCol() { + DataType[] sourceTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE, LongDataType.LONG}; + Object[][] sourceDatas1 = {{0, 1, 2, 0, 1}, {6.6, 5.5, 4.4, 6.6, 2.2}, + {100000L, 110000L, 120000L, 100000L, 110000L}}; + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceDatas1); + + int[] distinctCols = {0, 1}; + OmniDistinctLimitOperatorFactory distinctLimitOperatorFactory = new OmniDistinctLimitOperatorFactory( + sourceTypes, distinctCols, 2, sourceDatas1[0].length); + OmniOperator distinctLimitOperator = distinctLimitOperatorFactory.createOperator(); + distinctLimitOperator.addInput(vecBatch1); + Iterator results = distinctLimitOperator.getOutput(); + + Object[][] expectedDatas1 = {{0, 1, 2, 1}, {6.6, 5.5, 4.4, 2.2}, {100000L, 110000L, 120000L, 110000L}}; + + VecBatch resultVecBatch1 = results.next(); + assertVecBatchEquals(resultVecBatch1, expectedDatas1); + + freeVecBatch(resultVecBatch1); + distinctLimitOperator.close(); + distinctLimitOperatorFactory.close(); + } + + @Test + public void testFactoryContextEquals() { + DataType[] sourceTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE, LongDataType.LONG}; + Object[][] sourceDatas1 = {{0, 1, 2, 0, 1}, {6.6, 5.5, 4.4, 6.6, 2.2}, + {100000L, 110000L, 120000L, 100000L, 110000L}}; + + int[] distinctCols = {0, 1}; + FactoryContext factory1 = new FactoryContext(sourceTypes, distinctCols, 2, sourceDatas1[0].length, + new OperatorConfig()); + FactoryContext factory2 = new FactoryContext(sourceTypes, distinctCols, 2, sourceDatas1[0].length, + new OperatorConfig()); + FactoryContext factory3 = null; + + assertEquals(factory2, factory1); + assertEquals(factory1, factory1); + assertNotEquals(factory3, factory1); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniExprVerifyTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniExprVerifyTest.java new file mode 100644 index 0000000..77f289b --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniExprVerifyTest.java @@ -0,0 +1,89 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static org.testng.Assert.assertEquals; + +import com.google.common.collect.ImmutableList; + +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DataTypeSerializer; +import nova.hetu.omniruntime.type.Decimal128DataType; + +import org.testng.annotations.Test; + +import java.util.List; + +/** + * OmniExprVerify test. + * + * @since 2022-05-16 + */ +public class OmniExprVerifyTest { + /** + * Test for Spark check error. + */ + @Test + public void exprVerifierForSpark() { + DataType[] inputTypes = {new Decimal128DataType(21, 5)}; + List projectionsJSON = ImmutableList + .of("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":7,\"precision\":21,\"scale\":5,\"colVal\":0}"); + String filterJSON = "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN_THAN\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":7,\"colVal\":0," + + "\"precision\":21,\"scale\":5},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":6," + + "\"precision\":10,\"scale\":5,\"isNull\":false,\"value\":2000}}"; + + long isSupported = new OmniExprVerify().exprVerifyNative(DataTypeSerializer.serialize(inputTypes), 0, + filterJSON, projectionsJSON.toArray(new Object[0]), projectionsJSON.size(), 1); + + assertEquals(isSupported, 0); + } + + /** + * Test for Spark check success. + */ + @Test + public void exprVerifierForSpark2() { + DataType[] inputTypes = {new Decimal128DataType(21, 5)}; + List projectionsJSON = ImmutableList + .of("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":7,\"precision\":21,\"scale\":5,\"colVal\":0}"); + String filterJSON = "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"AND\"," + + "\"left\":{\"exprType\":\"UNARY\",\"returnType\":4,\"operator\":\"not\"," + + "\"expr\":{\"exprType\":\"IS_NULL\",\"returnType\":4," + + "\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":6,\"colVal\":0,\"precision\":17," + + "\"scale\":2}]}},\"right\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":6,\"colVal\":0,\"precision\":17," + + "\"scale\":2},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":6,\"isNull\":false,\"value\":2000," + + "\"precision\":17,\"scale\":2}}}"; + + long isSupported = new OmniExprVerify().exprVerifyNative(DataTypeSerializer.serialize(inputTypes), 0, + filterJSON, projectionsJSON.toArray(new Object[0]), projectionsJSON.size(), 1); + + assertEquals(isSupported, 1); + } + + /** + * Test for Spark check success. + */ + @Test + public void exprVerifierForSpark3() { + DataType[] inputTypes = {new Decimal128DataType(21, 5)}; + List projectionsJSON = ImmutableList + .of("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":7,\"precision\":21,\"scale\":5,\"colVal\":0}"); + String filterJSON = "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"AND\"," + + "\"left\":{\"exprType\":\"UNARY\",\"returnType\":4,\"operator\":\"not\"," + + "\"expr\":{\"exprType\":\"IS_NULL\",\"returnType\":4," + + "\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":7,\"colVal\":0,\"precision\":22," + + "\"scale\":6}]}},\"right\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":7,\"colVal\":0,\"precision\":22," + + "\"scale\":6},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":7,\"isNull\":false," + + "\"value\":\"20000000\",\"precision\":22,\"scale\":6}}}"; + + long isSupported = new OmniExprVerify().exprVerifyNative(DataTypeSerializer.serialize(inputTypes), 0, + filterJSON, projectionsJSON.toArray(new Object[0]), projectionsJSON.size(), 1); + + assertEquals(isSupported, 1); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniFilterAndProjectOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniFilterAndProjectOperatorTest.java new file mode 100644 index 0000000..a6699d7 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniFilterAndProjectOperatorTest.java @@ -0,0 +1,1492 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.createIntVec; +import static nova.hetu.omniruntime.util.TestUtils.createVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotEquals; +import static org.testng.Assert.assertTrue; + +import com.google.common.collect.ImmutableList; + +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.operator.filter.OmniFilterAndProjectOperatorFactory; +import nova.hetu.omniruntime.operator.filter.OmniFilterAndProjectOperatorFactory.FactoryContext; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.Decimal128DataType; +import nova.hetu.omniruntime.type.Decimal64DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.util.TestUtils; +import nova.hetu.omniruntime.vector.DictionaryVec; +import nova.hetu.omniruntime.vector.DoubleVec; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +/** + * The type Omni filter and project operator test. + * + * @since 2021-06-01 + */ +public class OmniFilterAndProjectOperatorTest { + private ImmutableList makeInput(int nRows, Vec... cols) { + return ImmutableList.copyOf(new VecBatch[]{new VecBatch(cols)}); + } + + /** + * Between int. + */ + @Test + public void betweenInt() { + DataType[] types = {IntDataType.INTEGER, IntDataType.INTEGER, IntDataType.INTEGER}; + Object[][] datas = {{0, 1, 2, 3, 4, 0, 1, 2, 3, 4}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + {0, 1, 2, 3, 4, 5, 6, 6, 6, 6}}; + + List projections = ImmutableList.of("#0", "#1", "#2"); + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory( + "$operator$BETWEEN:4(#1, #0, #2)", types, projections); + + VecBatch vecBatch = createVecBatch(types, datas); + OmniOperator op = factory.createOperator(); + op.addInput(vecBatch); + + Iterator results = op.getOutput(); + VecBatch resultVecBatch = results.next(); + + Object[][] expectDatas = {{0, 1, 2, 3, 4, 0, 1}, {0, 1, 2, 3, 4, 5, 6}, {0, 1, 2, 3, 4, 5, 6}}; + assertVecBatchEquals(resultVecBatch, expectDatas); + + freeVecBatch(resultVecBatch); + op.close(); + factory.close(); + } + + /** + * Between int dictionary. + */ + @Test + public void betweenIntDictionary() { + DataType[] types = {IntDataType.INTEGER, IntDataType.INTEGER, IntDataType.INTEGER}; + Object[][] datas = {{0, 1, 2, 3, 4, 0, 1, 2, 3, 4}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + {-3, -2, -1, 0, 1, 2, 3, 4, 5, 6}}; + Vec[] vecs = new Vec[3]; + vecs[0] = TestUtils.createIntVec(datas[0]); + vecs[1] = TestUtils.createIntVec(datas[1]); + + int[] ids = {3, 4, 5, 6, 7, 8, 9, 9, 9, 9}; + DictionaryVec dicVec = TestUtils.createDictionaryVec(types[2], datas[2], ids); + vecs[2] = dicVec; + + List projections = ImmutableList.of("#0", "#1", "#2"); + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory( + "$operator$BETWEEN:4(#1, #0, #2)", types, projections); + + OmniOperator op = factory.createOperator(); + VecBatch vecBatch = new VecBatch(vecs); + op.addInput(vecBatch); + + Iterator results = op.getOutput(); + VecBatch resultVecBatch = results.next(); + + Object[][] expectDatas = {{0, 1, 2, 3, 4, 0, 1}, {0, 1, 2, 3, 4, 5, 6}, {0, 1, 2, 3, 4, 5, 6}}; + assertVecBatchEquals(resultVecBatch, expectDatas); + + freeVecBatch(resultVecBatch); + op.close(); + factory.close(); + } + + /** + * Doubles. + */ + @Test + public void doubles() { + final int numRows = 5000; + DoubleVec col1 = new DoubleVec(numRows); + for (int i = 0; i < numRows; i++) { + col1.set(i, i % 2 == 0 ? 0.5 : 1.5); + } + DoubleVec col2 = new DoubleVec(numRows); + for (int i = 0; i < numRows; i++) { + col2.set(i, i % 2 == 0 ? 0.5 : 1.5); + } + + DataType[] types = {DoubleDataType.DOUBLE}; + List projections = ImmutableList.of("#0"); + List projectionsJSON = ImmutableList + .of("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":0}"); + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory( + "$operator$LESS_THAN:4(#0, 1.0:3)", types, projections); + OmniFilterAndProjectOperatorFactory factoryJSON = new OmniFilterAndProjectOperatorFactory( + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":0},\"right\"" + + ":{\"exprType\":\"LITERAL\",\"dataType\":3,\"isNull\":false,\"value\":1.0}}", + types, projectionsJSON, 1); + + OmniOperator op = factory.createOperator(); + OmniOperator opJSON = factoryJSON.createOperator(); + ImmutableList vecBatches1 = makeInput(numRows, col1); + ImmutableList vecBatches2 = makeInput(numRows, col2); + for (VecBatch vecBatch : vecBatches1) { + op.addInput(vecBatch); + } + for (VecBatch vecBatch : vecBatches2) { + opJSON.addInput(vecBatch); + } + + assertTrue(op.getOutput().hasNext()); + assertTrue(opJSON.getOutput().hasNext()); + VecBatch res = op.getOutput().next(); + VecBatch resJSON = opJSON.getOutput().next(); + assertEquals(res.getRowCount(), 2500); + assertEquals(resJSON.getRowCount(), 2500); + for (int i = 0; i < res.getRowCount(); i++) { + assertTrue(((DoubleVec) res.getVector(0)).get(i) < 1); + assertTrue(((DoubleVec) resJSON.getVector(0)).get(i) < 1); + } + + freeVecBatch(res); + freeVecBatch(resJSON); + op.close(); + opJSON.close(); + factory.close(); + factoryJSON.close(); + } + + /** + * Less than. + */ + @Test + public void lessThan() { + final int numRows = 5000; + IntVec col1 = new IntVec(numRows); + for (int i = 0; i < numRows; i++) { + col1.set(i, i); + } + IntVec col2 = new IntVec(numRows); + for (int i = 0; i < numRows; i++) { + col2.set(i, i); + } + + DataType[] types = {IntDataType.INTEGER}; + List projections = ImmutableList.of("#0"); + List projectionsJSON = ImmutableList + .of("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}"); + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory( + "$operator$LESS_THAN:4(#0, 2000:1)", types, projections); + OmniFilterAndProjectOperatorFactory factoryJSON = new OmniFilterAndProjectOperatorFactory( + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":2000}}", + types, projectionsJSON, 1); + + OmniOperator op = factory.createOperator(); + OmniOperator opJSON = factoryJSON.createOperator(); + ImmutableList vecBatches1 = makeInput(numRows, col1); + ImmutableList vecBatches2 = makeInput(numRows, col2); + for (VecBatch vecBatch : vecBatches1) { + op.addInput(vecBatch); + } + for (VecBatch vecBatch : vecBatches2) { + opJSON.addInput(vecBatch); + } + + assertTrue(op.getOutput().hasNext()); + assertTrue(opJSON.getOutput().hasNext()); + VecBatch res = op.getOutput().next(); + VecBatch resJSON = opJSON.getOutput().next(); + assertEquals(res.getRowCount(), 2000); + assertEquals(resJSON.getRowCount(), 2000); + for (int i = 0; i < res.getRowCount(); i++) { + assertTrue(((IntVec) res.getVector(0)).get(i) < 2000); + assertTrue(((IntVec) resJSON.getVector(0)).get(i) < 2000); + } + + freeVecBatch(res); + freeVecBatch(resJSON); + op.close(); + opJSON.close(); + factory.close(); + factoryJSON.close(); + } + + /** + * Less than dictionary varchar. + */ + @Test + public void lessThanDictionaryVarchar() { + DataType[] types = {IntDataType.INTEGER, new VarcharDataType(50)}; + Object[][] datas = {{0, 3, 9}, {"hello", "world", "friends"}}; + Vec[] vecs = new Vec[2]; + vecs[0] = createIntVec(datas[0]); + int[] ids = {0, 1, 2}; + DictionaryVec dicVec = TestUtils.createDictionaryVec(types[1], datas[1], ids); + vecs[1] = dicVec; + VecBatch vecBatch = new VecBatch(vecs); + + List projections = ImmutableList.of("#0", "#1"); + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory( + "$operator$LESS_THAN:4(#0, 6:1)", types, projections); + + OmniOperator op = factory.createOperator(); + op.addInput(vecBatch); + + Iterator results = op.getOutput(); + VecBatch resultVecBatch = results.next(); + + Object[][] expectDatas = {{0, 3}, {"hello", "world"}}; + assertVecBatchEquals(resultVecBatch, expectDatas); + + freeVecBatch(resultVecBatch); + op.close(); + factory.close(); + } + + /** + * Greater than. + */ + @Test + public void greaterThan() { + final int numRows = 5000; + IntVec col1 = new IntVec(numRows); + LongVec col2 = new LongVec(numRows); + for (int i = 0; i < numRows; i++) { + col1.set(i, i % 25); + col2.set(i, 3000000000L); + } + IntVec col3 = new IntVec(numRows); + LongVec col4 = new LongVec(numRows); + for (int i = 0; i < numRows; i++) { + col3.set(i, i % 25); + col4.set(i, 3000000000L); + } + + DataType[] types = {IntDataType.INTEGER, LongDataType.LONG}; + List projections = ImmutableList.of("#0", "#1"); + List projectionsJSON = ImmutableList.of( + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":1}"); + OmniFilterAndProjectOperatorFactory factoryJSON = new OmniFilterAndProjectOperatorFactory( + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"GREATER_THAN\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":20}}", + types, projectionsJSON, 1); + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory( + "$operator$GREATER_THAN:4(#0, 20:1)", types, projections); + + OmniOperator op = factory.createOperator(); + OmniOperator opJSON = factoryJSON.createOperator(); + ImmutableList vecBatches1 = makeInput(numRows, col1, col2); + ImmutableList vecBatches2 = makeInput(numRows, col3, col4); + for (VecBatch vecBatch : vecBatches1) { + op.addInput(vecBatch); + } + for (VecBatch vecBatch : vecBatches2) { + opJSON.addInput(vecBatch); + } + + assertTrue(op.getOutput().hasNext()); + assertTrue(opJSON.getOutput().hasNext()); + VecBatch res = op.getOutput().next(); + VecBatch resJSON = opJSON.getOutput().next(); + assertEquals(res.getRowCount(), 800); + assertEquals(resJSON.getRowCount(), 800); + for (int i = 0; i < res.getRowCount(); i++) { + assertTrue(((IntVec) res.getVector(0)).get(i) > 20); + assertEquals(((LongVec) res.getVector(1)).get(i), 3000000000L); + assertTrue(((IntVec) resJSON.getVector(0)).get(i) > 20); + assertEquals(((LongVec) resJSON.getVector(1)).get(i), 3000000000L); + } + + freeVecBatch(res); + freeVecBatch(resJSON); + op.close(); + opJSON.close(); + factory.close(); + factoryJSON.close(); + } + + /** + * Equal to. + */ + @Test + public void equalTo() { + final int numRows = 5000; + IntVec col1 = new IntVec(numRows); + LongVec col2 = new LongVec(numRows); + DoubleVec col3 = new DoubleVec(numRows); + for (int i = 0; i < numRows; i++) { + col2.set(i, i % 100); + col3.set(i, i % 100); + } + IntVec col4 = new IntVec(numRows); + LongVec col5 = new LongVec(numRows); + DoubleVec col6 = new DoubleVec(numRows); + for (int i = 0; i < numRows; i++) { + col5.set(i, i % 100); + col6.set(i, i % 100); + } + + DataType[] types = {IntDataType.INTEGER, LongDataType.LONG, DoubleDataType.DOUBLE}; + List projections = ImmutableList.of("#1", "#2"); + List projectionsJSON = ImmutableList.of( + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":1}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":2}"); + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory( + "$operator$EQUAL:4(#1, 50:2)", types, projections); + OmniFilterAndProjectOperatorFactory factoryJSON = new OmniFilterAndProjectOperatorFactory( + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":1}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":2,\"isNull\":false,\"value\":50}}", + types, projectionsJSON, 1); + + OmniOperator op = factory.createOperator(); + OmniOperator opJSON = factoryJSON.createOperator(); + ImmutableList vecBatches1 = makeInput(numRows, col1, col2, col3); + ImmutableList vecBatches2 = makeInput(numRows, col4, col5, col6); + for (VecBatch vecBatch : vecBatches1) { + op.addInput(vecBatch); + } + for (VecBatch vecBatch : vecBatches2) { + opJSON.addInput(vecBatch); + } + + assertTrue(op.getOutput().hasNext()); + assertTrue(opJSON.getOutput().hasNext()); + VecBatch res = op.getOutput().next(); + VecBatch resJSON = opJSON.getOutput().next(); + assertEquals(res.getRowCount(), 50); + assertEquals(resJSON.getRowCount(), 50); + for (int i = 0; i < res.getRowCount(); i++) { + assertEquals(((LongVec) res.getVector(0)).get(i), 50); + assertEquals(((DoubleVec) res.getVector(1)).get(i), 50.0); + assertEquals(((LongVec) resJSON.getVector(0)).get(i), 50); + assertEquals(((DoubleVec) resJSON.getVector(1)).get(i), 50.0); + } + + freeVecBatch(res); + freeVecBatch(resJSON); + op.close(); + opJSON.close(); + factory.close(); + factoryJSON.close(); + } + + /** + * Greater than or equal to. + */ + @Test + public void greaterThanOrEqualTo() { + final int numRows = 5000; + IntVec col1 = new IntVec(numRows); + IntVec col2 = new IntVec(numRows); + for (int i = 0; i < numRows; i++) { + col1.set(i, i); + int value = (i * (i + 2)) % 40; + if (i % 45 == 0) { + value = 30; + } + col2.set(i, value); + } + IntVec col3 = new IntVec(numRows); + IntVec col4 = new IntVec(numRows); + for (int i = 0; i < numRows; i++) { + col3.set(i, i); + int value = (i * (i + 2)) % 40; + if (i % 45 == 0) { + value = 30; + } + col4.set(i, value); + } + + DataType[] types = {IntDataType.INTEGER, IntDataType.INTEGER}; + List projections = ImmutableList.of("#1"); + List projectionsJSON = ImmutableList + .of("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}"); + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory( + "$operator$GREATER_THAN_OR_EQUAL:4(#1, 30:1)", types, projections); + OmniFilterAndProjectOperatorFactory factoryJSON = new OmniFilterAndProjectOperatorFactory( + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"GREATER_THAN_OR_EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":30}}", + types, projectionsJSON, 1); + + OmniOperator op = factory.createOperator(); + OmniOperator opJSON = factoryJSON.createOperator(); + ImmutableList vecBatches1 = makeInput(numRows, col1, col2); + ImmutableList vecBatches2 = makeInput(numRows, col3, col4); + for (VecBatch vecBatch : vecBatches1) { + op.addInput(vecBatch); + } + for (VecBatch vecBatch : vecBatches2) { + opJSON.addInput(vecBatch); + } + + assertTrue(op.getOutput().hasNext()); + assertTrue(opJSON.getOutput().hasNext()); + VecBatch res = op.getOutput().next(); + VecBatch resJSON = opJSON.getOutput().next(); + assertEquals(res.getRowCount(), 834); + assertEquals(resJSON.getRowCount(), 834); + for (int i = 0; i < res.getRowCount(); i++) { + assertTrue(((IntVec) res.getVector(0)).get(i) >= 30); + assertTrue(((IntVec) resJSON.getVector(0)).get(i) >= 30); + } + + freeVecBatch(res); + freeVecBatch(resJSON); + op.close(); + opJSON.close(); + factory.close(); + factoryJSON.close(); + } + + /** + * Not equal to. + */ + @Test + public void notEqualTo() { + final int numRows = 5000; + DoubleVec col1 = new DoubleVec(numRows); + for (int i = 0; i < numRows; i++) { + col1.set(i, i); + } + DoubleVec col2 = new DoubleVec(numRows); + for (int i = 0; i < numRows; i++) { + col2.set(i, i); + } + + DataType[] types = {DoubleDataType.DOUBLE}; + List projections = ImmutableList.of("#0"); + List projectionsJSON = ImmutableList + .of("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":0}"); + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory( + "$operator$NOT_EQUAL:4(#0, 0:3)", types, projections); + OmniFilterAndProjectOperatorFactory factoryJSON = new OmniFilterAndProjectOperatorFactory( + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":3,\"isNull\":false,\"value\":0}}", + types, projectionsJSON, 1); + + OmniOperator op = factory.createOperator(); + OmniOperator opJSON = factoryJSON.createOperator(); + ImmutableList vecBatches1 = makeInput(numRows, col1); + ImmutableList vecBatches2 = makeInput(numRows, col2); + for (VecBatch vecBatch : vecBatches1) { + op.addInput(vecBatch); + } + for (VecBatch vecBatch : vecBatches2) { + opJSON.addInput(vecBatch); + } + + assertTrue(op.getOutput().hasNext()); + assertTrue(opJSON.getOutput().hasNext()); + VecBatch res = op.getOutput().next(); + VecBatch resJSON = opJSON.getOutput().next(); + assertEquals(res.getRowCount(), 4999); + assertEquals(resJSON.getRowCount(), 4999); + double cnt = 1d; + for (int i = 0; i < res.getRowCount(); i++) { + assertEquals(((DoubleVec) res.getVector(0)).get(i), cnt); + assertEquals(((DoubleVec) resJSON.getVector(0)).get(i), cnt); + cnt++; + } + + freeVecBatch(res); + freeVecBatch(resJSON); + op.close(); + opJSON.close(); + factory.close(); + factoryJSON.close(); + } + + /** + * All pass. + */ + @Test + public void allPass() { + final int numRows = 20000; + IntVec col1 = new IntVec(numRows); + for (int i = 0; i < numRows; i++) { + col1.set(i, 9348); + } + IntVec col2 = new IntVec(numRows); + for (int i = 0; i < numRows; i++) { + col2.set(i, 9348); + } + + DataType[] types = {IntDataType.INTEGER}; + List projections = ImmutableList.of("#0"); + List projectionsJSON = ImmutableList + .of("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}"); + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory( + "$operator$EQUAL:4(#0, 9348:1)", types, projections); + OmniFilterAndProjectOperatorFactory factoryJSON = new OmniFilterAndProjectOperatorFactory( + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":9348}}", + types, projectionsJSON, 1); + + OmniOperator op = factory.createOperator(); + OmniOperator opJSON = factoryJSON.createOperator(); + ImmutableList vecBatches1 = makeInput(numRows, col1); + ImmutableList vecBatches2 = makeInput(numRows, col2); + for (VecBatch vecBatch : vecBatches1) { + op.addInput(vecBatch); + } + for (VecBatch vecBatch : vecBatches2) { + opJSON.addInput(vecBatch); + } + + assertTrue(op.getOutput().hasNext()); + assertTrue(opJSON.getOutput().hasNext()); + VecBatch res = op.getOutput().next(); + VecBatch resJSON = opJSON.getOutput().next(); + assertEquals(res.getRowCount(), 20000); + assertEquals(resJSON.getRowCount(), 20000); + for (int i = 0; i < res.getRowCount(); i++) { + assertEquals(((IntVec) res.getVector(0)).get(i), 9348); + assertEquals(((IntVec) resJSON.getVector(0)).get(i), 9348); + } + + freeVecBatch(res); + freeVecBatch(resJSON); + op.close(); + opJSON.close(); + factory.close(); + factoryJSON.close(); + } + + /** + * Multiple inputs. + */ + @Test + public void multipleInputs() { + final int numRows = 1000; + IntVec col1 = new IntVec(numRows); + IntVec col2 = new IntVec(numRows); + for (int i = 0; i < numRows; i++) { + col1.set(i, i % 10); + col2.set(i, i % 6 + 1); + } + IntVec col3 = new IntVec(numRows); + IntVec col4 = new IntVec(numRows); + for (int i = 0; i < numRows; i++) { + col3.set(i, i % 10); + col4.set(i, i % 6 + 1); + } + + DataType[] types = {IntDataType.INTEGER}; + List projections = ImmutableList.of("#0"); + List projectionsJSON = ImmutableList + .of("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}"); + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory( + "$operator$LESS_THAN_OR_EQUAL:4(#0, 4:1)", types, projections); + OmniFilterAndProjectOperatorFactory factoryJSON = new OmniFilterAndProjectOperatorFactory( + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN_OR_EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":4}}", + types, projectionsJSON, 1); + + OmniOperator op = factory.createOperator(); + OmniOperator opJSON = factoryJSON.createOperator(); + ImmutableList vecBatches01 = makeInput(numRows, col1); + ImmutableList vecBatches02 = makeInput(numRows, col3); + for (VecBatch vecBatch : vecBatches01) { + op.addInput(vecBatch); + } + for (VecBatch vecBatch : vecBatches02) { + opJSON.addInput(vecBatch); + } + + assertTrue(op.getOutput().hasNext()); + assertTrue(opJSON.getOutput().hasNext()); + VecBatch res = op.getOutput().next(); + VecBatch resJSON = opJSON.getOutput().next(); + assertEquals(res.getRowCount(), 500); + assertEquals(resJSON.getRowCount(), 500); + for (int i = 0; i < res.getRowCount(); i++) { + assertTrue(((IntVec) res.getVector(0)).get(i) <= 4); + assertTrue(((IntVec) resJSON.getVector(0)).get(i) <= 4); + } + + freeVecBatch(res); + freeVecBatch(resJSON); + + // Test multiple inputs + ImmutableList vecBatches11 = makeInput(numRows, col2); + ImmutableList vecBatches12 = makeInput(numRows, col4); + for (VecBatch vecBatch : vecBatches11) { + op.addInput(vecBatch); + } + for (VecBatch vecBatch : vecBatches12) { + opJSON.addInput(vecBatch); + } + assertTrue(op.getOutput().hasNext()); + assertTrue(opJSON.getOutput().hasNext()); + res = op.getOutput().next(); + resJSON = opJSON.getOutput().next(); + assertEquals(res.getRowCount(), 668); + assertEquals(resJSON.getRowCount(), 668); + for (int i = 0; i < res.getRowCount(); i++) { + assertTrue(((IntVec) res.getVector(0)).get(i) <= 4); + assertTrue(((IntVec) resJSON.getVector(0)).get(i) <= 4); + } + + freeVecBatch(res); + freeVecBatch(resJSON); + op.close(); + opJSON.close(); + factory.close(); + factoryJSON.close(); + } + + /** + * Negative values. + */ + @Test + public void negativeValues() { + final int numRows = 10000; + IntVec col1 = new IntVec(numRows); + LongVec col2 = new LongVec(numRows); + for (int i = 0; i < numRows; i++) { + int val1 = i * i + 1; + if (i % 5 == 0) { + val1 = -val1; + } + col1.set(i, val1); + long val2 = i % 100 + (long) 3e9; + if (i % 7 == 0) { + val2 = -val2; + } + col2.set(i, val2); + } + + IntVec col3 = new IntVec(numRows); + LongVec col4 = new LongVec(numRows); + for (int i = 0; i < numRows; i++) { + int val1 = i * i + 1; + if (i % 5 == 0) { + val1 = -val1; + } + col3.set(i, val1); + long val2 = i % 100 + (long) 3e9; + if (i % 7 == 0) { + val2 = -val2; + } + col4.set(i, val2); + } + + DataType[] types = {IntDataType.INTEGER, LongDataType.LONG}; + List projections = ImmutableList.of("#0", "#1"); + List projectionsJSON = ImmutableList.of( + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":1}"); + OmniFilterAndProjectOperatorFactory factoryJSON = new OmniFilterAndProjectOperatorFactory( + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"AND\"," + + "\"left\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN_OR_EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":-1}}," + + "\"right\":{\"exprType\":\"BINARY\",\"returnType\":4," + + "\"operator\":\"LESS_THAN_OR_EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":1}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":2,\"isNull\":false,\"value\":-1}}}", + types, projectionsJSON, 1); + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory( + "AND:4($operator$LESS_THAN_OR_EQUAL:4(#0, -1:1), $operator$LESS_THAN_OR_EQUAL:4(#1, -1:2))", types, + projections); + + OmniOperator op = factory.createOperator(); + OmniOperator opJSON = factoryJSON.createOperator(); + ImmutableList vecBatches1 = makeInput(numRows, col1, col2); + ImmutableList vecBatches2 = makeInput(numRows, col3, col4); + for (VecBatch vecBatch : vecBatches1) { + op.addInput(vecBatch); + } + for (VecBatch vecBatch : vecBatches2) { + opJSON.addInput(vecBatch); + } + + assertTrue(op.getOutput().hasNext()); + assertTrue(opJSON.getOutput().hasNext()); + VecBatch res = op.getOutput().next(); + VecBatch resJSON = opJSON.getOutput().next(); + assertEquals(res.getRowCount(), 286); + assertEquals(resJSON.getRowCount(), 286); + for (int i = 0; i < res.getRowCount(); i++) { + assertTrue(((IntVec) res.getVector(0)).get(i) < 0); + assertTrue(((LongVec) res.getVector(1)).get(i) < 0); + assertTrue(((IntVec) resJSON.getVector(0)).get(i) < 0); + assertTrue(((LongVec) resJSON.getVector(1)).get(i) < 0); + } + + freeVecBatch(res); + freeVecBatch(resJSON); + op.close(); + opJSON.close(); + factory.close(); + factoryJSON.close(); + } + + /** + * All types. + */ + @Test + public void allTypes() { + final int numRows = 10000; + IntVec col1 = new IntVec(numRows); + LongVec col2 = new LongVec(numRows); + DoubleVec col3 = new DoubleVec(numRows); + for (int i = 0; i < numRows; i++) { + col1.set(i, i % 3); + col2.set(i, i % 2 == 0 ? (long) 3e9 : 0); + col3.set(i, i % 10 / 10D); + } + IntVec col4 = new IntVec(numRows); + LongVec col5 = new LongVec(numRows); + DoubleVec col6 = new DoubleVec(numRows); + for (int i = 0; i < numRows; i++) { + col4.set(i, i % 3); + col5.set(i, i % 2 == 0 ? (long) 3e9 : 0); + col6.set(i, i % 10 / 10D); + } + + DataType[] types = {IntDataType.INTEGER, LongDataType.LONG, DoubleDataType.DOUBLE}; + List projections = ImmutableList.of("#0", "#1", "#2"); + List projectionsJSON = ImmutableList.of( + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":1}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":2}"); + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory( + "AND:4($operator$EQUAL:4(#0, 0:1), AND:4($operator$EQUAL:4(#1, 3000000000:2), " + + "$operator$GREATER_THAN_OR_EQUAL:4(#2, 0.4:3)))", types, projections); + OmniFilterAndProjectOperatorFactory factoryJSON = new OmniFilterAndProjectOperatorFactory( + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"AND\"," + + "\"left\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":0}}," + + "\"right\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"AND\"," + + "\"left\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":1}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":2,\"isNull\":false," + + "\"value\":3000000000}},\"right\":{\"exprType\":\"BINARY\",\"returnType\":4," + + "\"operator\":\"GREATER_THAN_OR_EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":2}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":3,\"isNull\":false,\"value\":0.4}}}}", + types, projectionsJSON, 1); + + OmniOperator op = factory.createOperator(); + OmniOperator opJSON = factoryJSON.createOperator(); + ImmutableList vecBatches1 = makeInput(numRows, col1, col2, col3); + ImmutableList vecBatches2 = makeInput(numRows, col4, col5, col6); + for (VecBatch vecBatch : vecBatches1) { + op.addInput(vecBatch); + } + for (VecBatch vecBatch : vecBatches2) { + opJSON.addInput(vecBatch); + } + + assertTrue(op.getOutput().hasNext()); + assertTrue(opJSON.getOutput().hasNext()); + VecBatch res = op.getOutput().next(); + VecBatch resJSON = opJSON.getOutput().next(); + assertEquals(res.getRowCount(), 1000); + assertEquals(resJSON.getRowCount(), 1000); + for (int i = 0; i < res.getRowCount(); i++) { + assertEquals(((IntVec) res.getVector(0)).get(i), 0); + assertEquals(((LongVec) res.getVector(1)).get(i), (long) 3e9); + assertTrue(((DoubleVec) res.getVector(2)).get(i) >= 0.4); + assertEquals(((IntVec) resJSON.getVector(0)).get(i), 0); + assertEquals(((LongVec) resJSON.getVector(1)).get(i), (long) 3e9); + assertTrue(((DoubleVec) resJSON.getVector(2)).get(i) >= 0.4); + } + + freeVecBatch(res); + freeVecBatch(resJSON); + op.close(); + opJSON.close(); + factory.close(); + factoryJSON.close(); + } + + /** + * Logical operators 1. + */ + @Test + public void logicalOperators1() { + final int numRows = 10000; + IntVec col01 = new IntVec(numRows); + IntVec col02 = new IntVec(numRows); + IntVec col03 = new IntVec(numRows); + LongVec col04 = new LongVec(numRows); + DoubleVec col05 = new DoubleVec(numRows); + LongVec col06 = new LongVec(numRows); + for (int i = 0; i < numRows; i++) { + col01.set(i, i % 3 == 0 ? 0 : 1); + col02.set(i, i); + col03.set(i, i); + col04.set(i, i % 2 == 0 ? 3000000000L : 2999999999L); + col05.set(i, 50 + i / 10D); + col06.set(i, i % 55); + } + IntVec col11 = new IntVec(numRows); + IntVec col12 = new IntVec(numRows); + IntVec col13 = new IntVec(numRows); + LongVec col14 = new LongVec(numRows); + DoubleVec col15 = new DoubleVec(numRows); + LongVec col16 = new LongVec(numRows); + for (int i = 0; i < numRows; i++) { + col11.set(i, i % 3 == 0 ? 0 : 1); + col12.set(i, i); + col13.set(i, i); + col14.set(i, i % 2 == 0 ? 3000000000L : 2999999999L); + col15.set(i, 50 + i / 10D); + col16.set(i, i % 55); + } + + DataType[] types = {IntDataType.INTEGER, IntDataType.INTEGER, IntDataType.INTEGER, LongDataType.LONG, + DoubleDataType.DOUBLE, LongDataType.LONG}; + List projections = ImmutableList.of("#0", "#2", "#4", "#5"); + String str = "OR:4($operator$GREATER_THAN_OR_EQUAL:4(#5, 52:2), AND:4($operator$LESS_THAN:4(#4, 50.8:3), " + + "AND:4(AND:4($operator$GREATER_THAN:4(#2, 4800:1), $operator$LESS_THAN_OR_EQUAL:4(#1, 9990:1)), " + + "AND:4($operator$NOT_EQUAL:4(#0, 1:1), $operator$EQUAL:4(#3, 3000000000:2)))))"; + List projectionsJSON = ImmutableList.of( + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\": 1,\"colVal\":0}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\": 1,\"colVal\":2}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\": 3,\"colVal\":4}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\": 2,\"colVal\":5}"); + String strJSON = "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"OR\"," + + "\"left\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":" + + "\"GREATER_THAN_OR_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2," + + "\"colVal\":5},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":2,\"isNull\":false," + + "\"value\":52}},\"right\":{\"exprType\":\"BINARY\",\"returnType\":4," + + "\"operator\":\"AND\",\"left\":{\"exprType\":\"BINARY\",\"returnType\":4," + + "\"operator\":\"LESS_THAN\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3," + + "\"colVal\":4},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":3,\"isNull\":false," + + "\"value\":50.8}},\"right\":{\"exprType\":\"BINARY\",\"returnType\":4," + + "\"operator\":\"AND\",\"left\":{\"exprType\":\"BINARY\",\"returnType\":4," + + "\"operator\":\"AND\",\"left\":{\"exprType\":\"BINARY\",\"returnType\":4," + + "\"operator\":\"GREATER_THAN\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1," + + "\"colVal\":2},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false," + + "\"value\":4800}},\"right\":{\"exprType\":\"BINARY\",\"returnType\":4," + + "\"operator\":\"LESS_THAN_OR_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":" + + "1,\"colVal\":1},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false," + + "\"value\":9990}}},\"right\":{\"exprType\":\"BINARY\",\"returnType\":" + + "4,\"operator\":\"AND\",\"left\":{\"exprType\":\"BINARY\",\"returnType\":" + + "4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":" + + "1,\"colVal\":0},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false," + + "\"value\":1}},\"right\":{\"exprType\":\"BINARY\",\"returnType\":" + + "4,\"operator\":\"EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":" + + "2,\"colVal\":3},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":2,\"isNull\":false," + + "\"value\":3000000000}}}}}}"; + + OmniFilterAndProjectOperatorFactory factoryJSON = new OmniFilterAndProjectOperatorFactory(strJSON, types, + projectionsJSON, 1); + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory(str, types, projections); + OmniOperator op = factory.createOperator(); + OmniOperator opJSON = factoryJSON.createOperator(); + ImmutableList vecBatches1 = makeInput(numRows, col01, col02, col03, col04, col05, col06); + ImmutableList vecBatches2 = makeInput(numRows, col11, col12, col13, col14, col15, col16); + for (VecBatch vecBatch : vecBatches1) { + op.addInput(vecBatch); + } + for (VecBatch vecBatch : vecBatches2) { + opJSON.addInput(vecBatch); + } + + assertTrue(op.getOutput().hasNext()); + assertTrue(opJSON.getOutput().hasNext()); + VecBatch res = op.getOutput().next(); + VecBatch resJSON = opJSON.getOutput().next(); + assertEquals(res.getRowCount(), 543); + assertEquals(resJSON.getRowCount(), 543); + IntVec res0 = ((IntVec) res.getVector(0)); + IntVec res1 = ((IntVec) res.getVector(1)); + IntVec resJSON0 = ((IntVec) resJSON.getVector(0)); + IntVec resJSON1 = ((IntVec) resJSON.getVector(1)); + DoubleVec res2 = ((DoubleVec) res.getVector(2)); + LongVec res3 = ((LongVec) res.getVector(3)); + DoubleVec resJSON2 = ((DoubleVec) resJSON.getVector(2)); + LongVec resJSON3 = ((LongVec) resJSON.getVector(3)); + for (int i = 0; i < res.getRowCount(); i++) { + assertTrue((res0.get(i) != 1 && res1.get(i) > 4800 && res2.get(i) < 50.8) || res3.get(i) >= 52); + assertTrue((resJSON0.get(i) != 1 && resJSON1.get(i) > 4800 && resJSON2.get(i) < 50.8) + || resJSON3.get(i) >= 52); + } + + freeVecBatch(res); + freeVecBatch(resJSON); + op.close(); + opJSON.close(); + factory.close(); + factoryJSON.close(); + } + + /** + * Logical operators 2. + */ + @Test + public void logicalOperators2() { + final int numRows = 10000; + IntVec col01 = new IntVec(numRows); + IntVec col02 = new IntVec(numRows); + LongVec col03 = new LongVec(numRows); + LongVec col04 = new LongVec(numRows); + for (int i = 0; i < numRows; i++) { + col01.set(i, i % 100); + col02.set(i, i % 7 == 0 ? -12 : i); + col03.set(i, i % 8 == 0 ? -i - 3000000000L : i + 3000000000L); + col04.set(i, i % 9 - 4); + } + IntVec col11 = new IntVec(numRows); + IntVec col12 = new IntVec(numRows); + LongVec col13 = new LongVec(numRows); + LongVec col14 = new LongVec(numRows); + for (int i = 0; i < numRows; i++) { + col11.set(i, i % 100); + col12.set(i, i % 7 == 0 ? -12 : i); + col13.set(i, i % 8 == 0 ? -i - 3000000000L : i + 3000000000L); + col14.set(i, i % 9 - 4); + } + + DataType[] types = {IntDataType.INTEGER, IntDataType.INTEGER, LongDataType.LONG, LongDataType.LONG}; + List projections = ImmutableList.of("#3", "#2", "#1", "#0"); + String str = "AND:4(OR:4($operator$LESS_THAN:4(#0, 50:1), $operator$EQUAL:4(#1, -12:1)), " + + "OR:4($operator$LESS_THAN_OR_EQUAL:4(#2, -3000000000:2), " + + "$operator$GREATER_THAN_OR_EQUAL:4(#3, 0:2)))"; + List projectionsJSON = ImmutableList.of( + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":3}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":2}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}"); + String strJSON = "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"AND\",\"left\":" + + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"OR\",\"left\":" + + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":50}}," + + "\"right\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":" + + "\"EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":-12}}}," + + "\"right\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":" + + "\"OR\",\"left\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":" + + "\"LESS_THAN_OR_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2," + + "\"colVal\":2},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":2,\"isNull\":false," + + "\"value\":-3000000000}},\"right\":{\"exprType\":\"BINARY\",\"returnType\":" + + "4,\"operator\":\"GREATER_THAN_OR_EQUAL\",\"left\":{\"exprType\":" + + "\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":3},\"right\":{\"exprType\":\"LITERAL\"," + + "\"dataType\":2,\"isNull\":false,\"value\":0}}}}"; + + OmniFilterAndProjectOperatorFactory factoryJSON = new OmniFilterAndProjectOperatorFactory(strJSON, types, + projectionsJSON, 1); + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory(str, types, projections); + OmniOperator op = factory.createOperator(); + OmniOperator opJSON = factoryJSON.createOperator(); + ImmutableList vecBatches1 = makeInput(numRows, col01, col02, col03, col04); + ImmutableList vecBatches2 = makeInput(numRows, col11, col12, col13, col14); + for (VecBatch vecBatch : vecBatches1) { + op.addInput(vecBatch); + } + for (VecBatch vecBatch : vecBatches2) { + opJSON.addInput(vecBatch); + } + + assertTrue(op.getOutput().hasNext()); + assertTrue(opJSON.getOutput().hasNext()); + VecBatch res = op.getOutput().next(); + VecBatch resJSON = opJSON.getOutput().next(); + assertEquals(res.getRowCount(), 3498); + assertEquals(resJSON.getRowCount(), 3498); + LongVec res0 = ((LongVec) res.getVector(0)); + LongVec res1 = ((LongVec) res.getVector(1)); + IntVec res2 = ((IntVec) res.getVector(2)); + IntVec res3 = ((IntVec) res.getVector(3)); + LongVec resJSON0 = ((LongVec) resJSON.getVector(0)); + LongVec resJSON1 = ((LongVec) resJSON.getVector(1)); + IntVec resJSON2 = ((IntVec) resJSON.getVector(2)); + IntVec resJSON3 = ((IntVec) resJSON.getVector(3)); + for (int i = 0; i < res.getRowCount(); i++) { + assertTrue((res0.get(i) >= 0 || res1.get(i) <= -3000000000L) && (res2.get(i) == -12 || res3.get(i) < 50)); + assertTrue((resJSON0.get(i) >= 0 || resJSON1.get(i) <= -3000000000L) + && (resJSON2.get(i) == -12 || resJSON3.get(i) < 50)); + } + + freeVecBatch(res); + freeVecBatch(resJSON); + op.close(); + opJSON.close(); + factory.close(); + factoryJSON.close(); + } + + /** + * Logical operators 3. + */ + @Test + public void logicalOperators3() { + final int numRows = 1024; + IntVec col01 = new IntVec(numRows); + DoubleVec col02 = new DoubleVec(numRows); + for (int i = 0; i < numRows; i++) { + col01.set(i, 0); + col02.set(i, 1.5); + } + col01.set(0, 0); + col01.set(1, 1); + col01.set(2, 1); + col01.set(3, 2); + col01.set(4, 3); + col01.set(5, 5); + col01.set(6, 8); + col01.set(7, 13); + col02.set(2, 0); + IntVec col11 = new IntVec(numRows); + DoubleVec col12 = new DoubleVec(numRows); + for (int i = 0; i < numRows; i++) { + col11.set(i, 0); + col12.set(i, 1.5); + } + col11.set(0, 0); + col11.set(1, 1); + col11.set(2, 1); + col11.set(3, 2); + col11.set(4, 3); + col11.set(5, 5); + col11.set(6, 8); + col11.set(7, 13); + col12.set(2, 0); + + DataType[] types = {IntDataType.INTEGER, DoubleDataType.DOUBLE}; + List projections = ImmutableList.of("#1", "#0"); + String expr = "AND:4($operator$NOT_EQUAL:4(#1, 0:3), OR:4(OR:4(OR:4($operator$EQUAL:4(#0, 1:1), " + + "$operator$EQUAL:4(#0, 2:1)), $operator$EQUAL:4(#0, 3:1)), " + + "OR:4(OR:4(OR:4($operator$EQUAL:4(55:1, #0), $operator$EQUAL:4(5:1, #0)), " + + "$operator$EQUAL:4(#0, 8:1)), $operator$EQUAL:4(#0, 13:1))))"; + List projectionsJSON = ImmutableList.of( + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":1}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}"); + String exprJSON = "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"AND\"," + + "\"left\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":1}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":3,\"isNull\":false,\"value\":0}}," + + "\"right\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"OR\"," + + "\"left\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"OR\"," + + "\"left\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"OR\"," + + "\"left\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":1}}," + + "\"right\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":2}}}," + + "\"right\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":3}}}," + + "\"right\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"OR\"," + + "\"left\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"OR\"," + + "\"left\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"OR\"," + + "\"left\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"EQUAL\"," + + "\"left\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":55}," + + "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}}," + + "\"right\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"EQUAL\"," + + "\"left\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":5}," + + "\"right\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}}}," + + "\"right\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":8}}}," + + "\"right\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":13}}}}}"; + OmniFilterAndProjectOperatorFactory factoryJSON = new OmniFilterAndProjectOperatorFactory(exprJSON, types, + projectionsJSON, 1); + + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory(expr, types, projections); + OmniOperator op = factory.createOperator(); + OmniOperator opJSON = factoryJSON.createOperator(); + ImmutableList vecBatches1 = makeInput(numRows, col01, col02); + ImmutableList vecBatches2 = makeInput(numRows, col11, col12); + for (VecBatch vecBatch : vecBatches1) { + op.addInput(vecBatch); + } + for (VecBatch vecBatch : vecBatches2) { + opJSON.addInput(vecBatch); + } + + assertTrue(op.getOutput().hasNext()); + assertTrue(opJSON.getOutput().hasNext()); + VecBatch res = op.getOutput().next(); + VecBatch resJSON = opJSON.getOutput().next(); + assertEquals(res.getRowCount(), 6); + int[] vals = {1, 2, 3, 5, 8, 13}; + for (int i = 0; i < res.getRowCount(); i++) { + assertEquals(((IntVec) res.getVector(1)).get(i), vals[i]); + assertEquals(((IntVec) resJSON.getVector(1)).get(i), vals[i]); + } + + freeVecBatch(res); + freeVecBatch(resJSON); + op.close(); + opJSON.close(); + factory.close(); + factoryJSON.close(); + } + + /** + * Arithmetic add. + */ + @Test + public void arithmeticAdd() { + final int numRows = 10000; + IntVec col01 = new IntVec(numRows); + for (int i = 0; i < numRows; i++) { + col01.set(i, i % 5); + } + IntVec col11 = new IntVec(numRows); + for (int i = 0; i < numRows; i++) { + col11.set(i, i % 5); + } + + DataType[] types = {IntDataType.INTEGER}; + List projections = ImmutableList.of("#0"); + List projectionsJSON = ImmutableList + .of("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}"); + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory( + "$operator$GREATER_THAN:4(ADD:1(#0, 1:1), 4:1)", types, projections); + OmniFilterAndProjectOperatorFactory factoryJSON = new OmniFilterAndProjectOperatorFactory( + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"GREATER_THAN\"," + + "\"left\":{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"ADD\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":1}}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":4}}", + types, projectionsJSON, 1); + + OmniOperator op = factory.createOperator(); + OmniOperator opJSON = factoryJSON.createOperator(); + ImmutableList vecBatches1 = makeInput(numRows, col01); + ImmutableList vecBatches2 = makeInput(numRows, col11); + for (VecBatch vecBatch : vecBatches1) { + op.addInput(vecBatch); + } + for (VecBatch vecBatch : vecBatches2) { + opJSON.addInput(vecBatch); + } + + assertTrue(op.getOutput().hasNext()); + assertTrue(opJSON.getOutput().hasNext()); + VecBatch res = op.getOutput().next(); + VecBatch resJSON = opJSON.getOutput().next(); + assertEquals(res.getRowCount(), 2000); + assertEquals(resJSON.getRowCount(), 2000); + for (int i = 0; i < res.getRowCount(); i++) { + assertTrue(((IntVec) res.getVector(0)).get(i) + 1 > 4); + assertTrue(((IntVec) resJSON.getVector(0)).get(i) + 1 > 4); + } + + freeVecBatch(res); + freeVecBatch(resJSON); + op.close(); + opJSON.close(); + factory.close(); + factoryJSON.close(); + } + + private List createTable(final int numRows) { + IntVec col1 = new IntVec(numRows); + IntVec col2 = new IntVec(numRows); + DoubleVec col3 = new DoubleVec(numRows); + DoubleVec col4 = new DoubleVec(numRows); + for (int i = 0; i < numRows; i++) { + col1.set(i, i); + col2.set(i, i); + col3.set(i, i); + col4.set(i, i); + } + List table = new ArrayList<>(); + table.add(col1); + table.add(col2); + table.add(col3); + table.add(col4); + return table; + } + + /** + * Multithread test. + * + * @throws InterruptedException thread interrupt exception + */ + @Test + public void multithreadTest() throws InterruptedException { + DataType[] types = {IntDataType.INTEGER, IntDataType.INTEGER, DoubleDataType.DOUBLE, DoubleDataType.DOUBLE}; + List projections = ImmutableList.of("#0", "#1", "#2", "#3"); + List projectionsJSON = ImmutableList.of( + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":2}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":3}"); + String str = "$operator$LESS_THAN_OR_EQUAL:4(#0, 500:1)"; + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory(str, types, projections); + + final int threadCount = 1000; + final int corePoolSize = 10; + final int maximumPoolSize = 50; + CountDownLatch countDownLatch = new CountDownLatch(threadCount); + ThreadPoolExecutor threadPool = new ThreadPoolExecutor(corePoolSize, maximumPoolSize, 60L, TimeUnit.SECONDS, + new LinkedBlockingQueue<>(threadCount)); + final int numRows = 1000; + + for (int i = 0; i < threadCount; i++) { + CompletableFuture.runAsync(() -> { + try { + OmniOperator op = factory.createOperator(); + VecBatch vecBatch = new VecBatch(createTable(numRows)); + op.addInput(vecBatch); + assertTrue(op.getOutput().hasNext()); + VecBatch res = op.getOutput().next(); + assertEquals(res.getRowCount(), 501); + + freeVecBatch(res); + op.close(); + } finally { + countDownLatch.countDown(); + } + }, threadPool); + } + + // This will wait until all future ready. + try { + countDownLatch.await(); + } catch (InterruptedException e) { + assertTrue(false); + } + + CountDownLatch countDownLatchJSON = new CountDownLatch(threadCount); + OmniFilterAndProjectOperatorFactory factoryJSON = new OmniFilterAndProjectOperatorFactory( + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN_OR_EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":500}}", + types, projectionsJSON, 1); + // test in JSON format + for (int i = 0; i < threadCount; i++) { + CompletableFuture.runAsync(() -> { + try { + OmniOperator opJSON = factoryJSON.createOperator(); + VecBatch vecBatch = new VecBatch(createTable(numRows)); + opJSON.addInput(vecBatch); + assertTrue(opJSON.getOutput().hasNext()); + VecBatch resJSON = opJSON.getOutput().next(); + assertEquals(resJSON.getRowCount(), 501); + + freeVecBatch(resJSON); + opJSON.close(); + } finally { + countDownLatchJSON.countDown(); + } + }, threadPool); + } + + // This will wait until all future ready. + try { + countDownLatchJSON.await(); + } catch (InterruptedException e) { + assertTrue(false); + } + + threadPool.shutdown(); + factory.close(); + factoryJSON.close(); + } + + /** + * Conditional. + */ + @Test + public void conditional() { + final int numRows = 10000; + IntVec col01 = new IntVec(numRows); + IntVec col02 = new IntVec(numRows); + IntVec col03 = new IntVec(numRows); + for (int i = 0; i < numRows; i++) { + col01.set(i, i % 2); + col02.set(i, i % 5); + col03.set(i, i % 10); + } + IntVec col11 = new IntVec(numRows); + IntVec col12 = new IntVec(numRows); + IntVec col13 = new IntVec(numRows); + for (int i = 0; i < numRows; i++) { + col11.set(i, i % 2); + col12.set(i, i % 5); + col13.set(i, i % 10); + } + + DataType[] types = {IntDataType.INTEGER, IntDataType.INTEGER, IntDataType.INTEGER}; + List projections = ImmutableList.of("#0", "#1", "#2"); + List projectionsJSON = ImmutableList.of( + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":2}"); + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory( + "AND:4(IF:4($operator$EQUAL:4(#0, 0:1), $operator$LESS_THAN:4(#1, 3:1), $operator$EQUAL:4(#1, 4:1)), " + + "$operator$GREATER_THAN:4(#2, 3:1))", + types, projections); + OmniFilterAndProjectOperatorFactory factoryJSON = new OmniFilterAndProjectOperatorFactory( + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"AND\"," + + "\"left\":{\"exprType\":\"IF\",\"returnType\":4," + + "\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":0}}," + + "\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":3}}," + + "\"if_false\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"EQUAL\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":4}}}," + + "\"right\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"GREATER_THAN\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":2}," + + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":3}}}", + types, projectionsJSON, 1); + + OmniOperator op = factory.createOperator(); + OmniOperator opJSON = factoryJSON.createOperator(); + ImmutableList vecBatches1 = makeInput(numRows, col01, col02, col03); + ImmutableList vecBatches2 = makeInput(numRows, col11, col12, col13); + for (VecBatch vecBatch : vecBatches1) { + op.addInput(vecBatch); + } + for (VecBatch vecBatch : vecBatches2) { + opJSON.addInput(vecBatch); + } + + assertTrue(op.getOutput().hasNext()); + assertTrue(opJSON.getOutput().hasNext()); + VecBatch res = op.getOutput().next(); + VecBatch resJSON = opJSON.getOutput().next(); + assertEquals(res.getRowCount(), 2000); + assertEquals(resJSON.getRowCount(), 2000); + freeVecBatch(res); + freeVecBatch(resJSON); + op.close(); + opJSON.close(); + factory.close(); + factoryJSON.close(); + } + + /** + * decimal InExpr WithNull + */ + @Test + public void decimalInExprWithNull() { + DataType[] sourceTypes = {new Decimal64DataType(7, 2), new Decimal64DataType(7, 5), + new Decimal64DataType(18, 9)}; + Object[][] sourceDatas = {{4570289L, -9999999L, null}, {9999999L, null, -234527L}, + {null, -999999999999999999L, -234527000012L}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + List projectionsJSON = ImmutableList.of( + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":6,\"colVal\":0, \"precision\":7,\"scale\":2}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":6,\"colVal\":1, \"precision\":7,\"scale\":5}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":6,\"colVal\":2, \"precision\":18,\"scale\":9}"); + // deci7_5 in (deci7_2, deci7_5, deci18_9) + String filterJSON = "{\"exprType\":\"IN\",\"returnType\":4,\"arguments\":[{\"exprType\":\"FUNCTION\"," + + "\"returnType\":6,\"function_name\":\"CAST\",\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\"," + + "\"dataType\":6,\"colVal\":1,\"precision\":7,\"scale\":5}],\"precision\":18,\"scale\":9}," + + "{\"exprType\":\"FUNCTION\",\"returnType\":6,\"function_name\":\"CAST\",\"arguments\":[{\"exprType" + + "\":\"FIELD_REFERENCE\",\"dataType\":6,\"colVal\":0,\"precision\":7,\"scale\":2}],\"precision\":18," + + "\"scale\":9},{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":6,\"colVal\":2,\"precision\":18," + + "\"scale\":9},{\"exprType\":\"FUNCTION\",\"returnType\":6,\"function_name\":\"CAST\",\"arguments" + + "\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":6,\"colVal\":1,\"precision\":7,\"scale\":5}]," + + "\"precision\":18,\"scale\":9}]}"; + OmniFilterAndProjectOperatorFactory factoryJSON = new OmniFilterAndProjectOperatorFactory(filterJSON, + sourceTypes, projectionsJSON, 1); + + OmniOperator opJSON = factoryJSON.createOperator(); + opJSON.addInput(vecBatch); + + Iterator results = opJSON.getOutput(); + VecBatch resultVecBatch = results.next(); + Object[][] expectedDatas = {{4570289L, null}, {9999999L, -234527L}, {null, -234527000012L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + opJSON.close(); + factoryJSON.close(); + } + + /** + * Unsupported expression. + */ + @Test + public void unsupportedExpr() { + DataType[] types = {DoubleDataType.DOUBLE}; + List projectionsJSON = ImmutableList + .of("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":0}"); + + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory( + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"CAST\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":0},\"right\"" + + ":{\"exprType\":\"LITERAL\",\"dataType\":3,\"isNull\":false,\"value\":1.0}}", + types, projectionsJSON, 1); + + assertFalse(factory.isSupported()); + factory.close(); + } + + @Test + public void exprVerifier() { + DataType[] types = {new Decimal128DataType(21, 5)}; + List projectionsJSON = ImmutableList + .of("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":7,\"precision\":21,\"scale\":5,\"colVal\":0}"); + + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory( + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN_THAN\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":7,\"colVal\":0," + + "\"precision\":21,\"scale\":5},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":6," + + "\"precision\":9,\"scale\":5,\"isNull\":false,\"value\":2000}}", + types, projectionsJSON, 1); + + assertFalse(factory.isSupported()); + factory.close(); + } + + @Test + public void testFactoryContextEquals() { + DataType[] types = {DoubleDataType.DOUBLE}; + List projectionsJSON = ImmutableList + .of("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":0}"); + + FactoryContext factory1 = new FactoryContext( + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"CAST\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":0},\"right\"" + + ":{\"exprType\":\"LITERAL\",\"dataType\":3,\"isNull\":false,\"value\":1.0}}", + types, projectionsJSON, 1, new OperatorConfig()); + FactoryContext factory2 = new FactoryContext( + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"CAST\"," + + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":0},\"right\"" + + ":{\"exprType\":\"LITERAL\",\"dataType\":3,\"isNull\":false,\"value\":1.0}}", + types, projectionsJSON, 1, new OperatorConfig()); + FactoryContext factory3 = null; + + assertEquals(factory2, factory1); + assertEquals(factory1, factory1); + assertNotEquals(factory3, factory1); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniHashAggregationOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniHashAggregationOperatorTest.java new file mode 100644 index 0000000..8b91560 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniHashAggregationOperatorTest.java @@ -0,0 +1,402 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static java.lang.String.format; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_COUNT_ALL; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_COUNT_COLUMN; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_SUM; +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.createVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +import com.google.common.collect.ImmutableList; + +import nova.hetu.omniruntime.constants.FunctionType; +import nova.hetu.omniruntime.operator.aggregator.OmniHashAggregationOperatorFactory; +import nova.hetu.omniruntime.operator.aggregator.OmniHashAggregationOperatorFactory.FactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +/** + * The type Omni hash aggregation operator test. + * + * @since 2021-07-21 + */ +public class OmniHashAggregationOperatorTest { + /** + * test hashAggregation performance whether with jit or not. + */ + @Test + public void testHashAggregationComparePref() { + String[] groupByChannel = {"#0", "#1"}; + DataType[] groupByTypes = {LongDataType.LONG, LongDataType.LONG}; + String[] aggChannels = {"#3"}; + DataType[] aggTypes = {LongDataType.LONG}; + FunctionType[] aggFunctionTypes = {OMNI_AGGREGATION_TYPE_COUNT_COLUMN, OMNI_AGGREGATION_TYPE_COUNT_ALL}; + DataType[] aggOutputTypes = {LongDataType.LONG, LongDataType.LONG}; + + OmniHashAggregationOperatorFactory factoryWithJit = new OmniHashAggregationOperatorFactory(groupByChannel, + groupByTypes, aggChannels, aggTypes, aggFunctionTypes, aggOutputTypes, true, false, + new OperatorConfig()); + OmniOperator omniOperatorWithJit = factoryWithJit.createOperator(); + + ImmutableList.Builder vecBatchList1 = ImmutableList.builder(); + int rowNum = 100000; + int pageCount = 10; + for (int i = 0; i < pageCount; i++) { + vecBatchList1.add(new VecBatch(buildDataForCount(rowNum))); + } + + long start1 = System.currentTimeMillis(); + for (VecBatch vecBatch : vecBatchList1.build()) { + omniOperatorWithJit.addInput(vecBatch); + } + + Iterator outputWithJit = omniOperatorWithJit.getOutput(); + long end1 = System.currentTimeMillis(); + System.out.println("HashAggregation with jit use " + (end1 - start1) + " ms."); + + OmniHashAggregationOperatorFactory factoryWithoutJit = new OmniHashAggregationOperatorFactory(groupByChannel, + groupByTypes, aggChannels, aggTypes, aggFunctionTypes, aggOutputTypes, true, false, + new OperatorConfig()); + OmniOperator omniOperatorWithoutJit = factoryWithoutJit.createOperator(); + + ImmutableList.Builder vecBatchList2 = ImmutableList.builder(); + for (int i = 0; i < pageCount; i++) { + vecBatchList2.add(new VecBatch(buildDataForCount(rowNum))); + } + + long start2 = System.currentTimeMillis(); + for (VecBatch vecBatch : vecBatchList2.build()) { + omniOperatorWithoutJit.addInput(vecBatch); + } + + Iterator outputWithoutJit = omniOperatorWithoutJit.getOutput(); + long end2 = System.currentTimeMillis(); + System.out.println("HashAggregation without jit use " + (end2 - start2) + " ms."); + + while (outputWithJit.hasNext() && outputWithoutJit.hasNext()) { + VecBatch resultWithJit = outputWithJit.next(); + VecBatch resultWithoutJit = outputWithoutJit.next(); + assertVecBatchEquals(resultWithJit, resultWithoutJit); + freeVecBatch(resultWithJit); + freeVecBatch(resultWithoutJit); + } + + omniOperatorWithJit.close(); + omniOperatorWithoutJit.close(); + factoryWithJit.close(); + factoryWithoutJit.close(); + } + + /** + * Test execute agg multiple page. + */ + @Test + public void testExecuteCountMultiplePage() { + String[] groupByChannel = {"#0", "#1"}; + DataType[] groupByTypes = {LongDataType.LONG, LongDataType.LONG}; + String[] aggChannels = {"#3"}; + DataType[] aggTypes = {LongDataType.LONG}; + FunctionType[] aggFunctionTypes = {OMNI_AGGREGATION_TYPE_COUNT_COLUMN, OMNI_AGGREGATION_TYPE_COUNT_ALL}; + DataType[] aggOutputTypes = {LongDataType.LONG, LongDataType.LONG}; + + OmniHashAggregationOperatorFactory factory = new OmniHashAggregationOperatorFactory(groupByChannel, + groupByTypes, aggChannels, aggTypes, aggFunctionTypes, aggOutputTypes, true, false); + int rowNum = 100; + int pageCount = 10; + + OmniOperator omniOperator = factory.createOperator(); + + for (int i = 0; i < pageCount; i++) { + VecBatch vecBatch = new VecBatch(buildDataForCount(rowNum)); + omniOperator.addInput(vecBatch); + } + + Iterator output = omniOperator.getOutput(); + VecBatch vecBatch = null; + while (output.hasNext()) { + vecBatch = output.next(); + if (vecBatch.getVectors().length != aggOutputTypes.length + groupByTypes.length) { + throw new IllegalArgumentException( + format("output vec size error: result size: %s, outputTypes size: %s,rows: %s", + vecBatch.getVectors().length, aggOutputTypes.length, vecBatch.getRowCount())); + } + assertNotNull(vecBatch); + assertEquals(vecBatch.getVectors().length, 4); + Vec[] vectors = vecBatch.getVectors(); + assertEquals(((LongVec) vectors[0]).get(0), 1); + assertEquals(((LongVec) vectors[1]).get(0), 1); + assertEquals(((LongVec) vectors[2]).get(0), 500L); + assertEquals(((LongVec) vectors[3]).get(0), 1000L); + + freeVecBatch(vecBatch); + } + + omniOperator.close(); + factory.close(); + } + + @Test + public void testExecuteAggMultiplePage() { + String[] groupByChanel = {"#0", "#1"}; + DataType[] groupByTypes = {LongDataType.LONG, LongDataType.LONG}; + String[] aggChannels = {"#2", "#3"}; + DataType[] aggTypes = {LongDataType.LONG, LongDataType.LONG}; + FunctionType[] aggFunctionTypes = {OMNI_AGGREGATION_TYPE_SUM, OMNI_AGGREGATION_TYPE_SUM}; + DataType[] aggOutputTypes = {LongDataType.LONG, LongDataType.LONG}; + + DataType[] inputTypes = {LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}; + OmniHashAggregationOperatorFactory factory = new OmniHashAggregationOperatorFactory(groupByChanel, groupByTypes, + aggChannels, aggTypes, aggFunctionTypes, aggOutputTypes, true, false); + int rowNum = 40000; + int pageCount = 10; + + OmniOperator omniOperator = factory.createOperator(); + + for (int i = 0; i < pageCount; i++) { + VecBatch vecBatch = new VecBatch(build4Columns(rowNum)); + omniOperator.addInput(vecBatch); + } + + Iterator output = omniOperator.getOutput(); + VecBatch vecBatch = null; + while (output.hasNext()) { + vecBatch = output.next(); + if (vecBatch.getVectors().length != aggOutputTypes.length + groupByTypes.length) { + throw new IllegalArgumentException( + format("output vec size error: result size: %s, outputTypes size: %s,rows: %s", + vecBatch.getVectors().length, aggOutputTypes.length, vecBatch.getRowCount())); + } + assertNotNull(vecBatch); + assertEquals(vecBatch.getVectors().length, 4); + Vec[] vectors = vecBatch.getVectors(); + assertEquals(((LongVec) vectors[0]).get(0), 1); + assertEquals(((LongVec) vectors[1]).get(0), 1); + assertEquals(((LongVec) vectors[2]).get(0), rowNum * pageCount); + assertEquals(((LongVec) vectors[3]).get(0), rowNum * pageCount); + freeVecBatch(vecBatch); + } + omniOperator.close(); + factory.close(); + } + + /** + * Test execute agg multiple thread. + */ + @Test + public void testExecuteAggMultipleThread() { + int pageCount = 10; + int threadCount = 1; + int rowNum = 100; + multiThreadExecution(threadCount, rowNum, pageCount); + } + + @Test + public void testFactoryContextEquals() { + String[] groupByChannel = {"#0", "#1"}; + DataType[] groupByTypes = {LongDataType.LONG, LongDataType.LONG}; + String[] aggChannels = {"#3"}; + DataType[] aggTypes = {LongDataType.LONG}; + FunctionType[] aggFunctionTypes = {OMNI_AGGREGATION_TYPE_COUNT_COLUMN, OMNI_AGGREGATION_TYPE_COUNT_ALL}; + DataType[] aggOutputTypes = {LongDataType.LONG, LongDataType.LONG}; + + FactoryContext factory1 = new FactoryContext(groupByChannel, groupByTypes, aggChannels, aggTypes, + aggFunctionTypes, aggOutputTypes, true, false, new OperatorConfig()); + FactoryContext factory2 = new FactoryContext(groupByChannel, groupByTypes, aggChannels, aggTypes, + aggFunctionTypes, aggOutputTypes, true, false, new OperatorConfig()); + FactoryContext factory3 = null; + + assertEquals(factory2, factory1); + assertNotEquals(factory3, factory1); + assertEquals(factory1, factory1); + } + + @Test + public void testExecuteHashAggEmptyString() { + String[] groupByChanel = {"#0"}; + int varcharWidth = 10; + DataType[] groupByTypes = {new VarcharDataType(varcharWidth)}; + String[] aggChannels = {"#1"}; + DataType[] aggTypes = {LongDataType.LONG}; + FunctionType[] aggFunctionTypes = {OMNI_AGGREGATION_TYPE_SUM}; + DataType[] aggOutputTypes = {LongDataType.LONG}; + + OmniHashAggregationOperatorFactory factory = new OmniHashAggregationOperatorFactory(groupByChanel, groupByTypes, + aggChannels, aggTypes, aggFunctionTypes, aggOutputTypes, true, false); + + Object[][] datas = {{"", null, "", null}, {1L, 2L, 3L, 4L}}; + DataType[] sourceTypes = {new VarcharDataType(varcharWidth), LongDataType.LONG}; + VecBatch vecBatch = createVecBatch(sourceTypes, datas); + + OmniOperator omniOperator = factory.createOperator(); + omniOperator.addInput(vecBatch); + + Iterator output = omniOperator.getOutput(); + VecBatch result = output.next(); + // adjust the output sequence in the vector. + Object[][] expectedDatas = {{null, ""}, {6L, 4L}}; + assertVecBatchEquals(result, expectedDatas); + + freeVecBatch(result); + omniOperator.close(); + factory.close(); + } + + private void multiThreadExecution(int threadCount, int rowNum, int pageCount) { + String[] groupByChanel = {"#0", "#1"}; + DataType[] groupByTypes = {LongDataType.LONG, LongDataType.LONG}; + String[] aggChannels = {"#2", "#3"}; + DataType[] aggTypes = {LongDataType.LONG, LongDataType.LONG}; + FunctionType[] aggFunctionTypes = {OMNI_AGGREGATION_TYPE_SUM, OMNI_AGGREGATION_TYPE_SUM}; + DataType[] aggOutputTypes = {LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}; + OmniHashAggregationOperatorFactory factory = new OmniHashAggregationOperatorFactory(groupByChanel, groupByTypes, + aggChannels, aggTypes, aggFunctionTypes, aggOutputTypes, true, false); + + CountDownLatch downLatch = new CountDownLatch(threadCount); + final int corePoolSize = 10; + final int maximumPoolSize = 50; + ThreadPoolExecutor threadPool = new ThreadPoolExecutor(corePoolSize, maximumPoolSize, 60L, TimeUnit.SECONDS, + new LinkedBlockingQueue<>(threadCount)); + + for (int tIdx = 0; tIdx < threadCount; tIdx++) { + CompletableFuture.runAsync(() -> { + try { + OmniOperator omniOperator = factory.createOperator(); + for (int i = 0; i < pageCount; i++) { + omniOperator.addInput(new VecBatch(build4Columns(rowNum))); + } + + assertResult(rowNum, pageCount, aggOutputTypes, omniOperator); + omniOperator.close(); + } finally { + downLatch.countDown(); + } + }, threadPool); + } + + try { + downLatch.await(); + } catch (InterruptedException ex) { + assertTrue(false); + } + + threadPool.shutdown(); + factory.close(); + } + + private void assertResult(int rowNum, int pageCount, DataType[] aggOutputTypes, OmniOperator omniOperator) { + Iterator output = omniOperator.getOutput(); + while (output.hasNext()) { + VecBatch vecBatch = output.next(); + if (vecBatch.getVectors().length != aggOutputTypes.length) { + throw new IllegalArgumentException( + format("output vec size error: result size: %s, outputTypes size: %s,rows: %s", + vecBatch.getVectors().length, aggOutputTypes.length, vecBatch.getRowCount())); + } + + assertNotNull(vecBatch); + assertEquals(vecBatch.getVectors().length, 4); + Vec[] vectors = vecBatch.getVectors(); + assertEquals(((LongVec) vectors[0]).get(0), 1); + assertEquals(((LongVec) vectors[1]).get(0), 1); + assertEquals(((LongVec) vectors[2]).get(0), rowNum * pageCount); + assertEquals(((LongVec) vectors[3]).get(0), rowNum * pageCount); + freeVecBatch(vecBatch); + } + } + + private List build4Columns(int rowNum) { + LongVec c1 = new LongVec(rowNum); + LongVec c2 = new LongVec(rowNum); + for (int i = 0; i < rowNum; i++) { + c1.set(i, 1); + c2.set(i, 1); + } + + LongVec c3 = new LongVec(rowNum); + LongVec c4 = new LongVec(rowNum); + for (int i = 0; i < rowNum; i++) { + c3.set(i, 1); + c4.set(i, 1); + } + + List columns = new ArrayList<>(); + columns.add(c1); + columns.add(c2); + columns.add(c3); + columns.add(c4); + + return columns; + } + + private List build2Columns(int rowNum) { + LongVec c1 = new LongVec(rowNum); + for (int i = 0; i < rowNum; i++) { + c1.set(i, 0); + } + + LongVec c2 = new LongVec(rowNum); + for (int i = 0; i < rowNum; i++) { + c2.set(i, 1); + } + + List columns = new ArrayList<>(); + columns.add(c1); + columns.add(c2); + + return columns; + } + + private List buildDataForCount(int rowNum) { + LongVec c1 = new LongVec(rowNum); + LongVec c2 = new LongVec(rowNum); + for (int i = 0; i < rowNum; i++) { + c1.set(i, 1); + c2.set(i, 1); + } + + LongVec c3 = new LongVec(rowNum); + LongVec c4 = new LongVec(rowNum); + for (int i = 0; i < rowNum; i++) { + if (i % 2 == 0) { + c3.set(i, 1); + c4.set(i, 1); + } else { + c3.setNull(i); + c4.setNull(i); + } + } + + List columns = new ArrayList<>(); + columns.add(c1); + columns.add(c2); + columns.add(c3); + columns.add(c4); + + return columns; + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniHashAggregationWithExprOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniHashAggregationWithExprOperatorTest.java new file mode 100644 index 0000000..1f937fb --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniHashAggregationWithExprOperatorTest.java @@ -0,0 +1,410 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_AVG; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_COUNT_ALL; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_COUNT_COLUMN; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_SUM; +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.createVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonLiteral; +import static nova.hetu.omniruntime.util.TestUtils.omniFunctionExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonFourArithmeticExpr; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; + +import com.google.common.collect.ImmutableList; + +import nova.hetu.omniruntime.constants.FunctionType; +import nova.hetu.omniruntime.operator.aggregator.OmniHashAggregationOperatorFactory; +import nova.hetu.omniruntime.operator.aggregator.OmniHashAggregationWithExprOperatorFactory; +import nova.hetu.omniruntime.operator.aggregator.OmniHashAggregationWithExprOperatorFactory.FactoryContext; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.utils.OmniRuntimeException; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +/** + * The type Omni hash aggregation with expression operator test. + * + * @since 2021-11-11 + */ +public class OmniHashAggregationWithExprOperatorTest { + /** + * test hashAggregationWithExpr performance whether with jit or not. + */ + @Test + public void testHashAggregationOutputlMultiVectorBatch() { + String[] groupByChannel = {"#0", "#1"}; + DataType[] groupByTypes = {LongDataType.LONG, LongDataType.LONG}; + String[] aggChannels = {"#3"}; + DataType[] aggTypes = {LongDataType.LONG}; + FunctionType[] aggFunctionTypes = {OMNI_AGGREGATION_TYPE_COUNT_COLUMN, OMNI_AGGREGATION_TYPE_COUNT_ALL}; + DataType[] aggOutputTypes = {LongDataType.LONG, LongDataType.LONG}; + + OmniHashAggregationOperatorFactory operatorFactory = new OmniHashAggregationOperatorFactory(groupByChannel, + groupByTypes, aggChannels, aggTypes, aggFunctionTypes, aggOutputTypes, true, false, + new OperatorConfig()); + OmniOperator omniOperator = operatorFactory.createOperator(); + + ImmutableList.Builder vecBatchList1 = ImmutableList.builder(); + int rowNum = 100000; + int pageCount = 10; + for (int i = 0; i < pageCount; i++) { + vecBatchList1.add(new VecBatch(buildDataForOutputMultiVectorBatch(rowNum))); + } + + for (VecBatch vecBatch : vecBatchList1.build()) { + omniOperator.addInput(vecBatch); + } + + Iterator outputVecBatch = omniOperator.getOutput(); + + int vecBatchCount = 0; + int totalRowcount = 0; + long col1Sum = 0L; + long col2Sum = 0L; + long col3Sum = 0L; + long col4Sum = 0L; + while (outputVecBatch.hasNext()) { + VecBatch result = outputVecBatch.next(); + Vec[] vectors = result.getVectors(); + int vecBatchRowCurrent = result.getRowCount(); + for (int i = 0; i < vecBatchRowCurrent; ++i) { + col1Sum += ((LongVec) vectors[0]).get(i); + col2Sum += ((LongVec) vectors[1]).get(i); + col3Sum += ((LongVec) vectors[2]).get(i); + col4Sum += ((LongVec) vectors[3]).get(i); + } + totalRowcount += vecBatchRowCurrent; + freeVecBatch(result); + vecBatchCount++; + } + omniOperator.close(); + operatorFactory.close(); + assertEquals(totalRowcount, rowNum); + // each row contains four columns, each of which contains 8 bytes. + int rowSize = 4 * 8; + // single vecBatch is 1MB, calculate the maximum number of rows in single + // vecBatch. + int rowsPerBatch = (1024 * 1024 + rowSize - 1) / rowSize; + int expectedBatchCount = (rowNum + rowsPerBatch - 1) / rowsPerBatch; + assertEquals(vecBatchCount, expectedBatchCount); + // sum of an arithmetic sequence with a step of 1 + assertEquals(col1Sum, (((long) rowNum - 1) * rowNum) / 2); + assertEquals(col2Sum, (long) rowNum); + assertEquals(col3Sum, (long) (rowNum / 2 * pageCount)); + assertEquals(col4Sum, (long) (rowNum * pageCount)); + } + + @Test + public void testHashAggregationWithExprComparePref() { + String[] groupByChannel = {getOmniJsonFieldReference(2, 0), getOmniJsonFieldReference(2, 1)}; + String[][] aggChannels = {{getOmniJsonFieldReference(2, 3)}}; + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}; + FunctionType[] aggFunctionTypes = {OMNI_AGGREGATION_TYPE_COUNT_ALL, OMNI_AGGREGATION_TYPE_COUNT_COLUMN}; + DataType[][] aggOutputTypes = {{LongDataType.LONG}, {LongDataType.LONG}}; + String[] aggChannelsfilter = {null, null}; + OmniHashAggregationWithExprOperatorFactory factoryWithJit = new OmniHashAggregationWithExprOperatorFactory( + groupByChannel, aggChannels, aggChannelsfilter, sourceTypes, aggFunctionTypes, aggOutputTypes, + new boolean[]{true, true}, new boolean[]{false, false}, new OperatorConfig()); + OmniOperator omniOperatorWithJit = factoryWithJit.createOperator(); + + ImmutableList.Builder vecBatchList1 = ImmutableList.builder(); + int rowNum = 100000; + int pageCount = 10; + for (int i = 0; i < pageCount; i++) { + vecBatchList1.add(new VecBatch(buildDataForCount(rowNum))); + } + + long start1 = System.currentTimeMillis(); + for (VecBatch vecBatch : vecBatchList1.build()) { + omniOperatorWithJit.addInput(vecBatch); + } + + Iterator outputWithJit = omniOperatorWithJit.getOutput(); + long end1 = System.currentTimeMillis(); + System.out.println("HashAggregationWithExpr with jit use " + (end1 - start1) + " ms."); + + OmniHashAggregationWithExprOperatorFactory factoryWithoutJit = new OmniHashAggregationWithExprOperatorFactory( + groupByChannel, aggChannels, aggChannelsfilter, sourceTypes, aggFunctionTypes, aggOutputTypes, + new boolean[]{true, true}, new boolean[]{false, false}, new OperatorConfig()); + OmniOperator omniOperatorWithoutJit = factoryWithoutJit.createOperator(); + + ImmutableList.Builder vecBatchList2 = ImmutableList.builder(); + for (int i = 0; i < pageCount; i++) { + vecBatchList2.add(new VecBatch(buildDataForCount(rowNum))); + } + + long start2 = System.currentTimeMillis(); + for (VecBatch vecBatch : vecBatchList2.build()) { + omniOperatorWithoutJit.addInput(vecBatch); + } + + Iterator outputWithoutJit = omniOperatorWithoutJit.getOutput(); + long end2 = System.currentTimeMillis(); + System.out.println("HashAggregationWithExpr without jit use " + (end2 - start2) + " ms."); + + while (outputWithJit.hasNext()) { + VecBatch resultWithJit = outputWithJit.next(); + VecBatch resultWithoutJit = outputWithoutJit.next(); + assertVecBatchEquals(resultWithJit, resultWithoutJit); + freeVecBatch(resultWithJit); + freeVecBatch(resultWithoutJit); + } + + omniOperatorWithJit.close(); + omniOperatorWithoutJit.close(); + factoryWithJit.close(); + factoryWithoutJit.close(); + } + + @Test + public void testHashAggWithPartialExpr() { + String[] groupByChanel = {omniJsonFourArithmeticExpr("MODULUS", 2, getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 3)), getOmniJsonFieldReference(1, 2)}; + String[][] aggChannels = {{omniJsonFourArithmeticExpr("MULTIPLY", 2, getOmniJsonFieldReference(2, 1), + getOmniJsonLiteral(2, false, 5))}, {getOmniJsonFieldReference(1, 3)}}; + + FunctionType[] aggFunctionTypes = {OMNI_AGGREGATION_TYPE_SUM, OMNI_AGGREGATION_TYPE_AVG}; + DataType[][] aggOutputTypes = {{LongDataType.LONG}, {DoubleDataType.DOUBLE}}; + + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG, IntDataType.INTEGER, IntDataType.INTEGER}; + String[] aggChannelsfilter = {null, null}; + OmniHashAggregationWithExprOperatorFactory factory = new OmniHashAggregationWithExprOperatorFactory( + groupByChanel, aggChannels, aggChannelsfilter, sourceTypes, aggFunctionTypes, aggOutputTypes, + new boolean[]{true, true}, new boolean[]{false, false}); + + OmniOperator omniOperator = factory.createOperator(); + + Object[][] sourceDatas = {{2L, 5L, 8L, 11L, 14L, 17L, 20L, 23L}, {5L, 3L, 2L, 6L, 1L, 4L, 7L, 8L}, + {5, 5, 5, 5, 5, 5, 5, 5}, {5, 3, 2, 6, 1, 4, 7, 8}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + omniOperator.addInput(vecBatch); + + Iterator results = omniOperator.getOutput(); + + assertEquals(results.hasNext(), true); + VecBatch resultVecBatch = results.next(); + assertEquals(results.hasNext(), false); + + // should return false when multiple invoke hasNext() + assertEquals(results.hasNext(), false); + assertEquals(resultVecBatch.getRowCount(), 1); + assertEquals(resultVecBatch.getVectorCount(), 4); + + Object[][] expectedDatas = {{2L}, {5}, {180L}, {4.5}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + omniOperator.close(); + factory.close(); + } + + @Test + public void testHashAggWithAllExpr() { + String[] groupByChanel = { + omniJsonFourArithmeticExpr("MODULUS", 2, getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 3)), + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 2), getOmniJsonLiteral(1, false, 5))}; + String[][] aggChannels = { + {omniJsonFourArithmeticExpr("MULTIPLY", 2, getOmniJsonFieldReference(2, 1), + getOmniJsonLiteral(2, false, 5))}, + {omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 3), + getOmniJsonLiteral(1, false, 5))}}; + + FunctionType[] aggFunctionTypes = {OMNI_AGGREGATION_TYPE_SUM, OMNI_AGGREGATION_TYPE_AVG}; + DataType[][] aggOutputTypes = {{LongDataType.LONG}, {DoubleDataType.DOUBLE}}; + + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG, IntDataType.INTEGER, IntDataType.INTEGER}; + String[] aggChannelsfilter = {null, null}; + OmniHashAggregationWithExprOperatorFactory factory = new OmniHashAggregationWithExprOperatorFactory( + groupByChanel, aggChannels, aggChannelsfilter, sourceTypes, aggFunctionTypes, aggOutputTypes, + new boolean[]{true, true}, new boolean[]{false, false}); + + OmniOperator omniOperator = factory.createOperator(); + + Object[][] sourceDatas = {{2L, 5L, 8L, 11L, 14L, 17L, 20L, 23L}, {5L, 3L, 2L, 6L, 1L, 4L, 7L, 8L}, + {5, 5, 5, 5, 5, 5, 5, 5}, {5, 3, 2, 6, 1, 4, 7, 8}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + omniOperator.addInput(vecBatch); + + Iterator results = omniOperator.getOutput(); + + assertEquals(results.hasNext(), true); + VecBatch resultVecBatch = results.next(); + assertEquals(results.hasNext(), false); + assertEquals(resultVecBatch.getRowCount(), 1); + assertEquals(resultVecBatch.getVectorCount(), 4); + + Object[][] expectedDatas = {{2L}, {10}, {180L}, {9.5}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + omniOperator.close(); + factory.close(); + } + + @Test + public void testHashAggWithNoExpr() { + String[] groupByChanel = {getOmniJsonFieldReference(2, 0), getOmniJsonFieldReference(1, 2)}; + String[][] aggChannels = {{getOmniJsonFieldReference(2, 1)}, {getOmniJsonFieldReference(1, 3)}}; + + FunctionType[] aggFunctionTypes = {OMNI_AGGREGATION_TYPE_SUM, OMNI_AGGREGATION_TYPE_AVG, + OMNI_AGGREGATION_TYPE_COUNT_ALL}; + DataType[][] aggOutputTypes = {{LongDataType.LONG}, {DoubleDataType.DOUBLE}, {LongDataType.LONG}}; + + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG, IntDataType.INTEGER, IntDataType.INTEGER}; + String[] aggChannelsfilter = {null, null, null}; + OmniHashAggregationWithExprOperatorFactory factory = new OmniHashAggregationWithExprOperatorFactory( + groupByChanel, aggChannels, aggChannelsfilter, sourceTypes, aggFunctionTypes, aggOutputTypes, + new boolean[]{true, true, true}, new boolean[]{false, false, false}); + + OmniOperator omniOperator = factory.createOperator(); + + Object[][] sourceDatas = {{2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L}, {5L, 3L, 2L, 6L, 1L, 4L, 7L, 8L}, + {5, 5, 5, 5, 5, 5, 5, 5}, {5, 3, 2, 6, 1, 4, 7, 8}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + omniOperator.addInput(vecBatch); + + Iterator results = omniOperator.getOutput(); + + assertEquals(results.hasNext(), true); + VecBatch resultVecBatch = results.next(); + assertEquals(results.hasNext(), false); + assertEquals(resultVecBatch.getRowCount(), 1); + assertEquals(resultVecBatch.getVectorCount(), 5); + + Object[][] expectedDatas = {{2L}, {5}, {36L}, {4.5}, {8L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + omniOperator.close(); + factory.close(); + } + + @Test(expectedExceptions = OmniRuntimeException.class, expectedExceptionsMessageRegExp = ".*EXPRESSION_NOT_SUPPORT.*") + public void testHashAggWithInvalidGroupByKeys() { + String[] groupByChanel = {omniFunctionExpr("abc", 2, getOmniJsonFieldReference(2, 0)), + getOmniJsonFieldReference(1, 2)}; + String[][] aggChannels = {{getOmniJsonFieldReference(2, 1)}, {getOmniJsonFieldReference(1, 3)}}; + FunctionType[] aggFunctionTypes = {OMNI_AGGREGATION_TYPE_SUM, OMNI_AGGREGATION_TYPE_AVG}; + DataType[][] aggOutputTypes = {{LongDataType.LONG}, {DoubleDataType.DOUBLE}}; + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG, IntDataType.INTEGER, IntDataType.INTEGER}; + String[] aggChannelsfilter = {null, null}; + OmniHashAggregationWithExprOperatorFactory factory = new OmniHashAggregationWithExprOperatorFactory( + groupByChanel, aggChannels, aggChannelsfilter, sourceTypes, aggFunctionTypes, aggOutputTypes, + new boolean[]{true, true}, new boolean[]{false, false}); + } + + @Test(expectedExceptions = OmniRuntimeException.class, expectedExceptionsMessageRegExp = ".*EXPRESSION_NOT_SUPPORT.*") + public void testHashAggWithInvalidAggKeys() { + String[] groupByChanel = {getOmniJsonFieldReference(2, 0), getOmniJsonFieldReference(1, 2)}; + String[][] aggChannels = {{omniFunctionExpr("abc", 2, getOmniJsonFieldReference(2, 1))}, + {getOmniJsonFieldReference(1, 3)}}; + FunctionType[] aggFunctionTypes = {OMNI_AGGREGATION_TYPE_SUM, OMNI_AGGREGATION_TYPE_AVG}; + DataType[][] aggOutputTypes = {{LongDataType.LONG}, {DoubleDataType.DOUBLE}}; + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG, IntDataType.INTEGER, IntDataType.INTEGER}; + String[] aggChannelsfilter = {null, null}; + OmniHashAggregationWithExprOperatorFactory factory = new OmniHashAggregationWithExprOperatorFactory( + groupByChanel, aggChannels, aggChannelsfilter, sourceTypes, aggFunctionTypes, aggOutputTypes, + new boolean[]{true, true}, new boolean[]{false, false}); + } + + @Test + public void testFactoryContextEquals() { + String[] groupByChanel = {getOmniJsonFieldReference(2, 0), getOmniJsonFieldReference(1, 2)}; + String[][] aggChannels = {{getOmniJsonFieldReference(2, 1)}, {getOmniJsonFieldReference(1, 3)}}; + + FunctionType[] aggFunctionTypes = {OMNI_AGGREGATION_TYPE_SUM, OMNI_AGGREGATION_TYPE_AVG}; + DataType[][] aggOutputTypes = {{LongDataType.LONG}, {DoubleDataType.DOUBLE}}; + + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG, IntDataType.INTEGER, IntDataType.INTEGER}; + String[] aggChannelsfilter = {null, null}; + FactoryContext factory1 = new FactoryContext(groupByChanel, aggChannels, aggChannelsfilter, sourceTypes, + aggFunctionTypes, aggOutputTypes, new boolean[]{true, true}, new boolean[]{false, false}, + new OperatorConfig()); + FactoryContext factory2 = new FactoryContext(groupByChanel, aggChannels, aggChannelsfilter, sourceTypes, + aggFunctionTypes, aggOutputTypes, new boolean[]{true, true}, new boolean[]{false, false}, + new OperatorConfig()); + FactoryContext factory3 = null; + + assertEquals(factory2, factory1); + assertEquals(factory1, factory1); + assertNotEquals(factory3, factory1); + } + + private List buildDataForCount(int rowNum) { + LongVec c1 = new LongVec(rowNum); + LongVec c2 = new LongVec(rowNum); + for (int i = 0; i < rowNum; i++) { + c1.set(i, 1); + c2.set(i, 1); + } + + LongVec c3 = new LongVec(rowNum); + LongVec c4 = new LongVec(rowNum); + for (int i = 0; i < rowNum; i++) { + if (i % 2 == 0) { + c3.set(i, 1); + c4.set(i, 1); + } else { + c3.setNull(i); + c4.setNull(i); + } + } + + List columns = new ArrayList<>(); + columns.add(c1); + columns.add(c2); + columns.add(c3); + columns.add(c4); + + return columns; + } + + private List buildDataForOutputMultiVectorBatch(int rowNum) { + LongVec c1 = new LongVec(rowNum); + LongVec c2 = new LongVec(rowNum); + for (int i = 0; i < rowNum; i++) { + c1.set(i, i); + c2.set(i, 1); + } + + LongVec c3 = new LongVec(rowNum); + LongVec c4 = new LongVec(rowNum); + for (int i = 0; i < rowNum; i++) { + if (i % 2 == 0) { + c3.set(i, 1); + c4.set(i, 1); + } else { + c3.setNull(i); + c4.setNull(i); + } + } + + List columns = new ArrayList<>(); + columns.add(c1); + columns.add(c2); + columns.add(c3); + columns.add(c4); + + return columns; + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniHashJoinOperatorsTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniHashJoinOperatorsTest.java new file mode 100644 index 0000000..9557e04 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniHashJoinOperatorsTest.java @@ -0,0 +1,1196 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.constants.JoinType.OMNI_JOIN_TYPE_FULL; +import static nova.hetu.omniruntime.constants.JoinType.OMNI_JOIN_TYPE_INNER; +import static nova.hetu.omniruntime.constants.JoinType.OMNI_JOIN_TYPE_LEFT; +import static nova.hetu.omniruntime.util.TestUtils.assertDecimal128VecEqualsIgnoreOrder; +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEqualsIgnoreOrder; +import static nova.hetu.omniruntime.util.TestUtils.assertVecEqualsIgnoreOrder; +import static nova.hetu.omniruntime.util.TestUtils.createDictionaryVec; +import static nova.hetu.omniruntime.util.TestUtils.createLongVec; +import static nova.hetu.omniruntime.util.TestUtils.createVec; +import static nova.hetu.omniruntime.util.TestUtils.createVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonLiteral; +import static nova.hetu.omniruntime.util.TestUtils.omniFunctionExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonGreaterThanExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonNotEqualExpr; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; +import static org.testng.Assert.assertTrue; + +import com.google.common.collect.ImmutableList; + +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.operator.join.OmniHashBuilderOperatorFactory; +import nova.hetu.omniruntime.operator.join.OmniLookupJoinOperatorFactory; +import nova.hetu.omniruntime.operator.join.OmniLookupOuterJoinOperatorFactory; +import nova.hetu.omniruntime.type.CharDataType; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.Date32DataType; +import nova.hetu.omniruntime.type.Decimal128DataType; +import nova.hetu.omniruntime.type.Decimal64DataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.utils.OmniRuntimeException; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; + +/** + * The type Omni hash join operators test. + * + * @since 2021-6-2 + */ +public class OmniHashJoinOperatorsTest { + /** + * The Page distinct count. + */ + int pageDistinctCount = 4; + + /** + * The Page distinct value repeat count. + */ + int pageDistinctValueRepeatCount = 100; + + /** + * test hash join performance whether with jit or not. + */ + @Test + public void testHashJoinComparePref() { + DataType[] buildTypes = {LongDataType.LONG, LongDataType.LONG}; + int[] buildHashCols = {0}; + int operatorCount = 1; + int buildPageCount = 10; + + OmniHashBuilderOperatorFactory hashBuilderOperatorFactoryWithoutJit = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashCols, operatorCount, new OperatorConfig()); + OmniOperator hashBuilderOperatorWithoutJit = hashBuilderOperatorFactoryWithoutJit.createOperator(); + ImmutableList buildVecsWithoutJit = buildVecs(buildPageCount); + + long start = System.currentTimeMillis(); + for (VecBatch vec : buildVecsWithoutJit) { + hashBuilderOperatorWithoutJit.addInput(vec); + } + Iterator hashBuilderOutputWithoutJit = hashBuilderOperatorWithoutJit.getOutput(); + long end = System.currentTimeMillis(); + System.out.println("HashBuilder without jit use " + (end - start) + " ms."); + + DataType[] probeTypes = {LongDataType.LONG, LongDataType.LONG}; + int[] probeOutputCols = {1}; + int[] probeHashCols = {0}; + int[] buildOutputCols = {1}; + DataType[] buildOutputTypes = {LongDataType.LONG}; + int probePageCount = 1; + + OmniLookupJoinOperatorFactory lookupJoinOperatorFactoryWithoutJit = new OmniLookupJoinOperatorFactory( + probeTypes, probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactoryWithoutJit, Optional.empty(), false, new OperatorConfig()); + OmniOperator lookupJoinOperatorWithoutJit = lookupJoinOperatorFactoryWithoutJit.createOperator(); + ImmutableList probeVecsWithoutJit = buildVecs(probePageCount); + + start = System.currentTimeMillis(); + lookupJoinOperatorWithoutJit.addInput(probeVecsWithoutJit.get(0)); + Iterator lookupJoinOutputWithoutJit = lookupJoinOperatorWithoutJit.getOutput(); + end = System.currentTimeMillis(); + System.out.println("LookupJoin without jit use " + (end - start) + " ms."); + + OmniHashBuilderOperatorFactory hashBuilderOperatorFactoryWithJit = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashCols, operatorCount, new OperatorConfig()); + OmniOperator hashBuilderOperatorWithJit = hashBuilderOperatorFactoryWithJit.createOperator(); + ImmutableList buildVecsWithJit = buildVecs(buildPageCount); + + start = System.currentTimeMillis(); + for (VecBatch vec : buildVecsWithJit) { + hashBuilderOperatorWithJit.addInput(vec); + } + Iterator hashBuilderOutputWithJit = hashBuilderOperatorWithJit.getOutput(); + end = System.currentTimeMillis(); + System.out.println("HashBuilder with jit use " + (end - start) + " ms."); + + OmniLookupJoinOperatorFactory lookupJoinOperatorFactoryWithJit = new OmniLookupJoinOperatorFactory(probeTypes, + probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactoryWithJit, Optional.empty(), false, new OperatorConfig()); + OmniOperator lookupJoinOperatorWithJit = lookupJoinOperatorFactoryWithJit.createOperator(); + ImmutableList probeVecsWithJit = buildVecs(probePageCount); + + start = System.currentTimeMillis(); + lookupJoinOperatorWithJit.addInput(probeVecsWithJit.get(0)); + Iterator lookupJoinOutputWithJit = lookupJoinOperatorWithJit.getOutput(); + end = System.currentTimeMillis(); + System.out.println("LookupJoin with jit use " + (end - start) + " ms."); + + while (hashBuilderOutputWithoutJit.hasNext()) { + VecBatch resultWithoutJit = hashBuilderOutputWithoutJit.next(); + VecBatch resultWithJit = hashBuilderOutputWithJit.next(); + assertVecBatchEqualsIgnoreOrder(resultWithoutJit, resultWithJit); + freeVecBatch(resultWithoutJit); + freeVecBatch(resultWithJit); + } + + while (lookupJoinOutputWithoutJit.hasNext() && lookupJoinOutputWithJit.hasNext()) { + VecBatch resultWithoutJit = lookupJoinOutputWithoutJit.next(); + VecBatch resultWithJit = lookupJoinOutputWithJit.next(); + assertVecBatchEqualsIgnoreOrder(resultWithoutJit, resultWithJit); + freeVecBatch(resultWithoutJit); + freeVecBatch(resultWithJit); + } + + lookupJoinOperatorWithoutJit.close(); + hashBuilderOperatorWithoutJit.close(); + lookupJoinOperatorWithJit.close(); + hashBuilderOperatorWithJit.close(); + lookupJoinOperatorFactoryWithoutJit.close(); + hashBuilderOperatorFactoryWithoutJit.close(); + lookupJoinOperatorFactoryWithJit.close(); + hashBuilderOperatorFactoryWithJit.close(); + } + + /** + * Test inner hash join one column 1. + */ + @Test + public void testInnerEqualityJoinOneColumn1() { + DataType[] buildTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] buildDatas = {{1L, 2L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 1L}, + {79L, 79L, 70L, 70L, 70L, 70L, 70L, 70L, 70L, 70L}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + int[] buildHashCols = {0}; + int operatorCount = 1; + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashCols, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] probeDatas = {{1L, 2L, 3L, 4L, 5L, 6L, 1L, 1L, 2L, 3L}, + {78L, 78L, 78L, 78L, 78L, 78L, 78L, 82L, 82L, 65L}}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + + int[] probeOutputCols = {1}; + int[] probeHashCols = {0}; + int[] buildOutputCols = {1}; + DataType[] buildOutputTypes = {LongDataType.LONG}; + OmniLookupJoinOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinOperatorFactory(probeTypes, + probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + VecBatch resultVecBatch = results.next(); + + int len = resultVecBatch.getRowCount(); + assertEquals(len, 18); + Object[][] expectedDatas = { + {78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 82L, 82L, 82L, 82L, 82L, 65L}, + {79L, 70L, 70L, 79L, 70L, 70L, 70L, 70L, 70L, 79L, 70L, 70L, 79L, 70L, 70L, 79L, 70L, 70L}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + /** + * Test inner hash join one column 2. + */ + @Test(enabled = false) + public void testInnerEqualityJoinOneColumn2() { + DataType[] buildTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] buildDatas1 = {{1L, 1L, 3L, 6L, 7L, 1L}, {79L, 70L, 70L, 70L, 70L, 70L}}; + VecBatch buildVecBatch1 = createVecBatch(buildTypes, buildDatas1); + Object[][] buildDatas2 = {{2L, 2L, 4L, 5L}, {79L, 70L, 70L, 70L}}; + VecBatch buildVecBatch2 = createVecBatch(buildTypes, buildDatas2); + + int[] buildHashCols = {0}; + int operatorCount = 2; + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashCols, operatorCount); + OmniOperator hashBuilderOperator1 = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator1.addInput(buildVecBatch1); + hashBuilderOperator1.getOutput(); + OmniOperator hashBuilderOperator2 = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator2.addInput(buildVecBatch2); + hashBuilderOperator2.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] probeDatas = {{1L, 2L, 3L, 4L, 5L, 6L, 1L, 1L, 2L, 3L}, + {78L, 78L, 78L, 78L, 78L, 78L, 78L, 82L, 82L, 65L}}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + + int[] probeOutputCols = {1}; + int[] probeHashCols = {0}; + int[] buildOutputCols = {1}; + DataType[] buildOutputTypes = {LongDataType.LONG}; + OmniLookupJoinOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinOperatorFactory(probeTypes, + probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + VecBatch resultVecBatch = results.next(); + + int len = resultVecBatch.getRowCount(); + assertEquals(len, 18); + Object[][] expectedDatas = { + {78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 82L, 82L, 82L, 82L, 82L, 65L}, + {79L, 70L, 70L, 79L, 70L, 70L, 70L, 70L, 70L, 79L, 70L, 70L, 79L, 70L, 70L, 79L, 70L, 70L}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + lookupJoinOperator.close(); + hashBuilderOperator1.close(); + hashBuilderOperator2.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + /** + * Test left join + */ + @Test + public void testLeftEqualityJoin() { + DataType[] buildTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] buildDatas = {{1L, 2L, 3L, 4L}, {111L, 11L, 333L, 33L}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + int[] buildHashCols = {1}; + int operatorCount = 1; + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_LEFT, buildTypes, buildHashCols, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] probeDatas = {{1L, 2L, 3L, 4L}, {11L, 22L, 33L, 44L}}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + + int[] probeOutputCols = {0, 1}; + int[] probeHashCols = {1}; + int[] buildOutputCols = {0, 1}; + DataType[] buildOutputTypes = {LongDataType.LONG, LongDataType.LONG}; + OmniLookupJoinOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinOperatorFactory(probeTypes, + probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + int len = resultVecBatch.getRowCount(); + assertEquals(len, 4); + assertEquals(resultVecBatch.getVectorCount(), 4); + Object[][] expectedDatas = {{1L, 2L, 3L, 4L}, {11L, 22L, 33L, 44L}, {2L, null, 4L, null}, + {11L, null, 33L, null}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + /** + * Test left join with varchar join key + */ + @Test + public void testLeftEqualityJoinVarchar() { + DataType[] buildTypes = {LongDataType.LONG, new VarcharDataType(3)}; + Object[][] buildDatas = {{1L, 2L, 3L, 4L}, {"aaa", "11", "ccc", "33"}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + int[] buildHashCols = {1}; + int operatorCount = 1; + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_LEFT, buildTypes, buildHashCols, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, new VarcharDataType(2)}; + Object[][] probeDatas = {{1L, 2L, 3L, 4L}, {"11", "22", "33", "44"}}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + + int[] probeOutputCols = {0, 1}; + int[] probeHashCols = {1}; + int[] buildOutputCols = {0, 1}; + DataType[] buildOutputTypes = {LongDataType.LONG, new VarcharDataType(2)}; + OmniLookupJoinOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinOperatorFactory(probeTypes, + probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + int len = resultVecBatch.getRowCount(); + assertEquals(len, 4); + assertEquals(resultVecBatch.getVectorCount(), 4); + Object[][] expectedDatas = {{1L, 2L, 3L, 4L}, {"11", "22", "33", "44"}, {2L, null, 4L, null}, + {"11", null, "33", null}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + /** + * Test left join with char join key + */ + @Test + public void testLeftEqualityJoinChar() { + DataType[] buildTypes = {LongDataType.LONG, new CharDataType(3)}; + Object[][] buildDatas = {{1L, 2L, 3L, 4L}, {"aaa", "11", "ccc", "33"}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + int[] buildJoinCols = {1}; + int operatorCount = 1; + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_LEFT, buildTypes, buildJoinCols, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, new CharDataType(2)}; + Object[][] probeDatas = {{1L, 2L, 3L, 4L}, {"11", "22", "33", "44"}}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + + int[] probeOutputCols = {0, 1}; + int[] probeHashCols = {1}; + int[] buildOutputCols = {0, 1}; + DataType[] buildOutputTypes = {LongDataType.LONG, new CharDataType(2)}; + OmniLookupJoinOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinOperatorFactory(probeTypes, + probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + int len = resultVecBatch.getRowCount(); + assertEquals(len, 4); + assertEquals(resultVecBatch.getVectorCount(), 4); + Object[][] expectedDatas = {{1L, 2L, 3L, 4L}, {"11", "22", "33", "44"}, {2L, null, 4L, null}, + {"11", null, "33", null}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + /** + * Test left join with date32 join key + */ + @Test + public void testLeftEqualityJoinDate32() { + DataType[] buildTypes = {LongDataType.LONG, new Date32DataType(DataType.DateUnit.DAY)}; + Object[][] buildDatas = {{1L, 2L, 3L, 4L}, {123, 11, 321, 33}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + int[] buildHashCols = {1}; + int operatorCount = 1; + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_LEFT, buildTypes, buildHashCols, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, new Date32DataType(DataType.DateUnit.DAY)}; + Object[][] probeDatas = {{1L, 2L, 3L, 4L}, {11, 22, 33, 44}}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + + int[] probeOutputCols = {0, 1}; + int[] probeHashCols = {1}; + int[] buildOutputCols = {0, 1}; + DataType[] buildOutputTypes = {LongDataType.LONG, new Date32DataType(DataType.DateUnit.DAY)}; + OmniLookupJoinOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinOperatorFactory(probeTypes, + probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + int len = resultVecBatch.getRowCount(); + assertEquals(len, 4); + assertEquals(resultVecBatch.getVectorCount(), 4); + Object[][] expectedDatas = {{1L, 2L, 3L, 4L}, {11, 22, 33, 44}, {2L, null, 4L, null}, {11, null, 33, null}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + /** + * Test left join with decimal64 join key + */ + @Test + public void testLeftEqualityJoinDecimal64() { + DataType[] buildTypes = {LongDataType.LONG, new Decimal64DataType(3, 0)}; + Object[][] buildDatas = {{1L, 2L, 3L, 4L}, {123L, 11L, 321L, 33L}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + int[] buildHashCols = {1}; + int operatorCount = 1; + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_LEFT, buildTypes, buildHashCols, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, new Decimal64DataType(2, 0)}; + Object[][] probeDatas = {{1L, 2L, 3L, 4L}, {11L, 22L, 33L, 44L}}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + + int[] probeOutputCols = {0, 1}; + int[] probeHashCols = {1}; + int[] buildOutputCols = {0, 1}; + DataType[] buildOutputTypes = {LongDataType.LONG, new Decimal64DataType(3, 0)}; + OmniLookupJoinOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinOperatorFactory(probeTypes, + probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + int len = resultVecBatch.getRowCount(); + assertEquals(len, 4); + assertEquals(resultVecBatch.getVectorCount(), 4); + Object[][] expectedDatas = {{1L, 2L, 3L, 4L}, {11L, 22L, 33L, 44L}, {2L, null, 4L, null}, + {11L, null, 33L, null}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + /** + * Test left join with decimal128 join key + */ + @Test + public void testLeftEqualityJoinDecimal128() { + DataType[] buildTypes = {LongDataType.LONG, new Decimal128DataType(3, 0)}; + Vec[] buildVecs = new Vec[buildTypes.length]; + buildVecs[0] = createVec(buildTypes[0], new Object[]{1L, 2L, 3L, 4L}); + buildVecs[1] = createVec(buildTypes[1], new Object[][]{{123L, 0L}, {11L, 0L}, {321L, 0L}, {33L, 0L}}); + VecBatch buildVecBatch = new VecBatch(buildVecs); + + int[] buildHashCols = {1}; + int operatorCount = 1; + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_LEFT, buildTypes, buildHashCols, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, new Decimal128DataType(2, 0)}; + Vec[] probeVecs = new Vec[probeTypes.length]; + probeVecs[0] = createVec(probeTypes[0], new Object[]{1L, 2L, 3L, 4L}); + probeVecs[1] = createVec(probeTypes[1], new Object[][]{{11L, 0L}, {22L, 0L}, {33L, 0L}, {44L, 0L}}); + VecBatch probeVecBatch = new VecBatch(probeVecs); + + int[] probeOutputCols = {0, 1}; + int[] probeHashCols = {1}; + int[] buildOutputCols = {0, 1}; + DataType[] buildOutputTypes = {LongDataType.LONG, new Decimal128DataType(3, 0)}; + OmniLookupJoinOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinOperatorFactory(probeTypes, + probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + int len = resultVecBatch.getRowCount(); + assertEquals(len, 4); + assertEquals(resultVecBatch.getVectorCount(), 4); + assertVecEqualsIgnoreOrder(resultVecBatch.getVectors()[0], new Object[]{1L, 2L, 3L, 4L}); + assertDecimal128VecEqualsIgnoreOrder(resultVecBatch.getVectors()[1], + new Long[][]{{11L, 0L}, {22L, 0L}, {33L, 0L}, {44L, 0L}}); + assertVecEqualsIgnoreOrder(resultVecBatch.getVectors()[2], new Object[]{2L, null, 4L, null}); + assertDecimal128VecEqualsIgnoreOrder(resultVecBatch.getVectors()[3], new Long[][]{{11L, 0L}, null, {33L, 0L}, + null}); + freeVecBatch(resultVecBatch); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + /** + * Test inner join with dictionary join key + */ + @Test + public void testInnerEqualityJoinDictionary() { + DataType[] buildTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] buildDatas = {{1L, null, 3L, null}, {111L, 11L, 333L, 33L}}; + Vec[] vecs = new Vec[2]; + int[] ids = {0, 1, 2, 3}; + vecs[0] = createLongVec(buildDatas[0]); + vecs[1] = createDictionaryVec(buildTypes[1], buildDatas[1], ids); + VecBatch buildVecBatch = new VecBatch(vecs); + + int[] buildHashCols = {1}; + int operatorCount = 1; + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashCols, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] probeDatas = {{null, 2L, null, 4L}, {11L, 22L, 33L, 44L}}; + Vec[] probeVecs = new Vec[2]; + probeVecs[0] = createLongVec(probeDatas[0]); + probeVecs[1] = createDictionaryVec(probeTypes[1], probeDatas[1], ids); + VecBatch probeVecBatch = new VecBatch(probeVecs); + + int[] probeOutputCols = {0, 1}; + int[] probeHashCols = {1}; + int[] buildOutputCols = {0, 1}; + DataType[] buildOutputTypes = {LongDataType.LONG, LongDataType.LONG}; + OmniLookupJoinOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinOperatorFactory(probeTypes, + probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + int len = resultVecBatch.getRowCount(); + assertEquals(len, 2); + assertEquals(resultVecBatch.getVectorCount(), 4); + Object[][] expectedDatas = {{null, null}, {11L, 33L}, {null, null}, {11L, 33L}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + /** + * Test inner join with join filter on int column + */ + @Test + public void testInnerEqualityJoinWithIntFilter() { + DataType[] buildTypes = {IntDataType.INTEGER, IntDataType.INTEGER}; + Object[][] buildDatas = {{19, 14, 7, 19, 1, 20, 10, 13, 20, 16}, + {35709, 31904, 35709, 31904, 35709, null, 35709, 31904, null, 31904}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + int[] buildHashCols = {0}; + int operatorCount = 1; + String filterExpression = omniJsonGreaterThanExpr(getOmniJsonFieldReference(1, 1), + getOmniJsonFieldReference(1, 3)); + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashCols, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {IntDataType.INTEGER, IntDataType.INTEGER}; + Object[][] probeDatas = {{20, 16, 13, 4, 20, 4, 22, 19, 8, 7}, + {35709, 35709, 31904, 12477, null, 38721, 90419, 35709, 88371, null}}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + + int[] probeOutputCols = {0, 1}; + int[] probeHashCols = {0}; + int[] buildOutputCols = {0, 1}; + DataType[] buildOutputTypes = {IntDataType.INTEGER, IntDataType.INTEGER}; + OmniLookupJoinOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinOperatorFactory(probeTypes, + probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, Optional.of(filterExpression), false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + int len = resultVecBatch.getRowCount(); + assertEquals(len, 2); + assertEquals(resultVecBatch.getVectorCount(), 4); + Object[][] expectedDatas = {{16, 19}, {35709, 35709}, {16, 19}, {31904, 31904}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + /** + * Test inner join with join filter on varchar column + */ + @Test + public void testInnerEqualityJoinWithCharFilter() { + DataType[] buildTypes = {IntDataType.INTEGER, new VarcharDataType(5)}; + Object[][] buildDatas = {{19, 14, 7, 19, 1, 20, 10, 13, 20, 16}, + {"35709", "31904", "35709", "31904", "35709", "31904", "35709", "31904", "35709", "31904"}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + int[] buildHashCols = {0}; + int operatorCount = 1; + String filterExpression = omniJsonNotEqualExpr( + omniFunctionExpr("substr", 15, + getOmniJsonFieldReference(15, 1) + "," + getOmniJsonLiteral(1, false, 1) + "," + + getOmniJsonLiteral(1, false, 5)), + omniFunctionExpr("substr", 15, getOmniJsonFieldReference(15, 3) + "," + getOmniJsonLiteral(1, false, 1) + + "," + getOmniJsonLiteral(1, false, 5))); + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashCols, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {IntDataType.INTEGER, new VarcharDataType(5)}; + Object[][] probeDatas = {{20, 16, 13, 4, 20, 4, 22, 19, 8, 7}, + {"35709", "35709", "31904", "12477", null, "38721", "90419", "35709", "88371", null}}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + + int[] probeOutputCols = {0, 1}; + int[] probeHashCols = {0}; + int[] buildOutputCols = {0, 1}; + DataType[] buildOutputTypes = {IntDataType.INTEGER, new VarcharDataType(5)}; + OmniLookupJoinOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinOperatorFactory(probeTypes, + probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, Optional.of(filterExpression), false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + int len = resultVecBatch.getRowCount(); + assertEquals(len, 3); + assertEquals(resultVecBatch.getVectorCount(), 4); + Object[][] expectedDatas = {{20, 16, 19}, {"35709", "35709", "35709"}, {20, 16, 19}, + {"31904", "31904", "31904"}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + /** + * Test inner join without output + */ + @Test + public void testInnerEqualityJoinWithNoOutput() { + DataType[] buildTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] buildDatas = {{1L, 2L, 3L, 4L}, {111L, 11L, 333L, 33L}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + int[] buildHashCols = {0}; + int operatorCount = 1; + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashCols, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] probeDatas = {{0L, 5L, 6L, 7L}, {11L, 22L, 33L, 44L}}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + + int[] probeOutputCols = {0, 1}; + int[] probeHashCols = {0}; + int[] buildOutputCols = {0, 1}; + DataType[] buildOutputTypes = {LongDataType.LONG, LongDataType.LONG}; + OmniLookupJoinOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinOperatorFactory(probeTypes, + probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + assertTrue(results != null); + assertEquals(results.hasNext(), false); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + @Test(expectedExceptions = OmniRuntimeException.class, expectedExceptionsMessageRegExp = + ".*EXPRESSION_NOT_SUPPORT.*") + public void testInnerEqualityJoinWithInvalidFilter() { + DataType[] buildTypes = {LongDataType.LONG, LongDataType.LONG}; + int[] buildHashCols = {0}; + int operatorCount = 1; + String filterExpression = omniJsonNotEqualExpr( + omniFunctionExpr("substring", 15, + getOmniJsonFieldReference(2, 1) + "," + getOmniJsonLiteral(1, false, 1) + "," + + getOmniJsonLiteral(1, false, 5)), + omniFunctionExpr("substring", 15, getOmniJsonFieldReference(2, 0) + "," + + getOmniJsonLiteral(1, false, 1) + "," + getOmniJsonLiteral(1, false, 5))); + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashCols, operatorCount); + + DataType[] probeTypes = {LongDataType.LONG, LongDataType.LONG}; + int[] probeOutputCols = {0, 1}; + int[] probeHashCols = {0}; + int[] buildOutputCols = {0, 1}; + DataType[] buildOutputTypes = {LongDataType.LONG, LongDataType.LONG}; + OmniLookupJoinOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinOperatorFactory(probeTypes, + probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, Optional.of(filterExpression), false); + + hashBuilderOperatorFactory.close(); + lookupJoinOperatorFactory.close(); + } + + @Test + public void testFactoryContextEquals() { + DataType[] buildTypes = {LongDataType.LONG, LongDataType.LONG}; + int[] buildHashCols = {0}; + int operatorCount = 1; + OmniHashBuilderOperatorFactory.FactoryContext hashBuilderOperatorFactory1 = + new OmniHashBuilderOperatorFactory.FactoryContext( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashCols, operatorCount, new OperatorConfig()); + OmniHashBuilderOperatorFactory.FactoryContext hashBuilderOperatorFactory2 = + new OmniHashBuilderOperatorFactory.FactoryContext( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashCols, operatorCount, new OperatorConfig()); + OmniHashBuilderOperatorFactory.FactoryContext hashBuilderOperatorFactory3 = null; + assertEquals(hashBuilderOperatorFactory2, hashBuilderOperatorFactory1); + assertEquals(hashBuilderOperatorFactory1, hashBuilderOperatorFactory1); + assertNotEquals(hashBuilderOperatorFactory3, hashBuilderOperatorFactory1); + + DataType[] probeTypes = {LongDataType.LONG, LongDataType.LONG}; + int[] probeOutputCols = {0, 1}; + int[] probeHashCols = {0}; + int[] buildOutputCols = {0, 1}; + DataType[] buildOutputTypes = {LongDataType.LONG, LongDataType.LONG}; + + OmniHashBuilderOperatorFactory omniHashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashCols, operatorCount, new OperatorConfig()); + OmniLookupJoinOperatorFactory.FactoryContext lookupJoinOperatorFactory1 = + new OmniLookupJoinOperatorFactory.FactoryContext( + probeTypes, probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + omniHashBuilderOperatorFactory, Optional.empty(), false, new OperatorConfig()); + OmniLookupJoinOperatorFactory.FactoryContext lookupJoinOperatorFactory2 = + new OmniLookupJoinOperatorFactory.FactoryContext( + probeTypes, probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + omniHashBuilderOperatorFactory, Optional.empty(), false, new OperatorConfig()); + OmniLookupJoinOperatorFactory.FactoryContext lookupJoinOperatorFactory3 = null; + + assertEquals(lookupJoinOperatorFactory2, lookupJoinOperatorFactory1); + assertEquals(lookupJoinOperatorFactory1, lookupJoinOperatorFactory1); + assertNotEquals(lookupJoinOperatorFactory3, lookupJoinOperatorFactory1); + } + + private ImmutableList buildVecs(int pageCount) { + ImmutableList.Builder vecBatchList = ImmutableList.builder(); + int positionCount = pageDistinctCount * pageDistinctValueRepeatCount; + List vecs = new ArrayList<>(); + for (int i = 0; i < pageCount; i++) { + LongVec longVec1 = new LongVec(positionCount); + LongVec longVec2 = new LongVec(positionCount); + int idx = 0; + for (int j = 0; j < pageDistinctCount; j++) { + for (int k = 0; k < pageDistinctValueRepeatCount; k++) { + longVec1.set(idx, j); + longVec2.set(idx, j); + idx++; + } + } + vecs.add(longVec1); + vecs.add(longVec2); + VecBatch vecBatch = new VecBatch(new Vec[]{longVec1, longVec2}); + vecBatchList.add(vecBatch); + } + return vecBatchList.build(); + } + + /** + * Test full join + */ + @Test + public void testFullEqualityJoin() { + DataType[] buildTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] buildDatas = {{1L, 2L, 3L, 4L}, {111L, 11L, 333L, 33L}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + int[] buildHashCols = {1}; + int operatorCount = 1; + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_FULL, buildTypes, buildHashCols, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] probeDatas = {{1L, 2L, 3L, 4L}, {11L, 22L, 33L, 44L}}; + int[] probeOutputCols = {0, 1}; + int[] probeHashCols = {1}; + int[] buildOutputCols = {0, 1}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + DataType[] buildOutputTypes = {LongDataType.LONG, LongDataType.LONG}; + OmniLookupJoinOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinOperatorFactory(probeTypes, + probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + OmniLookupOuterJoinOperatorFactory lookupOuterJoinOperatorFactory = new OmniLookupOuterJoinOperatorFactory( + probeTypes, probeOutputCols, buildOutputCols, buildOutputTypes, hashBuilderOperatorFactory, + new OperatorConfig()); + OmniOperator lookupOuterJoinOperator = lookupOuterJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + int len = resultVecBatch.getRowCount(); + assertEquals(len, 4); + assertEquals(resultVecBatch.getVectorCount(), 4); + Object[][] expectedDatas = {{1L, 2L, 3L, 4L}, {11L, 22L, 33L, 44L}, {2L, null, 4L, null}, + {11L, null, 33L, null}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + + Iterator appendResults = lookupOuterJoinOperator.getOutput(); + VecBatch appendBatch = appendResults.next(); + len = appendBatch.getRowCount(); + assertEquals(len, 2); + Object[][] expectedData = {{null, null}, {null, null}, {1L, 3L}, {111L, 333L}}; + assertVecBatchEqualsIgnoreOrder(appendBatch, expectedData); + + freeVecBatch(resultVecBatch); + freeVecBatch(appendBatch); + lookupOuterJoinOperator.close(); + lookupOuterJoinOperatorFactory.close(); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + /** + * Test full join with varchar join key + */ + @Test + public void testFullEqualityJoinVarchar() { + DataType[] buildTypes = {LongDataType.LONG, new VarcharDataType(3)}; + Object[][] buildDatas = {{1L, 2L, 3L, 4L}, {"aaa", "11", "ccc", "33"}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + int[] buildHashCols = {1}; + int operatorCount = 1; + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_FULL, buildTypes, buildHashCols, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, new VarcharDataType(2)}; + Object[][] probeDatas = {{1L, 2L, 3L, 4L}, {"11", "22", "33", "44"}}; + int[] probeOutputCols = {0, 1}; + int[] probeHashCols = {1}; + int[] buildOutputCols = {0, 1}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + DataType[] buildOutputTypes = {LongDataType.LONG, new VarcharDataType(2)}; + OmniLookupJoinOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinOperatorFactory(probeTypes, + probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + OmniLookupOuterJoinOperatorFactory lookupOuterJoinOperatorFactory = new OmniLookupOuterJoinOperatorFactory( + probeTypes, probeOutputCols, buildOutputCols, buildOutputTypes, hashBuilderOperatorFactory, + new OperatorConfig()); + OmniOperator lookupOuterJoinOperator = lookupOuterJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + int len = resultVecBatch.getRowCount(); + assertEquals(len, 4); + assertEquals(resultVecBatch.getVectorCount(), 4); + Object[][] expectedDatas = {{1L, 2L, 3L, 4L}, {"11", "22", "33", "44"}, {2L, null, 4L, null}, + {"11", null, "33", null}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + Iterator appendResults = lookupOuterJoinOperator.getOutput(); + VecBatch appendBatch = appendResults.next(); + len = appendBatch.getRowCount(); + assertEquals(len, 2); + Object[][] expectedData = {{null, null}, {null, null}, {1L, 3L}, {"aaa", "ccc"}}; + assertVecBatchEqualsIgnoreOrder(appendBatch, expectedData); + + freeVecBatch(resultVecBatch); + freeVecBatch(appendBatch); + lookupOuterJoinOperator.close(); + lookupOuterJoinOperatorFactory.close(); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + /** + * Test full join with char join key + */ + @Test + public void testFullEqualityJoinChar() { + DataType[] buildTypes = {LongDataType.LONG, new CharDataType(3)}; + Object[][] buildDatas = {{1L, 2L, 3L, 4L}, {"aaa", "11", "ccc", "33"}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + int[] buildJoinCols = {1}; + int operatorCount = 1; + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_FULL, buildTypes, buildJoinCols, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, new CharDataType(2)}; + Object[][] probeDatas = {{1L, 2L, 3L, 4L}, {"11", "22", "33", "44"}}; + int[] probeOutputCols = {0, 1}; + int[] probeHashCols = {1}; + int[] buildOutputCols = {0, 1}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + DataType[] buildOutputTypes = {LongDataType.LONG, new CharDataType(2)}; + OmniLookupJoinOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinOperatorFactory(probeTypes, + probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + OmniLookupOuterJoinOperatorFactory lookupOuterJoinOperatorFactory = new OmniLookupOuterJoinOperatorFactory( + probeTypes, probeOutputCols, buildOutputCols, buildOutputTypes, hashBuilderOperatorFactory, + new OperatorConfig()); + OmniOperator lookupOuterJoinOperator = lookupOuterJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + int len = resultVecBatch.getRowCount(); + assertEquals(len, 4); + assertEquals(resultVecBatch.getVectorCount(), 4); + Object[][] expectedDatas = {{1L, 2L, 3L, 4L}, {"11", "22", "33", "44"}, {2L, null, 4L, null}, + {"11", null, "33", null}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + + Iterator appendResults = lookupOuterJoinOperator.getOutput(); + VecBatch appendBatch = appendResults.next(); + len = appendBatch.getRowCount(); + assertEquals(len, 2); + Object[][] expectedData = {{null, null}, {null, null}, {1L, 3L}, {"aaa", "ccc"}}; + assertVecBatchEqualsIgnoreOrder(appendBatch, expectedData); + + freeVecBatch(appendBatch); + freeVecBatch(resultVecBatch); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupOuterJoinOperator.close(); + lookupOuterJoinOperatorFactory.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + /** + * Test full join with date32 join key + */ + @Test + public void testFullEqualityJoinDate32() { + DataType[] buildTypes = {LongDataType.LONG, new Date32DataType(DataType.DateUnit.DAY)}; + Object[][] buildDatas = {{1L, 2L, 3L, 4L}, {123, 11, 321, 33}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + int[] buildHashCols = {1}; + int operatorCount = 1; + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_FULL, buildTypes, buildHashCols, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, new Date32DataType(DataType.DateUnit.DAY)}; + Object[][] probeDatas = {{1L, 2L, 3L, 4L}, {11, 22, 33, 44}}; + int[] probeOutputCols = {0, 1}; + int[] probeHashCols = {1}; + int[] buildOutputCols = {0, 1}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + DataType[] buildOutputTypes = {LongDataType.LONG, new Date32DataType(DataType.DateUnit.DAY)}; + OmniLookupJoinOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinOperatorFactory(probeTypes, + probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + OmniLookupOuterJoinOperatorFactory lookupOuterJoinOperatorFactory = new OmniLookupOuterJoinOperatorFactory( + probeTypes, probeOutputCols, buildOutputCols, buildOutputTypes, hashBuilderOperatorFactory, + new OperatorConfig()); + OmniOperator lookupOuterJoinOperator = lookupOuterJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + int len = resultVecBatch.getRowCount(); + assertEquals(len, 4); + assertEquals(resultVecBatch.getVectorCount(), 4); + Object[][] expectedDatas = {{1L, 2L, 3L, 4L}, {11, 22, 33, 44}, {2L, null, 4L, null}, {11, null, 33, null}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + + Iterator appendResults = lookupOuterJoinOperator.getOutput(); + VecBatch appendBatch = appendResults.next(); + len = appendBatch.getRowCount(); + assertEquals(len, 2); + Object[][] expectedData = {{null, null}, {null, null}, {1L, 3L}, {123, 321}}; + assertVecBatchEqualsIgnoreOrder(appendBatch, expectedData); + + freeVecBatch(resultVecBatch); + freeVecBatch(appendBatch); + lookupOuterJoinOperator.close(); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupOuterJoinOperatorFactory.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + /** + * Test full join with decimal64 join key + */ + @Test + public void testFullEqualityJoinDecimal64() { + DataType[] buildTypes = {LongDataType.LONG, new Decimal64DataType(3, 0)}; + Object[][] buildDatas = {{1L, 2L, 3L, 4L}, {123L, 11L, 321L, 33L}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + int[] buildHashCols = {1}; + int operatorCount = 1; + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_FULL, buildTypes, buildHashCols, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, new Decimal64DataType(2, 0)}; + Object[][] probeDatas = {{1L, 2L, 3L, 4L}, {11L, 22L, 33L, 44L}}; + int[] probeOutputCols = {0, 1}; + int[] probeHashCols = {1}; + int[] buildOutputCols = {0, 1}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + DataType[] buildOutputTypes = {LongDataType.LONG, new Decimal64DataType(3, 0)}; + OmniLookupJoinOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinOperatorFactory(probeTypes, + probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + OmniLookupOuterJoinOperatorFactory lookupOuterJoinOperatorFactory = new OmniLookupOuterJoinOperatorFactory( + probeTypes, probeOutputCols, buildOutputCols, buildOutputTypes, hashBuilderOperatorFactory, + new OperatorConfig()); + OmniOperator lookupOuterJoinOperator = lookupOuterJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + int len = resultVecBatch.getRowCount(); + assertEquals(len, 4); + assertEquals(resultVecBatch.getVectorCount(), 4); + Object[][] expectedDatas = {{1L, 2L, 3L, 4L}, {11L, 22L, 33L, 44L}, {2L, null, 4L, null}, + {11L, null, 33L, null}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + + Iterator appendResults = lookupOuterJoinOperator.getOutput(); + VecBatch appendBatch = appendResults.next(); + len = appendBatch.getRowCount(); + assertEquals(len, 2); + Object[][] expectedData = {{null, null}, {null, null}, {1L, 3L}, {123L, 321L}}; + assertVecBatchEqualsIgnoreOrder(appendBatch, expectedData); + + freeVecBatch(resultVecBatch); + freeVecBatch(appendBatch); + lookupOuterJoinOperator.close(); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + lookupOuterJoinOperatorFactory.close(); + } + + /** + * Test full join with decimal128 join key + */ + @Test + public void testFullEqualityJoinDecimal128() { + DataType[] buildTypes = {LongDataType.LONG, new Decimal128DataType(3, 0)}; + Vec[] buildVecs = new Vec[buildTypes.length]; + buildVecs[0] = createVec(buildTypes[0], new Object[]{1L, 2L, 3L, 4L}); + buildVecs[1] = createVec(buildTypes[1], new Object[][]{{123L, 0L}, {11L, 0L}, {321L, 0L}, {33L, 0L}}); + VecBatch buildVecBatch = new VecBatch(buildVecs); + + int[] buildHashCols = {1}; + int operatorCount = 1; + OmniHashBuilderOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory( + OMNI_JOIN_TYPE_FULL, buildTypes, buildHashCols, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, new Decimal128DataType(2, 0)}; + Vec[] probeVecs = new Vec[probeTypes.length]; + probeVecs[0] = createVec(probeTypes[0], new Object[]{1L, 2L, 3L, 4L}); + probeVecs[1] = createVec(probeTypes[1], new Object[][]{{11L, 0L}, {22L, 0L}, {33L, 0L}, {44L, 0L}}); + + int[] probeOutputCols = {0, 1}; + int[] probeHashCols = {1}; + int[] buildOutputCols = {0, 1}; + VecBatch probeVecBatch = new VecBatch(probeVecs); + DataType[] buildOutputTypes = {LongDataType.LONG, new Decimal128DataType(3, 0)}; + OmniLookupJoinOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinOperatorFactory(probeTypes, + probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + OmniLookupOuterJoinOperatorFactory lookupOuterJoinOperatorFactory = new OmniLookupOuterJoinOperatorFactory( + probeTypes, probeOutputCols, buildOutputCols, buildOutputTypes, hashBuilderOperatorFactory, + new OperatorConfig()); + OmniOperator lookupOuterJoinOperator = lookupOuterJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + int len = resultVecBatch.getRowCount(); + assertEquals(len, 4); + assertEquals(resultVecBatch.getVectorCount(), 4); + assertVecEqualsIgnoreOrder(resultVecBatch.getVectors()[0], new Object[]{1L, 2L, 3L, 4L}); + assertDecimal128VecEqualsIgnoreOrder(resultVecBatch.getVectors()[1], + new Long[][]{{11L, 0L}, {22L, 0L}, {33L, 0L}, {44L, 0L}}); + assertVecEqualsIgnoreOrder(resultVecBatch.getVectors()[2], new Object[]{2L, null, 4L, null}); + assertDecimal128VecEqualsIgnoreOrder(resultVecBatch.getVectors()[3], new Long[][]{{11L, 0L}, null, {33L, 0L}, + null}); + + Iterator appendResults = lookupOuterJoinOperator.getOutput(); + VecBatch appendBatch = appendResults.next(); + len = appendBatch.getRowCount(); + assertEquals(len, 2); + assertVecEqualsIgnoreOrder(appendBatch.getVector(0), new Object[]{null, null}); + assertDecimal128VecEqualsIgnoreOrder(appendBatch.getVector(1), new Long[][]{null, null}); + assertVecEqualsIgnoreOrder(appendBatch.getVector(2), new Object[]{1L, 3L}); + assertDecimal128VecEqualsIgnoreOrder(appendBatch.getVector(3), new Long[][]{{123L, 0L}, {321L, 0L}}); + + freeVecBatch(resultVecBatch); + freeVecBatch(appendBatch); + lookupOuterJoinOperator.close(); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + lookupOuterJoinOperatorFactory.close(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniHashJoinWithExprOperatorsTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniHashJoinWithExprOperatorsTest.java new file mode 100644 index 0000000..b11ec18 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniHashJoinWithExprOperatorsTest.java @@ -0,0 +1,940 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.constants.JoinType.OMNI_JOIN_TYPE_FULL; +import static nova.hetu.omniruntime.constants.JoinType.OMNI_JOIN_TYPE_LEFT; +import static nova.hetu.omniruntime.constants.JoinType.OMNI_JOIN_TYPE_INNER; +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEqualsIgnoreOrder; +import static nova.hetu.omniruntime.constants.BuildSide.BUILD_LEFT; +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.createVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatches; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonLiteral; +import static nova.hetu.omniruntime.util.TestUtils.omniFunctionExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonFourArithmeticExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonNotEqualExpr; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; + +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.operator.join.OmniHashBuilderWithExprOperatorFactory; +import nova.hetu.omniruntime.operator.join.OmniLookupJoinWithExprOperatorFactory; +import nova.hetu.omniruntime.operator.join.OmniLookupOuterJoinWithExprOperatorFactory; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.util.TestUtils; +import nova.hetu.omniruntime.utils.OmniRuntimeException; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; + +/** + * The type Omni hash join with expression operator test. + * + * @since 2021-10-16 + */ +public class OmniHashJoinWithExprOperatorsTest { + /** + * Test inner hash join one column . + */ + @Test + public void testInnerEqualityJoinOneColumn() { + DataType[] buildTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] buildDatas = {{1L, 2L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 1L}, + {79L, 79L, 70L, 70L, 70L, 70L, 70L, 70L, 70L, 70L}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + String[] buildHashKeys = {omniJsonFourArithmeticExpr("ADD", 2, getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 50))}; + int operatorCount = 1; + OmniHashBuilderWithExprOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderWithExprOperatorFactory( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashKeys, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] probeDatas = {{1L, 2L, 3L, 4L, 5L, 6L, 1L, 1L, 2L, 3L}, + {78L, 78L, 78L, 78L, 78L, 78L, 78L, 82L, 82L, 65L}}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + + int[] probeOutputCols = {1}; + String[] probeHashKeys = {omniJsonFourArithmeticExpr("ADD", 2, getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 50))}; + int[] buildOutputCols = {1}; + DataType[] buildOutputTypes = {LongDataType.LONG}; + OmniLookupJoinWithExprOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinWithExprOperatorFactory( + probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + VecBatch resultVecBatch = results.next(); + + int len = resultVecBatch.getRowCount(); + assertEquals(len, 18); + Object[][] expectedDatas = { + {78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 82L, 82L, 82L, 82L, 82L, 65L}, + {79L, 70L, 70L, 79L, 70L, 70L, 70L, 70L, 70L, 79L, 70L, 70L, 79L, 70L, 70L, 79L, 70L, 70L}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + @Test + public void testInnerEqualityJoinOneColumnWithMultiVectorBatch() { + DataType[] buildTypes = {IntDataType.INTEGER, IntDataType.INTEGER}; + Object[][] buildDatas = {{1, 2, 1, 2, 3, 4, 5, 6, 7, 1}, {79, 79, 70, 70, 70, 70, 70, 70, 70, 70}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + String[] buildHashKeys = {omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), + getOmniJsonLiteral(1, false, 50))}; + int operatorCount = 1; + OmniHashBuilderWithExprOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderWithExprOperatorFactory( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashKeys, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {IntDataType.INTEGER, IntDataType.INTEGER}; + + int[] probeOutputCols = {1}; + String[] probeHashKeys = {omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), + getOmniJsonLiteral(1, false, 50))}; + int[] buildOutputCols = {1}; + DataType[] buildOutputTypes = {IntDataType.INTEGER}; + OmniLookupJoinWithExprOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinWithExprOperatorFactory( + probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + + Object[][] baseDatas = {{1, 3, 4, 5, 6, 1, 1, 2, 3}, {78, 78, 78, 78, 78, 78, 82, 82, 65}}; + int baseDataLen = 9; + + Object[][] probeDatas = new Object[2][]; + int baseRowCnt = 16; + int maxRowCntPerBatch = 131072; // 1M / (4+4) + // each batch of baseData will generate 16 row of records, and each vectorBatch + // output will have a maximum of 131072 rows.The final result here will output 3 + // vectorBatch + probeDatas[0] = new Object[3 * (maxRowCntPerBatch / baseRowCnt) * baseDataLen]; + probeDatas[1] = new Object[3 * (maxRowCntPerBatch / baseRowCnt) * baseDataLen]; + for (int i = 0; i < probeDatas[0].length; i++) { + probeDatas[0][i] = baseDatas[0][i % baseDataLen]; + probeDatas[1][i] = baseDatas[1][i % baseDataLen]; + } + + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + int actualRowCnt = 0; + Object[][] baseExpectedDatas = {{78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 82, 82, 82, 82, 82, 65}, + {79, 70, 70, 70, 70, 70, 70, 79, 70, 70, 79, 70, 70, 79, 70, 70}}; + while (results.hasNext()) { + VecBatch resultVecBatch = results.next(); + int rowCnt = resultVecBatch.getRowCount(); + actualRowCnt += rowCnt; + Object[][] expectedDatas = new Object[2][]; + expectedDatas[0] = new Object[rowCnt]; + expectedDatas[1] = new Object[rowCnt]; + for (int i = 0; i < rowCnt; i++) { + expectedDatas[0][i] = baseExpectedDatas[0][i % baseRowCnt]; + expectedDatas[1][i] = baseExpectedDatas[1][i % baseRowCnt]; + } + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + } + assertEquals(actualRowCnt, 393216); + + // the next batch of probe data + VecBatch probeVecBatch1 = createVecBatch(probeTypes, probeDatas); + lookupJoinOperator.addInput(probeVecBatch1); + results = lookupJoinOperator.getOutput(); + actualRowCnt = 0; + while (results.hasNext()) { + VecBatch resultVecBatch = results.next(); + int rowCnt = resultVecBatch.getRowCount(); + actualRowCnt += rowCnt; + Object[][] expectedDatas = new Object[2][]; + expectedDatas[0] = new Object[rowCnt]; + expectedDatas[1] = new Object[rowCnt]; + for (int i = 0; i < rowCnt; i++) { + expectedDatas[0][i] = baseExpectedDatas[0][i % baseRowCnt]; + expectedDatas[1][i] = baseExpectedDatas[1][i % baseRowCnt]; + } + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + } + assertEquals(actualRowCnt, 393216); + + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + /** + * Test inner hash join one dictionary column . + */ + @Test + public void testInnerEqualityJoinOneDictionaryColumn() { + DataType[] buildTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] buildDatas = {{1L, 2L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 1L}, + {79L, 79L, 70L, 70L, 70L, 70L, 70L, 70L, 70L, 70L}}; + Vec[] buildVecs = new Vec[2]; + int[] ids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + buildVecs[0] = TestUtils.createDictionaryVec(buildTypes[0], buildDatas[0], ids); + buildVecs[1] = TestUtils.createDictionaryVec(buildTypes[1], buildDatas[1], ids); + VecBatch buildVecBatch = new VecBatch(buildVecs); + + String[] buildHashKeys = {omniJsonFourArithmeticExpr("ADD", 2, getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 50))}; + int operatorCount = 1; + OmniHashBuilderWithExprOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderWithExprOperatorFactory( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashKeys, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] probeDatas = {{1L, 2L, 3L, 4L, 5L, 6L, 1L, 1L, 2L, 3L}, + {78L, 78L, 78L, 78L, 78L, 78L, 78L, 82L, 82L, 65L}}; + Vec[] probeVecs = new Vec[2]; + probeVecs[0] = TestUtils.createDictionaryVec(probeTypes[0], probeDatas[0], ids); + probeVecs[1] = TestUtils.createDictionaryVec(probeTypes[1], probeDatas[1], ids); + VecBatch probeVecBatch = new VecBatch(probeVecs); + + int[] probeOutputCols = {1}; + String[] probeHashKeys = {omniJsonFourArithmeticExpr("ADD", 2, getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 50))}; + int[] buildOutputCols = {1}; + DataType[] buildOutputTypes = {LongDataType.LONG}; + OmniLookupJoinWithExprOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinWithExprOperatorFactory( + probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + VecBatch resultVecBatch = results.next(); + + int len = resultVecBatch.getRowCount(); + assertEquals(len, 18); + Object[][] expectedDatas = { + {78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 82L, 82L, 82L, 82L, 82L, 65L}, + {79L, 70L, 70L, 79L, 70L, 70L, 70L, 70L, 70L, 79L, 70L, 70L, 79L, 70L, 70L, 79L, 70L, 70L}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + /** + * Test inner hash join with join filter expression . + */ + @Test + public void testInnerEqualityJoinWithCharFilter() { + DataType[] buildTypes = {IntDataType.INTEGER, new VarcharDataType(5)}; + Object[][] buildDatas = {{19, 14, 7, 19, 1, 20, 10, 13, 20, 16}, + {"35709", "31904", "35709", "31904", "35709", "31904", "35709", "31904", "35709", "31904"}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + String[] buildHashCols = {getOmniJsonFieldReference(1, 0)}; + int operatorCount = 1; + String filterExpression = omniJsonNotEqualExpr( + omniFunctionExpr("substr", 15, + getOmniJsonFieldReference(15, 1) + "," + getOmniJsonLiteral(1, false, 1) + "," + + getOmniJsonLiteral(1, false, 5)), + omniFunctionExpr("substr", 15, getOmniJsonFieldReference(15, 3) + "," + getOmniJsonLiteral(1, false, 1) + + "," + getOmniJsonLiteral(1, false, 5))); + OmniHashBuilderWithExprOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderWithExprOperatorFactory( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashCols, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {IntDataType.INTEGER, new VarcharDataType(5)}; + Object[][] probeDatas = {{20, 16, 13, 4, 20, 4, 22, 19, 8, 7}, + {"35709", "35709", "31904", "12477", null, "38721", "90419", "35709", "88371", null}}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + + int[] probeOutputCols = {0, 1}; + String[] probeHashCols = {getOmniJsonFieldReference(1, 0)}; + int[] buildOutputCols = {0, 1}; + DataType[] buildOutputTypes = {IntDataType.INTEGER, new VarcharDataType(5)}; + OmniLookupJoinWithExprOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinWithExprOperatorFactory( + probeTypes, probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, Optional.of(filterExpression), false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + int len = resultVecBatch.getRowCount(); + assertEquals(len, 3); + assertEquals(resultVecBatch.getVectorCount(), 4); + Object[][] expectedDatas = {{20, 16, 19}, {"35709", "35709", "35709"}, {20, 16, 19}, + {"31904", "31904", "31904"}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + /** + * Test opposite side hash join one column . + * It means (left outer join && BuildRight) or (right outer join && BuildLeft) . + * This test example is (left outer join && BuildRight). + */ + @Test + public void testOppositeSideOuterEqualityJoinOneColumn() { + DataType[] buildTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] buildDatas = {{1L, 2L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 1L}, + {79L, 79L, 70L, 70L, 70L, 70L, 70L, 70L, 70L, 70L}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + String[] buildHashKeys = {omniJsonFourArithmeticExpr("ADD", 2, getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 50))}; + int operatorCount = 1; + OmniHashBuilderWithExprOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderWithExprOperatorFactory( + OMNI_JOIN_TYPE_LEFT, buildTypes, buildHashKeys, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] probeDatas = {{1L, 2L, 3L, 4L, 5L, 6L, 1L, 1L, 2L, 3L, 9L}, + {78L, 78L, 78L, 78L, 78L, 78L, 78L, 82L, 82L, 65L, 99L}}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + + int[] probeOutputCols = {0, 1}; + String[] probeHashKeys = {omniJsonFourArithmeticExpr("ADD", 2, getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 50))}; + int[] buildOutputCols = {0, 1}; + DataType[] buildOutputTypes = {LongDataType.LONG, LongDataType.LONG}; + OmniLookupJoinWithExprOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinWithExprOperatorFactory( + probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getRowCount(), 19); + Object[][] expectedDatas = {{1L, 1L, 1L, 2L, 2L, 3L, 4L, 5L, 6L, 1L, 1L, 1L, 1L, 1L, 1L, 2L, 2L, 3L, 9L}, + {78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 82L, 82L, 82L, 82L, 82L, 65L, 99L}, + {1L, 1L, 1L, 2L, 2L, 3L, 4L, 5L, 6L, 1L, 1L, 1L, 1L, 1L, 1L, 2L, 2L, 3L, null}, + {79L, 70L, 70L, 79L, 70L, 70L, 70L, 70L, 70L, 79L, 70L, 70L, 79L, 70L, 70L, 79L, 70L, 70L, null}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + /** + * Test same side hash join one column . + * It means (left outer join && BuildLeft) or (right outer join && BuildRight) . + * This test example is (left outer join && BuildLeft). + */ + @Test + public void testSameSideOuterEqualityJoinOneColumn() { + DataType[] buildTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] buildDatas = {{1L, 2L, 3L, 4L, 5L, 6L, 1L, 1L, 2L, 3L, 9L}, + {78L, 78L, 78L, 78L, 78L, 78L, 78L, 82L, 82L, 65L, 99L}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + String[] buildHashKeys = {omniJsonFourArithmeticExpr("ADD", 2, getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 50))}; + int operatorCount = 1; + OmniHashBuilderWithExprOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderWithExprOperatorFactory( + OMNI_JOIN_TYPE_LEFT, BUILD_LEFT, buildTypes, buildHashKeys, operatorCount, new OperatorConfig()); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] probeDatas = {{1L, 2L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 1L}, + {79L, 79L, 70L, 70L, 70L, 70L, 70L, 70L, 70L, 70L}}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + + int[] probeOutputCols = {0, 1}; + String[] probeHashKeys = {omniJsonFourArithmeticExpr("ADD", 2, getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 50))}; + int[] buildOutputCols = {0, 1}; + DataType[] buildOutputTypes = {LongDataType.LONG, LongDataType.LONG}; + OmniLookupJoinWithExprOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinWithExprOperatorFactory( + probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + + OmniLookupOuterJoinWithExprOperatorFactory lookupOuterJoinOperatorFactory = + new OmniLookupOuterJoinWithExprOperatorFactory( + probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, new OperatorConfig()); + OmniOperator lookupOuterJoinOperator = lookupOuterJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getRowCount(), 18); + Object[][] expectedDatas = {{1L, 1L, 1L, 2L, 2L, 1L, 1L, 1L, 2L, 2L, 3L, 3L, 4L, 5L, 6L, 1L, 1L, 1L}, + {78L, 78L, 82L, 78L, 82L, 78L, 78L, 82L, 78L, 82L, 78L, 65L, 78L, 78L, 78L, 78L, 78L, 82L}, + {1L, 1L, 1L, 2L, 2L, 1L, 1L, 1L, 2L, 2L, 3L, 3L, 4L, 5L, 6L, 1L, 1L, 1L}, + {79L, 79L, 79L, 79L, 79L, 70L, 70L, 70L, 70L, 70L, 70L, 70L, 70L, 70L, 70L, 70L, 70L, 70L}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + Iterator appendResults = lookupOuterJoinOperator.getOutput(); + VecBatch appendBatch = appendResults.next(); + assertEquals(appendBatch.getRowCount(), 1); + Object[][] expectedData = {{null}, {null}, {9L}, {99L}}; + assertVecBatchEquals(appendBatch, expectedData); + freeVecBatch(resultVecBatch); + freeVecBatch(appendBatch); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupOuterJoinOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + lookupOuterJoinOperatorFactory.close(); + } + + /** + * Test full hash join one column . + */ + @Test + public void testFullOuterEqualityJoinOneColumn() { + DataType[] buildTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] buildDatas = {{1L, 2L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 1L}, + {79L, 79L, 70L, 70L, 70L, 70L, 70L, 70L, 70L, 70L}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + String[] buildHashKeys = {omniJsonFourArithmeticExpr("ADD", 2, getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 50))}; + int operatorCount = 1; + OmniHashBuilderWithExprOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderWithExprOperatorFactory( + OMNI_JOIN_TYPE_FULL, buildTypes, buildHashKeys, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] probeDatas = {{1L, 2L, 3L, 4L, 5L, 6L, 1L, 1L, 2L, 3L}, + {78L, 78L, 78L, 78L, 78L, 78L, 78L, 82L, 82L, 65L}}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + + int[] probeOutputCols = {0, 1}; + String[] probeHashKeys = {omniJsonFourArithmeticExpr("ADD", 2, getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 50))}; + int[] buildOutputCols = {0, 1}; + DataType[] buildOutputTypes = {LongDataType.LONG, LongDataType.LONG}; + OmniLookupJoinWithExprOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinWithExprOperatorFactory( + probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + + OmniLookupOuterJoinWithExprOperatorFactory lookupOuterJoinOperatorFactory = + new OmniLookupOuterJoinWithExprOperatorFactory( + probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, new OperatorConfig()); + OmniOperator lookupOuterJoinOperator = lookupOuterJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getRowCount(), 18); + Object[][] expectedDatas = {{1L, 1L, 1L, 2L, 2L, 3L, 4L, 5L, 6L, 1L, 1L, 1L, 1L, 1L, 1L, 2L, 2L, 3L}, + {78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 82L, 82L, 82L, 82L, 82L, 65L}, + {1L, 1L, 1L, 2L, 2L, 3L, 4L, 5L, 6L, 1L, 1L, 1L, 1L, 1L, 1L, 2L, 2L, 3L}, + {79L, 70L, 70L, 79L, 70L, 70L, 70L, 70L, 70L, 79L, 70L, 70L, 79L, 70L, 70L, 79L, 70L, 70L}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + Iterator appendResults = lookupOuterJoinOperator.getOutput(); + VecBatch appendBatch = appendResults.next(); + assertEquals(appendBatch.getRowCount(), 1); + Object[][] expectedData = {{null}, {null}, {7L}, {70L}}; + assertVecBatchEqualsIgnoreOrder(appendBatch, expectedData); + freeVecBatch(resultVecBatch); + freeVecBatch(appendBatch); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupOuterJoinOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + lookupOuterJoinOperatorFactory.close(); + } + + /** + * Test full hash join one dictionary column . + */ + @Test + public void testFullOuterEqualityJoinOneDictionaryColumn() { + DataType[] buildTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] buildDatas = {{1L, 2L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 1L}, + {79L, 79L, 70L, 70L, 70L, 70L, 70L, 70L, 70L, 70L}}; + Vec[] buildVecs = new Vec[2]; + int[] ids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + buildVecs[0] = TestUtils.createDictionaryVec(buildTypes[0], buildDatas[0], ids); + buildVecs[1] = TestUtils.createDictionaryVec(buildTypes[1], buildDatas[1], ids); + VecBatch buildVecBatch = new VecBatch(buildVecs); + + String[] buildHashKeys = {omniJsonFourArithmeticExpr("ADD", 2, getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 50))}; + int operatorCount = 1; + OmniHashBuilderWithExprOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderWithExprOperatorFactory( + OMNI_JOIN_TYPE_FULL, buildTypes, buildHashKeys, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, LongDataType.LONG}; + Object[][] probeDatas = {{1L, 2L, 3L, 4L, 5L, 6L, 1L, 1L, 2L, 3L}, + {78L, 78L, 78L, 78L, 78L, 78L, 78L, 82L, 82L, 65L}}; + Vec[] probeVecs = new Vec[2]; + probeVecs[0] = TestUtils.createDictionaryVec(probeTypes[0], probeDatas[0], ids); + probeVecs[1] = TestUtils.createDictionaryVec(probeTypes[1], probeDatas[1], ids); + VecBatch probeVecBatch = new VecBatch(probeVecs); + + int[] probeOutputCols = {1}; + String[] probeHashKeys = {omniJsonFourArithmeticExpr("ADD", 2, getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 50))}; + int[] buildOutputCols = {1}; + DataType[] buildOutputTypes = {LongDataType.LONG}; + OmniLookupJoinWithExprOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinWithExprOperatorFactory( + probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + OmniLookupOuterJoinWithExprOperatorFactory lookupOuterJoinOperatorFactory = + new OmniLookupOuterJoinWithExprOperatorFactory( + probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, new OperatorConfig()); + OmniOperator lookupOuterJoinOperator = lookupOuterJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + VecBatch resultVecBatch = results.next(); + int len = resultVecBatch.getRowCount(); + assertEquals(len, 18); + Object[][] expectedDatas = { + {78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 78L, 82L, 82L, 82L, 82L, 82L, 65L}, + {79L, 70L, 70L, 79L, 70L, 70L, 70L, 70L, 70L, 79L, 70L, 70L, 79L, 70L, 70L, 79L, 70L, 70L}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + + Iterator appendResults = lookupOuterJoinOperator.getOutput(); + VecBatch appendBatch = appendResults.next(); + len = appendBatch.getRowCount(); + assertEquals(len, 1); + Object[][] expectedData = {{null}, {70L}}; + assertVecBatchEqualsIgnoreOrder(appendBatch, expectedData); + freeVecBatch(resultVecBatch); + freeVecBatch(appendBatch); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupOuterJoinOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + lookupOuterJoinOperatorFactory.close(); + } + + /** + * Test full hash join with join filter expression . + */ + @Test + public void testFullEqualityJoinWithCharFilter() { + DataType[] buildTypes = {IntDataType.INTEGER, new VarcharDataType(5)}; + Object[][] buildDatas = {{19, 14, 7, 19, 1, 20, 10, 13, 20, 16}, + {"35709", "31904", "35709", "31904", "35709", "31904", "35709", "31904", "35709", "31904"}}; + VecBatch buildVecBatch = createVecBatch(buildTypes, buildDatas); + + String[] buildHashCols = {getOmniJsonFieldReference(1, 0)}; + int operatorCount = 1; + String filterExpression = omniJsonNotEqualExpr( + omniFunctionExpr("substr", 15, + getOmniJsonFieldReference(15, 1) + "," + getOmniJsonLiteral(1, false, 1) + "," + + getOmniJsonLiteral(1, false, 5)), + omniFunctionExpr("substr", 15, getOmniJsonFieldReference(15, 3) + "," + getOmniJsonLiteral(1, false, 1) + + "," + getOmniJsonLiteral(1, false, 5))); + OmniHashBuilderWithExprOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderWithExprOperatorFactory( + OMNI_JOIN_TYPE_FULL, buildTypes, buildHashCols, operatorCount); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(buildVecBatch); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {IntDataType.INTEGER, new VarcharDataType(5)}; + Object[][] probeDatas = {{20, 16, 13, 4, 20, 4, 22, 19, 8, 7}, + {"35709", "35709", "31904", "12477", null, "38721", "90419", "35709", "88371", null}}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + + int[] probeOutputCols = {0, 1}; + String[] probeHashCols = {getOmniJsonFieldReference(1, 0)}; + int[] buildOutputCols = {0, 1}; + DataType[] buildOutputTypes = {IntDataType.INTEGER, new VarcharDataType(5)}; + OmniLookupJoinWithExprOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinWithExprOperatorFactory( + probeTypes, probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, Optional.of(filterExpression), false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + OmniLookupOuterJoinWithExprOperatorFactory lookupOuterJoinOperatorFactory = + new OmniLookupOuterJoinWithExprOperatorFactory( + probeTypes, probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, new OperatorConfig()); + OmniOperator lookupOuterJoinOperator = lookupOuterJoinOperatorFactory.createOperator(); + lookupJoinOperator.addInput(probeVecBatch); + Iterator results = lookupJoinOperator.getOutput(); + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getRowCount(), 10); + Object[][] expectedDatas = {{20, 16, 13, 4, 20, 4, 22, 19, 8, 7}, + {"35709", "35709", "31904", "12477", null, "38721", "90419", "35709", "88371", null}, + {20, 16, null, null, null, null, null, 19, null, null}, + {"31904", "31904", null, null, null, null, null, "31904", null, null}}; + assertVecBatchEqualsIgnoreOrder(resultVecBatch, expectedDatas); + Iterator appendResults = lookupOuterJoinOperator.getOutput(); + VecBatch appendBatch = appendResults.next(); + assertEquals(appendBatch.getRowCount(), 7); + Object[][] expectedData = {{null, null, null, null, null, null, null}, + {null, null, null, null, null, null, null}, {1, 7, 10, 13, 14, 19, 20}, + {"35709", "35709", "35709", "31904", "31904", "35709", "35709"}}; + assertVecBatchEqualsIgnoreOrder(appendBatch, expectedData); + freeVecBatch(resultVecBatch); + freeVecBatch(appendBatch); + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + } + + /** + * Test full hash join multiple build batch. + */ + @Test + public void testFullOuterJoinMultipleBatch() { + DataType[] buildTypes = {LongDataType.LONG, new VarcharDataType(5)}; + String[] buildHashKeys = {getOmniJsonFieldReference(2, 0)}; + OmniHashBuilderWithExprOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderWithExprOperatorFactory( + OMNI_JOIN_TYPE_FULL, buildTypes, buildHashKeys, 1); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + hashBuilderOperator.addInput(createVecBatch(buildTypes, new Object[][]{{1L}, {"35709"}})); + hashBuilderOperator.addInput(createVecBatch(buildTypes, new Object[][]{{13L}, {"31904"}})); + hashBuilderOperator.addInput(createVecBatch(buildTypes, new Object[][]{{16L}, {"31904"}})); + hashBuilderOperator.addInput(createVecBatch(buildTypes, new Object[][]{{20L, 20L}, {"31904", "35709"}})); + hashBuilderOperator.addInput(createVecBatch(buildTypes, new Object[][]{{19L, 19L}, {"35709", "31904"}})); + hashBuilderOperator.addInput(createVecBatch(buildTypes, new Object[][]{{7L}, {"35709"}})); + hashBuilderOperator.addInput(createVecBatch(buildTypes, new Object[][]{{10L}, {"35709"}})); + hashBuilderOperator.addInput(createVecBatch(buildTypes, new Object[][]{{14L}, {"31904"}})); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, new VarcharDataType(5)}; + int[] probeOutputCols = {0, 1}; + String[] probeHashKeys = {getOmniJsonFieldReference(2, 0)}; + int[] buildOutputCols = {0, 1}; + DataType[] buildOutputTypes = {LongDataType.LONG, new VarcharDataType(5)}; + OmniLookupJoinWithExprOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinWithExprOperatorFactory( + probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + OmniLookupOuterJoinWithExprOperatorFactory lookupOuterJoinOperatorFactory = + new OmniLookupOuterJoinWithExprOperatorFactory( + probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, new OperatorConfig()); + OmniOperator lookupOuterJoinOperator = lookupOuterJoinOperatorFactory.createOperator(); + Object[][] probeDatas = {{22L, 13L, 16L, 20L, 20L, 19L, 4L, 4L, 8L, 7L}, + {"90419", "31904", "35709", "35709", null, "35709", "12477", "38721", "88371", null}}; + VecBatch probeVecBatch = createVecBatch(probeTypes, probeDatas); + lookupJoinOperator.addInput(probeVecBatch); + VecBatch resultVecBatch = lookupJoinOperator.getOutput().next(); + VecBatch appendBatch = lookupOuterJoinOperator.getOutput().next(); + assertVecBatchEqualsIgnoreOrder(resultVecBatch, + new Object[][]{{22L, 13L, 16L, 20L, 20L, 20L, 20L, 19L, 19L, 4L, 4L, 8L, 7L}, + {"90419", "31904", "35709", "35709", "35709", null, null, "35709", "35709", "12477", "38721", + "88371", null}, + {null, 13L, 16L, 20L, 20L, 20L, 20L, 19L, 19L, null, null, null, 7L}, {null, "31904", "31904", + "31904", "35709", "31904", "35709", "35709", "31904", null, null, null, "35709"}}); + assertVecBatchEqualsIgnoreOrder(appendBatch, + new Object[][]{{null, null, null}, {null, null, null}, {1L, 10L, 14L}, {"35709", "35709", "31904"}}); + + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupOuterJoinOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + lookupOuterJoinOperatorFactory.close(); + freeVecBatch(resultVecBatch); + freeVecBatch(appendBatch); + } + + @Test(expectedExceptions = OmniRuntimeException.class, expectedExceptionsMessageRegExp = + ".*EXPRESSION_NOT_SUPPORT.*") + public void testHashBuilderWithInvalidKeys() { + DataType[] buildTypes = {LongDataType.LONG, LongDataType.LONG}; + int operatorCount = 1; + + // invalid build hash key + String[] invalidBuildHashKeys = {omniFunctionExpr("abc", 2, getOmniJsonFieldReference(2, 1))}; + OmniHashBuilderWithExprOperatorFactory operatorFactory = new OmniHashBuilderWithExprOperatorFactory( + OMNI_JOIN_TYPE_FULL, buildTypes, invalidBuildHashKeys, operatorCount); + } + + @Test(expectedExceptions = OmniRuntimeException.class, expectedExceptionsMessageRegExp = + ".*EXPRESSION_NOT_SUPPORT.*") + public void testLookupJoinWithInvalidKeys() { + DataType[] buildTypes = {LongDataType.LONG, LongDataType.LONG}; + int operatorCount = 1; + String[] buildHashCols = {getOmniJsonFieldReference(2, 0)}; + OmniHashBuilderWithExprOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderWithExprOperatorFactory( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashCols, operatorCount); + + DataType[] probeTypes = {LongDataType.LONG, LongDataType.LONG}; + int[] probeOutputCols = {0, 1}; + String[] invalidProbeHashKeys = {omniFunctionExpr("abc", 2, getOmniJsonFieldReference(2, 1))}; + int[] buildOutputCols = {0, 1}; + DataType[] buildOutputTypes = {LongDataType.LONG, LongDataType.LONG}; + OmniLookupJoinWithExprOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinWithExprOperatorFactory( + probeTypes, probeOutputCols, invalidProbeHashKeys, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, false); + + hashBuilderOperatorFactory.close(); + } + + @Test(expectedExceptions = OmniRuntimeException.class, expectedExceptionsMessageRegExp = + ".*EXPRESSION_NOT_SUPPORT.*") + public void testInnerEqualityJoinWithInvalidExprs() { + DataType[] buildTypes = {LongDataType.LONG, LongDataType.LONG}; + int operatorCount = 1; + String[] buildHashCols = {getOmniJsonFieldReference(2, 0)}; + String filterExpression = omniJsonNotEqualExpr( + omniFunctionExpr("substring", 15, + getOmniJsonFieldReference(2, 1) + "," + getOmniJsonLiteral(1, false, 1) + "," + + getOmniJsonLiteral(1, false, 5)), + omniFunctionExpr("substring", 15, getOmniJsonFieldReference(2, 0) + "," + + getOmniJsonLiteral(1, false, 1) + "," + getOmniJsonLiteral(1, false, 5))); + OmniHashBuilderWithExprOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderWithExprOperatorFactory( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashCols, operatorCount); + + DataType[] probeTypes = {LongDataType.LONG, LongDataType.LONG}; + int[] probeOutputCols = {0, 1}; + String[] probeHashCols = {getOmniJsonFieldReference(2, 1)}; + int[] buildOutputCols = {0, 1}; + DataType[] buildOutputTypes = {LongDataType.LONG, LongDataType.LONG}; + OmniLookupJoinWithExprOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinWithExprOperatorFactory( + probeTypes, probeOutputCols, probeHashCols, buildOutputCols, buildOutputTypes, + hashBuilderOperatorFactory, Optional.of(filterExpression), false); + + hashBuilderOperatorFactory.close(); + } + + @Test + public void testFactoryContextEquals() { + DataType[] buildTypes = {LongDataType.LONG, LongDataType.LONG}; + String[] buildHashKeys = {omniJsonFourArithmeticExpr("ADD", 2, getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 50))}; + int operatorCount = 1; + OmniHashBuilderWithExprOperatorFactory.FactoryContext hashBuilderOperatorFactory1 = + new OmniHashBuilderWithExprOperatorFactory.FactoryContext( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashKeys, operatorCount, new OperatorConfig()); + OmniHashBuilderWithExprOperatorFactory.FactoryContext hashBuilderOperatorFactory2 = + new OmniHashBuilderWithExprOperatorFactory.FactoryContext( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashKeys, operatorCount, new OperatorConfig()); + OmniHashBuilderWithExprOperatorFactory.FactoryContext hashBuilderOperatorFactory3 = null; + assertEquals(hashBuilderOperatorFactory2, hashBuilderOperatorFactory1); + assertEquals(hashBuilderOperatorFactory1, hashBuilderOperatorFactory1); + assertNotEquals(hashBuilderOperatorFactory3, hashBuilderOperatorFactory1); + + DataType[] probeTypes = {LongDataType.LONG, LongDataType.LONG}; + int[] probeOutputCols = {1}; + String[] probeHashKeys = {omniJsonFourArithmeticExpr("ADD", 2, getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 50))}; + int[] buildOutputCols = {1}; + DataType[] buildOutputTypes = {LongDataType.LONG}; + + OmniHashBuilderWithExprOperatorFactory omniHashBuilderWithExprOperatorFactory = + new OmniHashBuilderWithExprOperatorFactory( + OMNI_JOIN_TYPE_INNER, buildTypes, buildHashKeys, operatorCount, new OperatorConfig()); + OmniLookupJoinWithExprOperatorFactory.FactoryContext lookupJoinOperatorFactory1 = + new OmniLookupJoinWithExprOperatorFactory.FactoryContext( + probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + omniHashBuilderWithExprOperatorFactory, Optional.empty(), false, new OperatorConfig()); + OmniLookupJoinWithExprOperatorFactory.FactoryContext lookupJoinOperatorFactory2 = + new OmniLookupJoinWithExprOperatorFactory.FactoryContext( + probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildOutputTypes, + omniHashBuilderWithExprOperatorFactory, Optional.empty(), false, new OperatorConfig()); + OmniLookupJoinWithExprOperatorFactory.FactoryContext lookupJoinOperatorFactory3 = null; + + assertEquals(lookupJoinOperatorFactory2, lookupJoinOperatorFactory1); + assertEquals(lookupJoinOperatorFactory1, lookupJoinOperatorFactory1); + assertNotEquals(lookupJoinOperatorFactory3, lookupJoinOperatorFactory1); + } + + private void buildLookupJoinExpectedData(Object[][] lookupJoinData1, Object[][] lookupJoinData2, + Object[][] lookupJoinData3, int tableSize, int maxRowCount) { + for (int i = 0; i < maxRowCount; i++) { + lookupJoinData1[0][i] = tableSize / 2 + i + 1L; + lookupJoinData1[1][i] = i + 200L; + lookupJoinData1[2][i] = i + 201L; + lookupJoinData1[3][i] = i + 202L; + + lookupJoinData1[4][i] = tableSize / 2 + i + 1L; + lookupJoinData1[5][i] = tableSize / 2 + i + 100L; + lookupJoinData1[6][i] = tableSize / 2 + i + 101L; + lookupJoinData1[7][i] = tableSize / 2 + i + 102L; + + lookupJoinData2[0][i] = tableSize / 2 + maxRowCount + i + 1L; + lookupJoinData2[1][i] = maxRowCount + i + 200L; + lookupJoinData2[2][i] = maxRowCount + i + 201L; + lookupJoinData2[3][i] = maxRowCount + i + 202L; + + if (i < 6) { + lookupJoinData2[4][i] = tableSize / 2 + maxRowCount + i + 1L; + lookupJoinData2[5][i] = tableSize / 2 + maxRowCount + i + 100L; + lookupJoinData2[6][i] = tableSize / 2 + maxRowCount + i + 101L; + lookupJoinData2[7][i] = tableSize / 2 + maxRowCount + i + 102L; + } else { + lookupJoinData2[4][i] = null; + lookupJoinData2[5][i] = null; + lookupJoinData2[6][i] = null; + lookupJoinData2[7][i] = null; + } + } + + for (int i = 0; i < tableSize - 2 * maxRowCount; i++) { + lookupJoinData3[0][i] = tableSize / 2 + 2 * maxRowCount + i + 1L; + lookupJoinData3[1][i] = 2 * maxRowCount + i + 200L; + lookupJoinData3[2][i] = 2 * maxRowCount + i + 201L; + lookupJoinData3[3][i] = 2 * maxRowCount + i + 202L; + lookupJoinData3[4][i] = null; + lookupJoinData3[5][i] = null; + lookupJoinData3[6][i] = null; + lookupJoinData3[7][i] = null; + } + } + + private void buildFullOuterExpectedData(Object[][] fullOuterData1, Object[][] fullOuterData2, int tableSize, + int maxRowCount) { + for (int i = 0; i < maxRowCount; i++) { + fullOuterData1[0][i] = null; + fullOuterData1[1][i] = null; + fullOuterData1[2][i] = null; + fullOuterData1[3][i] = null; + fullOuterData1[4][i] = i + 1L; + fullOuterData1[5][i] = i + 100L; + fullOuterData1[6][i] = i + 101L; + fullOuterData1[7][i] = i + 102L; + } + + for (int i = 0; i < tableSize / 2 - maxRowCount; i++) { + fullOuterData2[0][i] = null; + fullOuterData2[1][i] = null; + fullOuterData2[2][i] = null; + fullOuterData2[3][i] = null; + fullOuterData2[4][i] = maxRowCount + i + 1L; + fullOuterData2[5][i] = maxRowCount + i + 100L; + fullOuterData2[6][i] = maxRowCount + i + 101L; + fullOuterData2[7][i] = maxRowCount + i + 102L; + } + } + + private void buildFullOuterAddInputData(Object[][] buildData, Object[][] probeData) { + int tableSize = buildData[0].length; + for (int i = 0; i < tableSize; i++) { + buildData[0][i] = i + 1L; + buildData[1][i] = i + 100L; + buildData[2][i] = i + 101L; + buildData[3][i] = i + 102L; + probeData[0][i] = tableSize / 2 + i + 1L; + probeData[1][i] = i + 200L; + probeData[2][i] = i + 201L; + probeData[3][i] = i + 202L; + } + } + + /** + * Test full hash join multiple build batch. + */ + @Test + public void testFullOuterJoinIterativeGetOutput() { + DataType[] buildTypes = {LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}; + String[] buildHashKeys = {getOmniJsonFieldReference(2, 0)}; + OmniHashBuilderWithExprOperatorFactory hashBuilderOperatorFactory = new OmniHashBuilderWithExprOperatorFactory( + OMNI_JOIN_TYPE_FULL, buildTypes, buildHashKeys, 1); + OmniOperator hashBuilderOperator = hashBuilderOperatorFactory.createOperator(); + int tableSize = 32780; + Object[][] buildData = new Object[4][tableSize]; + Object[][] probeData = new Object[4][tableSize]; + buildFullOuterAddInputData(buildData, probeData); + hashBuilderOperator.addInput(createVecBatch(buildTypes, buildData)); + hashBuilderOperator.getOutput(); + + DataType[] probeTypes = {LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}; + int[] probeOutputCols = {0, 1, 2, 3}; + String[] probeHashKeys = {getOmniJsonFieldReference(2, 0)}; + int[] buildOutputCols = {0, 1, 2, 3}; + OmniLookupJoinWithExprOperatorFactory lookupJoinOperatorFactory = new OmniLookupJoinWithExprOperatorFactory( + probeTypes, probeOutputCols, probeHashKeys, + buildOutputCols, buildTypes, hashBuilderOperatorFactory, false); + OmniOperator lookupJoinOperator = lookupJoinOperatorFactory.createOperator(); + + OmniLookupOuterJoinWithExprOperatorFactory lookupOuterJoinOperatorFactory = + new OmniLookupOuterJoinWithExprOperatorFactory( + probeTypes, probeOutputCols, probeHashKeys, buildOutputCols, buildTypes, hashBuilderOperatorFactory, + new OperatorConfig()); + OmniOperator lookupOuterJoinOperator = lookupOuterJoinOperatorFactory.createOperator(); + VecBatch probeVecBatch = createVecBatch(probeTypes, probeData); + lookupJoinOperator.addInput(probeVecBatch); + + VecBatch vecBatch = null; + List lookupJoinList = new ArrayList<>(); + Iterator resultIterator = lookupJoinOperator.getOutput(); + while (resultIterator.hasNext()) { + vecBatch = resultIterator.next(); + lookupJoinList.add(vecBatch); + } + + List fullOuterList = new ArrayList<>(); + Iterator fullOuterIterator = lookupOuterJoinOperator.getOutput(); + while (fullOuterIterator.hasNext()) { + vecBatch = fullOuterIterator.next(); + fullOuterList.add(vecBatch); + } + + int maxRowCount = 16384; // 1M / (8 * 8) + Object[][] lookupJoinData1 = new Object[8][maxRowCount]; + Object[][] lookupJoinData2 = new Object[8][maxRowCount]; + Object[][] lookupJoinData3 = new Object[8][tableSize - 2 * maxRowCount]; + Object[][] fullOuterData1 = new Object[8][maxRowCount]; + Object[][] fullOuterData2 = new Object[8][tableSize / 2 - maxRowCount]; + + buildLookupJoinExpectedData(lookupJoinData1, lookupJoinData2, lookupJoinData3, tableSize, maxRowCount); + buildFullOuterExpectedData(fullOuterData1, fullOuterData2, tableSize, maxRowCount); + + assertEquals(lookupJoinList.size(), 3); + assertEquals(fullOuterList.size(), 2); + + assertVecBatchEqualsIgnoreOrder(lookupJoinList.get(0), lookupJoinData1); + assertVecBatchEqualsIgnoreOrder(lookupJoinList.get(1), lookupJoinData2); + assertVecBatchEqualsIgnoreOrder(lookupJoinList.get(2), lookupJoinData3); + assertEquals(fullOuterList.get(0).getVector(4).getSize(), maxRowCount); + assertEquals(fullOuterList.get(1).getVector(4).getSize(), tableSize / 2 - maxRowCount); + + lookupJoinOperator.close(); + hashBuilderOperator.close(); + lookupOuterJoinOperator.close(); + lookupJoinOperatorFactory.close(); + hashBuilderOperatorFactory.close(); + lookupOuterJoinOperatorFactory.close(); + + freeVecBatches(lookupJoinList); + freeVecBatches(fullOuterList); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniLimitOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniLimitOperatorTest.java new file mode 100644 index 0000000..fdb44b0 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniLimitOperatorTest.java @@ -0,0 +1,279 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.createVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatches; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; + +import nova.hetu.omniruntime.operator.limit.OmniLimitOperatorFactory; +import nova.hetu.omniruntime.operator.limit.OmniLimitOperatorFactory.FactoryContext; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Arrays; + +/** + * The type Omni limit operator test. + * + * @since 2021-11-27 + */ +public class OmniLimitOperatorTest { + @Test + public void testLimitByTwoColum() { + DataType[] sourceTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE}; + Object[][] sourceDatas1 = {{0, 1, 2, 0, 1, 2}, {6.6, 5.5, 4.4, 3.3, 2.2, 1.1}}; + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceDatas1); + + OmniLimitOperatorFactory limitOperatorFactory = new OmniLimitOperatorFactory(4); + OmniOperator limitOperator = limitOperatorFactory.createOperator(); + limitOperator.addInput(vecBatch1); + Iterator results = limitOperator.getOutput(); + + Object[][] expectedDatas1 = {{0, 1, 2, 0}, {6.6, 5.5, 4.4, 3.3}}; + + VecBatch resultVecBatch1 = results.next(); + assertVecBatchEquals(resultVecBatch1, expectedDatas1); + + resultVecBatch1.releaseAllVectors(); + resultVecBatch1.close(); + limitOperator.close(); + limitOperatorFactory.close(); + } + + @Test + public void testLimitWithNull() { + DataType[] sourceTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE}; + Object[][] sourceDatas1 = {{0, 1, 2, 3, 4, 5}, {6.6, 5.5, 4.4, 3.3, 2.2, 1.1}}; + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceDatas1); + Vec[] inVectors = vecBatch1.getVectors(); + inVectors[0].setNull(2); + inVectors[0].setNull(3); + inVectors[1].setNull(3); + inVectors[1].setNull(4); + + OmniLimitOperatorFactory limitOperatorFactory = new OmniLimitOperatorFactory(6); + OmniOperator limitOperator = limitOperatorFactory.createOperator(); + limitOperator.addInput(vecBatch1); + Iterator results = limitOperator.getOutput(); + + Object[][] expectedDatas1 = {{0, 1, null, null, 4, 5}, {6.6, 5.5, 4.4, null, null, 1.1}}; + + VecBatch resultVecBatch1 = results.next(); + assertVecBatchEquals(resultVecBatch1, expectedDatas1); + + resultVecBatch1.releaseAllVectors(); + resultVecBatch1.close(); + limitOperator.close(); + limitOperatorFactory.close(); + } + + @Test + public void testFactoryContextEquals() { + FactoryContext factory1 = new FactoryContext(6, 0); + FactoryContext factory2 = new FactoryContext(6, 0); + FactoryContext factory3 = null; + + assertEquals(factory2, factory1); + assertEquals(factory1, factory1); + assertNotEquals(factory3, factory1); + } + + private static void buildLimitExpectData(Object[][] expectedData1, Object[][] expectedData2, + int maxRowCount) { + for (int i = 0; i < expectedData1[0].length; i++) { + expectedData1[0][i] = i; + expectedData1[1][i] = i + 1; + expectedData1[2][i] = i + 2L; + expectedData1[3][i] = "abc" + i; + } + + for (int i = 0; i < expectedData2[0].length; i++) { + expectedData2[0][i] = i + maxRowCount; + expectedData2[1][i] = i + maxRowCount + 1; + expectedData2[2][i] = i + maxRowCount + 2L; + expectedData2[3][i] = "abc" + (i + maxRowCount); + } + } + + private static void buildLimitExpectData2(List> expectedData1, List> expectedData2, + int dataSize1, int dataSize2, int offset) { + for (int i = 0; i < dataSize1; i++) { + expectedData1.get(0).add(offset + i); + expectedData1.get(1).add(offset + i + 1); + expectedData1.get(2).add(offset + i + 2L); + expectedData1.get(3).add("abc" + (offset + i)); + } + for (int i = 0; i < dataSize2; i++) { + expectedData2.get(0).add(i + 32768); + expectedData2.get(1).add(i + 32768 + 1); + expectedData2.get(2).add(i + 32768 + 2L); + expectedData2.get(3).add("abc" + (i + 32768)); + } + } + + @Test + public void testLimitMultiBatchGetOutput() { + int dataSize = 32800; + int maxRowCount = 32768; // 1M / (4 + 4 + 8 + 8) + DataType[] sourceTypes = {IntDataType.INTEGER, IntDataType.INTEGER, LongDataType.LONG, new VarcharDataType(8)}; + Object[][] sourceData1 = new Object[sourceTypes.length][maxRowCount]; + Object[][] sourceData2 = new Object[sourceTypes.length][dataSize - maxRowCount]; + + for (int i = 0; i < maxRowCount; i++) { + sourceData1[0][i] = i; + sourceData1[1][i] = i + 1; + sourceData1[2][i] = i + 2L; + sourceData1[3][i] = "abc" + i; + } + for (int i = 0; i < dataSize - maxRowCount; i++) { + sourceData2[0][i] = i + maxRowCount; + sourceData2[1][i] = i + maxRowCount + 1; + sourceData2[2][i] = i + maxRowCount + 2L; + sourceData2[3][i] = "abc" + (i + maxRowCount); + } + + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceData1); + VecBatch vecBatch2 = createVecBatch(sourceTypes, sourceData2); + + int limitSize = 32780; + OmniLimitOperatorFactory limitOperatorFactory = new OmniLimitOperatorFactory(limitSize); + OmniOperator limitOperator = limitOperatorFactory.createOperator(); + limitOperator.addInput(vecBatch1); + + List resultList = new ArrayList<>(); + Iterator limitIterator = limitOperator.getOutput(); + resultList.add(limitIterator.next()); + limitOperator.addInput(vecBatch2); + while (limitIterator.hasNext()) { + resultList.add(limitIterator.next()); + } + assertEquals(resultList.size(), 2); + + Object[][] expectedData1 = new Object[sourceTypes.length][maxRowCount]; + Object[][] expectedData2 = new Object[sourceTypes.length][limitSize - maxRowCount]; + + buildLimitExpectData(expectedData1, expectedData2, maxRowCount); + assertVecBatchEquals(resultList.get(0), expectedData1); + assertVecBatchEquals(resultList.get(1), expectedData2); + + freeVecBatches(resultList); + limitOperator.close(); + limitOperatorFactory.close(); + } + + @Test + public void testOnlyOffsetByTwoColumn1() { + DataType[] sourceTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE}; + List> sourceDatas1 = Arrays.asList(Arrays.asList(0, 1, 2, 0, 1, 2), + Arrays.asList(1.1, 1.2, 1.3, 2.1, 2.2, 2.3)); + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceDatas1); + + OmniLimitOperatorFactory limitOperatorFactory = new OmniLimitOperatorFactory(-1, 2); + OmniOperator limitOperator = limitOperatorFactory.createOperator(); + limitOperator.addInput(vecBatch1); + Iterator results = limitOperator.getOutput(); + + List> expectedDatas1 = Arrays.asList(Arrays.asList(2, 0, 1, 2), + Arrays.asList(1.3, 2.1, 2.2, 2.3)); + VecBatch expectedVecBatch1 = createVecBatch(sourceTypes, expectedDatas1); + + VecBatch resultVecBatch1 = results.next(); + assertVecBatchEquals(resultVecBatch1, expectedVecBatch1); + + resultVecBatch1.releaseAllVectors(); + resultVecBatch1.close(); + limitOperator.close(); + limitOperatorFactory.close(); + } + + @Test + public void testLimitOffsetByTwoColumn() { + DataType[] sourceTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE}; + List> sourceDatas1 = Arrays.asList(Arrays.asList(0, 1, 2, 0, 1, 2), + Arrays.asList(1.1, 1.2, 1.3, 2.1, 2.2, 2.3)); + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceDatas1); + + OmniLimitOperatorFactory limitOperatorFactory = new OmniLimitOperatorFactory(4, 2); + OmniOperator limitOperator = limitOperatorFactory.createOperator(); + limitOperator.addInput(vecBatch1); + Iterator results = limitOperator.getOutput(); + + List> expectedDatas1 = Arrays.asList(Arrays.asList(2, 0), + Arrays.asList(1.3, 2.1)); + VecBatch expectedVecBatch1 = createVecBatch(sourceTypes, expectedDatas1); + + VecBatch resultVecBatch1 = results.next(); + assertVecBatchEquals(resultVecBatch1, expectedVecBatch1); + + resultVecBatch1.releaseAllVectors(); + resultVecBatch1.close(); + limitOperator.close(); + limitOperatorFactory.close(); + } + + @Test + public void testLimitOffsetMultiBatchGetOutput() { + int offset = 20; + int dataSize = 32800; + int maxRowCount = 32768; // 1M / (4 + 4 + 8 + 8) + DataType[] sourceTypes = {IntDataType.INTEGER, IntDataType.INTEGER, LongDataType.LONG, new VarcharDataType(8)}; + List> sourceData1 = new ArrayList<>(); + List> sourceData2 = new ArrayList<>(); + + for (int i = 0; i < sourceTypes.length; i++) { + sourceData1.add(new ArrayList<>()); + sourceData2.add(new ArrayList<>()); + } + + buildLimitExpectData2(sourceData1, sourceData2, maxRowCount, dataSize - maxRowCount, 0); + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceData1); + VecBatch vecBatch2 = createVecBatch(sourceTypes, sourceData2); + + int limitSize = 32780; + OmniLimitOperatorFactory limitOperatorFactory = new OmniLimitOperatorFactory(limitSize, offset); + OmniOperator limitOperator = limitOperatorFactory.createOperator(); + limitOperator.addInput(vecBatch1); + + List resultList = new ArrayList<>(); + Iterator limitIterator = limitOperator.getOutput(); + resultList.add(limitIterator.next()); + limitOperator.addInput(vecBatch2); + while (limitIterator.hasNext()) { + resultList.add(limitIterator.next()); + } + assertEquals(resultList.size(), 2); + + List> expectedData1 = new ArrayList<>(); + List> expectedData2 = new ArrayList<>(); + + for (int i = 0; i < sourceTypes.length; i++) { + expectedData1.add(new ArrayList<>()); + expectedData2.add(new ArrayList<>()); + } + + buildLimitExpectData2(expectedData1, expectedData2, maxRowCount - offset, limitSize - maxRowCount, offset); + VecBatch expectedVecBatch1 = createVecBatch(sourceTypes, expectedData1); + VecBatch expectedVecBatch2 = createVecBatch(sourceTypes, expectedData2); + assertVecBatchEquals(resultList.get(0), expectedVecBatch1); + assertVecBatchEquals(resultList.get(1), expectedVecBatch2); + + freeVecBatches(resultList); + limitOperator.close(); + limitOperatorFactory.close(); + } +} \ No newline at end of file diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniNestedLoopJoinOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniNestedLoopJoinOperatorTest.java new file mode 100644 index 0000000..903dd5f --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniNestedLoopJoinOperatorTest.java @@ -0,0 +1,377 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.constants.JoinType.OMNI_JOIN_TYPE_INNER; +import static nova.hetu.omniruntime.constants.JoinType.OMNI_JOIN_TYPE_LEFT; +import static nova.hetu.omniruntime.constants.JoinType.OMNI_JOIN_TYPE_RIGHT; +import static nova.hetu.omniruntime.type.VarcharDataType.MAX_WIDTH; +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.createDictionaryVec; +import static nova.hetu.omniruntime.util.TestUtils.createIntVec; +import static nova.hetu.omniruntime.util.TestUtils.createVecBatch; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; +import static org.testng.Assert.assertTrue; + +import nova.hetu.omniruntime.constants.JoinType; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.operator.join.OmniNestedLoopJoinBuildOperatorFactory; +import nova.hetu.omniruntime.operator.join.OmniNestedLoopJoinLookupOperatorFactory; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.util.TestUtils; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.Iterator; +import java.util.Optional; + +/** + * The type Omni Nested Loop Join operators test. + * + * @since 2024-12-2 + */ +public class OmniNestedLoopJoinOperatorTest { + @Test + public void testFactoryContextEquals() { + DataType[] sourceTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE}; + int[] outputCols = new int[]{0, 1}; + int[] outputCols1 = new int[]{1, 0}; + + OmniNestedLoopJoinBuildOperatorFactory.FactoryContext context1 = new OmniNestedLoopJoinBuildOperatorFactory + .FactoryContext(sourceTypes, outputCols); + OmniNestedLoopJoinBuildOperatorFactory.FactoryContext context2 = new OmniNestedLoopJoinBuildOperatorFactory + .FactoryContext(sourceTypes, outputCols); + OmniNestedLoopJoinBuildOperatorFactory.FactoryContext context3 = new OmniNestedLoopJoinBuildOperatorFactory + .FactoryContext(sourceTypes, outputCols1); + assertNotEquals(context1, null); + assertNotEquals(new Object(), context1); + assertNotEquals(context3, context1); + assertEquals(context1, context1); + assertEquals(context2, context1); + + OmniNestedLoopJoinBuildOperatorFactory omniNestedLoopJoinBuildOperatorFactory = + new OmniNestedLoopJoinBuildOperatorFactory(sourceTypes, outputCols); + OmniNestedLoopJoinLookupOperatorFactory.FactoryContext lookupContext1 = + new OmniNestedLoopJoinLookupOperatorFactory.FactoryContext(OMNI_JOIN_TYPE_INNER, sourceTypes, + outputCols, Optional.of("test"), omniNestedLoopJoinBuildOperatorFactory, new OperatorConfig()); + OmniNestedLoopJoinLookupOperatorFactory.FactoryContext lookupContext2 = + new OmniNestedLoopJoinLookupOperatorFactory.FactoryContext( + OMNI_JOIN_TYPE_INNER, sourceTypes, outputCols, Optional.of("test"), + omniNestedLoopJoinBuildOperatorFactory, new OperatorConfig()); + OmniNestedLoopJoinLookupOperatorFactory.FactoryContext lookupContext3 = + new OmniNestedLoopJoinLookupOperatorFactory.FactoryContext( + OMNI_JOIN_TYPE_INNER, sourceTypes, outputCols, Optional.of("test3"), + omniNestedLoopJoinBuildOperatorFactory, new OperatorConfig()); + OmniNestedLoopJoinLookupOperatorFactory.FactoryContext lookupContext4 = + new OmniNestedLoopJoinLookupOperatorFactory.FactoryContext( + OMNI_JOIN_TYPE_RIGHT, sourceTypes, outputCols, Optional.of("test"), + omniNestedLoopJoinBuildOperatorFactory, new OperatorConfig()); + assertNotEquals(lookupContext1, null); + assertNotEquals(new Object(), lookupContext1); + assertNotEquals(lookupContext3, lookupContext1); + assertNotEquals(lookupContext4, lookupContext1); + assertEquals(lookupContext1, lookupContext1); + assertEquals(lookupContext2, lookupContext1); + } + + @Test + public void testRightOutJoin() { + DataType[] sourceTypes = {new VarcharDataType(20), new VarcharDataType(20), IntDataType.INTEGER, + DoubleDataType.DOUBLE}; + int[] buildOutputCols = new int[]{0, 1, 2, 3}; + Object[][] sourceDatas1 = {{"abc", "yeah", "", "add"}, {"", "yeah", "Hello", "World"}, {4, 10, 1, 8}, + {2.0, 8.0, 1.0, 3.0}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas1); + + OmniNestedLoopJoinBuildOperatorFactory omniNestedLoopJoinBuildOperatorFactory = + new OmniNestedLoopJoinBuildOperatorFactory(sourceTypes, buildOutputCols); + OmniOperator omniOperator = omniNestedLoopJoinBuildOperatorFactory.createOperator(); + omniOperator.addInput(vecBatch); + omniOperator.getOutput(); + + DataType[] probeTypes = {new VarcharDataType(20), new VarcharDataType(20), IntDataType.INTEGER, + DoubleDataType.DOUBLE}; + int[] probeOutputCols = new int[]{0, 1, 2, 3}; + + String fieldExpr1 = TestUtils.getOmniJsonFieldReference(1, 2); + String fieldExpr2 = TestUtils.getOmniJsonFieldReference(1, 6); + Optional filter = Optional.of(TestUtils.omniJsonGreaterThanExpr(fieldExpr1, fieldExpr2)); + + OmniNestedLoopJoinLookupOperatorFactory omniNestedLoopJoinLookupOperatorFactory = + new OmniNestedLoopJoinLookupOperatorFactory(OMNI_JOIN_TYPE_RIGHT, probeTypes, probeOutputCols, filter, + omniNestedLoopJoinBuildOperatorFactory, new OperatorConfig()); + OmniOperator lookUpOperator = omniNestedLoopJoinLookupOperatorFactory.createOperator(); + Object[][] sourceDatas2 = {{"abc", "", "yeah", "add"}, {"", "Hello", "yeah", "World"}, {4, 2, 0, 1}, + {1.0, 2.0, 4.0, 3.0}}; + VecBatch vecBatch3 = createVecBatch(probeTypes, sourceDatas2); + lookUpOperator.addInput(vecBatch3); + Iterator results = lookUpOperator.getOutput(); + + Object[][] expectedDatas1 = {{"abc", "", "yeah", "add"}, {"", "Hello", "yeah", "World"}, {4, 2, 0, 1}, + {1.0, 2.0, 4.0, 3.0}, {"", "", null, null}, {"Hello", "Hello", null, null}, {1, 1, null, null}, + {1.0, 1.0, null, null}}; + VecBatch resultVecBatch1 = results.next(); + assertVecBatchEquals(resultVecBatch1, expectedDatas1); + resultVecBatch1.releaseAllVectors(); + resultVecBatch1.close(); + lookUpOperator.close(); + omniNestedLoopJoinLookupOperatorFactory.close(); + omniOperator.close(); + omniNestedLoopJoinBuildOperatorFactory.close(); + } + + @Test + public void testCrossJoin() { + DataType[] sourceTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE}; + int[] buildOutputCols = new int[]{0, 1}; + Object[][] sourceDatas1 = {{0, 1, 2}, {6.6, 5.5, 4.4}}; + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceDatas1); + VecBatch vecBatch2 = createVecBatch(sourceTypes, sourceDatas1); + + OmniNestedLoopJoinBuildOperatorFactory omniNestedLoopJoinBuildOperatorFactory = + new OmniNestedLoopJoinBuildOperatorFactory(sourceTypes, buildOutputCols); + OmniOperator omniOperator = omniNestedLoopJoinBuildOperatorFactory.createOperator(); + omniOperator.addInput(vecBatch1); + omniOperator.addInput(vecBatch2); + omniOperator.getOutput(); + + DataType[] probeTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE}; + int[] probeOutputCols = new int[]{0}; + Optional filter = Optional.empty(); + + OmniNestedLoopJoinLookupOperatorFactory omniNestedLoopJoinLookupOperatorFactory = + new OmniNestedLoopJoinLookupOperatorFactory(OMNI_JOIN_TYPE_INNER, probeTypes, probeOutputCols, filter, + omniNestedLoopJoinBuildOperatorFactory, new OperatorConfig()); + OmniOperator lookUpOperator = omniNestedLoopJoinLookupOperatorFactory.createOperator(); + VecBatch vecBatch3 = createVecBatch(sourceTypes, sourceDatas1); + lookUpOperator.addInput(vecBatch3); + Iterator results = lookUpOperator.getOutput(); + + Object[][] expectedDatas1 = {{0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2}, + {0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2}, + {6.6, 5.5, 4.4, 6.6, 5.5, 4.4, 6.6, 5.5, 4.4, 6.6, 5.5, 4.4, 6.6, 5.5, 4.4, 6.6, 5.5, 4.4}}; + VecBatch resultVecBatch1 = results.next(); + assertVecBatchEquals(resultVecBatch1, expectedDatas1); + resultVecBatch1.releaseAllVectors(); + resultVecBatch1.close(); + lookUpOperator.close(); + omniNestedLoopJoinLookupOperatorFactory.close(); + omniOperator.close(); + omniNestedLoopJoinBuildOperatorFactory.close(); + } + + @Test + public void testInnerJoin() { + DataType[] sourceTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE, new VarcharDataType(20)}; + int[] buildOutputCols = new int[]{0, 1, 2}; + Object[][] sourceDatas1 = {{0, 1, 2}, {6.6, 5.5, 4.4}, {"0123test", "012test", "01test"}}; + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceDatas1); + VecBatch vecBatch2 = createVecBatch(sourceTypes, sourceDatas1); + + OmniNestedLoopJoinBuildOperatorFactory omniNestedLoopJoinBuildOperatorFactory = + new OmniNestedLoopJoinBuildOperatorFactory(sourceTypes, buildOutputCols); + OmniOperator omniOperator = omniNestedLoopJoinBuildOperatorFactory.createOperator(); + omniOperator.addInput(vecBatch1); + omniOperator.addInput(vecBatch2); + omniOperator.getOutput(); + + DataType[] probeTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE}; + int[] probeOutputCols = new int[]{1}; + + String lengthExpr = TestUtils.omniFunctionExpr("length", 1, TestUtils.getOmniJsonFieldReference(15, 4)); + String fieldExpr = TestUtils.getOmniJsonFieldReference(1, 0); + Optional filter = Optional.of(TestUtils.omniJsonGreaterThanExpr(fieldExpr, lengthExpr)); + + OmniNestedLoopJoinLookupOperatorFactory omniNestedLoopJoinLookupOperatorFactory = + new OmniNestedLoopJoinLookupOperatorFactory(OMNI_JOIN_TYPE_INNER, probeTypes, probeOutputCols, filter, + omniNestedLoopJoinBuildOperatorFactory, new OperatorConfig()); + OmniOperator lookUpOperator = omniNestedLoopJoinLookupOperatorFactory.createOperator(); + Object[][] sourceDatas2 = {{9, 7, 6}, {6.6, 5.5, 4.4}}; + VecBatch vecBatch3 = createVecBatch(probeTypes, sourceDatas2); + lookUpOperator.addInput(vecBatch3); + Iterator results = lookUpOperator.getOutput(); + + Object[][] expectedDatas1 = {{6.6, 6.6, 6.6, 6.6, 6.6, 6.6, 5.5, 5.5}, {0, 1, 2, 0, 1, 2, 2, 2}, + {6.6, 5.5, 4.4, 6.6, 5.5, 4.4, 4.4, 4.4}, + {"0123test", "012test", "01test", "0123test", "012test", "01test", "01test", "01test"}}; + VecBatch resultVecBatch1 = results.next(); + assertVecBatchEquals(resultVecBatch1, expectedDatas1); + resultVecBatch1.releaseAllVectors(); + resultVecBatch1.close(); + lookUpOperator.close(); + omniNestedLoopJoinLookupOperatorFactory.close(); + omniOperator.close(); + omniNestedLoopJoinBuildOperatorFactory.close(); + } + + @Test + public void testInnerJoinWithLargeRowSize() { + DataType[] sourceTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE, new VarcharDataType(MAX_WIDTH)}; + int[] buildOutputCols = new int[]{0, 1, 2}; + Object[][] sourceDatas1 = {{0, 1, 2}, {6.6, 5.5, 4.4}, {"0123test", "012test", "01test"}}; + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceDatas1); + VecBatch vecBatch2 = createVecBatch(sourceTypes, sourceDatas1); + + OmniNestedLoopJoinBuildOperatorFactory omniNestedLoopJoinBuildOperatorFactory = + new OmniNestedLoopJoinBuildOperatorFactory(sourceTypes, buildOutputCols); + OmniOperator omniOperator = omniNestedLoopJoinBuildOperatorFactory.createOperator(); + omniOperator.addInput(vecBatch2); + omniOperator.addInput(vecBatch1); + omniOperator.getOutput(); + + DataType[] probeTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE}; + int[] probeOutputCols = new int[]{0, 1}; + + String fieldExpr = TestUtils.getOmniJsonFieldReference(1, 0); + String lengthExpr = TestUtils.omniFunctionExpr("length", 1, TestUtils.getOmniJsonFieldReference(15, 4)); + Optional filter = Optional.of(TestUtils.omniJsonGreaterThanExpr(fieldExpr, lengthExpr)); + + OmniNestedLoopJoinLookupOperatorFactory omniNestedLoopJoinLookupOperatorFactory = + new OmniNestedLoopJoinLookupOperatorFactory(OMNI_JOIN_TYPE_INNER, probeTypes, probeOutputCols, filter, + omniNestedLoopJoinBuildOperatorFactory, new OperatorConfig()); + Object[][] sourceDatas2 = {{9, 7, 6}, {6.6, 5.5, 4.4}}; + OmniOperator lookUpOperator = omniNestedLoopJoinLookupOperatorFactory.createOperator(); + VecBatch vecBatch3 = createVecBatch(probeTypes, sourceDatas2); + lookUpOperator.addInput(vecBatch3); + Iterator results = lookUpOperator.getOutput(); + + Object[][] expectedDatas1 = {{9, 9, 9, 9, 9, 9}, {6.6, 6.6, 6.6, 6.6, 6.6, 6.6}, {0, 1, 2, 0, 1, 2}, + {6.6, 5.5, 4.4, 6.6, 5.5, 4.4}, {"0123test", "012test", "01test", "0123test", "012test", "01test"}}; + VecBatch resultVecBatch1 = results.next(); + assertEquals(resultVecBatch1.getRowCount(), vecBatch1.getRowCount() + vecBatch1.getRowCount()); + assertVecBatchEquals(resultVecBatch1, expectedDatas1); + resultVecBatch1.releaseAllVectors(); + resultVecBatch1.close(); + assertTrue(results.hasNext()); + VecBatch resultVecBatch2 = results.next(); + Object[][] expectedDatas2 = {{7, 7}, {5.5, 5.5}, {2, 2}, {4.4, 4.4}, {"01test", "01test"}}; + assertVecBatchEquals(resultVecBatch2, expectedDatas2); + resultVecBatch2.releaseAllVectors(); + resultVecBatch2.close(); + lookUpOperator.close(); + omniNestedLoopJoinLookupOperatorFactory.close(); + omniOperator.close(); + omniNestedLoopJoinBuildOperatorFactory.close(); + } + + @Test + public void testOuterJoin() { + DataType[] sourceTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE, new VarcharDataType(20)}; + int[] buildOutputCols = new int[]{0, 1, 2}; + Object[][] sourceDatas1 = {{0, 1, 2}, {6.6, 5.5, 4.4}, {"0123test", "012test", "01test"}}; + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceDatas1); + VecBatch vecBatch2 = createVecBatch(sourceTypes, sourceDatas1); + + OmniNestedLoopJoinBuildOperatorFactory omniNestedLoopJoinBuildOperatorFactory = + new OmniNestedLoopJoinBuildOperatorFactory(sourceTypes, buildOutputCols); + OmniOperator omniOperator = omniNestedLoopJoinBuildOperatorFactory.createOperator(); + omniOperator.addInput(vecBatch2); + omniOperator.addInput(vecBatch1); + omniOperator.getOutput(); + + outerJoin(OMNI_JOIN_TYPE_LEFT, omniNestedLoopJoinBuildOperatorFactory); + + outerJoin(OMNI_JOIN_TYPE_RIGHT, omniNestedLoopJoinBuildOperatorFactory); + + omniOperator.close(); + omniNestedLoopJoinBuildOperatorFactory.close(); + } + + @Test + public void testOuterJoinWithDictionaryVec() { + DataType[] sourceTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE, new VarcharDataType(20)}; + int[] buildOutputCols = new int[]{0, 1, 2}; + Object[] filed1 = {0, 1, 2}; + Object[] filed2 = {6.6, 5.5, 4.4}; + Object[] filed3 = {"0123test", "012test", "01test"}; + IntVec intVec = createIntVec(filed1); + Vec doubleVec = createDictionaryVec(DoubleDataType.DOUBLE, filed2, new int[]{0, 1, 2}); + Vec varcharVec = createDictionaryVec(new VarcharDataType(20), filed3, new int[]{0, 1, 2}); + Vec[] vecs = {intVec, doubleVec, varcharVec}; + VecBatch vecBatch1 = new VecBatch(vecs); + IntVec intVec2 = createIntVec(filed1); + Vec doubleVec2 = createDictionaryVec(DoubleDataType.DOUBLE, filed2, new int[]{0, 1, 2}); + Vec varcharVec2 = createDictionaryVec(new VarcharDataType(20), filed3, new int[]{0, 1, 2}); + Vec[] vecs2 = {intVec2, doubleVec2, varcharVec2}; + VecBatch vecBatch2 = new VecBatch(vecs2); + + OmniNestedLoopJoinBuildOperatorFactory omniNestedLoopJoinBuildOperatorFactory = + new OmniNestedLoopJoinBuildOperatorFactory(sourceTypes, buildOutputCols); + OmniOperator omniOperator = omniNestedLoopJoinBuildOperatorFactory.createOperator(); + omniOperator.addInput(vecBatch1); + omniOperator.addInput(vecBatch2); + omniOperator.getOutput(); + + outerJoin(OMNI_JOIN_TYPE_LEFT, omniNestedLoopJoinBuildOperatorFactory); + + outerJoin(OMNI_JOIN_TYPE_RIGHT, omniNestedLoopJoinBuildOperatorFactory); + + omniOperator.close(); + omniNestedLoopJoinBuildOperatorFactory.close(); + } + + private static void outerJoin(JoinType joinType, + OmniNestedLoopJoinBuildOperatorFactory omniNestedLoopJoinBuildOperatorFactory) { + DataType[] probeTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE}; + int[] probeOutputCols = new int[]{1}; + String lengthExpr = TestUtils.omniFunctionExpr("length", 1, TestUtils.getOmniJsonFieldReference(15, 4)); + String fieldExpr = TestUtils.getOmniJsonFieldReference(1, 0); + Optional filter = Optional.of(TestUtils.omniJsonGreaterThanExpr(fieldExpr, lengthExpr)); + + OmniNestedLoopJoinLookupOperatorFactory omniNestedLoopJoinLookupOperatorFactory = + new OmniNestedLoopJoinLookupOperatorFactory(joinType, probeTypes, probeOutputCols, filter, + omniNestedLoopJoinBuildOperatorFactory, new OperatorConfig()); + OmniOperator lookUpOperator = omniNestedLoopJoinLookupOperatorFactory.createOperator(); + Object[][] sourceDatas2 = {{9, 7, 6}, {6.6, 5.5, 4.4}}; + VecBatch vecBatch3 = createVecBatch(probeTypes, sourceDatas2); + lookUpOperator.addInput(vecBatch3); + Iterator results = lookUpOperator.getOutput(); + + Object[][] expectedDatas1 = {{6.6, 6.6, 6.6, 6.6, 6.6, 6.6, 5.5, 5.5, 4.4}, {0, 1, 2, 0, 1, 2, 2, 2, null}, + {6.6, 5.5, 4.4, 6.6, 5.5, 4.4, 4.4, 4.4, null}, + {"0123test", "012test", "01test", "0123test", "012test", "01test", "01test", "01test", null}}; + VecBatch resultVecBatch1 = results.next(); + assertVecBatchEquals(resultVecBatch1, expectedDatas1); + resultVecBatch1.releaseAllVectors(); + resultVecBatch1.close(); + lookUpOperator.close(); + omniNestedLoopJoinLookupOperatorFactory.close(); + } + + @Test + public void testRightOutJoinWithShared() { + int builderNodeId = 10; + DataType[] sourceTypes = {new VarcharDataType(20), new VarcharDataType(20), IntDataType.INTEGER, + DoubleDataType.DOUBLE}; + int[] buildOutputCols = new int[]{0, 1, 2, 3}; + OmniNestedLoopJoinBuildOperatorFactory omniNestedLoopJoinBuildOperatorFactory = + new OmniNestedLoopJoinBuildOperatorFactory(sourceTypes, buildOutputCols); + OmniOperator omniOperator = omniNestedLoopJoinBuildOperatorFactory.createOperator(); + OmniNestedLoopJoinBuildOperatorFactory tempOmniNestedLoopJoinBuildOperatorFactory = + OmniNestedLoopJoinBuildOperatorFactory.getNestedLoopJoinBuilderOperatorFactory(builderNodeId); + assertEquals(tempOmniNestedLoopJoinBuildOperatorFactory == null, true); + OmniNestedLoopJoinBuildOperatorFactory.saveNestedLoopJoinBuilderOperatorAndFactory(builderNodeId, + omniNestedLoopJoinBuildOperatorFactory, omniOperator); + tempOmniNestedLoopJoinBuildOperatorFactory = OmniNestedLoopJoinBuildOperatorFactory + .getNestedLoopJoinBuilderOperatorFactory(builderNodeId); + assertEquals(omniNestedLoopJoinBuildOperatorFactory == tempOmniNestedLoopJoinBuildOperatorFactory, true); + OmniNestedLoopJoinBuildOperatorFactory.dereferenceNestedBuilderOperatorAndFactory(builderNodeId); + tempOmniNestedLoopJoinBuildOperatorFactory = OmniNestedLoopJoinBuildOperatorFactory + .getNestedLoopJoinBuilderOperatorFactory(builderNodeId); + assertEquals(omniNestedLoopJoinBuildOperatorFactory == tempOmniNestedLoopJoinBuildOperatorFactory, true); + OmniNestedLoopJoinBuildOperatorFactory.dereferenceNestedBuilderOperatorAndFactory(builderNodeId); + OmniNestedLoopJoinBuildOperatorFactory.dereferenceNestedBuilderOperatorAndFactory(builderNodeId); + tempOmniNestedLoopJoinBuildOperatorFactory = OmniNestedLoopJoinBuildOperatorFactory + .getNestedLoopJoinBuilderOperatorFactory(builderNodeId); + assertEquals(tempOmniNestedLoopJoinBuildOperatorFactory == null, true); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniOperatorConfigTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniOperatorConfigTest.java new file mode 100644 index 0000000..3bea3b8 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniOperatorConfigTest.java @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static org.testng.Assert.assertEquals; + +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.operator.config.SparkSpillConfig; +import nova.hetu.omniruntime.operator.config.SpillConfig; + +import org.testng.annotations.Test; + +import java.nio.file.Paths; + +/** + * The omni operator config test. + * + * @since 2022-04-16 + */ +public class OmniOperatorConfigTest { + @Test + public void TestSerialization() { + final String spillPath = Paths.get("").toAbsolutePath().toString(); + + // disable jit + String noneConfigString = OperatorConfig.serialize(OperatorConfig.NONE); + assertEquals(OperatorConfig.NONE, OperatorConfig.deserialize(noneConfigString)); + + OperatorConfig invalidSpillConfig = new OperatorConfig(SpillConfig.INVALID); + String invalidConfigString = OperatorConfig.serialize(invalidSpillConfig); + assertEquals(invalidSpillConfig, OperatorConfig.deserialize(invalidConfigString)); + + OperatorConfig sparkOperatorConfig1 = new OperatorConfig(new SparkSpillConfig(spillPath, 5)); + String sparkConfigString1 = OperatorConfig.serialize(sparkOperatorConfig1); + assertEquals(sparkOperatorConfig1, OperatorConfig.deserialize(sparkConfigString1)); + + OperatorConfig sparkOperatorConfig2 = new OperatorConfig(new SparkSpillConfig(false, spillPath, 1024, 1)); + String sparkConfigString2 = OperatorConfig.serialize(sparkOperatorConfig2); + assertEquals(sparkOperatorConfig2, OperatorConfig.deserialize(sparkConfigString2)); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniOperatorFactoryTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniOperatorFactoryTest.java new file mode 100644 index 0000000..a6ab7a1 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniOperatorFactoryTest.java @@ -0,0 +1,141 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; +import static org.testng.Assert.assertTrue; + +import nova.hetu.omniruntime.operator.config.OperatorConfig; + +import org.testng.annotations.Test; + +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +/** + * The type Omni operator factory test. + * + * @since 2021-6-7 + */ +public class OmniOperatorFactoryTest { + /** + * The Check factory. + */ + MockOperatorFactory checkFactory = new MockOperatorFactory(1); + + /** + * Test operator factory cache. + */ + @Test + public void testOperatorFactoryCache() { + MockOperatorFactory factory1 = new MockOperatorFactory(1); + MockOperatorFactory factory2 = new MockOperatorFactory(1); + assertEquals(factory1.getNativeOperatorFactory(), factory2.getNativeOperatorFactory()); + MockOperatorFactory factory3 = new MockOperatorFactory(2); + assertNotEquals(factory1.getNativeOperatorFactory(), factory3.getNativeOperatorFactory()); + } + + /** + * Test operator factory cache multi thread. + * + * @throws InterruptedException the interrupted exception + */ + @Test + public void testOperatorFactoryCacheMultiThread() { + final int threadNum = 10000; + final int corePoolSize = 10; + final int maximumPoolSize = 50; + CountDownLatch countDownLatch = new CountDownLatch(threadNum); + ThreadPoolExecutor threadPool = new ThreadPoolExecutor(corePoolSize, maximumPoolSize, 60L, TimeUnit.SECONDS, + new LinkedBlockingQueue<>(threadNum)); + + for (int i = 0; i < threadNum; i++) { + CompletableFuture.runAsync(() -> { + try { + MockOperatorFactory factory = new MockOperatorFactory(1); + assertEquals(checkFactory.getNativeOperatorFactory(), factory.getNativeOperatorFactory()); + factory.close(); + } finally { + countDownLatch.countDown(); + } + }, threadPool); + } + // This will wait until all future ready. + try { + countDownLatch.await(); + } catch (InterruptedException ex) { + assertTrue(false); + } + + threadPool.shutdown(); + } + + /** + * The type Mock operator factory. + */ + public static class MockOperatorFactory extends OmniOperatorFactory { + /** + * Instantiates a new Mock operator factory. + * + * @param context the context + */ + public MockOperatorFactory(long context) { + super(new FactoryContext(context, new OperatorConfig())); + } + + @Override + protected long createNativeOperatorFactory(FactoryContext context) { + return System.nanoTime(); + } + + /** + * Factory Context + * + * @since 2021-7-13 + */ + public static class FactoryContext extends OmniOperatorFactoryContext { + private final OperatorConfig operatorConfig; + + private final long context; + + /** + * Instantiates a new Context. + * + * @param operatorConfig operatorConfig + * @param context context + */ + public FactoryContext(long context, OperatorConfig operatorConfig) { + this.context = context; + this.operatorConfig = operatorConfig; + } + + /** + * Calculate hash code + * + * @return hash value + */ + @Override + public int hashCode() { + return Objects.hash(context); + } + + /** + * Check equals + * + * @param that object + * @return whether equals + */ + @Override + public boolean equals(Object that) { + return context == ((FactoryContext) that).context; + } + } + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniPartionOutOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniPartionOutOperatorTest.java new file mode 100644 index 0000000..3aaf72c --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniPartionOutOperatorTest.java @@ -0,0 +1,159 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.createVecBatch; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; + +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.operator.partitionedoutput.OmniPartitionedOutPutOperatorFactory; +import nova.hetu.omniruntime.operator.partitionedoutput.OmniPartitionedOutPutOperatorFactory.FactoryContext; +import nova.hetu.omniruntime.type.CharDataType; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.util.TestUtils; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.OptionalInt; + +/** + * The type Omni partition out operators test. + * + * @since 2021-6-30 + */ +public class OmniPartionOutOperatorTest { + @Test(enabled = false) + public void testPartionOut() { + OptionalInt nullChannel = OptionalInt.empty(); + + int[] partitionChannels = {0}; + int partitionCount = 1; + int[] bucketToPartition = {0}; + DataType[] hashChannelTypes = {VarcharDataType.VARCHAR}; + int[] hashChannels = {0}; + + DataType[] buildTypes = {new VarcharDataType(3), new VarcharDataType(3)}; + Object[][] buildDatas = {{"abc", "de", "f"}, {"def", "bc", "a"}}; + VecBatch vecBatch = createVecBatch(buildTypes, buildDatas); + DataType[] sourceTypes = {VarcharDataType.VARCHAR}; + + OmniPartitionedOutPutOperatorFactory omniPartitionedOutPutOperatorFactory = new OmniPartitionedOutPutOperatorFactory( + sourceTypes, false, nullChannel, partitionChannels, partitionCount, bucketToPartition, false, + hashChannelTypes, hashChannels); + OmniOperator omniOperator = omniPartitionedOutPutOperatorFactory.createOperator(); + omniOperator.addInput(vecBatch); + + Iterator results = omniOperator.getOutput(); + List resultList = new ArrayList<>(); + while (results.hasNext()) { + resultList.add(results.next()); + } + + assertEquals(resultList.get(0).getRowCount(), 3); + + Object[][] expectedDatas = {{"abc", "de", "f"}}; + assertVecBatchEquals(resultList.get(0), expectedDatas); + TestUtils.freeVecBatch(resultList.get(0)); + omniOperator.close(); + omniPartitionedOutPutOperatorFactory.close(); + } + + @Test(enabled = false) + public void testPartionOutCache() { + OptionalInt nullChannel = OptionalInt.empty(); + int[] partitionChannels = {0}; + int partitionCount = 1; + int[] bucketToPartition = {0}; + DataType[] hashChannelTypes = {VarcharDataType.VARCHAR}; + int[] hashChannels = {0}; + + DataType[] buildTypes = {new VarcharDataType(3), new VarcharDataType(3)}; + Object[][] buildDatas = {{"abc", "de", null}, {"abc", "de", null}}; + VecBatch vecBatch = createVecBatch(buildTypes, buildDatas); + DataType[] sourceTypes = {VarcharDataType.VARCHAR}; + + OmniPartitionedOutPutOperatorFactory omniPartitionedOutPutOperatorFactory = new OmniPartitionedOutPutOperatorFactory( + sourceTypes, false, nullChannel, partitionChannels, partitionCount, bucketToPartition, false, + hashChannelTypes, hashChannels); + OmniOperator omniOperator = omniPartitionedOutPutOperatorFactory.createOperator(); + omniOperator.addInput(vecBatch); + + Iterator results = omniOperator.getOutput(); + List resultList = new ArrayList<>(); + while (results.hasNext()) { + resultList.add(results.next()); + } + + assertEquals(resultList.get(0).getRowCount(), 3); + + Object[][] expectedDatas = {{"abc", "de", null}}; + assertVecBatchEquals(resultList.get(0), expectedDatas); + TestUtils.freeVecBatch(resultList.get(0)); + omniOperator.close(); + omniPartitionedOutPutOperatorFactory.close(); + } + + @Test(enabled = false) + public void testPartionOutChar() { + OptionalInt nullChannel = OptionalInt.empty(); + int[] partitionChannels = {0}; + int partitionCount = 1; + int[] bucketToPartition = {0}; + DataType[] hashChannelTypes = {CharDataType.CHAR}; + int[] hashChannels = {0}; + + DataType[] buildTypes = {new CharDataType(3), new CharDataType(3)}; + Object[][] buildDatas = {{"abc", "de", "f"}, {"def", "bc", "a"}}; + VecBatch vecBatch = createVecBatch(buildTypes, buildDatas); + DataType[] sourceTypes = {CharDataType.CHAR}; + + OmniPartitionedOutPutOperatorFactory omniPartitionedOutPutOperatorFactory = new OmniPartitionedOutPutOperatorFactory( + sourceTypes, false, nullChannel, partitionChannels, partitionCount, bucketToPartition, false, + hashChannelTypes, hashChannels); + OmniOperator omniOperator = omniPartitionedOutPutOperatorFactory.createOperator(); + omniOperator.addInput(vecBatch); + + Iterator results = omniOperator.getOutput(); + List resultList = new ArrayList<>(); + while (results.hasNext()) { + resultList.add(results.next()); + } + + assertEquals(resultList.get(0).getRowCount(), 3); + + Object[][] expectedDatas = {{"abc", "de", "f"}}; + assertVecBatchEquals(resultList.get(0), expectedDatas); + TestUtils.freeVecBatch(resultList.get(0)); + omniOperator.close(); + omniPartitionedOutPutOperatorFactory.close(); + } + + @Test + public void testFactoryContextEquals() { + DataType[] sourceTypes = {CharDataType.CHAR}; + OptionalInt nullChannel = OptionalInt.empty(); + int[] partitionChannels = {0}; + int partitionCount = 1; + int[] bucketToPartition = {0}; + DataType[] hashChannelTypes = {CharDataType.CHAR}; + int[] hashChannels = {0}; + + FactoryContext factory1 = new FactoryContext(sourceTypes, false, nullChannel, partitionChannels, partitionCount, + bucketToPartition, false, hashChannelTypes, hashChannels, new OperatorConfig()); + FactoryContext factory2 = new FactoryContext(sourceTypes, false, nullChannel, partitionChannels, partitionCount, + bucketToPartition, false, hashChannelTypes, hashChannels, new OperatorConfig()); + FactoryContext factory3 = null; + assertEquals(factory2, factory1); + assertEquals(factory1, factory1); + assertNotEquals(factory3, factory1); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniProjectOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniProjectOperatorTest.java new file mode 100644 index 0000000..645a6a3 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniProjectOperatorTest.java @@ -0,0 +1,339 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.createVec; +import static nova.hetu.omniruntime.util.TestUtils.createVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotEquals; +import static org.testng.Assert.assertTrue; + +import com.google.common.collect.ImmutableList; + +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.operator.project.OmniProjectOperatorFactory; +import nova.hetu.omniruntime.operator.project.OmniProjectOperatorFactory.FactoryContext; +import nova.hetu.omniruntime.type.BooleanDataType; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.Decimal128DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.vector.BooleanVec; +import nova.hetu.omniruntime.vector.Decimal128Vec; +import nova.hetu.omniruntime.vector.DoubleVec; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.VarcharVec; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.nio.charset.StandardCharsets; +import java.util.Iterator; + +/** + * The type Omni project operator test. + * + * @since 2021-7-6 + */ +public class OmniProjectOperatorTest { + private ImmutableList makeInput(int nRows, Vec... cols) { + return ImmutableList.copyOf(new VecBatch[]{new VecBatch(cols)}); + } + + /** + * Simple test. + */ + @Test + public void simpleTest() { + String[] exprs = {"$operator$ADD:1(#0, 5:1)"}; + DataType[] inputTypes = {IntDataType.INTEGER}; + OmniProjectOperatorFactory factory = new OmniProjectOperatorFactory(exprs, inputTypes); + final int numRows = 1000; + IntVec col1 = new IntVec(numRows); + for (int i = 0; i < numRows; i++) { + col1.set(i, i); + } + OmniOperator op = factory.createOperator(); + ImmutableList vecBatches = makeInput(numRows, col1); + for (VecBatch vecBatch : vecBatches) { + op.addInput(vecBatch); + } + + Iterator vecBatchIterator = op.getOutput(); + assertTrue(vecBatchIterator.hasNext()); + VecBatch res = op.getOutput().next(); + assertFalse(vecBatchIterator.hasNext()); + assertEquals(res.getRowCount(), numRows); + for (int i = 0; i < res.getRowCount(); i++) { + assertEquals(((IntVec) res.getVector(0)).get(i), i + 5); + } + + freeVecBatch(res); + op.close(); + factory.close(); + } + + /** + * Complex test. + */ + @Test + public void complexTest() { + String[] exprs = {"$operator$MULTIPLY:1(#0, #1)", "IF:2($operator$LESS_THAN:4(#0, 500:1), 4000000000:2, #2)"}; + DataType[] inputTypes = {IntDataType.INTEGER, IntDataType.INTEGER, LongDataType.LONG}; + OmniProjectOperatorFactory factory = new OmniProjectOperatorFactory(exprs, inputTypes); + final int numRows = 1000; + IntVec col1 = new IntVec(numRows); + IntVec col2 = new IntVec(numRows); + LongVec col3 = new LongVec(numRows); + for (int i = 0; i < numRows; i++) { + col1.set(i, i + 1); + col2.set(i, i - 100); + col3.set(i, i + 3000000000L); + } + OmniOperator op = factory.createOperator(); + ImmutableList vecBatches = makeInput(numRows, col1, col2, col3); + for (VecBatch vecBatch : vecBatches) { + op.addInput(vecBatch); + } + + assertTrue(op.getOutput().hasNext()); + VecBatch res = op.getOutput().next(); + assertEquals(res.getRowCount(), numRows); + for (int i = 0; i < res.getRowCount(); i++) { + assertEquals(((IntVec) res.getVector(0)).get(i), (i + 1) * (i - 100)); + assertEquals(((LongVec) res.getVector(1)).get(i), (i + 1) < 500 ? 4000000000L : i + 3000000000L); + } + + freeVecBatch(res); + op.close(); + factory.close(); + } + + /** + * Murmur3hash&Pmod test. + */ + @Test + public void mm3HashAndPmodTest() { + final int numRows = 3; + final byte[] byteVal1 = "Wednesday".getBytes(StandardCharsets.UTF_8); + final byte[] byteVal2 = "Hello World".getBytes(StandardCharsets.UTF_8); + IntVec col1 = new IntVec(numRows); + DoubleVec col2 = new DoubleVec(numRows); + VarcharVec col3 = new VarcharVec(byteVal1.length + byteVal2.length, numRows); + Decimal128Vec col4 = new Decimal128Vec(numRows); + BooleanVec col5 = new BooleanVec(numRows); + col1.set(0, Integer.MIN_VALUE); + col2.set(0, Double.MAX_VALUE); + col3.set(0, byteVal1); + col4.set(0, new long[]{Long.MIN_VALUE, Long.MAX_VALUE}); + col5.set(0, true); + col1.set(1, Integer.MAX_VALUE); + col2.set(1, Double.MIN_VALUE); + col3.set(1, byteVal2); + col4.set(1, new long[]{Long.MAX_VALUE, Long.MIN_VALUE}); + col5.set(1, false); + // null value + col1.set(2, Integer.MIN_VALUE); + col1.setNull(2); + col2.set(2, Double.MAX_VALUE); + col2.setNull(2); + col3.setNull(2); + col4.set(2, new long[]{Long.MAX_VALUE, Long.MAX_VALUE}); + col4.setNull(2); + col5.setNull(2); + + String[] exprs = {"pmod:1(mm3hash:1(#0, 42:1), 42:1)", "mm3hash:1(#1, 42:1)", "mm3hash:1(#2, 42:1)", + "mm3hash:1(#3, 42:1)", "mm3hash:1(#4, 42:1)"}; + DataType[] inputTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE, VarcharDataType.VARCHAR, + Decimal128DataType.DECIMAL128, BooleanDataType.BOOLEAN}; + OmniProjectOperatorFactory factory = new OmniProjectOperatorFactory(exprs, inputTypes); + OmniOperator op = factory.createOperator(); + ImmutableList vecBatches = makeInput(numRows, col1, col2, col3, col4, col5); + for (VecBatch vecBatch : vecBatches) { + op.addInput(vecBatch); + } + + assertTrue(op.getOutput().hasNext()); + VecBatch res = op.getOutput().next(); + assertEquals(res.getRowCount(), numRows); + assertEquals(res.getVectors().length, exprs.length); + assertEquals(((IntVec) res.getVector(0)).get(0), 20); + assertEquals(((IntVec) res.getVector(1)).get(0), -508695674); + assertEquals(((IntVec) res.getVector(2)).get(0), 613818021); + assertEquals(((IntVec) res.getVector(3)).get(0), 1090190174); + assertEquals(((IntVec) res.getVector(4)).get(0), -559580957); + assertEquals(((IntVec) res.getVector(0)).get(1), 25); + assertEquals(((IntVec) res.getVector(1)).get(1), -1712319331); + assertEquals(((IntVec) res.getVector(2)).get(1), 352365215); + assertEquals(((IntVec) res.getVector(3)).get(1), 1352383760); + assertEquals(((IntVec) res.getVector(4)).get(1), 933211791); + // null value check + assertEquals(((IntVec) res.getVector(0)).get(2), 0); + assertEquals(((IntVec) res.getVector(1)).get(2), 42); + assertEquals(((IntVec) res.getVector(2)).get(2), 42); + assertEquals(((IntVec) res.getVector(3)).get(2), 42); + assertEquals(((IntVec) res.getVector(4)).get(2), 42); + + freeVecBatch(res); + op.close(); + factory.close(); + } + + /** + * xxHash64 test. + */ + @Test + public void xxHash64StringTest() { + DataType[] inputTypes = {new VarcharDataType(50)}; + Object[][] datas = {{"hello world", "abcdefghijklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ", "china"}}; + VecBatch vecBatch = createVecBatch(inputTypes, datas); + String[] expressions = {"{\"exprType\":\"FUNCTION\",\"returnType\":2,\"function_name\":\"xxhash64\"," + + "\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":50}," + + "{\"exprType\":\"LITERAL\",\"dataType\":2,\"isNull\":false,\"value\":42}]}"}; + + OmniProjectOperatorFactory factory = new OmniProjectOperatorFactory(expressions, inputTypes, 1, + new OperatorConfig()); + + OmniOperator op = factory.createOperator(); + op.addInput(vecBatch); + + Iterator results = op.getOutput(); + VecBatch resultVecBatch = results.next(); + + Object[][] expectDatas = {{7620854247404556961L, -8961370173016112133L, 1148854020565811068L}}; + assertVecBatchEquals(resultVecBatch, expectDatas); + + freeVecBatch(resultVecBatch); + op.close(); + factory.close(); + } + + @Test + public void xxHash64Decimal128Test() { + DataType[] inputTypes = {new Decimal128DataType(38, 16)}; + Object[][] datas = {{4000L, 0L}, {2000L, 0L}, {1000L, 0L}}; + Vec[] buildVecs = new Vec[inputTypes.length]; + buildVecs[0] = createVec(inputTypes[0], datas); + VecBatch vecBatch = new VecBatch(buildVecs); + String[] expressions = {"{\"exprType\":\"FUNCTION\",\"returnType\":2,\"function_name\":\"xxhash64\"," + + "\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":7,\"precision\":38,\"scale\":16," + + "\"colVal\":0},{\"exprType\":\"LITERAL\",\"dataType\":2,\"isNull\":false,\"value\":42}]}"}; + + OmniProjectOperatorFactory factory = new OmniProjectOperatorFactory(expressions, inputTypes, 1, + new OperatorConfig()); + + OmniOperator op = factory.createOperator(); + op.addInput(vecBatch); + + Iterator results = op.getOutput(); + VecBatch resultVecBatch = results.next(); + + Object[][] expectDatas = {{4122469761574251967L, -1100376009453395183L, -1241606273492999864L}}; + assertVecBatchEquals(resultVecBatch, expectDatas); + + freeVecBatch(resultVecBatch); + op.close(); + factory.close(); + } + + /** + * Unsupported expression. + */ + @Test + public void unsupportedCast() { + DataType[] types = {}; + String[] projectionsJSON = {"{\"exprType\": \"FUNCTION\", \"returnType\": 2, \"function_name\": \"CAST\", " + + "\"arguments\": [{\"exprType\": \"IF\", \"returnType\": 1, \"condition\": {\"exprType\": " + + "\"FUNCTION\", \"returnType\": 4, \"function_name\": \"not\", \"arguments\": " + + "[{ \"exprType\": \"LITERAL\", \"dataType\": 1, \"isNull\": true}]}, \"if_true\": " + + "{ \"exprType\": \"LITERAL\", \"dataType\": 1, \"isNull\": false, \"value\": 1}, " + + "\"if_false\": { \"exprType\": \"LITERAL\", \"dataType\": 1, \"isNull\": false, \"value\": 0}}]}"}; + + OmniProjectOperatorFactory factory = new OmniProjectOperatorFactory(projectionsJSON, types, 1); + + assertFalse(factory.isSupported()); + factory.close(); + } + + @Test + public void testFactoryContextEquals() { + DataType[] types = {}; + String[] projectionsJSON = {"{\"exprType\": \"FUNCTION\", \"returnType\": 2, \"function_name\": \"CAST\", " + + "\"arguments\": [{\"exprType\": \"IF\", \"returnType\": 1, \"condition\": {\"exprType\": " + + "\"FUNCTION\", \"returnType\": 4, \"function_name\": \"not\", \"arguments\": " + + "[{ \"exprType\": \"LITERAL\", \"dataType\": 1, \"isNull\": true}]}, \"if_true\": " + + "{ \"exprType\": \"LITERAL\", \"dataType\": 1, \"isNull\": false, \"value\": 1}, " + + "\"if_false\": { \"exprType\": \"LITERAL\", \"dataType\": 1, \"isNull\": false, \"value\": 0}}]}"}; + FactoryContext factory1 = new FactoryContext(projectionsJSON, types, 1, new OperatorConfig()); + FactoryContext factory2 = new FactoryContext(projectionsJSON, types, 1, new OperatorConfig()); + FactoryContext factory3 = null; + assertEquals(factory2, factory1); + assertEquals(factory1, factory1); + assertNotEquals(factory3, factory1); + } + + @Test + public void testReplaceWithRep() { + DataType[] types = {new VarcharDataType(20), new VarcharDataType(10), new VarcharDataType(10)}; + Object[][] datas = {{"varchar100", "varchar200", "varchar300"}, {"char1", "char", "char3"}, + {"opera", "*#", "VARCHAR"}}; + VecBatch vecBatch = createVecBatch(types, datas); + String[] expressions = {"{\"exprType\":\"FUNCTION\",\"returnType\":15,\"function_name\":\"replace\"," + + "\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":20}," + + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":10},{\"exprType\":" + + "\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":10}],\"width\":100}"}; + DataType[] inputTypes = {new VarcharDataType(20), new VarcharDataType(10), new VarcharDataType(10)}; + + OmniProjectOperatorFactory factory = new OmniProjectOperatorFactory(expressions, inputTypes, 1, + new OperatorConfig()); + + OmniOperator op = factory.createOperator(); + op.addInput(vecBatch); + + Iterator results = op.getOutput(); + VecBatch resultVecBatch = results.next(); + + Object[][] expectDatas = {{"varopera00", "var*#200", "varVARCHAR00"}}; + assertVecBatchEquals(resultVecBatch, expectDatas); + + freeVecBatch(resultVecBatch); + op.close(); + factory.close(); + } + + @Test + public void testReplaceWithoutRep() { + DataType[] types = {new VarcharDataType(20), new VarcharDataType(10)}; + Object[][] datas = {{"varchar100", "varchar200", "varchar300"}, {"char1", "char2", "char3"}}; + VecBatch vecBatch = createVecBatch(types, datas); + String[] expressions = {"{\"exprType\":\"FUNCTION\",\"returnType\":15,\"function_name\":\"replace\"," + + "\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":20}," + + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":10}],\"width\":100}"}; + DataType[] inputTypes = {new VarcharDataType(20), new VarcharDataType(10)}; + + OmniProjectOperatorFactory factory = new OmniProjectOperatorFactory(expressions, inputTypes, 1, + new OperatorConfig()); + + OmniOperator op = factory.createOperator(); + op.addInput(vecBatch); + + Iterator results = op.getOutput(); + VecBatch resultVecBatch = results.next(); + + Object[][] expectDatas = {{"var00", "var00", "var00"}}; + assertVecBatchEquals(resultVecBatch, expectDatas); + + freeVecBatch(resultVecBatch); + op.close(); + factory.close(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniSortMergeJoinWithExprOperatorsTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniSortMergeJoinWithExprOperatorsTest.java new file mode 100644 index 0000000..1a7dfa0 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniSortMergeJoinWithExprOperatorsTest.java @@ -0,0 +1,447 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.constants.JoinType.OMNI_JOIN_TYPE_INNER; +import static nova.hetu.omniruntime.constants.JoinType.OMNI_JOIN_TYPE_LEFT_ANTI; +import static nova.hetu.omniruntime.constants.JoinType.OMNI_JOIN_TYPE_LEFT_SEMI; +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.createBlankVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.createVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.decodeAddFlag; +import static nova.hetu.omniruntime.util.TestUtils.decodeFetchFlag; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonLiteral; +import static nova.hetu.omniruntime.util.TestUtils.omniFunctionExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonFourArithmeticExpr; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; + +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.operator.join.OmniSmjBufferedTableWithExprOperatorFactory; +import nova.hetu.omniruntime.operator.join.OmniSmjStreamedTableWithExprOperatorFactory; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.utils.OmniRuntimeException; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; + +/** + * The type Omni sort merge join with expression operators test. + * + * @since 2022-1-10 + */ +public class OmniSortMergeJoinWithExprOperatorsTest { + /** + * Test inner hash join one column 1. + */ + @Test + public void testSmjOneTimeEqualCondition() { + DataType[] streamedTypes = {IntDataType.INTEGER, LongDataType.LONG}; + + String[] streamedKeyExps = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 5))}; + int[] streamedOutputCols = {1}; + OmniSmjStreamedTableWithExprOperatorFactory streamedBuilderWithExprOperatorFactory = + new OmniSmjStreamedTableWithExprOperatorFactory( + streamedTypes, streamedKeyExps, streamedOutputCols, OMNI_JOIN_TYPE_INNER, Optional.empty()); + + DataType[] bufferedTypes = {LongDataType.LONG, IntDataType.INTEGER}; + + int[] bufferedOutputCols = {0}; + String[] bufferedKeyExps = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 1), getOmniJsonLiteral(1, false, 5))}; + OmniSmjBufferedTableWithExprOperatorFactory bufferedWithExprOperatorFactory = + new OmniSmjBufferedTableWithExprOperatorFactory( + bufferedTypes, bufferedKeyExps, bufferedOutputCols, streamedBuilderWithExprOperatorFactory); + OmniOperator bufferedTableOperator = bufferedWithExprOperatorFactory.createOperator(); + + // start to add input + Object[][] streamedDatas1 = {{0, 1, 2, 3, 4, 5}, {6600L, 5500L, 4400L, 3300L, 2200L, 1100L}}; + VecBatch streamedVecBatch1 = createVecBatch(streamedTypes, streamedDatas1); + OmniOperator streamedTableOperator = streamedBuilderWithExprOperatorFactory.createOperator(); + int intputResult = streamedTableOperator.addInput(streamedVecBatch1); + assertEquals(decodeAddFlag(intputResult), 3); + + Object[][] bufferedDatas1 = {{6006L, 5005L, 4004L, 3003L, 2002L, 1001L}, {0, 1, 2, 3, 4, 5}}; + VecBatch bufferedVecBatch1 = createVecBatch(bufferedTypes, bufferedDatas1); + intputResult = bufferedTableOperator.addInput(bufferedVecBatch1); + assertEquals(decodeAddFlag(intputResult), 3); + + VecBatch bufferedVecBatchEof = createBlankVecBatch(bufferedTypes); + intputResult = bufferedTableOperator.addInput(bufferedVecBatchEof); + assertEquals(decodeAddFlag(intputResult), 2); + + VecBatch streamedVecBatchEof = createBlankVecBatch(streamedTypes); + intputResult = streamedTableOperator.addInput(streamedVecBatchEof); + assertEquals(decodeFetchFlag(intputResult), 5); + + Iterator results = bufferedTableOperator.getOutput(); + VecBatch resultVecBatch = results.next(); + + int len = resultVecBatch.getRowCount(); + assertEquals(len, 6); + Object[][] expectedDatas = {{6600L, 5500L, 4400L, 3300L, 2200L, 1100L}, + {6006L, 5005L, 4004L, 3003L, 2002L, 1001L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + bufferedTableOperator.close(); + bufferedWithExprOperatorFactory.close(); + streamedTableOperator.close(); + streamedBuilderWithExprOperatorFactory.close(); + } + + @Test(expectedExceptions = OmniRuntimeException.class, expectedExceptionsMessageRegExp = + ".*EXPRESSION_NOT_SUPPORT.*") + public void testInvalidStreamedKeys() { + DataType[] streamedTypes = {IntDataType.INTEGER, LongDataType.LONG}; + String[] streamedKeyExps = {omniFunctionExpr("abc", 1, getOmniJsonFieldReference(1, 0))}; + int[] streamedOutputCols = {1}; + OmniSmjStreamedTableWithExprOperatorFactory streamedBuilderWithExprOperatorFactory = new OmniSmjStreamedTableWithExprOperatorFactory( + streamedTypes, streamedKeyExps, streamedOutputCols, OMNI_JOIN_TYPE_INNER, Optional.empty()); + } + + @Test(expectedExceptions = OmniRuntimeException.class, expectedExceptionsMessageRegExp = + ".*EXPRESSION_NOT_SUPPORT.*") + public void testInvalidBufferedKeys() { + DataType[] streamedTypes = {IntDataType.INTEGER, LongDataType.LONG}; + String[] streamedKeyExps = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 5))}; + int[] streamedOutputCols = {1}; + OmniSmjStreamedTableWithExprOperatorFactory streamedBuilderWithExprOperatorFactory = + new OmniSmjStreamedTableWithExprOperatorFactory( + streamedTypes, streamedKeyExps, streamedOutputCols, OMNI_JOIN_TYPE_INNER, Optional.empty()); + + DataType[] bufferedTypes = {LongDataType.LONG, IntDataType.INTEGER}; + int[] bufferedOutputCols = {0}; + String[] bufferedKeyExps = {omniFunctionExpr("abc", 2, getOmniJsonFieldReference(2, 1))}; + OmniSmjBufferedTableWithExprOperatorFactory bufferedWithExprOperatorFactory = + new OmniSmjBufferedTableWithExprOperatorFactory( + bufferedTypes, bufferedKeyExps, bufferedOutputCols, streamedBuilderWithExprOperatorFactory); + } + + @Test + public void testFactoryContextEquals() { + DataType[] streamedTypes = {IntDataType.INTEGER, LongDataType.LONG}; + + String[] streamedKeyExps = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 5))}; + int[] streamedOutputCols = {1}; + OmniSmjStreamedTableWithExprOperatorFactory.FactoryContext streamedBuilderWithExprOperatorFactory1 = + new OmniSmjStreamedTableWithExprOperatorFactory.FactoryContext( + streamedTypes, streamedKeyExps, streamedOutputCols, OMNI_JOIN_TYPE_INNER, Optional.empty(), + new OperatorConfig()); + OmniSmjStreamedTableWithExprOperatorFactory.FactoryContext streamedBuilderWithExprOperatorFactory2 = + new OmniSmjStreamedTableWithExprOperatorFactory.FactoryContext( + streamedTypes, streamedKeyExps, streamedOutputCols, OMNI_JOIN_TYPE_INNER, Optional.empty(), + new OperatorConfig()); + OmniSmjStreamedTableWithExprOperatorFactory.FactoryContext streamedBuilderWithExprOperatorFactory3 = null; + assertEquals(streamedBuilderWithExprOperatorFactory2, streamedBuilderWithExprOperatorFactory1); + assertEquals(streamedBuilderWithExprOperatorFactory1, streamedBuilderWithExprOperatorFactory1); + assertNotEquals(streamedBuilderWithExprOperatorFactory3, streamedBuilderWithExprOperatorFactory1); + + DataType[] bufferedTypes = {LongDataType.LONG, IntDataType.INTEGER}; + + OmniSmjStreamedTableWithExprOperatorFactory streamedBuilderWithExprOperatorFactory = + new OmniSmjStreamedTableWithExprOperatorFactory( + streamedTypes, streamedKeyExps, streamedOutputCols, OMNI_JOIN_TYPE_INNER, Optional.empty()); + + int[] bufferedOutputCols = {0}; + String[] bufferedKeyExps = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 5))}; + OmniSmjBufferedTableWithExprOperatorFactory.FactoryContext bufferedWithExprOperatorFactory1 = + new OmniSmjBufferedTableWithExprOperatorFactory.FactoryContext( + bufferedTypes, bufferedKeyExps, bufferedOutputCols, new OperatorConfig(), + streamedBuilderWithExprOperatorFactory); + OmniSmjBufferedTableWithExprOperatorFactory.FactoryContext bufferedWithExprOperatorFactory2 = + new OmniSmjBufferedTableWithExprOperatorFactory.FactoryContext( + bufferedTypes, bufferedKeyExps, bufferedOutputCols, new OperatorConfig(), + streamedBuilderWithExprOperatorFactory); + OmniSmjBufferedTableWithExprOperatorFactory.FactoryContext bufferedWithExprOperatorFactory3 = null; + assertEquals(bufferedWithExprOperatorFactory2, bufferedWithExprOperatorFactory1); + assertEquals(bufferedWithExprOperatorFactory1, bufferedWithExprOperatorFactory1); + assertNotEquals(bufferedWithExprOperatorFactory3, bufferedWithExprOperatorFactory1); + } + + /** + * Test left semi join one column 1. + */ + @Test + public void testSmjOneTimeEqualConditionForLeftSemiJoin() { + DataType[] streamedTypes = {IntDataType.INTEGER, LongDataType.LONG}; + String[] streamedKeyExps = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 5))}; + int[] streamedOutputCols = {0, 1}; + OmniSmjStreamedTableWithExprOperatorFactory streamedBuilderWithExprOperatorFactory = new OmniSmjStreamedTableWithExprOperatorFactory( + streamedTypes, streamedKeyExps, streamedOutputCols, OMNI_JOIN_TYPE_LEFT_SEMI, Optional.empty()); + + DataType[] bufferedTypes = {LongDataType.LONG, IntDataType.INTEGER}; + int[] bufferedOutputCols = {}; + String[] bufferedKeyExps = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 1), getOmniJsonLiteral(1, false, 5))}; + OmniSmjBufferedTableWithExprOperatorFactory bufferedWithExprOperatorFactory = new OmniSmjBufferedTableWithExprOperatorFactory( + bufferedTypes, bufferedKeyExps, bufferedOutputCols, streamedBuilderWithExprOperatorFactory); + OmniOperator bufferedTableOperator = bufferedWithExprOperatorFactory.createOperator(); + + // start to add input + Object[][] streamedDatas1 = {{0, 1, 2, 2, 2, 3, 4, 5}, + {8800L, 7700L, 6600L, 5500L, 4400L, 3300L, 2200L, 1100L}}; + VecBatch streamedVecBatch1 = createVecBatch(streamedTypes, streamedDatas1); // 0, 1, 2, 2, 2, 3, 4, 5 + OmniOperator streamedTableOperator = streamedBuilderWithExprOperatorFactory.createOperator(); + int intputResult = streamedTableOperator.addInput(streamedVecBatch1); + assertEquals(decodeAddFlag(intputResult), 3); + + Object[][] bufferedDatas1 = {{8008L, 7007L, 6006L, 5005L, 4004L, 3003L, 2002L, 1001L}, + {0, 1, 2, 2, 3, 3, 4, 5}}; + VecBatch bufferedVecBatch1 = createVecBatch(bufferedTypes, bufferedDatas1); // 0, 1, 2, 2, 3, 3, 4, 5 + intputResult = bufferedTableOperator.addInput(bufferedVecBatch1); + assertEquals(decodeAddFlag(intputResult), 3); + + VecBatch bufferedVecBatchEof = createBlankVecBatch(bufferedTypes); + intputResult = bufferedTableOperator.addInput(bufferedVecBatchEof); + assertEquals(decodeAddFlag(intputResult), 2); + + VecBatch streamedVecBatchEof = createBlankVecBatch(streamedTypes); + intputResult = streamedTableOperator.addInput(streamedVecBatchEof); + assertEquals(decodeFetchFlag(intputResult), 5); + + Iterator results = bufferedTableOperator.getOutput(); + VecBatch resultVecBatch = results.next(); + + int len = resultVecBatch.getRowCount(); + assertEquals(len, 8); + Object[][] expectedDatas = {{0, 1, 2, 2, 2, 3, 4, 5}, {8800L, 7700L, 6600L, 5500L, 4400L, 3300L, 2200L, 1100L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + bufferedTableOperator.close(); + bufferedWithExprOperatorFactory.close(); + streamedTableOperator.close(); + streamedBuilderWithExprOperatorFactory.close(); + } + + /** + * Test left anti join one column 1. + */ + @Test + public void testSmjOneTimeEqualConditionForLeftAntiJoin() { + DataType[] streamedTypes = {IntDataType.INTEGER, LongDataType.LONG}; + String[] streamedKeyExps = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 5))}; + int[] streamedOutputCols = {0, 1}; + OmniSmjStreamedTableWithExprOperatorFactory streamedBuilderWithExprOperatorFactory = new OmniSmjStreamedTableWithExprOperatorFactory( + streamedTypes, streamedKeyExps, streamedOutputCols, OMNI_JOIN_TYPE_LEFT_ANTI, Optional.empty()); + + DataType[] bufferedTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE}; + int[] bufferedOutputCols = {}; + String[] bufferedKeyExps = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 5))}; + OmniSmjBufferedTableWithExprOperatorFactory bufferedWithExprOperatorFactory = new OmniSmjBufferedTableWithExprOperatorFactory( + bufferedTypes, bufferedKeyExps, bufferedOutputCols, streamedBuilderWithExprOperatorFactory); + OmniOperator bufferedTableOperator = bufferedWithExprOperatorFactory.createOperator(); + + // start to add input + Object[][] streamedDatas1 = {{1, 1, 2, 3}, {40L, 25L, 35L, 30L}}; + VecBatch streamedVecBatch1 = createVecBatch(streamedTypes, streamedDatas1); + OmniOperator streamedTableOperator = streamedBuilderWithExprOperatorFactory.createOperator(); + int intputResult = streamedTableOperator.addInput(streamedVecBatch1); + assertEquals(decodeAddFlag(intputResult), 3); + + Object[][] bufferedDatas1 = {{3, 3, 4, 4}, {3.3, 3.5, 4.4, 4.5}}; + VecBatch bufferedVecBatch1 = createVecBatch(bufferedTypes, bufferedDatas1); + intputResult = bufferedTableOperator.addInput(bufferedVecBatch1); + assertEquals(decodeAddFlag(intputResult), 2); + + VecBatch streamedVecBatchEof = createBlankVecBatch(streamedTypes); + intputResult = streamedTableOperator.addInput(streamedVecBatchEof); + assertEquals(decodeFetchFlag(intputResult), 5); + + Iterator results = bufferedTableOperator.getOutput(); + VecBatch resultVecBatch = results.next(); + + int len = resultVecBatch.getRowCount(); + assertEquals(len, 3); + Object[][] expectedDatas = {{1, 1, 2}, {40L, 25L, 35L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + bufferedTableOperator.close(); + bufferedWithExprOperatorFactory.close(); + streamedTableOperator.close(); + streamedBuilderWithExprOperatorFactory.close(); + } + + private void buildAddInputData(Object[][] streamedData1, Object[][] bufferedData1, Object[][] streamedData2, + Object[][] bufferedData2, int tableSize) { + for (int i = 0; i < tableSize; i++) { + streamedData1[0][i] = i; + streamedData2[0][i] = i + tableSize; + for (int j = 1; j < streamedData1.length; j++) { + streamedData1[j][i] = i + 1001L; + streamedData2[j][i] = i + 1001L + tableSize; + } + + bufferedData1[bufferedData1.length - 1][i] = i; + bufferedData2[bufferedData2.length - 1][i] = i + tableSize; + for (int k = 0; k < bufferedData1.length - 1; k++) { + bufferedData1[k][i] = i + 1003L; + bufferedData2[k][i] = i + 1003L + tableSize; + } + } + } + + private void buildExpectedData(Object[][] expectedData1, Object[][] expectedData2, Object[][] expectedData3, + Object[][] expectedData4, int maxRowCount, int remainCount) { + for (int i = 0; i < maxRowCount; i++) { + for (int j = 0; j < 4; j++) { + expectedData1[j][i] = i + 1001L; + expectedData1[j + 4][i] = i + 1003L; + expectedData3[j][i] = i + 1001L + maxRowCount + remainCount; + expectedData3[j + 4][i] = i + 1003L + maxRowCount + remainCount; + } + } + + for (int i = 0; i < remainCount; i++) { + for (int j = 0; j < 4; j++) { + expectedData2[j][i] = i + 1001L + maxRowCount; + expectedData4[j][i] = i + 1001L + 2 * maxRowCount + remainCount; + expectedData2[j + 4][i] = i + 1003L + maxRowCount; + expectedData4[j + 4][i] = i + 1003L + 2 * maxRowCount + remainCount; + } + } + for (int j = 0; j < 4; j++) { + expectedData4[j][remainCount] = remainCount + 1001L + 2 * maxRowCount + remainCount; + expectedData4[j + 4][remainCount] = remainCount + 1003L + 2 * maxRowCount + remainCount; + } + } + + /** + * Test smj iterative getOutput. + */ + @Test + public void testSmjIterativeGetOutput() { + DataType[] streamedTypes = {IntDataType.INTEGER, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, + LongDataType.LONG}; + + String[] streamedKeyExps = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 5))}; + int[] streamedOutputCols = {1, 2, 3, 4}; + OmniSmjStreamedTableWithExprOperatorFactory streamedBuilderWithExprOperatorFactory = + new OmniSmjStreamedTableWithExprOperatorFactory(streamedTypes, streamedKeyExps, streamedOutputCols, + OMNI_JOIN_TYPE_INNER, Optional.empty()); + + DataType[] bufferedTypes = {LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, + IntDataType.INTEGER}; + + int[] bufferedOutputCols = {0, 1, 2, 3}; + + String[] bufferedKeyExps = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 4), getOmniJsonLiteral(1, false, 5))}; + OmniSmjBufferedTableWithExprOperatorFactory bufferedWithExprOperatorFactory = + new OmniSmjBufferedTableWithExprOperatorFactory(bufferedTypes, bufferedKeyExps, bufferedOutputCols, + streamedBuilderWithExprOperatorFactory); + OmniOperator bufferedTableOperator = bufferedWithExprOperatorFactory.createOperator(); + + // construct addInput data + int tableSize = 20000; + Object[][] streamedData1 = new Object[5][tableSize]; + Object[][] bufferedData1 = new Object[5][tableSize]; + Object[][] streamedData2 = new Object[5][tableSize]; + Object[][] bufferedData2 = new Object[5][tableSize]; + buildAddInputData(streamedData1, bufferedData1, streamedData2, bufferedData2, tableSize); + + // start to add input + VecBatch streamedVecBatch1 = createVecBatch(streamedTypes, streamedData1); + OmniOperator streamedTableOperator = streamedBuilderWithExprOperatorFactory.createOperator(); + int intputResult = streamedTableOperator.addInput(streamedVecBatch1); + assertEquals(decodeAddFlag(intputResult), 3); + + VecBatch bufferedVecBatch1 = createVecBatch(bufferedTypes, bufferedData1); + intputResult = bufferedTableOperator.addInput(bufferedVecBatch1); + assertEquals(decodeAddFlag(intputResult), 3); + assertEquals(decodeFetchFlag(intputResult), 5); + + Iterator results = bufferedTableOperator.getOutput(); + VecBatch resultVecBatch = null; + + List result = new ArrayList<>(); + while (results.hasNext()) { + resultVecBatch = results.next(); + result.add(resultVecBatch); + } + + VecBatch bufferedVecBatch2 = createVecBatch(bufferedTypes, bufferedData2); + intputResult = bufferedTableOperator.addInput(bufferedVecBatch2); + assertEquals(decodeAddFlag(intputResult), 2); + + VecBatch streamedVecBatch2 = createVecBatch(streamedTypes, streamedData2); + intputResult = streamedTableOperator.addInput(streamedVecBatch2); + assertEquals(decodeAddFlag(intputResult), 3); + assertEquals(decodeFetchFlag(intputResult), 5); + + results = streamedTableOperator.getOutput(); + while (results.hasNext()) { + resultVecBatch = results.next(); + result.add(resultVecBatch); + } + + VecBatch bufferedVecBatchEof = createBlankVecBatch(bufferedTypes); + intputResult = bufferedTableOperator.addInput(bufferedVecBatchEof); + assertEquals(decodeAddFlag(intputResult), 2); + + VecBatch streamedVecBatchEof = createBlankVecBatch(streamedTypes); + intputResult = streamedTableOperator.addInput(streamedVecBatchEof); + assertEquals(decodeFetchFlag(intputResult), 5); + + results = bufferedTableOperator.getOutput(); + + while (results.hasNext()) { + resultVecBatch = results.next(); + result.add(resultVecBatch); + } + + int rowCount = 0; + for (int i = 0; i < result.size(); i++) { + rowCount += result.get(i).getRowCount(); + } + + assertEquals(rowCount, 2 * tableSize); + + int maxRowCount = 16384; // 1M / (8 * 8) + int remainCount = tableSize - maxRowCount - 1; + // construct expected data + Object[][] expectedData1 = new Object[8][maxRowCount]; + Object[][] expectedData2 = new Object[8][remainCount]; + Object[][] expectedData3 = new Object[8][maxRowCount]; + Object[][] expectedData4 = new Object[8][remainCount + 1]; + buildExpectedData(expectedData1, expectedData2, expectedData3, expectedData4, maxRowCount, remainCount); + Object[][] expectedData5 = {{41000L}, {41000L}, {41000L}, {41000L}, {41002L}, {41002L}, {41002L}, {41002L}}; + + assertVecBatchEquals(result.get(0), expectedData1); + assertVecBatchEquals(result.get(1), expectedData2); + assertVecBatchEquals(result.get(2), expectedData3); + assertVecBatchEquals(result.get(3), expectedData4); + assertVecBatchEquals(result.get(4), expectedData5); + + for (int i = 0; i < result.size(); i++) { + freeVecBatch(result.get(i)); + } + + bufferedTableOperator.close(); + bufferedWithExprOperatorFactory.close(); + streamedTableOperator.close(); + streamedBuilderWithExprOperatorFactory.close(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniSortMergeJoinWithExprOperatorsV3Test.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniSortMergeJoinWithExprOperatorsV3Test.java new file mode 100644 index 0000000..d24521c --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniSortMergeJoinWithExprOperatorsV3Test.java @@ -0,0 +1,259 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.constants.JoinType.OMNI_JOIN_TYPE_INNER; +import static nova.hetu.omniruntime.constants.JoinType.OMNI_JOIN_TYPE_LEFT_ANTI; +import static nova.hetu.omniruntime.constants.JoinType.OMNI_JOIN_TYPE_LEFT_SEMI; +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.createVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonLiteral; +import static nova.hetu.omniruntime.util.TestUtils.omniFunctionExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonFourArithmeticExpr; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; + +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.operator.join.OmniSmjBufferedTableWithExprOperatorFactoryV3; +import nova.hetu.omniruntime.operator.join.OmniSmjStreamedTableWithExprOperatorFactoryV3; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.utils.OmniRuntimeException; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.Iterator; +import java.util.Optional; + +/** + * The type Omni sort merge join with expression operators test. + * + * @since 2022-1-10 + */ +public class OmniSortMergeJoinWithExprOperatorsV3Test { + /** + * Test inner hash join one column 1. + */ + @Test + public void testSmjOneTimeEqualCondition() { + DataType[] streamedTypes = {IntDataType.INTEGER, LongDataType.LONG}; + + String[] streamedKeyExps = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 5))}; + int[] streamedOutputCols = {1}; + OmniSmjStreamedTableWithExprOperatorFactoryV3 streamedBuilderWithExprOperatorFactory = + new OmniSmjStreamedTableWithExprOperatorFactoryV3( + streamedTypes, streamedKeyExps, streamedOutputCols, OMNI_JOIN_TYPE_INNER, Optional.empty()); + + DataType[] bufferedTypes = {LongDataType.LONG, IntDataType.INTEGER}; + + int[] bufferedOutputCols = {0}; + String[] bufferedKeyExps = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 1), getOmniJsonLiteral(1, false, 5))}; + OmniSmjBufferedTableWithExprOperatorFactoryV3 bufferedWithExprOperatorFactory = + new OmniSmjBufferedTableWithExprOperatorFactoryV3( + bufferedTypes, bufferedKeyExps, bufferedOutputCols, streamedBuilderWithExprOperatorFactory); + OmniOperator bufferedTableOperator = bufferedWithExprOperatorFactory.createOperator(); + + // start to add input + Object[][] streamedDatas1 = {{0, 1, 2, 3, 4, 5}, {6600L, 5500L, 4400L, 3300L, 2200L, 1100L}}; + VecBatch streamedVecBatch1 = createVecBatch(streamedTypes, streamedDatas1); + OmniOperator streamedTableOperator = streamedBuilderWithExprOperatorFactory.createOperator(); + streamedTableOperator.addInput(streamedVecBatch1); + + Object[][] bufferedDatas1 = {{6006L, 5005L, 4004L, 3003L, 2002L, 1001L}, {0, 1, 2, 3, 4, 5}}; + VecBatch bufferedVecBatch1 = createVecBatch(bufferedTypes, bufferedDatas1); + bufferedTableOperator.addInput(bufferedVecBatch1); + + Iterator results = bufferedTableOperator.getOutput(); + VecBatch resultVecBatch = results.next(); + + int len = resultVecBatch.getRowCount(); + assertEquals(len, 6); + Object[][] expectedDatas = {{6600L, 5500L, 4400L, 3300L, 2200L, 1100L}, + {6006L, 5005L, 4004L, 3003L, 2002L, 1001L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + bufferedTableOperator.close(); + bufferedWithExprOperatorFactory.close(); + streamedTableOperator.close(); + streamedBuilderWithExprOperatorFactory.close(); + } + + @Test(expectedExceptions = OmniRuntimeException.class, expectedExceptionsMessageRegExp = + ".*EXPRESSION_NOT_SUPPORT.*") + public void testInvalidStreamedKeys() { + DataType[] streamedTypes = {IntDataType.INTEGER, LongDataType.LONG}; + String[] streamedKeyExps = {omniFunctionExpr("abc", 1, getOmniJsonFieldReference(1, 0))}; + int[] streamedOutputCols = {1}; + OmniSmjStreamedTableWithExprOperatorFactoryV3 streamedBuilderWithExprOperatorFactory = + new OmniSmjStreamedTableWithExprOperatorFactoryV3( + streamedTypes, streamedKeyExps, streamedOutputCols, OMNI_JOIN_TYPE_INNER, Optional.empty()); + } + + @Test(expectedExceptions = OmniRuntimeException.class, expectedExceptionsMessageRegExp = + ".*EXPRESSION_NOT_SUPPORT.*") + public void testInvalidBufferedKeys() { + DataType[] streamedTypes = {IntDataType.INTEGER, LongDataType.LONG}; + String[] streamedKeyExps = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 5))}; + int[] streamedOutputCols = {1}; + OmniSmjStreamedTableWithExprOperatorFactoryV3 streamedBuilderWithExprOperatorFactory = + new OmniSmjStreamedTableWithExprOperatorFactoryV3( + streamedTypes, streamedKeyExps, streamedOutputCols, OMNI_JOIN_TYPE_INNER, Optional.empty()); + + DataType[] bufferedTypes = {LongDataType.LONG, IntDataType.INTEGER}; + int[] bufferedOutputCols = {0}; + String[] bufferedKeyExps = {omniFunctionExpr("abc", 2, getOmniJsonFieldReference(2, 1))}; + OmniSmjBufferedTableWithExprOperatorFactoryV3 bufferedWithExprOperatorFactory = + new OmniSmjBufferedTableWithExprOperatorFactoryV3( + bufferedTypes, bufferedKeyExps, bufferedOutputCols, streamedBuilderWithExprOperatorFactory); + } + + @Test + public void testFactoryContextEquals() { + DataType[] streamedTypes = {IntDataType.INTEGER, LongDataType.LONG}; + + String[] streamedKeyExps = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 5))}; + int[] streamedOutputCols = {1}; + OmniSmjStreamedTableWithExprOperatorFactoryV3.FactoryContext streamedBuilderWithExprOperatorFactory1 = + new OmniSmjStreamedTableWithExprOperatorFactoryV3.FactoryContext( + streamedTypes, streamedKeyExps, streamedOutputCols, OMNI_JOIN_TYPE_INNER, Optional.empty(), + new OperatorConfig()); + OmniSmjStreamedTableWithExprOperatorFactoryV3.FactoryContext streamedBuilderWithExprOperatorFactory2 = + new OmniSmjStreamedTableWithExprOperatorFactoryV3.FactoryContext( + streamedTypes, streamedKeyExps, streamedOutputCols, OMNI_JOIN_TYPE_INNER, Optional.empty(), + new OperatorConfig()); + OmniSmjStreamedTableWithExprOperatorFactoryV3.FactoryContext streamedBuilderWithExprOperatorFactory3 = null; + assertEquals(streamedBuilderWithExprOperatorFactory2, streamedBuilderWithExprOperatorFactory1); + assertEquals(streamedBuilderWithExprOperatorFactory1, streamedBuilderWithExprOperatorFactory1); + assertNotEquals(streamedBuilderWithExprOperatorFactory3, streamedBuilderWithExprOperatorFactory1); + + DataType[] bufferedTypes = {LongDataType.LONG, IntDataType.INTEGER}; + + OmniSmjStreamedTableWithExprOperatorFactoryV3 streamedBuilderWithExprOperatorFactory = + new OmniSmjStreamedTableWithExprOperatorFactoryV3( + streamedTypes, streamedKeyExps, streamedOutputCols, OMNI_JOIN_TYPE_INNER, Optional.empty()); + + int[] bufferedOutputCols = {0}; + String[] bufferedKeyExps = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 5))}; + OmniSmjBufferedTableWithExprOperatorFactoryV3.FactoryContext bufferedWithExprOperatorFactory1 = + new OmniSmjBufferedTableWithExprOperatorFactoryV3.FactoryContext( + bufferedTypes, bufferedKeyExps, bufferedOutputCols, new OperatorConfig(), + streamedBuilderWithExprOperatorFactory); + OmniSmjBufferedTableWithExprOperatorFactoryV3.FactoryContext bufferedWithExprOperatorFactory2 = + new OmniSmjBufferedTableWithExprOperatorFactoryV3.FactoryContext( + bufferedTypes, bufferedKeyExps, bufferedOutputCols, new OperatorConfig(), + streamedBuilderWithExprOperatorFactory); + OmniSmjBufferedTableWithExprOperatorFactoryV3.FactoryContext bufferedWithExprOperatorFactory3 = null; + assertEquals(bufferedWithExprOperatorFactory2, bufferedWithExprOperatorFactory1); + assertEquals(bufferedWithExprOperatorFactory1, bufferedWithExprOperatorFactory1); + assertNotEquals(bufferedWithExprOperatorFactory3, bufferedWithExprOperatorFactory1); + } + + /** + * Test left semi join one column 1. + */ + @Test + public void testSmjOneTimeEqualConditionForLeftSemiJoin() { + DataType[] streamedTypes = {IntDataType.INTEGER, LongDataType.LONG}; + String[] streamedKeyExps = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 5))}; + int[] streamedOutputCols = {0, 1}; + OmniSmjStreamedTableWithExprOperatorFactoryV3 streamedBuilderWithExprOperatorFactory = + new OmniSmjStreamedTableWithExprOperatorFactoryV3( + streamedTypes, streamedKeyExps, streamedOutputCols, OMNI_JOIN_TYPE_LEFT_SEMI, Optional.empty()); + + DataType[] bufferedTypes = {LongDataType.LONG, IntDataType.INTEGER}; + int[] bufferedOutputCols = {}; + String[] bufferedKeyExps = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 1), getOmniJsonLiteral(1, false, 5))}; + OmniSmjBufferedTableWithExprOperatorFactoryV3 bufferedWithExprOperatorFactory = + new OmniSmjBufferedTableWithExprOperatorFactoryV3( + bufferedTypes, bufferedKeyExps, bufferedOutputCols, streamedBuilderWithExprOperatorFactory); + OmniOperator bufferedTableOperator = bufferedWithExprOperatorFactory.createOperator(); + + // start to add input + Object[][] streamedDatas1 = {{0, 1, 2, 2, 2, 3, 4, 5}, + {8800L, 7700L, 6600L, 5500L, 4400L, 3300L, 2200L, 1100L}}; + VecBatch streamedVecBatch1 = createVecBatch(streamedTypes, streamedDatas1); // 0, 1, 2, 2, 2, 3, 4, 5 + OmniOperator streamedTableOperator = streamedBuilderWithExprOperatorFactory.createOperator(); + streamedTableOperator.addInput(streamedVecBatch1); + + Object[][] bufferedDatas1 = {{8008L, 7007L, 6006L, 5005L, 4004L, 3003L, 2002L, 1001L}, + {0, 1, 2, 2, 3, 3, 4, 5}}; + VecBatch bufferedVecBatch1 = createVecBatch(bufferedTypes, bufferedDatas1); // 0, 1, 2, 2, 3, 3, 4, 5 + bufferedTableOperator.addInput(bufferedVecBatch1); + + Iterator results = bufferedTableOperator.getOutput(); + VecBatch resultVecBatch = results.next(); + + int len = resultVecBatch.getRowCount(); + assertEquals(len, 8); + Object[][] expectedDatas = {{0, 1, 2, 2, 2, 3, 4, 5}, {8800L, 7700L, 6600L, 5500L, 4400L, 3300L, 2200L, 1100L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + bufferedTableOperator.close(); + bufferedWithExprOperatorFactory.close(); + streamedTableOperator.close(); + streamedBuilderWithExprOperatorFactory.close(); + } + + /** + * Test left anti join one column 1. + */ + @Test + public void testSmjOneTimeEqualConditionForLeftAntiJoin() { + DataType[] streamedTypes = {IntDataType.INTEGER, LongDataType.LONG}; + String[] streamedKeyExps = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 5))}; + int[] streamedOutputCols = {0, 1}; + OmniSmjStreamedTableWithExprOperatorFactoryV3 streamedBuilderWithExprOperatorFactory = + new OmniSmjStreamedTableWithExprOperatorFactoryV3( + streamedTypes, streamedKeyExps, streamedOutputCols, OMNI_JOIN_TYPE_LEFT_ANTI, Optional.empty()); + + DataType[] bufferedTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE}; + int[] bufferedOutputCols = {}; + String[] bufferedKeyExps = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 5))}; + OmniSmjBufferedTableWithExprOperatorFactoryV3 bufferedWithExprOperatorFactory = + new OmniSmjBufferedTableWithExprOperatorFactoryV3( + bufferedTypes, bufferedKeyExps, bufferedOutputCols, streamedBuilderWithExprOperatorFactory); + OmniOperator bufferedTableOperator = bufferedWithExprOperatorFactory.createOperator(); + + // start to add input + Object[][] streamedDatas1 = {{1, 1, 2, 3}, {40L, 25L, 35L, 30L}}; + VecBatch streamedVecBatch1 = createVecBatch(streamedTypes, streamedDatas1); + OmniOperator streamedTableOperator = streamedBuilderWithExprOperatorFactory.createOperator(); + streamedTableOperator.addInput(streamedVecBatch1); + + Object[][] bufferedDatas1 = {{3, 3, 4, 4}, {3.3, 3.5, 4.4, 4.5}}; + VecBatch bufferedVecBatch1 = createVecBatch(bufferedTypes, bufferedDatas1); + bufferedTableOperator.addInput(bufferedVecBatch1); + + Iterator results = bufferedTableOperator.getOutput(); + VecBatch resultVecBatch = results.next(); + + int len = resultVecBatch.getRowCount(); + assertEquals(len, 3); + Object[][] expectedDatas = {{1, 1, 2}, {40L, 25L, 35L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + bufferedTableOperator.close(); + bufferedWithExprOperatorFactory.close(); + streamedTableOperator.close(); + streamedBuilderWithExprOperatorFactory.close(); + } +} \ No newline at end of file diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniSortOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniSortOperatorTest.java new file mode 100644 index 0000000..a566ce9 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniSortOperatorTest.java @@ -0,0 +1,635 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.assertVecEquals; +import static nova.hetu.omniruntime.util.TestUtils.createVec; +import static nova.hetu.omniruntime.util.TestUtils.createVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; +import static nova.hetu.omniruntime.memory.MemoryManager.clearMemory; +import static nova.hetu.omniruntime.memory.MemoryManager.setGlobalMemoryLimit; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; +import static org.testng.Assert.assertTrue; + +import com.google.common.collect.ImmutableList; + +import nova.hetu.omniruntime.memory.MemoryManager; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.operator.sort.OmniSortOperatorFactory; +import nova.hetu.omniruntime.operator.sort.OmniSortOperatorFactory.FactoryContext; +import nova.hetu.omniruntime.type.CharDataType; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.Date32DataType; +import nova.hetu.omniruntime.type.Decimal128DataType; +import nova.hetu.omniruntime.type.Decimal64DataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.util.TestUtils; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +/** + * The type Omni sort operator test. + * + * @since 2021-5-10 + */ +public class OmniSortOperatorTest { + /** + * The Total page count. + */ + int totalPageCount = 2; + + /** + * The Page distinct count. + */ + int pageDistinctCount = 4; + + /** + * The Page distinct value repeat count. + */ + int pageDistinctValueRepeatCount = 2500; + + /** + * Test sort two columns. + */ + @Test + public void testSortTwoColumns() { + DataType[] sourceTypes = {IntDataType.INTEGER, IntDataType.INTEGER}; + Object[][] sourceDatas = {{5, 3, 2, 6, 1, 4, 7, 8}, {5, 3, 2, 6, 1, 4, 7, 8}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + int[] outputCols = {0, 1}; + String[] sortCols = {"#0", "#1"}; + int[] ascendings = {1, 1}; + int[] nullFirsts = {0, 0}; + OmniSortOperatorFactory sortOperatorFactory = new OmniSortOperatorFactory(sourceTypes, outputCols, sortCols, + ascendings, nullFirsts); + OmniOperator sortOperator = sortOperatorFactory.createOperator(); + sortOperator.addInput(vecBatch); + Iterator results = sortOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + int len = resultVecBatch.getRowCount(); + assertEquals(len, sourceDatas[0].length); + + Object[][] expectedDatas = {{1, 2, 3, 4, 5, 6, 7, 8}, {1, 2, 3, 4, 5, 6, 7, 8}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + sortOperator.close(); + sortOperatorFactory.close(); + } + + /** + * test return multi batch + */ + @Test + public void testSortTwoColumnsWithReturnMultiBatch() { + int maxRowCntPerBatch = 131072; // 1M / (4+4) + int totalRowCnt = maxRowCntPerBatch * 3; + DataType[] sourceTypes = {IntDataType.INTEGER, IntDataType.INTEGER}; + Object[][] sourceDatas = new Object[2][]; + sourceDatas[0] = new Object[totalRowCnt]; + sourceDatas[1] = new Object[totalRowCnt]; + for (int i = 0; i < totalRowCnt; i++) { + sourceDatas[0][i] = totalRowCnt - i - 1; + sourceDatas[1][i] = totalRowCnt - i - 1; + } + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + int[] outputCols = {0, 1}; + String[] sortCols = {"#0", "#1"}; + int[] ascendings = {1, 1}; + int[] nullFirsts = {0, 0}; + OmniSortOperatorFactory sortOperatorFactory = new OmniSortOperatorFactory(sourceTypes, outputCols, sortCols, + ascendings, nullFirsts); + OmniOperator sortOperator = sortOperatorFactory.createOperator(); + sortOperator.addInput(vecBatch); + Iterator results = sortOperator.getOutput(); + int baseValue = 0; + while (results.hasNext()) { + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getRowCount(), maxRowCntPerBatch); + + Object[][] expectedDatas = new Object[2][]; + expectedDatas[0] = new Object[maxRowCntPerBatch]; + expectedDatas[1] = new Object[maxRowCntPerBatch]; + for (int i = 0; i < maxRowCntPerBatch; i++) { + expectedDatas[0][i] = baseValue + i; + expectedDatas[1][i] = baseValue + i; + } + baseValue += maxRowCntPerBatch; + assertVecBatchEquals(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + } + sortOperator.close(); + sortOperatorFactory.close(); + } + + /** + * Test sort two columns with dictionary vector. + */ + @Test + public void testSortTwoColumnsWithDict() { + DataType[] sourceTypes = {IntDataType.INTEGER, IntDataType.INTEGER}; + Object[][] sourceDatas = {{5, 3, 2, 6, 1, 4, 7, 8}, {5, 3, 2, 6, 1, 4, 7, 8}}; + Vec[] vecs = new Vec[2]; + vecs[0] = TestUtils.createIntVec(sourceDatas[0]); + int[] ids = {0, 1, 2, 3, 4, 5, 6, 7}; + vecs[1] = TestUtils.createDictionaryVec(sourceTypes[1], sourceDatas[1], ids); + VecBatch vecBatch = new VecBatch(vecs); + + int[] outputCols = {0, 1}; + String[] sortCols = {"#0", "#1"}; + int[] ascendings = {1, 1}; + int[] nullFirsts = {0, 0}; + OmniSortOperatorFactory sortOperatorFactory = new OmniSortOperatorFactory(sourceTypes, outputCols, sortCols, + ascendings, nullFirsts); + OmniOperator sortOperator = sortOperatorFactory.createOperator(); + sortOperator.addInput(vecBatch); + Iterator results = sortOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + int len = resultVecBatch.getRowCount(); + assertEquals(len, sourceDatas[0].length); + + Object[][] expectedDatas = {{1, 2, 3, 4, 5, 6, 7, 8}, {1, 2, 3, 4, 5, 6, 7, 8}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + sortOperator.close(); + sortOperatorFactory.close(); + } + + /** + * Test sort two varchar columns. + */ + @Test + public void testSortTwoVarcharColumns() { + DataType[] sourceTypes = {new VarcharDataType(1), LongDataType.LONG, new VarcharDataType(3)}; + Object[][] sourceDatas = {{"0", "1", "2", "0", "1", "2"}, {0L, 1L, 2L, 3L, 4L, 5L}, + {"6.6", "5.5", "4.4", "3.3", "2.2", "1.1"}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + int[] outputCols = {1, 2}; + String[] sortCols = {"#0", "#2"}; + int[] ascendings = {0, 1}; + int[] nullFirsts = {1, 1}; + OmniSortOperatorFactory sortOperatorFactory = new OmniSortOperatorFactory(sourceTypes, outputCols, sortCols, + ascendings, nullFirsts); + OmniOperator sortOperator = sortOperatorFactory.createOperator(); + sortOperator.addInput(vecBatch); + Iterator results = sortOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + Object[][] expectedDatas = {{5L, 2L, 4L, 1L, 3L, 0L}, {"1.1", "4.4", "2.2", "5.5", "3.3", "6.6"}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + sortOperator.close(); + sortOperatorFactory.close(); + } + + /** + * Test sort two char columns. + */ + @Test + public void testSortTwoCharColumns() { + DataType[] sourceTypes = {new CharDataType(1), LongDataType.LONG, new CharDataType(3)}; + Object[][] sourceDatas = {{"0", "1", "2", "0", "1", "2"}, {0L, 1L, 2L, 3L, 4L, 5L}, + {"6.6", "5.5", "4.4", "3.3", "2.2", "1.1"}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + int[] outputCols = {1, 2}; + String[] sortCols = {"#0", "#2"}; + int[] ascendings = {0, 1}; + int[] nullFirsts = {1, 1}; + OmniSortOperatorFactory sortOperatorFactory = new OmniSortOperatorFactory(sourceTypes, outputCols, sortCols, + ascendings, nullFirsts); + OmniOperator sortOperator = sortOperatorFactory.createOperator(); + sortOperator.addInput(vecBatch); + Iterator results = sortOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + Object[][] expectedDatas = {{5L, 2L, 4L, 1L, 3L, 0L}, {"1.1", "4.4", "2.2", "5.5", "3.3", "6.6"}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + sortOperator.close(); + sortOperatorFactory.close(); + } + + /** + * Test sort two date32 columns. + */ + @Test + public void testSortTwoDate32Columns() { + DataType[] sourceTypes = {new Date32DataType(DataType.DateUnit.DAY), LongDataType.LONG, + new Date32DataType(DataType.DateUnit.MILLI)}; + Object[][] sourceDatas = {{0, 1, 2, 0, 1, 2}, {0L, 1L, 2L, 3L, 4L, 5L}, {66, 55, 44, 33, 22, 11}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + int[] outputCols = {1, 2}; + String[] sortCols = {"#0", "#2"}; + int[] ascendings = {0, 1}; + int[] nullFirsts = {1, 1}; + OmniSortOperatorFactory sortOperatorFactory = new OmniSortOperatorFactory(sourceTypes, outputCols, sortCols, + ascendings, nullFirsts); + OmniOperator sortOperator = sortOperatorFactory.createOperator(); + sortOperator.addInput(vecBatch); + Iterator results = sortOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + Object[][] expectedDatas = {{5L, 2L, 4L, 1L, 3L, 0L}, {11, 44, 22, 55, 33, 66}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + sortOperator.close(); + sortOperatorFactory.close(); + } + + /** + * Test sort two decimal64 columns. + */ + @Test + public void testSortTwoDecimal64Columns() { + DataType[] sourceTypes = {new Decimal64DataType(1, 0), LongDataType.LONG, new Decimal64DataType(2, 0)}; + Object[][] sourceDatas = {{0L, 1L, 2L, 0L, 1L, 2L}, {0L, 1L, 2L, 3L, 4L, 5L}, {66L, 55L, 44L, 33L, 22L, 11L}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + int[] outputCols = {1, 2}; + String[] sortCols = {"#0", "#2"}; + int[] ascendings = {0, 1}; + int[] nullFirsts = {1, 1}; + OmniSortOperatorFactory sortOperatorFactory = new OmniSortOperatorFactory(sourceTypes, outputCols, sortCols, + ascendings, nullFirsts); + OmniOperator sortOperator = sortOperatorFactory.createOperator(); + sortOperator.addInput(vecBatch); + Iterator results = sortOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + Object[][] expectedDatas = {{5L, 2L, 4L, 1L, 3L, 0L}, {11L, 44L, 22L, 55L, 33L, 66L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + sortOperator.close(); + sortOperatorFactory.close(); + } + + /** + * Test sort two decimal128 columns. + */ + @Test + public void testSortTwoDecimal128Columns() { + DataType[] sourceTypes = {new Decimal128DataType(1, 0), LongDataType.LONG, new Decimal128DataType(2, 0)}; + Vec[] vecs = new Vec[sourceTypes.length]; + vecs[0] = createVec(sourceTypes[0], new Object[][]{{0L, 0L}, {1L, 0L}, {2L, 0L}, {0L, 0L}, {1L, 0L}, {2L, 0L}}); + vecs[1] = createVec(sourceTypes[1], new Object[]{0L, 1L, 2L, 3L, 4L, 5L}); + vecs[2] = createVec(sourceTypes[2], + new Object[][]{{66L, 0L}, {55L, 0L}, {44L, 0L}, {33L, 0L}, {22L, 0L}, {11L, 0L}}); + VecBatch vecBatch = new VecBatch(vecs); + + int[] outputCols = {1, 2}; + String[] sortCols = {"#0", "#2"}; + int[] ascendings = {0, 1}; + int[] nullFirsts = {1, 1}; + OmniSortOperatorFactory sortOperatorFactory = new OmniSortOperatorFactory(sourceTypes, outputCols, sortCols, + ascendings, nullFirsts); + OmniOperator sortOperator = sortOperatorFactory.createOperator(); + sortOperator.addInput(vecBatch); + Iterator results = sortOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getVectorCount(), outputCols.length); + assertVecEquals(resultVecBatch.getVectors()[0], new Object[]{5L, 2L, 4L, 1L, 3L, 0L}); + assertVecEquals(resultVecBatch.getVectors()[1], + new Object[][]{{11L, 0L}, {44L, 0L}, {22L, 0L}, {55L, 0L}, {33L, 0L}, {66L, 0L}}); + freeVecBatch(resultVecBatch); + sortOperator.close(); + sortOperatorFactory.close(); + } + + /** + * Test sort with null first. + */ + @Test + public void testSortWithNullFirst() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG}; + Object[][] sourceDatas = {{4, 3, 2, 1, 0, null}, {0L, 1L, 2L, 3L, 4L, null}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + int[] outputCols = {0, 1}; + String[] sortCols = {"#1"}; + int[] ascendings = {0}; + int[] nullFirsts = {1}; + OmniSortOperatorFactory sortOperatorFactory = new OmniSortOperatorFactory(sourceTypes, outputCols, sortCols, + ascendings, nullFirsts); + OmniOperator sortOperator = sortOperatorFactory.createOperator(); + sortOperator.addInput(vecBatch); + Iterator results = sortOperator.getOutput(); + VecBatch resultVecBatch = results.next(); + + Object[][] expectedDatas = {{null, 0, 1, 2, 3, 4}, {null, 4L, 3L, 2L, 1L, 0L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + sortOperator.close(); + sortOperatorFactory.close(); + } + + /** + * Test sort with null last. + */ + @Test + public void testSortWithNullLast() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG}; + Object[][] sourceDatas = {{4, 3, 2, 1, 0, null}, {0L, 1L, 2L, 3L, 4L, null}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + int[] outputCols = {0, 1}; + String[] sortCols = {"#1"}; + int[] ascendings = {0}; + int[] nullFirsts = {0}; + OmniSortOperatorFactory sortOperatorFactory = new OmniSortOperatorFactory(sourceTypes, outputCols, sortCols, + ascendings, nullFirsts); + OmniOperator sortOperator = sortOperatorFactory.createOperator(); + sortOperator.addInput(vecBatch); + Iterator results = sortOperator.getOutput(); + VecBatch resultVecBatch = results.next(); + + Object[][] expectedDatas = {{0, 1, 2, 3, 4, null}, {4L, 3L, 2L, 1L, 0L, null}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + sortOperator.close(); + sortOperatorFactory.close(); + } + + /** + * Test sort with multi nulls. + */ + @Test + public void testSortWithMultiNulls() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG}; + Object[][] sourceDatas = {{4, 3, 2, 1, 0, null}, {0L, 1L, null, null, null, null}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + int[] outputCols = {0, 1}; + String[] sortCols = {"#1", "#0"}; + int[] ascendings = {0, 0}; + int[] nullFirsts = {1, 1}; + OmniSortOperatorFactory sortOperatorFactory = new OmniSortOperatorFactory(sourceTypes, outputCols, sortCols, + ascendings, nullFirsts); + OmniOperator sortOperator = sortOperatorFactory.createOperator(); + sortOperator.addInput(vecBatch); + Iterator results = sortOperator.getOutput(); + VecBatch resultVecBatch = results.next(); + + Object[][] expectedDatas = {{null, 2, 1, 0, 3, 4}, {null, null, null, null, 1L, 0L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + sortOperator.close(); + sortOperatorFactory.close(); + } + + /** + * Test sort performance whether with jit or not. + */ + @Test + public void testSortComparePerf() { + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG}; + int[] outputCols = {0, 1}; + String[] sortCols = {"#0", "#1"}; + int[] ascendings = {1, 1}; + int[] nullFirsts = {0, 0}; + + OmniSortOperatorFactory sortOperatorFactoryWithoutJit = new OmniSortOperatorFactory(sourceTypes, outputCols, + sortCols, ascendings, nullFirsts, new OperatorConfig()); + OmniOperator sortOperatorWithoutJit = sortOperatorFactoryWithoutJit.createOperator(); + ImmutableList vecsWithoutJit = buildVecs(); + + long start = System.currentTimeMillis(); + for (VecBatch vec : vecsWithoutJit) { + sortOperatorWithoutJit.addInput(vec); + } + Iterator outputWithoutJit = sortOperatorWithoutJit.getOutput(); + long end = System.currentTimeMillis(); + System.out.println("Sort without jit use " + (end - start) + " ms."); + + OmniSortOperatorFactory sortOperatorFactoryWithJit = new OmniSortOperatorFactory(sourceTypes, outputCols, + sortCols, ascendings, nullFirsts, new OperatorConfig()); + OmniOperator sortOperatorWithJit = sortOperatorFactoryWithJit.createOperator(); + ImmutableList vecsWithJit = buildVecs(); + + start = System.currentTimeMillis(); + for (VecBatch vec : vecsWithJit) { + sortOperatorWithJit.addInput(vec); + } + Iterator outputWithJit = sortOperatorWithJit.getOutput(); + end = System.currentTimeMillis(); + System.out.println("Sort with jit use " + (end - start) + " ms."); + + while (outputWithoutJit.hasNext() && outputWithJit.hasNext()) { + VecBatch resultWithoutJit = outputWithoutJit.next(); + VecBatch resultWithJit = outputWithJit.next(); + assertVecBatchEquals(resultWithoutJit, resultWithJit); + freeVecBatch(resultWithoutJit); + freeVecBatch(resultWithJit); + } + + sortOperatorWithoutJit.close(); + sortOperatorWithJit.close(); + sortOperatorFactoryWithoutJit.close(); + sortOperatorFactoryWithJit.close(); + } + + private VecBatch duplicateVecBatch(VecBatch vecBatch) { + int vecCount = vecBatch.getVectorCount(); + int rowCount = vecBatch.getRowCount(); + Vec[] vecs = new Vec[vecCount]; + for (int i = 0; i < vecCount; i++) { + vecs[i] = vecBatch.getVector(i).slice(0, rowCount); + } + return new VecBatch(vecs); + } + + /** + * Test sort performance when multi threads. + */ + @Test + public void testSortMultiThreadsPerformance() { + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG}; + int[] outputCols = {0, 1}; + String[] sortCols = {"#0", "#1"}; + int[] ascendings = {1, 1}; + int[] nullFirsts = {0, 0}; + OmniSortOperatorFactory sortOperatorFactory = new OmniSortOperatorFactory(sourceTypes, outputCols, sortCols, + ascendings, nullFirsts); + + final int threadNum = 4; + final int corePoolSize = 10; + final int maximumPoolSize = 50; + CountDownLatch countDownLatch = new CountDownLatch(threadNum); + ThreadPoolExecutor threadPool = new ThreadPoolExecutor(corePoolSize, maximumPoolSize, 60L, TimeUnit.SECONDS, + new LinkedBlockingQueue<>(threadNum)); + ImmutableList vecs = buildVecs(); + for (int i = 0; i < threadNum; i++) { + CompletableFuture.runAsync(() -> { + try { + OmniOperator sortOperator = sortOperatorFactory.createOperator(); + for (VecBatch vec : vecs) { + sortOperator.addInput(duplicateVecBatch(vec)); + } + Iterator iterator = sortOperator.getOutput(); + while (iterator.hasNext()) { + VecBatch result = iterator.next(); + freeVecBatch(result); + } + sortOperator.close(); + } finally { + countDownLatch.countDown(); + } + }, threadPool); + } + + // This will wait until all future ready. + try { + countDownLatch.await(); + } catch (InterruptedException e) { + assertTrue(false); + } + + threadPool.shutdown(); + vecs.forEach(TestUtils::freeVecBatch); + sortOperatorFactory.close(); + } + + @Test + public void testFactoryContextEquals() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG}; + int[] outputCols = {0, 1}; + String[] sortCols = {"#1", "#0"}; + int[] ascendings = {0, 0}; + int[] nullFirsts = {1, 1}; + FactoryContext factory1 = new FactoryContext(sourceTypes, outputCols, sortCols, ascendings, nullFirsts, + new OperatorConfig()); + FactoryContext factory2 = new FactoryContext(sourceTypes, outputCols, sortCols, ascendings, nullFirsts, + new OperatorConfig()); + FactoryContext factory3 = null; + assertEquals(factory2, factory1); + assertEquals(factory1, factory1); + assertNotEquals(factory3, factory1); + } + + private ImmutableList buildVecs() { + ImmutableList.Builder vecBatchList = ImmutableList.builder(); + int positionCount = pageDistinctCount * pageDistinctValueRepeatCount; + List vecs = new ArrayList<>(); + for (int i = 0; i < totalPageCount; i++) { + LongVec longVec1 = new LongVec(positionCount); + LongVec longVec2 = new LongVec(positionCount); + int idx = 0; + for (int j = 0; j < pageDistinctCount; j++) { + for (int k = 0; k < pageDistinctValueRepeatCount; k++) { + longVec1.set(idx, j); + longVec2.set(idx, j); + idx++; + } + } + vecs.add(longVec1); + vecs.add(longVec2); + VecBatch vecBatch = new VecBatch(new Vec[]{longVec1, longVec2}); + vecBatchList.add(vecBatch); + } + return vecBatchList.build(); + } + + @Test + public void testSortMultiThreadsAllocatorStatisticsBasic() { + long limit = 1 << 30; + clearMemory(); + setGlobalMemoryLimit(limit); + + int[] outputCols = {0, 1}; + String[] sortCols = {"#0", "#1"}; + int[] ascendings = {1, 1}; + int[] nullFirsts = {0, 0}; + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG}; + OmniSortOperatorFactory sortOperatorFactory = new OmniSortOperatorFactory(sourceTypes, outputCols, sortCols, + ascendings, nullFirsts); + + ConcurrentHashMap> vecBatchListMap = new ConcurrentHashMap<>(); + + final int threadNum = 4; + ThreadPoolExecutor threadPool = new ThreadPoolExecutor(10, 50, 60L, TimeUnit.SECONDS, + new LinkedBlockingQueue<>(threadNum)); + CountDownLatch countDownLatch = new CountDownLatch(threadNum); + + for (int i = 0; i < threadNum; i++) { + CompletableFuture.runAsync(() -> { + try { + MemoryManager subMemoryManager = new MemoryManager(); + List vecBatchList = new ArrayList<>(); + vecBatchListMap.put(subMemoryManager, vecBatchList); + multiThreadsOpTask(subMemoryManager, sortOperatorFactory, vecBatchListMap); + } finally { + countDownLatch.countDown(); + } + }, threadPool); + } + + // This will wait until all future ready. + try { + countDownLatch.await(); + } catch (InterruptedException e) { + assertTrue(false); + } + + threadPool.shutdown(); + sortOperatorFactory.close(); + } + + private void multiThreadsOpTask(MemoryManager subMemoryManager, OmniSortOperatorFactory sortOperatorFactory, + ConcurrentHashMap> vecBatchListMap) { + OmniOperator sortOperator = sortOperatorFactory.createOperator(); + ImmutableList vecs = buildVecs(); + + // 1048576(4 * 25000 * 8) + 131072(4 * 25000 * 1) + long unitLongVecAllocated = 1179648L; + long initMemory = subMemoryManager.getAllocatedMemory(); + + for (VecBatch vec : vecs) { + sortOperator.addInput(vec); + } + Iterator iterator = sortOperator.getOutput(); + long vecBatchMem = 0L; + while (iterator.hasNext()) { + VecBatch result = iterator.next(); + vecBatchListMap.get(subMemoryManager).add(result); + int rowCount = result.getRowCount(); + int colCount = result.getVectorCount(); + vecBatchMem += (rowCount * (8L + 1L) * colCount); + } + long middleMemory = subMemoryManager.getAllocatedMemory(); + + for (VecBatch vecBatch : vecBatchListMap.get(subMemoryManager)) { + freeVecBatch(vecBatch); + } + + sortOperator.close(); + assertEquals(initMemory, totalPageCount * 2 * unitLongVecAllocated); + assertEquals(middleMemory, vecBatchMem); + assertEquals(subMemoryManager.getAllocatedMemory(), 0); + } +} \ No newline at end of file diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniSortWithExprOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniSortWithExprOperatorTest.java new file mode 100644 index 0000000..055ac9d --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniSortWithExprOperatorTest.java @@ -0,0 +1,614 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.createVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonLiteral; +import static nova.hetu.omniruntime.util.TestUtils.omniFunctionExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonFourArithmeticExpr; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; +import static org.testng.Assert.assertTrue; + +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.operator.config.OverflowConfig; +import nova.hetu.omniruntime.operator.config.SparkSpillConfig; +import nova.hetu.omniruntime.operator.sort.OmniSortWithExprOperatorFactory; +import nova.hetu.omniruntime.operator.sort.OmniSortWithExprOperatorFactory.FactoryContext; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.util.TestUtils; +import nova.hetu.omniruntime.utils.OmniRuntimeException; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.io.File; +import java.nio.file.Paths; +import java.util.Iterator; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +/** + * The type Omni sort with expression operator test. + * + * @since 2021-10-16 + */ +public class OmniSortWithExprOperatorTest { + private final long MAX_SPILL_BYTES = 20 * 1024 * 1024; + + private String generateSpillPath() { + return "/opt" + File.separator + System.currentTimeMillis(); + } + + /** + * Test Sort by zero columns which one with expression + */ + @Test + public void TestSortByZeroColumnWithExpr() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG}; + Object[][] sourceDatas = {{5, 3, 2, 6, 1, 4, 7, 8}, {5L, 3L, 2L, 6L, 1L, 4L, 7L, 8L}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + int[] outputCols = {0, 1}; + String[] sortKeys = {getOmniJsonFieldReference(1, 0), getOmniJsonFieldReference(2, 1)}; + int[] ascendings = {1, 1}; + int[] nullFirsts = {0, 0}; + OmniSortWithExprOperatorFactory sortWithExprOperatorFactory = new OmniSortWithExprOperatorFactory(sourceTypes, + outputCols, sortKeys, ascendings, nullFirsts); + OmniOperator sortWithExprOperator = sortWithExprOperatorFactory.createOperator(); + sortWithExprOperator.addInput(vecBatch); + Iterator results = sortWithExprOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getRowCount(), sourceDatas[0].length); + Object[][] expectedDatas = {{1, 2, 3, 4, 5, 6, 7, 8}, {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + sortWithExprOperator.close(); + sortWithExprOperatorFactory.close(); + } + + /** + * Test Sort by one columns which one with expression + */ + @Test + public void TestSortByOneColumnWithExpr() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG}; + Object[][] sourceDatas = {{5, 3, 2, 6, 1, 4, 7, 8}, {5L, 3L, 2L, 6L, 1L, 4L, 7L, 8L}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + int[] outputCols = {0, 1}; + String[] sortKeys = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 5)), + getOmniJsonFieldReference(2, 1)}; + int[] ascendings = {1, 1}; + int[] nullFirsts = {0, 0}; + OmniSortWithExprOperatorFactory sortWithExprOperatorFactory = new OmniSortWithExprOperatorFactory(sourceTypes, + outputCols, sortKeys, ascendings, nullFirsts); + OmniOperator sortWithExprOperator = sortWithExprOperatorFactory.createOperator(); + sortWithExprOperator.addInput(vecBatch); + Iterator results = sortWithExprOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getRowCount(), sourceDatas[0].length); + Object[][] expectedDatas = {{1, 2, 3, 4, 5, 6, 7, 8}, {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + sortWithExprOperator.close(); + sortWithExprOperatorFactory.close(); + } + + /** + * Test Sort by two columns with expression + */ + @Test + public void TestSortByTwoColumnsWithExpr() { + DataType[] sourceTypes = {IntDataType.INTEGER, IntDataType.INTEGER}; + Object[][] sourceDatas = {{5, 3, 2, 6, 1, 4, 7, 8}, {5, 3, 2, 6, 1, 4, 7, 8}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + int[] outputCols = {0, 1}; + String[] sortKeys = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 5)), + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonLiteral(1, false, 5), getOmniJsonFieldReference(1, 1))}; + int[] ascendings = {1, 1}; + int[] nullFirsts = {0, 0}; + OmniSortWithExprOperatorFactory sortWithExprOperatorFactory = new OmniSortWithExprOperatorFactory(sourceTypes, + outputCols, sortKeys, ascendings, nullFirsts); + OmniOperator sortWithExprOperator = sortWithExprOperatorFactory.createOperator(); + sortWithExprOperator.addInput(vecBatch); + Iterator results = sortWithExprOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getRowCount(), sourceDatas[0].length); + Object[][] expectedDatas = {{1, 2, 3, 4, 5, 6, 7, 8}, {1, 2, 3, 4, 5, 6, 7, 8}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + sortWithExprOperator.close(); + sortWithExprOperatorFactory.close(); + } + + /** + * Test Sort by two dictionary columns with expression + */ + @Test + public void TestSortByTwoDictionaryWithExpr() { + DataType[] sourceTypes = {IntDataType.INTEGER, IntDataType.INTEGER}; + Object[][] sourceDatas = {{5, 3, 2, 6, 1, 4, 7, 8}, {5, 3, 2, 6, 1, 4, 7, 8}}; + Vec[] vecs = new Vec[2]; + int[] ids = {0, 1, 2, 3, 4, 5, 6, 7}; + vecs[0] = TestUtils.createDictionaryVec(sourceTypes[0], sourceDatas[0], ids); + vecs[1] = TestUtils.createDictionaryVec(sourceTypes[1], sourceDatas[1], ids); + VecBatch vecBatch = new VecBatch(vecs); + + int[] outputCols = {0, 1}; + String[] sortKeys = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 5)), + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonLiteral(1, false, 5), getOmniJsonFieldReference(1, 1))}; + int[] ascendings = {1, 1}; + int[] nullFirsts = {0, 0}; + OmniSortWithExprOperatorFactory sortWithExprOperatorFactory = new OmniSortWithExprOperatorFactory(sourceTypes, + outputCols, sortKeys, ascendings, nullFirsts); + OmniOperator sortWithExprOperator = sortWithExprOperatorFactory.createOperator(); + sortWithExprOperator.addInput(vecBatch); + Iterator results = sortWithExprOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getRowCount(), sourceDatas[0].length); + Object[][] expectedDatas = {{1, 2, 3, 4, 5, 6, 7, 8}, {1, 2, 3, 4, 5, 6, 7, 8}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + sortWithExprOperator.close(); + sortWithExprOperatorFactory.close(); + } + + /** + * Test factory context equal + */ + @Test + public void testFactoryContextEquals() { + DataType[] sourceTypes = {IntDataType.INTEGER, IntDataType.INTEGER}; + int[] outputCols = {0, 1}; + String[] sortKeys = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 5)), + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonLiteral(1, false, 5), getOmniJsonFieldReference(1, 1))}; + int[] ascendings = {1, 1}; + int[] nullFirsts = {0, 0}; + FactoryContext factory1 = new FactoryContext(sourceTypes, outputCols, sortKeys, ascendings, nullFirsts, + new OperatorConfig()); + FactoryContext factory2 = new FactoryContext(sourceTypes, outputCols, sortKeys, ascendings, nullFirsts, + new OperatorConfig()); + FactoryContext factory3 = null; + assertEquals(factory2, factory1); + assertEquals(factory1, factory1); + assertNotEquals(factory3, factory1); + } + + /** + * Test Sort spill ascending with spill + */ + @Test + public void testSortAscendingWithSpill() { + DataType[] sourceTypes = {IntDataType.INTEGER, IntDataType.INTEGER, IntDataType.INTEGER}; + int[] outputCols = {0, 1, 2}; + String[] sortKeys = {getOmniJsonFieldReference(1, 1), getOmniJsonFieldReference(1, 0)}; + int[] ascendings = {1, 1}; + int[] nullFirsts = {0, 0}; + String spillPath = generateSpillPath(); + OmniSortWithExprOperatorFactory sortWithExprOperatorFactory = new OmniSortWithExprOperatorFactory(sourceTypes, + outputCols, sortKeys, ascendings, nullFirsts, + new OperatorConfig(new SparkSpillConfig(true, spillPath, MAX_SPILL_BYTES, 5))); + OmniOperator sortWithExprOperator = sortWithExprOperatorFactory.createOperator(); + + Object[][] sourceData1 = {{23, 23, 23, 23, 23, 23, 23, 23, 23, 23}, {1, 1, 1, 2, 1, 1, 1, 1, 2, 2}, + {12, 12, 12, 12, 12, 12, 12, 12, 12, 12}}; + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceData1); + Object[][] sourceData2 = {{45, 45, 45, 45, 45, 45, 45, 45, 45, 45}, {1, 1, 1, 2, 1, 1, 1, 1, 2, 2}, + {24, 24, 24, 24, 24, 24, 24, 24, 24, 24}}; + VecBatch vecBatch2 = createVecBatch(sourceTypes, sourceData2); + Object[][] sourceData3 = {{67, 67, 67, 67, 67, 67, 67, 67, 67, 67}, {1, 1, 1, 2, 1, 1, 1, 1, 2, 2}, + {36, 36, 36, 36, 36, 36, 36, 36, 36, 36}}; + VecBatch vecBatch3 = createVecBatch(sourceTypes, sourceData3); + Object[][] sourceData4 = {{89, 89, 89, 89, 89, 89, 89, 89, 89, 89}, {1, 1, 1, 2, 1, 1, 1, 1, 2, 2}, + {48, 48, 48, 48, 48, 48, 48, 48, 48, 48}}; + VecBatch vecBatch4 = createVecBatch(sourceTypes, sourceData4); + + sortWithExprOperator.addInput(vecBatch1); + sortWithExprOperator.addInput(vecBatch2); + sortWithExprOperator.addInput(vecBatch3); + sortWithExprOperator.addInput(vecBatch4); + Iterator results = sortWithExprOperator.getOutput(); + VecBatch resultVecBatch = results.next(); + + Object[][] expectedDatas = { + {23, 23, 23, 23, 23, 23, 23, 45, 45, 45, 45, 45, 45, 45, 67, 67, 67, 67, 67, 67, 67, 89, 89, 89, 89, 89, + 89, 89, 23, 23, 23, 45, 45, 45, 67, 67, 67, 89, 89, 89}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2}, + {12, 12, 12, 12, 12, 12, 12, 24, 24, 24, 24, 24, 24, 24, 36, 36, 36, 36, 36, 36, 36, 48, 48, 48, 48, 48, + 48, 48, 12, 12, 12, 24, 24, 24, 36, 36, 36, 48, 48, 48}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + long spilledBytes = sortWithExprOperator.getSpilledBytes(); + assertTrue(spilledBytes != 0); + sortWithExprOperator.close(); + sortWithExprOperatorFactory.close(); + File spillDir = new File(spillPath); + spillDir.delete(); + } + + /** + * Test Sort spill descending with spill + */ + @Test + public void testSortDescendingWithSpill() { + DataType[] sourceTypes = {IntDataType.INTEGER, IntDataType.INTEGER, IntDataType.INTEGER}; + int[] outputCols = {0, 1, 2}; + String[] sortKeys = {getOmniJsonFieldReference(1, 2), getOmniJsonFieldReference(1, 1)}; + int[] ascendings = {0, 1}; + int[] nullFirsts = {0, 0}; + String spillPath = generateSpillPath(); + OmniSortWithExprOperatorFactory sortWithExprOperatorFactory = new OmniSortWithExprOperatorFactory(sourceTypes, + outputCols, sortKeys, ascendings, nullFirsts, + new OperatorConfig(new SparkSpillConfig(true, spillPath, MAX_SPILL_BYTES, 5))); + OmniOperator sortWithExprOperator = sortWithExprOperatorFactory.createOperator(); + + Object[][] sourceData1 = {{23, 23, 23, 23, 23, 23, 23, 23, 23, 23}, {1, 1, 1, 2, 1, 1, 1, 1, 2, 2}, + {12, 12, 12, 12, 12, 12, 12, 12, 12, 12}}; + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceData1); + Object[][] sourceData2 = {{45, 45, 45, 45, 45, 45, 45, 45, 45, 45}, {1, 1, 1, 2, 1, 1, 1, 1, 2, 2}, + {24, 24, 24, 24, 24, 24, 24, 24, 24, 24}}; + VecBatch vecBatch2 = createVecBatch(sourceTypes, sourceData2); + Object[][] sourceData3 = {{67, 67, 67, 67, 67, 67, 67, 67, 67, 67}, {1, 1, 1, 2, 1, 1, 1, 1, 2, 2}, + {36, 36, 36, 36, 36, 36, 36, 36, 36, 36}}; + VecBatch vecBatch3 = createVecBatch(sourceTypes, sourceData3); + Object[][] sourceData4 = {{89, 89, 89, 89, 89, 89, 89, 89, 89, 89}, {1, 1, 1, 2, 1, 1, 1, 1, 2, 2}, + {48, 48, 48, 48, 48, 48, 48, 48, 48, 48}}; + VecBatch vecBatch4 = createVecBatch(sourceTypes, sourceData4); + + sortWithExprOperator.addInput(vecBatch1); + sortWithExprOperator.addInput(vecBatch2); + sortWithExprOperator.addInput(vecBatch3); + sortWithExprOperator.addInput(vecBatch4); + Iterator results = sortWithExprOperator.getOutput(); + VecBatch resultVecBatch = results.next(); + + Object[][] expectedDatas = { + {89, 89, 89, 89, 89, 89, 89, 89, 89, 89, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 45, 45, 45, 45, 45, 45, + 45, 45, 45, 45, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23}, + {1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, + 1, 1, 1, 2, 2, 2}, + {48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + sortWithExprOperator.close(); + sortWithExprOperatorFactory.close(); + File spillDir = new File(spillPath); + spillDir.delete(); + } + + /** + * Test Sort spill with multi records + */ + @Test + public void testSortSpillWithMultiRecords() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG}; + int[] outputCols = {0, 1}; + String[] sortKeys = {getOmniJsonFieldReference(1, 0), getOmniJsonFieldReference(2, 1)}; + int[] ascendings = {1, 1}; + int[] nullFirsts = {0, 0}; + String spillPath = generateSpillPath(); + OmniSortWithExprOperatorFactory sortWithExprOperatorFactory = new OmniSortWithExprOperatorFactory(sourceTypes, + outputCols, sortKeys, ascendings, nullFirsts, + new OperatorConfig(new SparkSpillConfig(true, spillPath, MAX_SPILL_BYTES, 5))); + OmniOperator sortWithExprOperator = sortWithExprOperatorFactory.createOperator(); + + Object[][] sourceDatas1 = {{5, 3, 2, 6, 1}, {5L, 3L, 2L, 6L, 1L}}; + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceDatas1); + sortWithExprOperator.addInput(vecBatch1); + + Object[][] sourceDatas2 = {{4}, {4L}}; + VecBatch vecBatch2 = createVecBatch(sourceTypes, sourceDatas2); + sortWithExprOperator.addInput(vecBatch2); + + Object[][] sourceDatas3 = {{15, 13, 12, 16, 11}, {15L, 13L, 12L, 16L, 11L}}; + VecBatch vecBatch3 = createVecBatch(sourceTypes, sourceDatas3); + sortWithExprOperator.addInput(vecBatch3); + Iterator results = sortWithExprOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getRowCount(), + sourceDatas1[0].length + sourceDatas2[0].length + sourceDatas3[0].length); + Object[][] expectedDatas = {{1, 2, 3, 4, 5, 6, 11, 12, 13, 15, 16}, + {1L, 2L, 3L, 4L, 5L, 6L, 11L, 12L, 13L, 15L, 16L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + sortWithExprOperator.close(); + sortWithExprOperatorFactory.close(); + File spillDir = new File(spillPath); + spillDir.delete(); + } + + /** + * Test Sort spill with return multi batch + */ + @Test + public void testSortSpillWithReturnMultiBatch() { + DataType[] sourceTypes = {IntDataType.INTEGER, IntDataType.INTEGER}; + int[] outputCols = {0, 1}; + String[] sortKeys = {getOmniJsonFieldReference(1, 0), getOmniJsonFieldReference(1, 1)}; + int[] ascendings = {1, 1}; + int[] nullFirsts = {0, 0}; + String spillPath = generateSpillPath(); + OmniSortWithExprOperatorFactory sortWithExprOperatorFactory = new OmniSortWithExprOperatorFactory(sourceTypes, + outputCols, sortKeys, ascendings, nullFirsts, + new OperatorConfig(new SparkSpillConfig(true, spillPath, MAX_SPILL_BYTES, 10000))); + OmniOperator sortWithExprOperator = sortWithExprOperatorFactory.createOperator(); + + int maxRowCntPerBatch = 131072; // 1M / (4+4) + Object[][] sourceDatas1 = new Object[2][]; + sourceDatas1[0] = new Object[maxRowCntPerBatch]; + sourceDatas1[1] = new Object[maxRowCntPerBatch]; + for (int i = 0; i < maxRowCntPerBatch; i++) { + sourceDatas1[0][maxRowCntPerBatch - i - 1] = i; + sourceDatas1[1][maxRowCntPerBatch - i - 1] = i; + } + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceDatas1); + sortWithExprOperator.addInput(vecBatch1); + + Object[][] sourceDatas2 = new Object[2][]; + sourceDatas2[0] = new Object[maxRowCntPerBatch]; + sourceDatas2[1] = new Object[maxRowCntPerBatch]; + for (int i = 0; i < maxRowCntPerBatch; i++) { + sourceDatas2[0][maxRowCntPerBatch - i - 1] = maxRowCntPerBatch + i; + sourceDatas2[1][maxRowCntPerBatch - i - 1] = maxRowCntPerBatch + i; + } + + VecBatch vecBatch2 = createVecBatch(sourceTypes, sourceDatas2); + sortWithExprOperator.addInput(vecBatch2); + + Object[][] sourceDatas3 = new Object[2][]; + sourceDatas3[0] = new Object[maxRowCntPerBatch]; + sourceDatas3[1] = new Object[maxRowCntPerBatch]; + for (int i = 0; i < maxRowCntPerBatch; i++) { + sourceDatas3[0][maxRowCntPerBatch - i - 1] = maxRowCntPerBatch * 2 + i; + sourceDatas3[1][maxRowCntPerBatch - i - 1] = maxRowCntPerBatch * 2 + i; + } + VecBatch vecBatch3 = createVecBatch(sourceTypes, sourceDatas3); + sortWithExprOperator.addInput(vecBatch3); + Iterator results = sortWithExprOperator.getOutput(); + + int baseValue = 0; + while (results.hasNext()) { + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getRowCount(), maxRowCntPerBatch); + + Object[][] expectedDatas = new Object[2][]; + expectedDatas[0] = new Object[maxRowCntPerBatch]; + expectedDatas[1] = new Object[maxRowCntPerBatch]; + for (int i = 0; i < maxRowCntPerBatch; i++) { + expectedDatas[0][i] = baseValue + i; + expectedDatas[1][i] = baseValue + i; + } + baseValue += maxRowCntPerBatch; + assertVecBatchEquals(resultVecBatch, expectedDatas); + freeVecBatch(resultVecBatch); + } + + long spilledBytes = sortWithExprOperator.getSpilledBytes(); + assertTrue(spilledBytes != 0); + sortWithExprOperator.close(); + sortWithExprOperatorFactory.close(); + File spillDir = new File(spillPath); + spillDir.delete(); + } + + /** + * Test Sort spill with one record + */ + @Test + public void testSortSpillWithOneRecord() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG}; + int[] outputCols = {0, 1}; + String[] sortKeys = {getOmniJsonFieldReference(1, 0), getOmniJsonFieldReference(2, 1)}; + int[] ascendings = {1, 1}; + int[] nullFirsts = {0, 0}; + String spillPath = generateSpillPath(); + OmniSortWithExprOperatorFactory sortWithExprOperatorFactory = new OmniSortWithExprOperatorFactory(sourceTypes, + outputCols, sortKeys, ascendings, nullFirsts, + new OperatorConfig(new SparkSpillConfig(true, spillPath, MAX_SPILL_BYTES, 1))); + OmniOperator sortWithExprOperator = sortWithExprOperatorFactory.createOperator(); + + Object[][] sourceDatas1 = {{5}, {3L}}; + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceDatas1); + sortWithExprOperator.addInput(vecBatch1); + + Object[][] sourceDatas2 = {{15}, {13L}}; + VecBatch vecBatch2 = createVecBatch(sourceTypes, sourceDatas2); + sortWithExprOperator.addInput(vecBatch2); + + Object[][] sourceDatas3 = {{10}, {8L}}; + VecBatch vecBatch3 = createVecBatch(sourceTypes, sourceDatas3); + sortWithExprOperator.addInput(vecBatch3); + + Iterator results = sortWithExprOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getRowCount(), + sourceDatas1[0].length + sourceDatas2[0].length + sourceDatas3[0].length); + Object[][] expectedDatas = {{5, 10, 15}, {3L, 8L, 13L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + sortWithExprOperator.close(); + sortWithExprOperatorFactory.close(); + File spillDir = new File(spillPath); + spillDir.delete(); + } + + /** + * Test sort spill with null path + */ + @Test(expectedExceptions = OmniRuntimeException.class, expectedExceptionsMessageRegExp = "Enable spill but do not config spill path.") + public void testSortSpillWithNullPath() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG}; + int[] outputCols = {0, 1}; + String[] sortKeys = {getOmniJsonFieldReference(1, 0), getOmniJsonFieldReference(2, 1)}; + int[] ascendings = {1, 1}; + int[] nullFirsts = {0, 0}; + + OmniSortWithExprOperatorFactory sortWithExprOperatorFactory = new OmniSortWithExprOperatorFactory(sourceTypes, + outputCols, sortKeys, ascendings, nullFirsts, new OperatorConfig(new SparkSpillConfig(null, 1))); + } + + /** + * Test sort spill with empty path + */ + @Test(expectedExceptions = OmniRuntimeException.class, expectedExceptionsMessageRegExp = "Enable spill but do not config spill path.") + public void testSortSpillWithEmptyPath() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG}; + int[] outputCols = {0, 1}; + String[] sortKeys = {getOmniJsonFieldReference(1, 0), getOmniJsonFieldReference(2, 1)}; + int[] ascendings = {1, 1}; + int[] nullFirsts = {0, 0}; + + OmniSortWithExprOperatorFactory sortWithExprOperatorFactory = new OmniSortWithExprOperatorFactory(sourceTypes, + outputCols, sortKeys, ascendings, nullFirsts, new OperatorConfig(new SparkSpillConfig("", 1))); + } + + /** + * Test Sort spill with existed path + */ + @Test + public void testSortSpillWithExistedPath() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG}; + int[] outputCols = {0, 1}; + String[] sortKeys = {getOmniJsonFieldReference(1, 0), getOmniJsonFieldReference(2, 1)}; + int[] ascendings = {1, 1}; + int[] nullFirsts = {0, 0}; + OmniSortWithExprOperatorFactory sortWithExprOperatorFactory = new OmniSortWithExprOperatorFactory(sourceTypes, + outputCols, sortKeys, ascendings, nullFirsts, + new OperatorConfig(new SparkSpillConfig(true, "/opt", MAX_SPILL_BYTES, 1))); + OmniOperator sortWithExprOperator = sortWithExprOperatorFactory.createOperator(); + Object[][] sourceDatas1 = {{5, 3}, {5L, 3L}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas1); + sortWithExprOperator.addInput(vecBatch); + Iterator results = sortWithExprOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + Object[][] expectedDatas = {{3, 5}, {3L, 5L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + sortWithExprOperator.close(); + sortWithExprOperatorFactory.close(); + } + + /** + * Test sort spill with invalid path + */ + @Test + public void testSortSpillWithInvalidPath() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG}; + int[] outputCols = {0, 1}; + String[] sortKeys = {getOmniJsonFieldReference(1, 0), getOmniJsonFieldReference(2, 1)}; + int[] ascendings = {1, 1}; + int[] nullFirsts = {0, 0}; + String spillPath = "/opt/+-ab23"; + OmniSortWithExprOperatorFactory sortWithExprOperatorFactory = new OmniSortWithExprOperatorFactory(sourceTypes, + outputCols, sortKeys, ascendings, nullFirsts, + new OperatorConfig(new SparkSpillConfig(true, spillPath, MAX_SPILL_BYTES, 1))); + OmniOperator sortWithExprOperator = sortWithExprOperatorFactory.createOperator(); + Object[][] sourceDatas1 = {{5, 3}, {5L, 3L}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas1); + sortWithExprOperator.addInput(vecBatch); + Iterator results = sortWithExprOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + Object[][] expectedDatas = {{3, 5}, {3L, 5L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + sortWithExprOperator.close(); + sortWithExprOperatorFactory.close(); + File spillDir = new File(spillPath); + spillDir.delete(); + } + + /** + * Test sort spill with invalid keys + */ + @Test(expectedExceptions = OmniRuntimeException.class, expectedExceptionsMessageRegExp = ".*EXPRESSION_NOT_SUPPORT.*") + public void testSortWithInvalidKeys() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG}; + int[] outputCols = {0, 1}; + String[] sortKeys = {omniFunctionExpr("abc", 2, getOmniJsonFieldReference(2, 1))}; + int[] ascendings = {1}; + int[] nullFirsts = {0}; + OmniSortWithExprOperatorFactory sortWithExprOperatorFactory = new OmniSortWithExprOperatorFactory(sourceTypes, + outputCols, sortKeys, ascendings, nullFirsts); + } + + /** + * Test sort multi threads + */ + @Test + public void testSortWithExprOperatorFactoryMultiThreads() { + final int threadNum = 100; + final int corePoolSize = 50; + final int maximumPoolSize = 100; + CountDownLatch countDownLatch = new CountDownLatch(threadNum); + ThreadPoolExecutor threadPool = new ThreadPoolExecutor(corePoolSize, maximumPoolSize, 60L, TimeUnit.SECONDS, + new LinkedBlockingQueue<>(threadNum)); + for (int i = 0; i < threadNum; i++) { + CompletableFuture.runAsync(() -> { + try { + DataType[] sourceTypes = {VarcharDataType.VARCHAR, IntDataType.INTEGER}; + int[] outputCols = {0, 1}; + String[] sortKeys = {getOmniJsonFieldReference(1, 1)}; + int[] ascendings = {1}; + int[] nullFirsts = {0}; + String spillPath = Paths.get("").toAbsolutePath() + File.separator + UUID.randomUUID(); + SparkSpillConfig spillConfig = new SparkSpillConfig(false, spillPath, Long.MAX_VALUE, + Integer.MAX_VALUE); + OmniSortWithExprOperatorFactory sortWithExprOperatorFactory = new OmniSortWithExprOperatorFactory( + sourceTypes, outputCols, sortKeys, ascendings, nullFirsts, + new OperatorConfig(spillConfig, new OverflowConfig(), true)); + sortWithExprOperatorFactory.close(); + } finally { + countDownLatch.countDown(); + } + }, threadPool); + } + + // This will wait until all future ready. + try { + countDownLatch.await(); + } catch (InterruptedException e) { + assertTrue(false); + } + + threadPool.shutdown(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniTopNOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniTopNOperatorTest.java new file mode 100644 index 0000000..ac7bb44 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniTopNOperatorTest.java @@ -0,0 +1,345 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; + +import com.google.common.collect.ImmutableList; + +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.operator.topn.OmniTopNOperatorFactory; +import nova.hetu.omniruntime.operator.topn.OmniTopNOperatorFactory.FactoryContext; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.util.TestUtils; +import nova.hetu.omniruntime.vector.DoubleVec; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +/** + * The type Omni TopN operator test. + * + * @since 2021-7-31 + */ +public class OmniTopNOperatorTest { + /** + * The Total page count. + */ + int totalPageCount = 10; + + /** + * The Page distinct count. + */ + int pageDistinctCount = 4; + + /** + * The Page distinct value repeat count. + */ + int pageDistinctValueRepeatCount = 2500; + + /** + * test topN performance whether with jit or not. + */ + @Test + public void testTopNComparePref() { + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG}; + int limitN = 10; + int[] outputCols = {0, 1}; + String[] sortCols = {"#0", "#1"}; + int[] sotrAscendings = {1, 1}; + int[] sortNullFirsts = {0, 0}; + + OmniTopNOperatorFactory topNOperatorFactoryWithoutJit = new OmniTopNOperatorFactory(sourceTypes, limitN, + sortCols, sotrAscendings, sortNullFirsts, new OperatorConfig()); + OmniOperator topNOperatorWithoutJit = topNOperatorFactoryWithoutJit.createOperator(); + ImmutableList vecsWithoutJit = buildVecs(); + + long start = System.currentTimeMillis(); + for (VecBatch vec : vecsWithoutJit) { + topNOperatorWithoutJit.addInput(vec); + } + Iterator outputWithoutJit = topNOperatorWithoutJit.getOutput(); + long end = System.currentTimeMillis(); + System.out.println("TopN without jit use " + (end - start) + " ms."); + + OmniTopNOperatorFactory topNOperatorFactoryWithJit = new OmniTopNOperatorFactory(sourceTypes, limitN, sortCols, + sotrAscendings, sortNullFirsts, new OperatorConfig()); + OmniOperator topNOperatorWithJit = topNOperatorFactoryWithJit.createOperator(); + ImmutableList vecsWithJit = buildVecs(); + + start = System.currentTimeMillis(); + for (VecBatch vec : vecsWithJit) { + topNOperatorWithJit.addInput(vec); + } + Iterator outputWithJit = topNOperatorWithJit.getOutput(); + end = System.currentTimeMillis(); + System.out.println("TopN with jit use " + (end - start) + " ms."); + + while (outputWithoutJit.hasNext()) { + VecBatch resultWithoutJit = outputWithoutJit.next(); + VecBatch resultWithJit = outputWithJit.next(); + assertVecBatchEquals(resultWithoutJit, resultWithJit); + freeVecBatch(resultWithoutJit); + freeVecBatch(resultWithJit); + } + + topNOperatorWithoutJit.close(); + topNOperatorWithJit.close(); + topNOperatorFactoryWithoutJit.close(); + topNOperatorFactoryWithJit.close(); + } + + @Test + public void testOneColumn() { + int rowSize = 6; + int expectedRowSize = 5; + long[] rawData = {0, 1, 2, 0, 1, 2}; + LongVec longVec = new LongVec(6); + longVec.put(rawData, 0, 0, rowSize); + ArrayList longVecs = new ArrayList<>(); + longVecs.add(longVec); + + DataType[] sourceTypes = {LongDataType.LONG}; + String[] sortCols = {"#0"}; + int[] sortAsc = {0}; + int[] nullFirst = {0}; + OmniTopNOperatorFactory omniTopNOperatorFactory = new OmniTopNOperatorFactory(sourceTypes, expectedRowSize, + sortCols, sortAsc, nullFirst); + OmniOperator operator = omniTopNOperatorFactory.createOperator(); + operator.addInput(new VecBatch(longVecs)); + Iterator output = operator.getOutput(); + VecBatch result = output.next(); + assertEquals(result.getRowCount(), expectedRowSize); + Vec vector = result.getVectors()[0]; + long[] resultArray = new long[expectedRowSize]; + for (int i = 0; i < vector.getSize(); i++) { + resultArray[i] = ((LongVec) vector).get(i); + } + long[] expectedArray = {2, 2, 1, 1, 0}; + assertEquals(resultArray, expectedArray); + + TestUtils.freeVecBatch(result); + + operator.close(); + omniTopNOperatorFactory.close(); + } + + @Test + public void testMultipleColumns() { + int rowSize = 6; + int[] rawData1 = {0, 1, 2, 0, 1, 2}; + long[] rawData2 = {0, 1, 2, 3, 4, 5}; + double[] rawData3 = {6.6, 5.5, 4.4, 3.3, 2.2, 1.1}; + IntVec vec1 = new IntVec(6); + LongVec vec2 = new LongVec(6); + DoubleVec vec3 = new DoubleVec(6); + vec1.put(rawData1, 0, 0, rowSize); + vec2.put(rawData2, 0, 0, rowSize); + vec3.put(rawData3, 0, 0, rowSize); + ArrayList longVecs = new ArrayList<>(); + longVecs.add(vec1); + longVecs.add(vec2); + longVecs.add(vec3); + + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG, DoubleDataType.DOUBLE}; + String[] sortCols = {"#0", "#1"}; + int[] sortAsc = {1, 1}; + int[] nullFirst = {0, 0}; + int expectedRowSize = 5; + OmniTopNOperatorFactory omniTopNOperatorFactory = new OmniTopNOperatorFactory(sourceTypes, expectedRowSize, + sortCols, sortAsc, nullFirst); + OmniOperator operator = omniTopNOperatorFactory.createOperator(); + operator.addInput(new VecBatch(longVecs)); + Iterator output = operator.getOutput(); + VecBatch result = output.next(); + assertEquals(result.getRowCount(), expectedRowSize); + Vec[] vector = result.getVectors(); + int[] resultArray1 = new int[expectedRowSize]; + long[] resultArray2 = new long[expectedRowSize]; + double[] resultArray3 = new double[expectedRowSize]; + for (int i = 0; i < vector[0].getSize(); i++) { + resultArray1[i] = ((IntVec) vector[0]).get(i); + } + for (int i = 0; i < vector[1].getSize(); i++) { + resultArray2[i] = ((LongVec) vector[1]).get(i); + } + for (int i = 0; i < vector[2].getSize(); i++) { + resultArray3[i] = ((DoubleVec) vector[2]).get(i); + } + int[] expectedArray1 = {0, 0, 1, 1, 2}; + long[] expectedArray2 = {0, 3, 1, 4, 2}; + double[] expectedArray3 = {6.6, 3.3, 5.5, 2.2, 4.4}; + assertEquals(resultArray1, expectedArray1); + assertEquals(resultArray2, expectedArray2); + assertEquals(resultArray3, expectedArray3); + + TestUtils.freeVecBatch(result); + + operator.close(); + omniTopNOperatorFactory.close(); + } + + @Test + public void testTopNWithOffset() { + int rowSize = 6; + int[] rawData1 = {0, 1, 2, 0, 1, 2}; + double[] rawData2 = {6.6, 5.5, 4.4, 3.3, 2.2, 1.1}; + IntVec vec1 = new IntVec(6); + DoubleVec vec2 = new DoubleVec(6); + vec1.put(rawData1, 0, 0, rowSize); + vec2.put(rawData2, 0, 0, rowSize); + ArrayList longVecs = new ArrayList<>(); + longVecs.add(vec1); + longVecs.add(vec2); + + DataType[] sourceTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE}; + String[] sortCols = {"#0", "#1"}; + int[] sortAsc = {1, 0}; + int[] nullFirst = {0, 0}; + int limitSize = 5; + int expectedRowSize = limitSize - 2; + OmniTopNOperatorFactory omniTopNOperatorFactory = new OmniTopNOperatorFactory(sourceTypes, limitSize, 2, + sortCols, sortAsc, nullFirst); + OmniOperator operator = omniTopNOperatorFactory.createOperator(); + operator.addInput(new VecBatch(longVecs)); + Iterator output = operator.getOutput(); + VecBatch result = output.next(); + assertEquals(result.getRowCount(), expectedRowSize); + Vec[] vector = result.getVectors(); + int[] resultArray1 = new int[expectedRowSize]; + double[] resultArray2 = new double[expectedRowSize]; + for (int i = 0; i < vector[0].getSize(); i++) { + if (vector[0] instanceof IntVec) { + resultArray1[i] = ((IntVec) vector[0]).get(i); + } + } + for (int i = 0; i < vector[1].getSize(); i++) { + if (vector[1] instanceof DoubleVec) { + resultArray2[i] = ((DoubleVec) vector[1]).get(i); + } + } + + int[] expectedArray1 = {1, 1, 2}; + double[] expectedArray2 = {5.5, 2.2, 4.4}; + assertEquals(resultArray1, expectedArray1); + assertEquals(resultArray2, expectedArray2); + + TestUtils.freeVecBatch(result); + + operator.close(); + omniTopNOperatorFactory.close(); + } + + @Test + public void testTopNDescMultiColumnSortColumn1() { + int rowSize = 6; + int[] rawData1 = {0, 1, 2, 0, 1, 2}; + long[] rawData2 = {0, 1, 2, 3, 4, 5}; + double[] rawData3 = {6.6, 5.5, 4.4, 3.3, 2.2, 1.1}; + IntVec vec1 = new IntVec(6); + LongVec vec2 = new LongVec(6); + DoubleVec vec3 = new DoubleVec(6); + vec1.put(rawData1, 0, 0, rowSize); + vec2.put(rawData2, 0, 0, rowSize); + vec3.put(rawData3, 0, 0, rowSize); + ArrayList longVecs = new ArrayList<>(); + longVecs.add(vec1); + longVecs.add(vec2); + longVecs.add(vec3); + + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG, DoubleDataType.DOUBLE}; + String[] sortCols = {"#1"}; + int[] sortAsc = {0}; + int[] nullFirst = {0}; + int expectedRowSize = 5; + OmniTopNOperatorFactory omniTopNOperatorFactory = new OmniTopNOperatorFactory(sourceTypes, expectedRowSize, + sortCols, sortAsc, nullFirst); + OmniOperator operator = omniTopNOperatorFactory.createOperator(); + operator.addInput(new VecBatch(longVecs)); + Iterator output = operator.getOutput(); + VecBatch result = output.next(); + assertEquals(result.getRowCount(), expectedRowSize); + Vec[] vector = result.getVectors(); + List resultList1 = new ArrayList<>(); + List resultList2 = new ArrayList<>(); + List resultList3 = new ArrayList<>(); + + for (int i = 0; i < vector[0].getSize(); i++) { + resultList1.add(((IntVec) vector[0]).get(i)); + } + for (int i = 0; i < vector[1].getSize(); i++) { + resultList2.add(((LongVec) vector[1]).get(i)); + } + for (int i = 0; i < vector[2].getSize(); i++) { + resultList3.add(((DoubleVec) vector[2]).get(i)); + } + + ArrayList expectList1 = new ArrayList<>(Arrays.asList(2, 1, 0, 2, 1)); + ArrayList expectList2 = new ArrayList<>(Arrays.asList(5L, 4L, 3L, 2L, 1L)); + ArrayList expectList3 = new ArrayList<>(Arrays.asList(1.1, 2.2, 3.3, 4.4, 5.5)); + assertEquals(expectList1, resultList1); + assertEquals(expectList2, resultList2); + assertEquals(expectList3, resultList3); + + TestUtils.freeVecBatch(result); + + operator.close(); + omniTopNOperatorFactory.close(); + } + + @Test + public void testFactoryContextEquals() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG, DoubleDataType.DOUBLE}; + String[] sortCols = {"#1"}; + int[] sortAsc = {0}; + int[] nullFirst = {0}; + int expectedRowSize = 5; + FactoryContext factory1 = new FactoryContext(sourceTypes, expectedRowSize, 0, sortCols, sortAsc, nullFirst, + new OperatorConfig()); + FactoryContext factory2 = new FactoryContext(sourceTypes, expectedRowSize, 0, sortCols, sortAsc, nullFirst, + new OperatorConfig()); + FactoryContext factory3 = null; + assertEquals(factory2, factory1); + assertEquals(factory1, factory1); + assertNotEquals(factory3, factory1); + } + + private ImmutableList buildVecs() { + ImmutableList.Builder vecBatchList = ImmutableList.builder(); + int positionCount = pageDistinctCount * pageDistinctValueRepeatCount; + List vecs = new ArrayList<>(); + for (int i = 0; i < totalPageCount; i++) { + LongVec longVec1 = new LongVec(positionCount); + LongVec longVec2 = new LongVec(positionCount); + int idx = 0; + for (int j = 0; j < pageDistinctCount; j++) { + for (int k = 0; k < pageDistinctValueRepeatCount; k++) { + longVec1.set(idx, j); + longVec2.set(idx, j); + idx++; + } + } + vecs.add(longVec1); + vecs.add(longVec2); + VecBatch vecBatch = new VecBatch(new Vec[]{longVec1, longVec2}); + vecBatchList.add(vecBatch); + } + return vecBatchList.build(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniTopNSortWithExprOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniTopNSortWithExprOperatorTest.java new file mode 100644 index 0000000..6b943aa --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniTopNSortWithExprOperatorTest.java @@ -0,0 +1,136 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.createVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; + +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.operator.topnsort.OmniTopNSortWithExprOperatorFactory; +import nova.hetu.omniruntime.operator.topnsort.OmniTopNSortWithExprOperatorFactory.FactoryContext; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.Iterator; + +/** + * The omni TopNSort with expression operator test. + * + * @since 2021-7-31 + */ +public class OmniTopNSortWithExprOperatorTest { + @Test + public void testTopNSortDescNullLast() { + DataType[] sourceTypes = {new VarcharDataType(10), LongDataType.LONG, LongDataType.LONG}; + Object[][] sourceDatas = {{"hi", "hi", "hi", "bye", "bye", "bye", "bye", "bye"}, + {2L, 5L, 3L, 11L, 4L, 3L, 0L, 23L}, {3L, 5L, 8L, 3L, 5L, 3L, 4L, 3L}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + String[] partitionKeys = {getOmniJsonFieldReference(15, 0)}; + String[] sortKeys = {getOmniJsonFieldReference(2, 2)}; + int[] ascendings = {0}; + int[] nullFirsts = {0}; + OmniTopNSortWithExprOperatorFactory topNSortOperatorFactory = new OmniTopNSortWithExprOperatorFactory( + sourceTypes, 3, false, partitionKeys, sortKeys, ascendings, nullFirsts); + OmniOperator topNSortOperator = topNSortOperatorFactory.createOperator(); + topNSortOperator.addInput(vecBatch); + Iterator results = topNSortOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getRowCount(), sourceDatas[0].length); + Object[][] expectedDatas = {{"bye", "bye", "bye", "bye", "bye", "hi", "hi", "hi"}, + {4L, 0L, 11L, 3L, 23L, 3L, 5L, 2L}, {5L, 4L, 3L, 3L, 3L, 8L, 5L, 3L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + topNSortOperator.close(); + topNSortOperatorFactory.close(); + } + + @Test + public void testTopNSortAscNullLast() { + DataType[] sourceTypes = {new VarcharDataType(10), LongDataType.LONG, LongDataType.LONG}; + Object[][] sourceDatas = {{"hi", "hi", "hi", "bye", "bye", "bye", "bye", "bye"}, + {2L, 5L, 3L, 11L, 4L, 3L, 0L, 23L}, {5L, 3L, 8L, 3L, 6L, 6L, 4L, 6L}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + String[] partitionKeys = {getOmniJsonFieldReference(15, 0)}; + String[] sortKeys = {getOmniJsonFieldReference(2, 2)}; + int[] ascendings = {1}; + int[] nullFirsts = {0}; + OmniTopNSortWithExprOperatorFactory topNSortOperatorFactory = new OmniTopNSortWithExprOperatorFactory( + sourceTypes, 3, false, partitionKeys, sortKeys, ascendings, nullFirsts); + OmniOperator topNSortOperator = topNSortOperatorFactory.createOperator(); + topNSortOperator.addInput(vecBatch); + Iterator results = topNSortOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getRowCount(), sourceDatas[0].length); + Object[][] expectedDatas = {{"bye", "bye", "bye", "bye", "bye", "hi", "hi", "hi"}, + {11L, 0L, 4L, 3L, 23L, 5L, 2L, 3L}, {3L, 4L, 6L, 6L, 6L, 3L, 5L, 8L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + topNSortOperator.close(); + topNSortOperatorFactory.close(); + } + + @Test + public void testTopNSortAscNullFirst() { + DataType[] sourceTypes = {new VarcharDataType(10), LongDataType.LONG, LongDataType.LONG}; + Object[][] sourceDatas = {{"hi", "hi", "hi", "bye", "bye", "bye", "bye", "bye"}, + {2L, 5L, 3L, 11L, 3L, 3L, 0L, 3L}, {5L, 3L, 8L, 3L, 6L, 6L, 4L, 6L}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + String[] partitionKeys = {getOmniJsonFieldReference(15, 0)}; + String[] sortKeys = {getOmniJsonFieldReference(2, 2), getOmniJsonFieldReference(2, 1)}; + int[] ascendings = {1, 1}; + int[] nullFirsts = {0, 0}; + OmniTopNSortWithExprOperatorFactory topNSortOperatorFactory = new OmniTopNSortWithExprOperatorFactory( + sourceTypes, 3, false, partitionKeys, sortKeys, ascendings, nullFirsts, new OperatorConfig()); + OmniOperator topNSortOperator = topNSortOperatorFactory.createOperator(); + topNSortOperator.addInput(vecBatch); + Iterator results = topNSortOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getRowCount(), sourceDatas[0].length); + Object[][] expectedDatas = {{"bye", "bye", "bye", "bye", "bye", "hi", "hi", "hi"}, + {11L, 0L, 3L, 3L, 3L, 5L, 2L, 3L}, {3L, 4L, 6L, 6L, 6L, 3L, 5L, 8L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + topNSortOperator.close(); + topNSortOperatorFactory.close(); + } + + @Test + public void testFactoryContextEquals() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG}; + int limitN = 10; + boolean isStrictTopN = false; + String[] partitionKeys = {getOmniJsonFieldReference(1, 0)}; + String[] sortKeys = {getOmniJsonFieldReference(2, 1)}; + int[] sortAscendings = {1}; + int[] sortNullFirsts = {1}; + OperatorConfig operatorConfig = new OperatorConfig(); + FactoryContext factory1 = new FactoryContext(sourceTypes, limitN, isStrictTopN, partitionKeys, sortKeys, + sortAscendings, sortNullFirsts, operatorConfig); + FactoryContext factory2 = new FactoryContext(sourceTypes, limitN, isStrictTopN, partitionKeys, sortKeys, + sortAscendings, sortNullFirsts, operatorConfig); + FactoryContext factory3 = null; + assertEquals(factory2, factory1); + assertEquals(factory1, factory1); + assertNotEquals(factory3, factory1); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniTopNWithExprOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniTopNWithExprOperatorTest.java new file mode 100644 index 0000000..d47853a --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniTopNWithExprOperatorTest.java @@ -0,0 +1,301 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.createVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatches; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonLiteral; +import static nova.hetu.omniruntime.util.TestUtils.omniFunctionExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonFourArithmeticExpr; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; + +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.operator.topn.OmniTopNWithExprOperatorFactory; +import nova.hetu.omniruntime.operator.topn.OmniTopNWithExprOperatorFactory.FactoryContext; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.utils.OmniRuntimeException; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Arrays; + +/** + * The type Omni TopN with expression operator test. + * + * @since 2021-11-11 + */ +public class OmniTopNWithExprOperatorTest { + @Test + public void testTopNWithAllExpr() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG, LongDataType.LONG}; + String[] sortKeys = { + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 5)), + omniJsonFourArithmeticExpr("MODULUS", 2, getOmniJsonFieldReference(2, 2), + getOmniJsonLiteral(2, false, 3))}; + int[] sortAsc = {0, 1}; + int[] nullFirst = {0, 0}; + + int expectedRowSize = 5; + + OmniTopNWithExprOperatorFactory omniTopNOperatorFactory = new OmniTopNWithExprOperatorFactory(sourceTypes, + expectedRowSize, sortKeys, sortAsc, nullFirst); + OmniOperator operator = omniTopNOperatorFactory.createOperator(); + + Object[][] sourceDatas = {{5, 8, 8, 6, 8, 4, 13, 15}, {2L, 5L, 3L, 11L, 4L, 3L, 0L, 23L}, + {5L, 3L, 2L, 6L, 1L, 4L, 7L, 8L}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + operator.addInput(vecBatch); + Iterator output = operator.getOutput(); + + assertEquals(output.hasNext(), true); + VecBatch resultVecBatch = output.next(); + assertEquals(output.hasNext(), false); + assertEquals(resultVecBatch.getRowCount(), expectedRowSize); + assertEquals(resultVecBatch.getVectorCount(), sourceTypes.length + sortKeys.length); + + Object[][] expectedDatas = {{15, 13, 8, 8, 8}, {23L, 0L, 5L, 4L, 3L}, {8L, 7L, 3L, 1L, 2L}, + {20, 18, 13, 13, 13}, {2L, 1L, 0L, 1L, 2L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + operator.close(); + omniTopNOperatorFactory.close(); + } + + @Test + public void testTopNWithPartialExpr() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG, LongDataType.LONG}; + String[] sortKeys = {getOmniJsonFieldReference(1, 0), omniJsonFourArithmeticExpr("MODULUS", 2, + getOmniJsonFieldReference(2, 2), getOmniJsonLiteral(2, false, 3))}; + int[] sortAsc = {0, 1}; + int[] nullFirst = {0, 0}; + + int expectedRowSize = 5; + + OmniTopNWithExprOperatorFactory omniTopNOperatorFactory = new OmniTopNWithExprOperatorFactory(sourceTypes, + expectedRowSize, sortKeys, sortAsc, nullFirst); + OmniOperator operator = omniTopNOperatorFactory.createOperator(); + + Object[][] sourceDatas = {{5, 8, 8, 6, 8, 4, 13, 15}, {2L, 5L, 3L, 11L, 4L, 3L, 0L, 23L}, + {5L, 3L, 2L, 6L, 1L, 4L, 7L, 8L}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + operator.addInput(vecBatch); + Iterator output = operator.getOutput(); + + assertEquals(output.hasNext(), true); + VecBatch resultVecBatch = output.next(); + assertEquals(output.hasNext(), false); + assertEquals(resultVecBatch.getRowCount(), expectedRowSize); + assertEquals(resultVecBatch.getVectorCount(), 4); + + Object[][] expectedDatas = {{15, 13, 8, 8, 8}, {23L, 0L, 5L, 4L, 3L}, {8L, 7L, 3L, 1L, 2L}, + {2L, 1L, 0L, 1L, 2L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + operator.close(); + omniTopNOperatorFactory.close(); + } + + @Test + public void testTopNWithNoExpr() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG, LongDataType.LONG}; + String[] sortKeys = {getOmniJsonFieldReference(1, 0), getOmniJsonFieldReference(2, 2)}; + int[] sortAsc = {0, 1}; + int[] nullFirst = {0, 0}; + + int expectedRowSize = 5; + + OmniTopNWithExprOperatorFactory omniTopNOperatorFactory = new OmniTopNWithExprOperatorFactory(sourceTypes, + expectedRowSize, sortKeys, sortAsc, nullFirst); + OmniOperator operator = omniTopNOperatorFactory.createOperator(); + + Object[][] sourceDatas = {{5, 8, 8, 6, 8, 4, 13, 15}, {2L, 5L, 3L, 11L, 4L, 3L, 0L, 23L}, + {5L, 3L, 2L, 6L, 1L, 4L, 7L, 8L}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + operator.addInput(vecBatch); + Iterator output = operator.getOutput(); + + assertEquals(output.hasNext(), true); + VecBatch resultVecBatch = output.next(); + assertEquals(output.hasNext(), false); + assertEquals(resultVecBatch.getRowCount(), expectedRowSize); + assertEquals(resultVecBatch.getVectorCount(), 3); + + Object[][] expectedDatas = {{15, 13, 8, 8, 8}, {23L, 0L, 4L, 3L, 5L}, {8L, 7L, 1L, 2L, 3L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + operator.close(); + omniTopNOperatorFactory.close(); + } + + @Test(expectedExceptions = OmniRuntimeException.class, expectedExceptionsMessageRegExp = ".*EXPRESSION_NOT_SUPPORT.*") + public void testTopNWithInvalidKeys() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG, LongDataType.LONG}; + String[] sortKeys = {omniFunctionExpr("abc", 2, getOmniJsonFieldReference(2, 1))}; + int[] sortAsc = {0}; + int[] nullFirst = {0}; + int expectedRowSize = 5; + + OmniTopNWithExprOperatorFactory omniTopNOperatorFactory = new OmniTopNWithExprOperatorFactory(sourceTypes, + expectedRowSize, sortKeys, sortAsc, nullFirst); + } + + @Test + public void testFactoryContextEquals() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG, LongDataType.LONG}; + String[] sortKeys = {getOmniJsonFieldReference(1, 0), getOmniJsonFieldReference(2, 2)}; + int[] sortAsc = {0, 1}; + int[] nullFirst = {0, 0}; + int expectedRowSize = 5; + FactoryContext factory1 = new FactoryContext(sourceTypes, expectedRowSize, 0, sortKeys, sortAsc, nullFirst, + new OperatorConfig()); + FactoryContext factory2 = new FactoryContext(sourceTypes, expectedRowSize, 0, sortKeys, sortAsc, nullFirst, + new OperatorConfig()); + FactoryContext factory3 = null; + assertEquals(factory2, factory1); + assertEquals(factory1, factory1); + assertNotEquals(factory3, factory1); + } + + private void buildIterativeExpectedData(Object[][] expectedData1, Object[][] expectedData2, int sourceDataSize, + int maxRowCount, int expectedRowSize) { + for (int i = 0; i < maxRowCount; i++) { + if (i % 2 == 0) { + expectedData1[0][i] = sourceDataSize - 1L - i; + expectedData1[1][i] = sourceDataSize - 1L - i; + } else { + expectedData1[0][i] = expectedData1[0][i - 1]; + expectedData1[1][i] = (long) expectedData1[1][i - 1] + 1L; + } + expectedData1[2][i] = (long) expectedData1[0][i] + 5L; + expectedData1[3][i] = (long) expectedData1[1][i] + 2L; + } + + for (int i = 0; i < expectedRowSize - maxRowCount; i++) { + if (i % 2 == 0) { + expectedData2[0][i] = sourceDataSize - maxRowCount - 1L - i; + expectedData2[1][i] = sourceDataSize - maxRowCount - 1L - i; + } else { + expectedData2[0][i] = expectedData2[0][i - 1]; + expectedData2[1][i] = (long) expectedData2[1][i - 1] + 1L; + } + expectedData2[2][i] = (long) expectedData2[0][i] + 5L; + expectedData2[3][i] = (long) expectedData2[1][i] + 2L; + } + } + + @Test + public void testTopNWithExprAndOffset() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG, LongDataType.LONG}; + String[] sortKeys = {getOmniJsonFieldReference(1, 0), getOmniJsonFieldReference(2, 2)}; + int[] sortAsc = {0, 1}; + int[] nullFirst = {0, 0}; + + int limitSize = 5; + int offsetSize = 2; + int expectedRowSize = limitSize - offsetSize; + + OmniTopNWithExprOperatorFactory omniTopNOperatorFactory = new OmniTopNWithExprOperatorFactory(sourceTypes, + limitSize, offsetSize, sortKeys, sortAsc, nullFirst); + OmniOperator operator = omniTopNOperatorFactory.createOperator(); + + List> sourceDatas = Arrays.asList(Arrays.asList(5, 8, 8, 6, 8, 4, 13, 15), + Arrays.asList(2L, 5L, 3L, 11L, 4L, 3L, 0L, 23L), + Arrays.asList(5L, 3L, 2L, 6L, 1L, 4L, 7L, 8L)); + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + operator.addInput(vecBatch); + Iterator output = operator.getOutput(); + + assertEquals(output.hasNext(), true); + VecBatch resultVecBatch = output.next(); + assertEquals(output.hasNext(), false); + assertEquals(resultVecBatch.getRowCount(), expectedRowSize); + assertEquals(resultVecBatch.getVectorCount(), 3); + + List> expectedDatas = new ArrayList<>(); + expectedDatas.add(Arrays.asList(8, 8, 8)); + expectedDatas.add(Arrays.asList(4L, 3L, 5L)); + expectedDatas.add(Arrays.asList(1L, 2L, 3L)); + + VecBatch expectedVecBatch = createVecBatch(sourceTypes, expectedDatas); + assertVecBatchEquals(resultVecBatch, expectedVecBatch); + + freeVecBatch(resultVecBatch); + operator.close(); + omniTopNOperatorFactory.close(); + } + + @Test + public void testTopNIterativeGetOutput() { + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG}; + String[] sortKeys = { + omniJsonFourArithmeticExpr("ADD", 2, getOmniJsonFieldReference(2, 0), getOmniJsonLiteral(2, false, 5)), + omniJsonFourArithmeticExpr("ADD", 2, getOmniJsonFieldReference(2, 1), getOmniJsonLiteral(2, false, 2))}; + int[] sortAsc = {0, 1}; + int[] nullFirst = {0, 0}; + + int sourceDataSize = 33000; + int expectedRowSize = 32800; + + OmniTopNWithExprOperatorFactory omniTopNOperatorFactory = new OmniTopNWithExprOperatorFactory(sourceTypes, + expectedRowSize, sortKeys, sortAsc, nullFirst); + OmniOperator operator = omniTopNOperatorFactory.createOperator(); + + Object[][] sourceData = new Object[2][sourceDataSize]; + for (int i = 0; i < sourceDataSize; i++) { + if (i % 2 == 0) { + sourceData[0][i] = i + 1L; + } else { + sourceData[0][i] = sourceData[0][i - 1]; + } + sourceData[1][i] = i + 1L; + } + VecBatch vecBatch = createVecBatch(sourceTypes, sourceData); + + operator.addInput(vecBatch); + Iterator topNIterator = operator.getOutput(); + List resultList = new ArrayList<>(); + while (topNIterator.hasNext()) { + resultList.add(topNIterator.next()); + } + + int resultRowCount = 0; + int resultVectorCount = 0; + for (int i = 0; i < resultList.size(); i++) { + resultRowCount += resultList.get(i).getRowCount(); + resultVectorCount = resultList.get(i).getVectorCount(); + } + assertEquals(resultRowCount, expectedRowSize); + assertEquals(resultVectorCount, sourceTypes.length + sortKeys.length); + + int maxRowCount = 32768; // 1M / (4 * 8) + Object[][] expectedData1 = new Object[resultVectorCount][maxRowCount]; + Object[][] expectedData2 = new Object[resultVectorCount][expectedRowSize - maxRowCount]; + buildIterativeExpectedData(expectedData1, expectedData2, sourceDataSize, maxRowCount, expectedRowSize); + + assertVecBatchEquals(resultList.get(0), expectedData1); + assertVecBatchEquals(resultList.get(1), expectedData2); + + freeVecBatches(resultList); + operator.close(); + omniTopNOperatorFactory.close(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniUnionOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniUnionOperatorTest.java new file mode 100644 index 0000000..5263002 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniUnionOperatorTest.java @@ -0,0 +1,170 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.createDictionaryVec; +import static nova.hetu.omniruntime.util.TestUtils.createLongVec; +import static nova.hetu.omniruntime.util.TestUtils.createVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; + +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.operator.union.OmniUnionOperatorFactory; +import nova.hetu.omniruntime.operator.union.OmniUnionOperatorFactory.FactoryContext; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +/** + * The type Omni union operator test. + * + * @since 2021-8-11 + */ +public class OmniUnionOperatorTest { + /** + * Test the correctness of Omni union operator. + */ + @Test + public void testUnionByTwoCols() { + DataType[] sourceTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE}; + Object[][] sourceDatas1 = {{5, 3, 2, 6, 1, 4, 7, 8}, {5.0, 3.0, 2.0, 6.0, 1.0, 4.0, 7.0, 8.0}}; + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceDatas1); + + Object[][] sourceDatas2 = {{15, 13, 12, 16, 11, 14, 17, 18}, {15.0, 13.0, 12.0, 16.0, 11.0, 14.0, 17.0, 18.0}}; + VecBatch vecBatch2 = createVecBatch(sourceTypes, sourceDatas2); + + OmniUnionOperatorFactory unionOperatorFactory = new OmniUnionOperatorFactory(sourceTypes, false); + OmniOperator unionOperator = unionOperatorFactory.createOperator(); + unionOperator.addInput(vecBatch1); + unionOperator.addInput(vecBatch2); + Iterator results = unionOperator.getOutput(); + + Object[][] expectedDatas1 = {{5, 3, 2, 6, 1, 4, 7, 8}, {5.0, 3.0, 2.0, 6.0, 1.0, 4.0, 7.0, 8.0}}; + Object[][] expectedDatas2 = {{15, 13, 12, 16, 11, 14, 17, 18}, + {15.0, 13.0, 12.0, 16.0, 11.0, 14.0, 17.0, 18.0}}; + + List resultList = new ArrayList<>(); + while (results.hasNext()) { + resultList.add(results.next()); + } + + assertEquals(resultList.size(), 2); + assertVecBatchEquals(resultList.get(0), expectedDatas1); + assertVecBatchEquals(resultList.get(1), expectedDatas2); + + for (int i = 0; i < resultList.size(); i++) { + freeVecBatch(resultList.get(i)); + } + + unionOperator.close(); + unionOperatorFactory.close(); + } + + /** + * Test the correctness of Omni union operator when the data has null value. + */ + @Test + public void testUnionByTwoColsWithNulls() { + DataType[] sourceTypes = {IntDataType.INTEGER, DoubleDataType.DOUBLE}; + Object[][] sourceDatas1 = {{null, 3, 2, 6, 1, 4, 7, 8}, {5.0, 3.0, 2.0, 6.0, 1.0, 4.0, null, 8.0}}; + VecBatch vecBatch1 = createVecBatch(sourceTypes, sourceDatas1); + + Object[][] sourceDatas2 = {{15, 13, null, 16, 11, 14, 17, 18}, + {15.0, null, 12.0, 16.0, 11.0, 14.0, 17.0, 18.0}}; + VecBatch vecBatch2 = createVecBatch(sourceTypes, sourceDatas2); + + OmniUnionOperatorFactory unionOperatorFactory = new OmniUnionOperatorFactory(sourceTypes, false); + OmniOperator unionOperator = unionOperatorFactory.createOperator(); + unionOperator.addInput(vecBatch1); + unionOperator.addInput(vecBatch2); + Iterator results = unionOperator.getOutput(); + + List resultList = new ArrayList<>(); + while (results.hasNext()) { + resultList.add(results.next()); + } + + Object[][] expectedDatas1 = {{null, 3, 2, 6, 1, 4, 7, 8}, {5.0, 3.0, 2.0, 6.0, 1.0, 4.0, null, 8.0}}; + Object[][] expectedDatas2 = {{15, 13, null, 16, 11, 14, 17, 18}, + {15.0, null, 12.0, 16.0, 11.0, 14.0, 17.0, 18.0}}; + + assertEquals(resultList.size(), 2); + assertVecBatchEquals(resultList.get(0), expectedDatas1); + assertVecBatchEquals(resultList.get(1), expectedDatas2); + + for (int i = 0; i < resultList.size(); i++) { + freeVecBatch(resultList.get(i)); + } + unionOperator.close(); + unionOperatorFactory.close(); + } + + /** + * Test the correctness of Omni union operator when the data has dictionary + * type. + */ + @Test + public void testUnionWithDictionaryType() { + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG}; + int[] ids = {0, 1, 2, 3}; + + Object[][] sourceDatas1 = {{1L, null, 3L, null}, {111L, 11L, 333L, 33L}}; + Vec[] vecs1 = new Vec[2]; + vecs1[0] = createLongVec(sourceDatas1[0]); + vecs1[1] = createDictionaryVec(sourceTypes[1], sourceDatas1[1], ids); + VecBatch vecBatch1 = new VecBatch(vecs1); + + Object[][] sourceDatas2 = {{null, 2L, null, 4L}, {11L, 22L, 33L, 44L}}; + Vec[] vecs2 = new Vec[2]; + vecs2[0] = createLongVec(sourceDatas2[0]); + vecs2[1] = createDictionaryVec(sourceTypes[1], sourceDatas2[1], ids); + VecBatch vecBatch2 = new VecBatch(vecs2); + + OmniUnionOperatorFactory unionOperatorFactory = new OmniUnionOperatorFactory(sourceTypes, false); + OmniOperator unionOperator = unionOperatorFactory.createOperator(); + unionOperator.addInput(vecBatch1); + unionOperator.addInput(vecBatch2); + Iterator results = unionOperator.getOutput(); + List resultList = new ArrayList<>(); + while (results.hasNext()) { + resultList.add(results.next()); + } + + Object[][] expectedDatas1 = {{1L, null, 3L, null}, {111L, 11L, 333L, 33L}}; + Object[][] expectedDatas2 = {{null, 2L, null, 4L}, {11L, 22L, 33L, 44L}}; + + assertEquals(resultList.size(), 2); + assertVecBatchEquals(resultList.get(0), expectedDatas1); + assertVecBatchEquals(resultList.get(1), expectedDatas2); + + for (int i = 0; i < resultList.size(); i++) { + freeVecBatch(resultList.get(i)); + } + unionOperator.close(); + unionOperatorFactory.close(); + } + + @Test + public void testFactoryContextEquals() { + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG}; + FactoryContext factory1 = new FactoryContext(sourceTypes, false, new OperatorConfig()); + FactoryContext factory2 = new FactoryContext(sourceTypes, false, new OperatorConfig()); + FactoryContext factory3 = null; + assertEquals(factory2, factory1); + assertEquals(factory1, factory1); + assertNotEquals(factory3, factory1); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniWindowGroupLimitWithExprOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniWindowGroupLimitWithExprOperatorTest.java new file mode 100644 index 0000000..5c3f193 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniWindowGroupLimitWithExprOperatorTest.java @@ -0,0 +1,178 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.createVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; + +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.operator.window.OmniWindowGroupLimitWithExprOperatorFactory; +import nova.hetu.omniruntime.operator.window.OmniWindowGroupLimitWithExprOperatorFactory.FactoryContext; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.Iterator; + +/** + * The omni WindowGroupLimit with expression operator test. + * + * @since 2025-1-23 + */ +public class OmniWindowGroupLimitWithExprOperatorTest { + // rank + Desc + NullLast + @Test + public void testWindowGroupLimitRankDescNullLast() { + DataType[] sourceTypes = {new VarcharDataType(10), LongDataType.LONG, LongDataType.LONG}; + Object[][] sourceDatas = {{"hi", "hi", "hi", "bye", "bye", "bye", "bye", "bye"}, + {2L, 5L, 3L, 11L, 4L, 3L, 0L, 23L}, {3L, 5L, 8L, 3L, 5L, 3L, 4L, 3L}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + String[] partitionKeys = {getOmniJsonFieldReference(15, 0)}; + String[] sortKeys = {getOmniJsonFieldReference(2, 2)}; + int[] ascendings = {0}; + int[] nullFirsts = {0}; + OmniWindowGroupLimitWithExprOperatorFactory windowGroupLimitOperatorFactory = + new OmniWindowGroupLimitWithExprOperatorFactory(sourceTypes, 3, "rank", partitionKeys, + sortKeys, ascendings, nullFirsts); + OmniOperator windowGroupLimitOperator = windowGroupLimitOperatorFactory.createOperator(); + windowGroupLimitOperator.addInput(vecBatch); + Iterator results = windowGroupLimitOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getRowCount(), sourceDatas[0].length); + Object[][] expectedDatas = {{"bye", "bye", "bye", "bye", "bye", "hi", "hi", "hi"}, + {4L, 0L, 11L, 3L, 23L, 3L, 5L, 2L}, {5L, 4L, 3L, 3L, 3L, 8L, 5L, 3L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + windowGroupLimitOperator.close(); + windowGroupLimitOperatorFactory.close(); + } + + // rank + Asc + NullLast + @Test + public void testWindowGroupLimitRankAscNullLast() { + DataType[] sourceTypes = {new VarcharDataType(10), LongDataType.LONG, LongDataType.LONG}; + Object[][] sourceDatas = {{"hi", "hi", "hi", "bye", "bye", "bye", "bye", "bye"}, + {2L, 5L, 3L, 11L, 4L, 3L, 0L, 23L}, {5L, 3L, 8L, 3L, 6L, 6L, 4L, 6L}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + String[] partitionKeys = {getOmniJsonFieldReference(15, 0)}; + String[] sortKeys = {getOmniJsonFieldReference(2, 2)}; + int[] ascendings = {1}; + int[] nullFirsts = {0}; + OmniWindowGroupLimitWithExprOperatorFactory windowGroupLimitOperatorFactory = + new OmniWindowGroupLimitWithExprOperatorFactory(sourceTypes, 3, "rank", partitionKeys, + sortKeys, ascendings, nullFirsts); + OmniOperator windowGroupLimitOperator = windowGroupLimitOperatorFactory.createOperator(); + windowGroupLimitOperator.addInput(vecBatch); + Iterator results = windowGroupLimitOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getRowCount(), sourceDatas[0].length); + Object[][] expectedDatas = {{"bye", "bye", "bye", "bye", "bye", "hi", "hi", "hi"}, + {11L, 0L, 4L, 3L, 23L, 5L, 2L, 3L}, {3L, 4L, 6L, 6L, 6L, 3L, 5L, 8L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + windowGroupLimitOperator.close(); + windowGroupLimitOperatorFactory.close(); + } + + // Row_number + Desc + NullLast + @Test + public void testWindowGroupLimitRowNumberDescNullLast() { + DataType[] sourceTypes = {new VarcharDataType(10), LongDataType.LONG, LongDataType.LONG}; + Object[][] sourceDatas = {{"hi", "hi", "hi", "bye", "bye", "bye", "bye", "bye"}, + {2L, 5L, 3L, 11L, 4L, 3L, 0L, 23L}, {3L, 5L, 8L, 3L, 5L, 3L, 4L, 3L}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + String[] partitionKeys = {getOmniJsonFieldReference(15, 0)}; + String[] sortKeys = {getOmniJsonFieldReference(2, 2)}; + int[] ascendings = {0}; + int[] nullFirsts = {0}; + OmniWindowGroupLimitWithExprOperatorFactory windowGroupLimitOperatorFactory = + new OmniWindowGroupLimitWithExprOperatorFactory(sourceTypes, 3, "row_number", partitionKeys, + sortKeys, ascendings, nullFirsts); + OmniOperator windowGroupLimitOperator = windowGroupLimitOperatorFactory.createOperator(); + windowGroupLimitOperator.addInput(vecBatch); + Iterator results = windowGroupLimitOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getRowCount(), 6); + Object[][] expectedDatas = {{"bye", "bye", "bye", "hi", "hi", "hi"}, {4L, 0L, 11L, 3L, 5L, 2L}, + {5L, 4L, 3L, 8L, 5L, 3L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + windowGroupLimitOperator.close(); + windowGroupLimitOperatorFactory.close(); + } + + // Row_number + Asc + NullLast + @Test + public void testWindowGroupLimitRowNumberAscNullLast() { + DataType[] sourceTypes = {new VarcharDataType(10), LongDataType.LONG, LongDataType.LONG}; + Object[][] sourceDatas = {{"hi", "hi", "hi", "bye", "bye", "bye", "bye", "bye"}, + {2L, 5L, 3L, 11L, 4L, 3L, 0L, 23L}, {5L, 3L, 8L, 3L, 6L, 6L, 4L, 6L}}; + VecBatch vecBatch = createVecBatch(sourceTypes, sourceDatas); + + String[] partitionKeys = {getOmniJsonFieldReference(15, 0)}; + String[] sortKeys = {getOmniJsonFieldReference(2, 2)}; + int[] ascendings = {1}; + int[] nullFirsts = {0}; + OmniWindowGroupLimitWithExprOperatorFactory windowGroupLimitOperatorFactory = new + OmniWindowGroupLimitWithExprOperatorFactory(sourceTypes, 3, "row_number", partitionKeys, sortKeys, + ascendings, nullFirsts); + OmniOperator windowGroupLimitOperator = windowGroupLimitOperatorFactory.createOperator(); + windowGroupLimitOperator.addInput(vecBatch); + Iterator results = windowGroupLimitOperator.getOutput(); + + VecBatch resultVecBatch = results.next(); + assertEquals(resultVecBatch.getRowCount(), 6); + Object[][] expectedDatas = {{"bye", "bye", "bye", "hi", "hi", "hi"}, {11L, 0L, 4L, 5L, 2L, 3L}, + {3L, 4L, 6L, 3L, 5L, 8L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(resultVecBatch); + windowGroupLimitOperator.close(); + windowGroupLimitOperatorFactory.close(); + } + + @Test + public void testFactoryContextEquals() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG}; + String[] partitionKeys = {getOmniJsonFieldReference(1, 0)}; + String[] sortKeys = {getOmniJsonFieldReference(2, 1)}; + int[] sortAscendings = {1}; + int[] sortNullFirsts = {0}; + OperatorConfig operatorConfig = new OperatorConfig(); + FactoryContext factory1 = new FactoryContext(sourceTypes, 10, "rank", partitionKeys, sortKeys, sortAscendings, + sortNullFirsts, operatorConfig); + FactoryContext factory2 = new FactoryContext(sourceTypes, 10, "rank", partitionKeys, sortKeys, sortAscendings, + sortNullFirsts, operatorConfig); + FactoryContext factory3 = null; + assertEquals(factory2, factory1); + assertEquals(factory1, factory1); + assertNotEquals(factory3, factory1); + FactoryContext factory4 = new FactoryContext(sourceTypes, 10, "row_number", partitionKeys, sortKeys, + sortAscendings, sortNullFirsts, operatorConfig); + FactoryContext factory5 = new FactoryContext(sourceTypes, 10, "row_number", partitionKeys, sortKeys, + sortAscendings, sortNullFirsts, operatorConfig); + assertEquals(factory4, factory5); + assertEquals(factory4, factory4); + assertNotEquals(factory4, factory3); + assertNotEquals(factory4, factory2); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniWindowOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniWindowOperatorTest.java new file mode 100644 index 0000000..e171643 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniWindowOperatorTest.java @@ -0,0 +1,315 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.constants.OmniWindowFrameBoundType.OMNI_FRAME_BOUND_CURRENT_ROW; +import static nova.hetu.omniruntime.constants.OmniWindowFrameBoundType.OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING; +import static nova.hetu.omniruntime.constants.OmniWindowFrameType.OMNI_FRAME_TYPE_RANGE; +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; + +import com.google.common.collect.ImmutableList; + +import nova.hetu.omniruntime.constants.FunctionType; +import nova.hetu.omniruntime.constants.OmniWindowFrameBoundType; +import nova.hetu.omniruntime.constants.OmniWindowFrameType; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.operator.window.OmniWindowOperatorFactory; +import nova.hetu.omniruntime.operator.window.OmniWindowOperatorFactory.FactoryContext; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +/** + * The type Omni window operator test. + * + * @since 2021-6-4 + */ +public class OmniWindowOperatorTest { + /** + * The Total page count. + */ + int totalPageCount = 1; + + /** + * The Page distinct count. + */ + int pageDistinctCount = 4; + + /** + * The Page distinct value repeat count. + */ + int pageDistinctValueRepeatCount = 5000; + + /** + * test window performance whether with jit or not. + */ + @Test + public void testWindowComparePref() { + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG}; + int[] outputChannels = {0, 1}; + FunctionType[] windowFunction = {FunctionType.OMNI_AGGREGATION_TYPE_COUNT_COLUMN, + FunctionType.OMNI_AGGREGATION_TYPE_COUNT_ALL}; + OmniWindowFrameType[] windowFrameTypes = {OMNI_FRAME_TYPE_RANGE, OMNI_FRAME_TYPE_RANGE, OMNI_FRAME_TYPE_RANGE}; + OmniWindowFrameBoundType[] windowFrameStartTypes = {OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING, + OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING, OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING}; + int[] winddowFrameStartChannels = {-1, -1, -1}; + OmniWindowFrameBoundType[] windowFrameEndTypes = {OMNI_FRAME_BOUND_CURRENT_ROW, OMNI_FRAME_BOUND_CURRENT_ROW, + OMNI_FRAME_BOUND_CURRENT_ROW}; + int[] winddowFrameEndChannels = {-1, -1, -1}; + int[] partitionChannels = {0}; + int[] preGroupedChannels = {}; + int[] sortChannels = {1}; + int[] sortOrder = {1}; + int[] sortNullFirsts = {0}; + int preSortedChannelPrefix = 0; + int expectedPositions = 10000; + int[] argumentChannels = {1, -1}; + DataType[] windowFunctionReturnType = {LongDataType.LONG, LongDataType.LONG}; + + OmniWindowOperatorFactory windowOperatorFactoryWithoutJit = new OmniWindowOperatorFactory(sourceTypes, + outputChannels, windowFunction, partitionChannels, preGroupedChannels, sortChannels, sortOrder, + sortNullFirsts, preSortedChannelPrefix, expectedPositions, argumentChannels, windowFunctionReturnType, + windowFrameTypes, windowFrameStartTypes, winddowFrameStartChannels, windowFrameEndTypes, + winddowFrameEndChannels, new OperatorConfig()); + OmniOperator windowOperatorWithoutJit = windowOperatorFactoryWithoutJit.createOperator(); + ImmutableList vecsWithoutJit = buildVecs(); + + long start = System.currentTimeMillis(); + for (VecBatch vec : vecsWithoutJit) { + windowOperatorWithoutJit.addInput(vec); + } + Iterator outputWithoutJit = windowOperatorWithoutJit.getOutput(); + long end = System.currentTimeMillis(); + System.out.println("Window without jit use " + (end - start) + " ms."); + + OmniWindowOperatorFactory windowOperatorFactoryWithJit = new OmniWindowOperatorFactory(sourceTypes, + outputChannels, windowFunction, partitionChannels, preGroupedChannels, sortChannels, sortOrder, + sortNullFirsts, preSortedChannelPrefix, expectedPositions, argumentChannels, windowFunctionReturnType, + windowFrameTypes, windowFrameStartTypes, winddowFrameStartChannels, windowFrameEndTypes, + winddowFrameEndChannels, new OperatorConfig()); + OmniOperator windowOperatorWithJit = windowOperatorFactoryWithJit.createOperator(); + ImmutableList vecsWithJit = buildVecs(); + + start = System.currentTimeMillis(); + for (VecBatch vec : vecsWithJit) { + windowOperatorWithJit.addInput(vec); + } + Iterator outputWithJit = windowOperatorWithJit.getOutput(); + end = System.currentTimeMillis(); + System.out.println("Window with jit use " + (end - start) + " ms."); + + while (outputWithoutJit.hasNext() && outputWithJit.hasNext()) { + VecBatch resultWithoutJit = outputWithoutJit.next(); + VecBatch resultWithJit = outputWithJit.next(); + assertVecBatchEquals(resultWithoutJit, resultWithJit); + freeVecBatch(resultWithoutJit); + freeVecBatch(resultWithJit); + } + + windowOperatorWithoutJit.close(); + windowOperatorWithJit.close(); + windowOperatorFactoryWithoutJit.close(); + windowOperatorFactoryWithJit.close(); + } + + /** + * Test rank. + */ + @Test + public void testRank() { + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG}; + int[] outputChannels = {0, 1}; + FunctionType[] windowFunction = {FunctionType.OMNI_WINDOW_TYPE_RANK}; + OmniWindowFrameType[] windowFrameTypes = {OMNI_FRAME_TYPE_RANGE}; + OmniWindowFrameBoundType[] windowFrameStartTypes = {OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING}; + int[] winddowFrameStartChannels = {-1}; + OmniWindowFrameBoundType[] windowFrameEndTypes = {OMNI_FRAME_BOUND_CURRENT_ROW}; + int[] winddowFrameEndChannels = {-1}; + int[] partitionChannels = {0}; + int[] preGroupedChannels = {}; + int[] sortChannels = {1}; + int[] sortOrder = {1}; + int[] sortNullFirsts = {0}; + int preSortedChannelPrefix = 0; + int expectedPositions = 10000; + int[] argumentChannels = {}; + DataType[] windowFunctionReturnType = {LongDataType.LONG}; + OmniWindowOperatorFactory omniWindowOperatorFactory = new OmniWindowOperatorFactory(sourceTypes, outputChannels, + windowFunction, partitionChannels, preGroupedChannels, sortChannels, sortOrder, sortNullFirsts, + preSortedChannelPrefix, expectedPositions, argumentChannels, windowFunctionReturnType, windowFrameTypes, + windowFrameStartTypes, winddowFrameStartChannels, windowFrameEndTypes, winddowFrameEndChannels); + OmniOperator omniOperator = omniWindowOperatorFactory.createOperator(); + + VecBatch vecBatch = buildData(); + + omniOperator.addInput(vecBatch); + Iterator output = omniOperator.getOutput(); + if (output.hasNext()) { + VecBatch outputVecBatch = output.next(); + Vec[] vectors = outputVecBatch.getVectors(); + assertEquals(((LongVec) vectors[0]).get(0), 1); + assertEquals(((LongVec) vectors[0]).get(1), 1); + assertEquals(((LongVec) vectors[0]).get(2), 1); + assertEquals(((LongVec) vectors[0]).get(3), 2); + assertEquals(((LongVec) vectors[0]).get(4), 2); + assertEquals(((LongVec) vectors[1]).get(0), 2); + assertEquals(((LongVec) vectors[1]).get(1), 4); + assertEquals(((LongVec) vectors[1]).get(2), 6); + assertEquals(((LongVec) vectors[1]).get(3), -1); + assertEquals(((LongVec) vectors[1]).get(4), 5); + assertEquals(((LongVec) vectors[2]).get(0), 1); + assertEquals(((LongVec) vectors[2]).get(1), 2); + assertEquals(((LongVec) vectors[2]).get(2), 3); + assertEquals(((LongVec) vectors[2]).get(3), 1); + assertEquals(((LongVec) vectors[2]).get(4), 2); + freeVecBatch(outputVecBatch); + } + + omniOperator.close(); + omniWindowOperatorFactory.close(); + } + + /** + * Test count. + */ + @Test + public void testCount() { + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG}; + int[] outputChannels = {0, 1}; + FunctionType[] windowFunction = {FunctionType.OMNI_AGGREGATION_TYPE_COUNT_COLUMN, + FunctionType.OMNI_AGGREGATION_TYPE_COUNT_ALL}; + OmniWindowFrameType[] windowFrameTypes = {OMNI_FRAME_TYPE_RANGE, OMNI_FRAME_TYPE_RANGE}; + OmniWindowFrameBoundType[] windowFrameStartTypes = {OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING, + OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING}; + int[] winddowFrameStartChannels = {-1, -1}; + OmniWindowFrameBoundType[] windowFrameEndTypes = {OMNI_FRAME_BOUND_CURRENT_ROW, OMNI_FRAME_BOUND_CURRENT_ROW}; + int[] winddowFrameEndChannels = {-1, -1}; + int[] partitionChannels = {0}; + int[] preGroupedChannels = {}; + int[] sortChannels = {1}; + int[] sortOrder = {1}; + int[] sortNullFirsts = {0}; + int preSortedChannelPrefix = 0; + int expectedPositions = 10000; + int[] argumentChannels = {1, -1}; + DataType[] windowFunctionReturnType = {LongDataType.LONG, LongDataType.LONG}; + OmniWindowOperatorFactory omniWindowOperatorFactory = new OmniWindowOperatorFactory(sourceTypes, outputChannels, + windowFunction, partitionChannels, preGroupedChannels, sortChannels, sortOrder, sortNullFirsts, + preSortedChannelPrefix, expectedPositions, argumentChannels, windowFunctionReturnType, windowFrameTypes, + windowFrameStartTypes, winddowFrameStartChannels, windowFrameEndTypes, winddowFrameEndChannels); + OmniOperator omniOperator = omniWindowOperatorFactory.createOperator(); + + VecBatch vecBatch = buildData(); + vecBatch.getVectors()[1].setNull(2); + + omniOperator.addInput(vecBatch); + Iterator output = omniOperator.getOutput(); + if (output.hasNext()) { + VecBatch outputVecBatch = output.next(); + Object[][] expectedData = {{1L, 1L, 1L, 2L, 2L}, {2L, 6L, null, -1L, 5L}, {1L, 2L, 2L, 1L, 2L}, + {1L, 2L, 3L, 1L, 2L}}; + + assertEquals(outputVecBatch.getRowCount(), 5); + assertEquals(outputVecBatch.getVectors().length, 4); + assertVecBatchEquals(outputVecBatch, expectedData); + + freeVecBatch(outputVecBatch); + } + + omniOperator.close(); + omniWindowOperatorFactory.close(); + } + + @Test + public void testFactoryContextEquals() { + FunctionType[] windowFunction = {FunctionType.OMNI_AGGREGATION_TYPE_COUNT_COLUMN, + FunctionType.OMNI_AGGREGATION_TYPE_COUNT_ALL}; + OmniWindowFrameType[] windowFrameTypes = {OMNI_FRAME_TYPE_RANGE, OMNI_FRAME_TYPE_RANGE}; + OmniWindowFrameBoundType[] windowFrameStartTypes = {OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING, + OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING}; + int[] winddowFrameStartChannels = {-1, -1}; + OmniWindowFrameBoundType[] windowFrameEndTypes = {OMNI_FRAME_BOUND_CURRENT_ROW, OMNI_FRAME_BOUND_CURRENT_ROW}; + int[] winddowFrameEndChannels = {-1, -1}; + int[] partitionChannels = {0}; + int[] preGroupedChannels = {}; + int[] sortChannels = {1}; + int[] sortOrder = {1}; + int[] sortNullFirsts = {0}; + int preSortedChannelPrefix = 0; + int expectedPositions = 10000; + int[] argumentChannels = {1, -1}; + DataType[] windowFunctionReturnType = {LongDataType.LONG, LongDataType.LONG}; + DataType[] sourceTypes = {LongDataType.LONG, LongDataType.LONG}; + int[] outputChannels = {0, 1}; + FactoryContext factory1 = new FactoryContext(sourceTypes, outputChannels, windowFunction, partitionChannels, + preGroupedChannels, sortChannels, sortOrder, sortNullFirsts, preSortedChannelPrefix, expectedPositions, + argumentChannels, windowFunctionReturnType, windowFrameTypes, windowFrameStartTypes, + winddowFrameStartChannels, windowFrameEndTypes, winddowFrameEndChannels, new OperatorConfig()); + FactoryContext factory2 = new FactoryContext(sourceTypes, outputChannels, windowFunction, partitionChannels, + preGroupedChannels, sortChannels, sortOrder, sortNullFirsts, preSortedChannelPrefix, expectedPositions, + argumentChannels, windowFunctionReturnType, windowFrameTypes, windowFrameStartTypes, + winddowFrameStartChannels, windowFrameEndTypes, winddowFrameEndChannels, new OperatorConfig()); + FactoryContext factory3 = null; + assertEquals(factory2, factory1); + assertEquals(factory1, factory1); + assertNotEquals(factory3, factory1); + } + + private VecBatch buildData() { + int rowNum = 5; + LongVec longVec1 = new LongVec(rowNum); + LongVec longVec2 = new LongVec(rowNum); + longVec1.set(0, 2); + longVec1.set(1, 1); + longVec1.set(2, 1); + longVec1.set(3, 2); + longVec1.set(4, 1); + longVec2.set(0, -1); + longVec2.set(1, 2); + longVec2.set(2, 4); + longVec2.set(3, 5); + longVec2.set(4, 6); + + List columns = new ArrayList<>(); + columns.add(longVec1); + columns.add(longVec2); + return new VecBatch(columns, rowNum); + } + + private ImmutableList buildVecs() { + ImmutableList.Builder vecBatchList = ImmutableList.builder(); + int positionCount = pageDistinctCount * pageDistinctValueRepeatCount; + List vecs = new ArrayList<>(); + for (int i = 0; i < totalPageCount; i++) { + LongVec longVec1 = new LongVec(positionCount); + LongVec longVec2 = new LongVec(positionCount); + int idx = 0; + for (int j = 0; j < pageDistinctCount; j++) { + for (int k = 0; k < pageDistinctValueRepeatCount; k++) { + longVec1.set(idx, j); + longVec2.set(idx, j); + idx++; + } + } + vecs.add(longVec1); + vecs.add(longVec2); + VecBatch vecBatch = new VecBatch(new Vec[]{longVec1, longVec2}); + vecBatchList.add(vecBatch); + } + return vecBatchList.build(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniWindowWithExprOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniWindowWithExprOperatorTest.java new file mode 100644 index 0000000..ba7b00e --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/operator/OmniWindowWithExprOperatorTest.java @@ -0,0 +1,333 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.operator; + +import static nova.hetu.omniruntime.constants.OmniWindowFrameBoundType.OMNI_FRAME_BOUND_CURRENT_ROW; +import static nova.hetu.omniruntime.constants.OmniWindowFrameBoundType.OMNI_FRAME_BOUND_UNBOUNDED_FOLLOWING; +import static nova.hetu.omniruntime.constants.OmniWindowFrameBoundType.OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING; +import static nova.hetu.omniruntime.constants.OmniWindowFrameType.OMNI_FRAME_TYPE_RANGE; +import static nova.hetu.omniruntime.constants.OmniWindowFrameType.OMNI_FRAME_TYPE_ROWS; +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonLiteral; +import static nova.hetu.omniruntime.util.TestUtils.omniFunctionExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonFourArithmeticExpr; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; + +import nova.hetu.omniruntime.constants.FunctionType; +import nova.hetu.omniruntime.constants.OmniWindowFrameBoundType; +import nova.hetu.omniruntime.constants.OmniWindowFrameType; +import nova.hetu.omniruntime.operator.config.OperatorConfig; +import nova.hetu.omniruntime.operator.window.OmniWindowWithExprOperatorFactory; +import nova.hetu.omniruntime.operator.window.OmniWindowWithExprOperatorFactory.FactoryContext; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.utils.OmniRuntimeException; +import nova.hetu.omniruntime.vector.DoubleVec; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecBatch; + +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +/** + * The type Omni window operator test. + * + * @since 2021-11-3 + */ +public class OmniWindowWithExprOperatorTest { + /** + * Test iterative output + */ + @Test + public void testOutputMultiVecBatch() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG, DoubleDataType.DOUBLE}; + int[] outputChannels = {0, 1, 2}; + FunctionType[] windowFunction = {FunctionType.OMNI_AGGREGATION_TYPE_MAX}; + OmniWindowFrameType[] windowFrameTypes = {OMNI_FRAME_TYPE_RANGE}; + OmniWindowFrameBoundType[] windowFrameStartTypes = {OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING}; + int[] winddowFrameStartChannels = {-1}; + OmniWindowFrameBoundType[] windowFrameEndTypes = {OMNI_FRAME_BOUND_CURRENT_ROW}; + int[] winddowFrameEndChannels = {-1}; + int[] partitionChannels = {0}; + int[] preGroupedChannels = {}; + int[] sortChannels = {1}; + int[] sortOrder = {0}; + int[] sortNullFirsts = {0}; + int preSortedChannelPrefix = 0; + String[] argumentKeys = {omniJsonFourArithmeticExpr("ADD", 3, getOmniJsonFieldReference(3, 2), + getOmniJsonLiteral(3, false, 50.0))}; + DataType[] windowFunctionReturnType = {DoubleDataType.DOUBLE}; + OmniWindowWithExprOperatorFactory omniWindowOperatorFactory = new OmniWindowWithExprOperatorFactory(sourceTypes, + outputChannels, windowFunction, partitionChannels, preGroupedChannels, sortChannels, sortOrder, + sortNullFirsts, preSortedChannelPrefix, 10000, argumentKeys, windowFunctionReturnType, windowFrameTypes, + windowFrameStartTypes, winddowFrameStartChannels, windowFrameEndTypes, winddowFrameEndChannels); + OmniOperator omniOperator = omniWindowOperatorFactory.createOperator(); + + int column = 4; + int rowNum = 30000; + VecBatch vecBatch = new VecBatch(buildDataForOutputMultiVectorBatch(rowNum)); + omniOperator.addInput(vecBatch); + + // the value rowsPerBatch = (1M / 36) + 1 + int rowsPerBatch = 30000; + Object[][] expectedData1 = new Object[column][rowsPerBatch]; + Object[][] expectedData2 = new Object[column][rowNum - rowsPerBatch]; + buildIterativeExpectedData(expectedData1, expectedData2, rowsPerBatch, rowNum); + Iterator outputVecBatch = omniOperator.getOutput(); + List resultList = new ArrayList<>(); + while (outputVecBatch.hasNext()) { + resultList.add(outputVecBatch.next()); + } + + int totalRowCount = 0; + for (int i = 0; i < resultList.size(); i++) { + totalRowCount += resultList.get(i).getRowCount(); + } + + assertEquals(totalRowCount, rowNum); + assertVecBatchEquals(resultList.get(0), expectedData1); + + for (int i = 0; i < resultList.size(); i++) { + freeVecBatch(resultList.get(i)); + } + + omniOperator.close(); + omniWindowOperatorFactory.close(); + } + + /** + * Test max. + */ + @Test + public void testMax() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG, DoubleDataType.DOUBLE}; + int[] outputChannels = {0, 1, 2}; + FunctionType[] windowFunction = {FunctionType.OMNI_AGGREGATION_TYPE_MAX}; + OmniWindowFrameType[] windowFrameTypes = {OMNI_FRAME_TYPE_RANGE}; + OmniWindowFrameBoundType[] windowFrameStartTypes = {OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING}; + int[] winddowFrameStartChannels = {-1}; + OmniWindowFrameBoundType[] windowFrameEndTypes = {OMNI_FRAME_BOUND_CURRENT_ROW}; + int[] winddowFrameEndChannels = {-1}; + int[] partitionChannels = {0}; + int[] preGroupedChannels = {}; + int[] sortChannels = {1}; + int[] sortOrder = {0}; + int[] sortNullFirsts = {0}; + int preSortedChannelPrefix = 0; + String[] argumentKeys = {omniJsonFourArithmeticExpr("ADD", 3, getOmniJsonFieldReference(3, 2), + getOmniJsonLiteral(3, false, 50))}; + DataType[] windowFunctionReturnType = {DoubleDataType.DOUBLE}; + OmniWindowWithExprOperatorFactory omniWindowOperatorFactory = new OmniWindowWithExprOperatorFactory(sourceTypes, + outputChannels, windowFunction, partitionChannels, preGroupedChannels, sortChannels, sortOrder, + sortNullFirsts, preSortedChannelPrefix, 10000, argumentKeys, windowFunctionReturnType, windowFrameTypes, + windowFrameStartTypes, winddowFrameStartChannels, windowFrameEndTypes, winddowFrameEndChannels); + OmniOperator omniOperator = omniWindowOperatorFactory.createOperator(); + + VecBatch vecBatch = buildData(); + + omniOperator.addInput(vecBatch); + Iterator output = omniOperator.getOutput(); + VecBatch outputVecBatch = output.next(); + Object[][] expectedDatas = {{0, 0, 1, 1, 2, 2}, {8L, 8L, 4L, 1L, 5L, 2L}, {6.6D, 3.3D, 2.2D, 5.5D, 1.1D, 4.4D}, + {56.6D, 56.6D, 52.2D, 55.5D, 51.1D, 54.4D}}; + assertVecBatchEquals(outputVecBatch, expectedDatas); + freeVecBatch(outputVecBatch); + + omniOperator.close(); + omniWindowOperatorFactory.close(); + } + + @Test(expectedExceptions = OmniRuntimeException.class, expectedExceptionsMessageRegExp = ".*EXPRESSION_NOT_SUPPORT.*") + public void testWindowWithInvalidKeys() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG, DoubleDataType.DOUBLE}; + int[] outputChannels = {0, 1, 2}; + FunctionType[] windowFunction = {FunctionType.OMNI_AGGREGATION_TYPE_MAX}; + OmniWindowFrameType[] windowFrameTypes = {OMNI_FRAME_TYPE_RANGE}; + OmniWindowFrameBoundType[] windowFrameStartTypes = {OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING}; + int[] winddowFrameStartChannels = {-1}; + OmniWindowFrameBoundType[] windowFrameEndTypes = {OMNI_FRAME_BOUND_CURRENT_ROW}; + int[] winddowFrameEndChannels = {-1}; + int[] partitionChannels = {0}; + int[] preGroupedChannels = {}; + int[] sortChannels = {1}; + int[] sortOrder = {0}; + int[] sortNullFirsts = {0}; + int preSortedChannelPrefix = 0; + String[] argumentKeys = {omniFunctionExpr("abc", 3, getOmniJsonFieldReference(3, 2))}; + DataType[] windowFunctionReturnType = {DoubleDataType.DOUBLE}; + OmniWindowWithExprOperatorFactory omniWindowOperatorFactory = new OmniWindowWithExprOperatorFactory(sourceTypes, + outputChannels, windowFunction, partitionChannels, preGroupedChannels, sortChannels, sortOrder, + sortNullFirsts, preSortedChannelPrefix, 10000, argumentKeys, windowFunctionReturnType, windowFrameTypes, + windowFrameStartTypes, winddowFrameStartChannels, windowFrameEndTypes, winddowFrameEndChannels); + } + + @Test + public void testFactoryContextEquals() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG, DoubleDataType.DOUBLE}; + int[] outputChannels = {0, 1, 2}; + FunctionType[] windowFunction = {FunctionType.OMNI_AGGREGATION_TYPE_MAX}; + OmniWindowFrameType[] windowFrameTypes = {OMNI_FRAME_TYPE_RANGE}; + OmniWindowFrameBoundType[] windowFrameStartTypes = {OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING}; + int[] winddowFrameStartChannels = {-1}; + OmniWindowFrameBoundType[] windowFrameEndTypes = {OMNI_FRAME_BOUND_CURRENT_ROW}; + int[] winddowFrameEndChannels = {-1}; + int[] partitionChannels = {0}; + int[] preGroupedChannels = {}; + int[] sortChannels = {1}; + int[] sortOrder = {0}; + int[] sortNullFirsts = {0}; + int preSortedChannelPrefix = 0; + String[] argumentKeys = {omniJsonFourArithmeticExpr("ADD", 3, getOmniJsonFieldReference(3, 2), + getOmniJsonLiteral(3, false, 50))}; + DataType[] windowFunctionReturnType = {DoubleDataType.DOUBLE}; + FactoryContext factory1 = new FactoryContext(sourceTypes, outputChannels, windowFunction, partitionChannels, + preGroupedChannels, sortChannels, sortOrder, sortNullFirsts, preSortedChannelPrefix, 10000, + argumentKeys, windowFunctionReturnType, windowFrameTypes, windowFrameStartTypes, + winddowFrameStartChannels, windowFrameEndTypes, winddowFrameEndChannels, new OperatorConfig()); + FactoryContext factory2 = new FactoryContext(sourceTypes, outputChannels, windowFunction, partitionChannels, + preGroupedChannels, sortChannels, sortOrder, sortNullFirsts, preSortedChannelPrefix, 10000, + argumentKeys, windowFunctionReturnType, windowFrameTypes, windowFrameStartTypes, + winddowFrameStartChannels, windowFrameEndTypes, winddowFrameEndChannels, new OperatorConfig()); + FactoryContext factory3 = null; + assertEquals(factory2, factory1); + assertEquals(factory1, factory1); + assertNotEquals(factory3, factory1); + } + + @Test + public void testWindowFunctionMix() { + DataType[] sourceTypes = {IntDataType.INTEGER, LongDataType.LONG, DoubleDataType.DOUBLE}; + int[] outputChannels = {0, 1, 2}; + FunctionType[] windowFunction = {FunctionType.OMNI_WINDOW_TYPE_RANK, FunctionType.OMNI_AGGREGATION_TYPE_AVG}; + OmniWindowFrameType[] windowFrameTypes = {OMNI_FRAME_TYPE_ROWS, OMNI_FRAME_TYPE_ROWS}; + OmniWindowFrameBoundType[] windowFrameStartTypes = {OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING, + OMNI_FRAME_BOUND_UNBOUNDED_PRECEDING}; + int[] winddowFrameStartChannels = {-1, -1}; + OmniWindowFrameBoundType[] windowFrameEndTypes = {OMNI_FRAME_BOUND_UNBOUNDED_FOLLOWING, + OMNI_FRAME_BOUND_UNBOUNDED_FOLLOWING}; + int[] winddowFrameEndChannels = {-1, -1}; + int[] partitionChannels = {0}; + int[] preGroupedChannels = {}; + int[] sortChannels = {2}; + int[] sortOrder = {1}; + int[] sortNullFirsts = {0}; + int preSortedChannelPrefix = 0; + String[] argumentKeys = {omniFunctionExpr("abs", 2, getOmniJsonFieldReference(2, 1))}; + DataType[] windowFunctionReturnType = {IntDataType.INTEGER, DoubleDataType.DOUBLE}; + OmniWindowWithExprOperatorFactory omniWindowOperatorFactory = new OmniWindowWithExprOperatorFactory(sourceTypes, + outputChannels, windowFunction, partitionChannels, preGroupedChannels, sortChannels, sortOrder, + sortNullFirsts, preSortedChannelPrefix, 10000, argumentKeys, windowFunctionReturnType, windowFrameTypes, + windowFrameStartTypes, winddowFrameStartChannels, windowFrameEndTypes, winddowFrameEndChannels); + OmniOperator omniOperator = omniWindowOperatorFactory.createOperator(); + + VecBatch vecBatch = buildData(); + + omniOperator.addInput(vecBatch); + Iterator output = omniOperator.getOutput(); + VecBatch outputVecBatch = output.next(); + + Object[][] expectedDatas = {{0, 0, 1, 1, 2, 2}, {8L, 8L, 4L, 1L, 5L, 2L}, {3.3D, 6.6D, 2.2D, 5.5D, 1.1D, 4.4D}, + {1, 2, 1, 2, 1, 2}, {8.0D, 8.0D, 2.5D, 2.5D, 3.5D, 3.5D}}; + assertVecBatchEquals(outputVecBatch, expectedDatas); + freeVecBatch(outputVecBatch); + + omniOperator.close(); + omniWindowOperatorFactory.close(); + } + + private VecBatch buildData() { + int rowNum = 6; + IntVec vec1 = new IntVec(rowNum); + vec1.set(0, 0); + vec1.set(1, 1); + vec1.set(2, 2); + vec1.set(3, 0); + vec1.set(4, 1); + vec1.set(5, 2); + LongVec vec2 = new LongVec(rowNum); + vec2.set(0, 8); + vec2.set(1, 1); + vec2.set(2, 2); + vec2.set(3, 8); + vec2.set(4, 4); + vec2.set(5, 5); + DoubleVec vec3 = new DoubleVec(rowNum); + vec3.set(0, 6.6); + vec3.set(1, 5.5); + vec3.set(2, 4.4); + vec3.set(3, 3.3); + vec3.set(4, 2.2); + vec3.set(5, 1.1); + List columns = new ArrayList<>(); + columns.add(vec1); + columns.add(vec2); + columns.add(vec3); + return new VecBatch(columns); + } + + private List buildDataForOutputMultiVectorBatch(int rowNum) { + IntVec c1 = new IntVec(rowNum); + for (int i = 0; i < rowNum / 2; i++) { + c1.set(i, i); + } + + for (int i = rowNum / 2; i < rowNum; i++) { + c1.set(i, i - rowNum / 2); + } + + LongVec c2 = new LongVec(rowNum); + DoubleVec c3 = new DoubleVec(rowNum); + for (int i = 0; i < rowNum; i++) { + c2.set(i, i); + c3.set(i, i); + } + + List columns = new ArrayList<>(); + columns.add(c1); + columns.add(c2); + columns.add(c3); + + return columns; + } + + private void buildIterativeExpectedData(Object[][] expectedData1, Object[][] expectedData2, int maxRowCount, + int expectedRowSize) { + int offset1 = maxRowCount / 2; + int offset2 = expectedRowSize / 2; + for (int i = 0; i < offset1; i++) { + expectedData1[0][i * 2] = i; + expectedData1[0][i * 2 + 1] = i; + expectedData1[1][i * 2] = (long) (i + offset2); + expectedData1[1][i * 2 + 1] = (long) i; + expectedData1[2][i * 2] = (double) (i + offset2); + expectedData1[2][i * 2 + 1] = (double) i; + expectedData1[3][i * 2] = (double) (i + offset2 + 50); + expectedData1[3][i * 2 + 1] = (double) (i + offset2 + 50); + } + + int offset3 = offset1 + offset2; + int offset4 = (expectedRowSize - maxRowCount) / 2; + for (int i = 0; i < offset4; i++) { + expectedData2[0][i * 2] = i + offset1; + expectedData2[0][i * 2 + 1] = i + offset1; + expectedData2[1][i * 2] = (long) (i + offset3); + expectedData2[1][i * 2 + 1] = (long) (i + offset1); + expectedData2[2][i * 2] = (double) (i + offset3); + expectedData2[2][i * 2 + 1] = (double) (i + offset1); + expectedData2[3][i * 2] = (double) (i + 50 + offset3); + expectedData2[3][i * 2 + 1] = (double) (i + 50 + offset3); + } + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql10ForOmniFilterOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql10ForOmniFilterOperatorTest.java new file mode 100644 index 0000000..eeea83a --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql10ForOmniFilterOperatorTest.java @@ -0,0 +1,240 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.tensql; + +import static nova.hetu.omniruntime.util.TestUtils.filterOperatorMatch; +import static nova.hetu.omniruntime.util.TestUtils.filterOperatorMatchWithJson; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonLiteral; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonAndExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonEqualExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonGreaterThanOrEqualExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonIsNotNullExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonLessThanOrEqualExpr; + +import com.google.common.collect.ImmutableList; + +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.VarcharDataType; + +import org.testng.annotations.Test; + +import java.util.List; + +/** + * sql10 for OmniFilter operator test. + * + * @since 2022-03-31 + */ +public class Sql10ForOmniFilterOperatorTest { + private static final String MUST_TEST_EXP1 = "AND:4(AND:4(AND:4(AND:4(AND:4(IS_NOT_NULL:4(#0) , " + + "$operator$GREATER_THAN_OR_EQUAL:4(#0 , 2451484:2)) , $operator$LESS_THAN_OR_EQUAL:4(#0 , 2451513:2)) , " + + "IS_NOT_NULL:4(#1)) , IS_NOT_NULL:4(#2)) , IS_NOT_NULL:4(#3))"; + private static final String MUST_TEST_EXP2 = "AND:4(AND:4(AND:4(AND:4(AND:4(AND:4(IS_NOT_NULL:4(#4) , " + + "IS_NOT_NULL:4(#5)) , $operator$EQUAL:4(#4 , 11:1)) , $operator$EQUAL:4(#5 , 1999:1)) , " + + "$operator$GREATER_THAN_OR_EQUAL:4(#6 , 2451484:2)) , $operator$LESS_THAN_OR_EQUAL:4(#6 , 2451513:2)) , " + + "IS_NOT_NULL:4(#6))"; + private static final String MUST_TEST_EXP3 = "AND:4(AND:4(IS_NOT_NULL:4(#7) , $operator$EQUAL:4(#7 , 7:1)) , " + + "IS_NOT_NULL:4(#8))"; + private static final String MUST_TEST_EXP4 = "AND:4(IS_NOT_NULL:4(#9), IS_NOT_NULL:4(#10))"; + private static final String MUST_TEST_EXP5 = "AND:4(IS_NOT_NULL:4(#11), IS_NOT_NULL:4(#12))"; + private static final String MUST_TEST_EXP6 = "AND:4(IS_NOT_NULL:4(#13), IS_NOT_NULL:4(#14))"; + + private static final String MUST_TEST_EXP1_JSON = omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 0)), + omniJsonGreaterThanOrEqualExpr(getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 2451484))), + omniJsonLessThanOrEqualExpr(getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 2451513))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 1))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 2))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 3))); + private static final String MUST_TEST_EXP2_JSON = omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(1, 4)), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(1, 5))), + omniJsonEqualExpr(getOmniJsonFieldReference(1, 4), + getOmniJsonLiteral(1, false, 11))), + omniJsonEqualExpr(getOmniJsonFieldReference(1, 5), + getOmniJsonLiteral(1, false, 1999))), + omniJsonGreaterThanOrEqualExpr(getOmniJsonFieldReference(2, 6), + getOmniJsonLiteral(2, false, 2451484))), + omniJsonLessThanOrEqualExpr(getOmniJsonFieldReference(2, 6), + getOmniJsonLiteral(2, false, 2451513))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 6))); + private static final String MUST_TEST_EXP3_JSON = omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(1, 7)), + omniJsonEqualExpr(getOmniJsonFieldReference(1, 7), getOmniJsonLiteral(1, false, 7))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 8))); + private static final String MUST_TEST_EXP4_JSON = omniJsonAndExpr( + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 9)), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 10))); + private static final String MUST_TEST_EXP5_JSON = omniJsonAndExpr( + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 11)), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(15, 12))); + private static final String MUST_TEST_EXP6_JSON = omniJsonAndExpr( + omniJsonIsNotNullExpr(getOmniJsonFieldReference(15, 13)), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 14))); + + Object[][] dataSourceValue = {{1214L, 2451484L, 2451484L, 2451489L, 2451513L, 2452000L}, // #0 ss_sold_date_sk #long + {1L, 2L, 3L, 4L, 5L, 6L}, // #1 ss_item_sk #long + {1L, 2L, 3L, 4L, 5L, 6L}, // #2 ss_customer_sk #long + {1L, 2L, 3L, 4L, 5L, 6L}, // #3 ss_store_sk #long + {-1, 11, 11, 11, 11, 15}, // #4 d_moy #int + {1000, 1999, 1999, 1999, 2000, 2001}, // d_year #5 #int + {1214L, 2451484L, 2451484L, 2451489L, 2451513L, 2452000L}, // d_date_sk #6 #long + {-1, 7, 7, 7, 7, 10}, // #7 i_manager_id #int + {1L, 2L, 3L, 4L, 5L, 6L}, // #8 i_item_sk #long + {1L, 2L, 3L, 4L, 5L, 6L}, // #9 c_customer_sk #long + {1L, 2L, 3L, 4L, 5L, 6L}, // #10 c_current_addr_sk #long + {1L, 2L, 3L, 4L, 5L, 6L}, // #11 ca_address_sk #long + {"a", "ab", " ", "", " ", " ab "}, // #12 ca_zip #char(10) + {"a", " ab", "ab ", "abc", " ", ""}, // s_zip #13 char(10) + {1L, 2L, 3L, 4L, 5L, 6L} // s_store_sk #14 #long + }; + + Object[][] dataSourceValueWithNull = {{1214L, null, 2451484L, 2451489L, 2451513L, 2452000L}, + // #0 ss_sold_date_sk #long + {null, 2L, 3L, 4L, 5L, null}, // #1 ss_item_sk #long + {null, null, 3L, 4L, 5L, 6L}, // #2 ss_customer_sk #long + {1L, 2L, 3L, null, 5L, 6L}, // #3 ss_store_sk #long + {null, 11, 11, 11, 11, 15}, // #4 d_moy #int + {1000, null, 1999, 1999, 2000, 2001}, // d_year #5 #int + {1214L, 2451484L, null, 2451489L, 2451513L, 2452000L}, // d_date_sk #6 #long + {-1, null, 7, 7, 7, 10}, // #7 i_manager_id #int + {1L, 2L, null, 4L, 5L, 6L}, // #8 i_item_sk #long + {null, null, 3L, 4L, 5L, 6L}, // #9 c_customer_sk #long + {null, 2L, 3L, 4L, 5L, null}, // #10 c_current_addr_sk #long + {1L, null, 3L, 4L, 5L, 6L}, // #11 ca_address_sk #long + {null, "ab", null, null, " ", " ab "}, // #12 ca_zip #char(10) + {null, " ab", "ab ", "abc", " ", null}, // s_zip #13 char(10) + {1L, null, null, 4L, 5L, 6L} // s_store_sk #14 #long + }; + + DataType[] dataSourceTypes = {LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, + IntDataType.INTEGER, IntDataType.INTEGER, LongDataType.LONG, IntDataType.INTEGER, LongDataType.LONG, + LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, VarcharDataType.VARCHAR, VarcharDataType.VARCHAR, + LongDataType.LONG}; + + List dataSourceProjects = ImmutableList.of("#0", "#1", "#2", "#3", "#4", "#5", "#6", "#7", "#8", "#9", + "#10", "#11", "#12", "#13", "#14"); + + List dataSourceProjectionJson = ImmutableList.of( + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":0}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":1}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":2}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":3}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":4}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":5}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":6}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":7}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":8}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":9}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":10}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":11}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":12,\"width\":50}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":13,\"width\":50}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":14}"); + + @Test + public void testMustExpJson() { + int[] resultKeepRowIdxForEXP1 = {2, 4}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP1, MUST_TEST_EXP1_JSON, 1); + + int[] resultKeepRowIdxForEXP2WithNull = {3}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP2WithNull, MUST_TEST_EXP2_JSON, 1); + + int[] resultKeepRowIdxForEXP3 = {3, 4}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP3, MUST_TEST_EXP3_JSON, 1); + + int[] resultKeepRowIdxForEXP4 = {2, 3, 4}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP4, MUST_TEST_EXP4_JSON, 1); + + int[] resultKeepRowIdxForEXP5 = {4, 5}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP5, MUST_TEST_EXP5_JSON, 1); + + int[] resultKeepRowIdxForEXP6 = {3, 4}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP6, MUST_TEST_EXP6_JSON, 1); + } + + @Test + public void testMustExp() { + int[] resultKeepRowIdxForEXP1 = {2, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP1, + MUST_TEST_EXP1); + + int[] resultKeepRowIdxForEXP2WithNull = {3}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, + resultKeepRowIdxForEXP2WithNull, MUST_TEST_EXP2); + + int[] resultKeepRowIdxForEXP3 = {3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP3, + MUST_TEST_EXP3); + + int[] resultKeepRowIdxForEXP4 = {2, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP4, + MUST_TEST_EXP4); + + int[] resultKeepRowIdxForEXP5 = {4, 5}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP5, + MUST_TEST_EXP5); + + int[] resultKeepRowIdxForEXP6 = {3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP6, + MUST_TEST_EXP6); + } + + @Test + public void testForStoreSalesTable() { + testForStoreSalesTableWithNull(); + testForStoreSalesTableWithNotNull(); + } + + private void testForStoreSalesTableWithNull() { + String expNotNull1 = "IS_NOT_NULL:4(#0)"; + int[] resultKeepRowIdxForNotNull1 = {0, 2, 3, 4, 5}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull1, + expNotNull1); + + String expNotNull2 = "IS_NOT_NULL:4(#1)"; + int[] resultKeepRowIdxForNotNull2 = {1, 2, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull2, + expNotNull2); + + String expNotNull3 = "IS_NOT_NULL:4(#2)"; + int[] resultKeepRowIdxForNotNull3 = {2, 3, 4, 5}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull3, + expNotNull3); + + String expNotNull4 = "IS_NOT_NULL:4(#3)"; + int[] resultKeepRowIdxForNotNull4 = {0, 1, 2, 4, 5}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull4, + expNotNull4); + } + + private void testForStoreSalesTableWithNotNull() { + String expGe = "$operator$GREATER_THAN_OR_EQUAL:4(#0 , 2451484:2)"; + int[] resultKeepRowIdxForGe = {1, 2, 3, 4, 5}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForGe, expGe); + + String expLe = "$operator$LESS_THAN_OR_EQUAL:4(#0 , 2451513:2)"; + int[] resultKeepRowIdxForLe = {0, 1, 2, 3, 4}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForLe, expLe); + } +} \ No newline at end of file diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql1ForOmniFilterOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql1ForOmniFilterOperatorTest.java new file mode 100644 index 0000000..9babd32 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql1ForOmniFilterOperatorTest.java @@ -0,0 +1,314 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.tensql; + +import static nova.hetu.omniruntime.util.TestUtils.filterOperatorMatch; +import static nova.hetu.omniruntime.util.TestUtils.filterOperatorMatchWithJson; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonLiteral; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonAndExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonGreaterThanOrEqualExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonInExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonIsNotNullExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonLessThanOrEqualExpr; + +import com.google.common.collect.ImmutableList; + +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.Date32DataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; + +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.List; + +/** + * sql1 for OmniFilter operator test. + * + * @since 2022-03-31 + */ +public class Sql1ForOmniFilterOperatorTest { + private static final String MUST_TEST_EXP1 = "AND:4(AND:4(AND:4(IS_NOT_NULL:4(#3) , " + + "$operator$GREATER_THAN_OR_EQUAL:4(#3 , 10406:8)) , $operator$LESS_THAN_OR_EQUAL:4(#3 , 10467:8)) , " + + "IS_NOT_NULL:4(#4))"; + + private static final String MUST_TEST_EXP2 = "AND:4(AND:4(AND:4(AND:4(IS_NOT_NULL:4(#5) , " + + "$operator$GREATER_THAN_OR_EQUAL:4(#5 , 100:1)) , $operator$LESS_THAN_OR_EQUAL:4(#5 , 500:1)) , " + + "IS_NOT_NULL:4(#6)) , IS_NOT_NULL:4(#7))"; + + private static final String MUST_TEST_EXP3 = "AND:4(AND:4(AND:4(AND:4(IS_NOT_NULL:4(#0) , " + + "$operator$GREATER_THAN_OR_EQUAL:4(#0 , 76:1)) , $operator$LESS_THAN_OR_EQUAL:4(#0 , 106:1)) , " + + "IN:4(#1,512:1,409:1,677:1,16:1)) , IS_NOT_NULL:4(#1))"; + + private static final String MUST_TEST_EXP4 = "IS_NOT_NULL:4(#8)"; + + private static final String MUST_TEST_EXP1_JSON = omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(8, 3)), + omniJsonGreaterThanOrEqualExpr(getOmniJsonFieldReference(8, 3), + getOmniJsonLiteral(8, false, 10406))), + omniJsonLessThanOrEqualExpr(getOmniJsonFieldReference(8, 3), getOmniJsonLiteral(8, false, 10467))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 4))); + + private static final String MUST_TEST_EXP2_JSON = omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(1, 5)), + omniJsonGreaterThanOrEqualExpr(getOmniJsonFieldReference(1, 5), + getOmniJsonLiteral(1, false, 100))), + omniJsonLessThanOrEqualExpr(getOmniJsonFieldReference(1, 5), + getOmniJsonLiteral(1, false, 500))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 6))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 7))); + + private static final String MUST_TEST_EXP3_JSON = omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(1, 0)), + omniJsonGreaterThanOrEqualExpr(getOmniJsonFieldReference(1, 0), + getOmniJsonLiteral(1, false, 76))), + omniJsonLessThanOrEqualExpr(getOmniJsonFieldReference(1, 0), + getOmniJsonLiteral(1, false, 106))), + omniJsonInExpr(1, 1, Arrays.asList(512, 409, 677, 16))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(1, 1))); + + private static final String MUST_TEST_EXP4_JSON = omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 8)); + + /* + * 涉及表:item,date_dim,inventory,store_sales + */ + Object[][] dataSource = {{50, 76, 80, 106, 200}, // #0 i_current_price + {16, 30, 409, 512, 677}, // #1 i_manufact_id #int + {0L, 5L, 10L, 15L, 20L}, // #2 i_item_sk #long + {10405, 10406, 10450, 10467, 10500}, // #3 d_date #date + {1L, 2L, 3L, 4L, 5L}, // #4 d_date_sk #long + {10, 100, 200, 500, 1000}, // #5 inv_quantity_on_hand #int + {0L, 5L, 20L, 30L, 50L}, // #6 inv_item_sk #long + {0L, 1L, 20L, 40L, 60L}, // #7 inv_date_sk #long + {2L, 4L, 6L, 8L, 10L} // #8 ss_item_sk #long + }; + + Object[][] dataSourceWithNull = {{50, 76, null, 106, 200}, // #0 i_current_price + {16, 30, 409, 512, null}, // #1 i_manufact_id #int + {0L, null, 10L, null, 20L}, // #2 i_item_sk #long + {10405, 10406, 10450, null, 10500}, // #3 d_date #date + {1L, 2L, null, 4L, 5L}, // #4 d_date_sk #long + {10, 100, 200, null, 1000}, // #5 inv_quantity_on_hand #int + {0L, 5L, 20L, null, 50L}, // #6 inv_item_sk #long + {0L, 1L, null, null, 60L}, // #7 inv_date_sk #long + {2L, null, null, 8L, 10L} // #8 ss_item_sk #long + }; + + // 数据源类型 + DataType[] dataSourceTypes = {IntDataType.INTEGER, IntDataType.INTEGER, LongDataType.LONG, Date32DataType.DATE32, + LongDataType.LONG, IntDataType.INTEGER, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG}; + + List dataSourceProjections = ImmutableList.of("#0", "#1", "#2", "#3", "#4", "#5", "#6", "#7", "#8"); + + List dataSourceProjectionJson = ImmutableList.of( + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":1}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":2}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":8,\"colVal\":3}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":4}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":5}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":6}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":7}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":8}"); + + @Test + public void testMustExpJson() { + int[] resultKeepRowIdxForEXP1 = {1}; + filterOperatorMatchWithJson(dataSourceWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP1, MUST_TEST_EXP1_JSON, 1); + + int[] resultKeepRowIdxForEXP2 = {1}; + filterOperatorMatchWithJson(dataSourceWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP2, MUST_TEST_EXP2_JSON, 1); + + int[] resultKeepRowIdxForEXP3 = {3}; + filterOperatorMatchWithJson(dataSourceWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP3, MUST_TEST_EXP3_JSON, 1); + + int[] resultKeepRowIdxForEXP4 = {0, 3, 4}; + filterOperatorMatchWithJson(dataSourceWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP4, MUST_TEST_EXP4_JSON, 1); + } + + @Test + public void testMustExp() { + int[] resultKeepRowIdxForEXP1 = {1}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForEXP1, + MUST_TEST_EXP1); + + int[] resultKeepRowIdxForEXP2 = {1}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForEXP2, + MUST_TEST_EXP2); + + int[] resultKeepRowIdxForEXP3 = {3}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForEXP3, + MUST_TEST_EXP3); + + int[] resultKeepRowIdxForEXP4 = {0, 3, 4}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForEXP4, + MUST_TEST_EXP4); + } + + @Test + public void testForItemTableFilter() { + testForItemTableWithNotNull(); + testForItemTableWithNull(); + } + + @Test + public void testForDatedimTableFilter() { + testForDateDimTableWithNotNull(); + testForDateDimTableWithNull(); + } + + @Test + public void testForInventoryTableFilter() { + testForInventoryTableWithNotNull(); + testForInventoryTableWithNull(); + } + + private void testForInventoryTableWithNotNull() { + String expLe = "$operator$LESS_THAN_OR_EQUAL:4(#5 , 500:1)"; + int[] resultKeepRowIdxForLe = {0, 1, 2, 3}; + filterOperatorMatch(dataSource, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForLe, expLe); + + String expGe = "$operator$GREATER_THAN_OR_EQUAL:4(#5 , 100:1)"; + int[] resultKeepRowIdxForGe = {1, 2, 3, 4}; + filterOperatorMatch(dataSource, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForGe, expGe); + + String expLeAndGe = "AND:4($operator$LESS_THAN_OR_EQUAL:4(#5 , 500:1), " + + "$operator$GREATER_THAN_OR_EQUAL:4(#5 , 100:1))"; + int[] resultKeepRowIdxForLeAndGe = {1, 2, 3}; + filterOperatorMatch(dataSource, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForLeAndGe, expLeAndGe); + } + + private void testForInventoryTableWithNull() { + String expNotNull1 = "IS_NOT_NULL:4(#5)"; + int[] resultKeepRowIdxForNotNull1 = {0, 1, 2, 4}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForNotNull1, + expNotNull1); + + String expNotNull2 = "IS_NOT_NULL:4(#6)"; + int[] resultKeepRowIdxForNotNull2 = {0, 1, 2, 4}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForNotNull2, + expNotNull2); + + String expNotNull3 = "IS_NOT_NULL:4(#7)"; + int[] resultKeepRowIdxForNotNull3 = {0, 1, 4}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForNotNull3, + expNotNull3); + + String expGe = "AND:4(IS_NOT_NULL:4(#5) , $operator$GREATER_THAN_OR_EQUAL:4(#5 , 100:1))"; + int[] resultKeepRowIdxForGe = {1, 2, 4}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForGe, expGe); + + String expLe = "AND:4(IS_NOT_NULL:4(#5) , $operator$LESS_THAN_OR_EQUAL:4(#5 , 500:1))"; + int[] resultKeepRowIdxForLe = {0, 1, 2}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForLe, expLe); + + String expLeAndGe1 = "AND:4(AND:4(IS_NOT_NULL:4(#5) , $operator$GREATER_THAN_OR_EQUAL:4(#5 , 100:1)) ," + + " $operator$LESS_THAN_OR_EQUAL:4(#5 , 500:1))"; + int[] resultKeepRowIdxForLeAndGe1 = {1, 2}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForLeAndGe1, + expLeAndGe1); + + String expLeAndGe2 = "AND:4(AND:4(AND:4(IS_NOT_NULL:4(#5) , $operator$GREATER_THAN_OR_EQUAL:4(#5 , 100:1)) ," + + " $operator$LESS_THAN_OR_EQUAL:4(#5 , 500:1)) , IS_NOT_NULL:4(#6))"; + int[] resultKeepRowIdxForLeAndGe2 = {1, 2}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForLeAndGe2, + expLeAndGe2); + } + + private void testForDateDimTableWithNotNull() { + String expGe = "$operator$GREATER_THAN_OR_EQUAL:4(#3 , 10406:8)"; + int[] resultKeepRowIdxForGe = {1, 2, 3, 4}; + filterOperatorMatch(dataSource, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForGe, expGe); + + String expLe = "$operator$LESS_THAN_OR_EQUAL:4(#3 , 10467:8)"; + int[] resultKeepRowIdxForLe = {0, 1, 2, 3}; + filterOperatorMatch(dataSource, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForLe, expLe); + + String expLeAndGe = "AND:4($operator$GREATER_THAN_OR_EQUAL:4(#3 , 10406:8) , " + + "$operator$LESS_THAN_OR_EQUAL:4(#3 , 10467:8))"; + int[] resultKeepRowIdxForLeAndGe = {1, 2, 3}; + filterOperatorMatch(dataSource, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForLeAndGe, expLeAndGe); + } + + private void testForDateDimTableWithNull() { + String expNotNull1 = "IS_NOT_NULL:4(#3)"; + int[] resultKeepRowIdxForNotNull1 = {0, 1, 2, 4}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForNotNull1, + expNotNull1); + + String expNotNull2 = "IS_NOT_NULL:4(#4)"; + int[] resultKeepRowIdxForNotNull2 = {0, 1, 3, 4}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForNotNull2, + expNotNull2); + + String expGe = "AND:4(IS_NOT_NULL:4(#3) , $operator$GREATER_THAN_OR_EQUAL:4(#3 , 10406:8))"; + int[] resultKeepRowIdxForGe = {1, 2, 4}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForGe, expGe); + + String expLe = "AND:4(IS_NOT_NULL:4(#3) , $operator$LESS_THAN_OR_EQUAL:4(#3 , 10467:8))"; + int[] resultKeepRowIdxForLe = {0, 1, 2}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForLe, expLe); + + String expLeAndGe1 = "AND:4(AND:4(IS_NOT_NULL:4(#3) , " + + "$operator$GREATER_THAN_OR_EQUAL:4(#3 , 10406:8)) , $operator$LESS_THAN_OR_EQUAL:4(#3 , 10467:8))"; + int[] resultKeepRowIdxForLeAndGe1 = {1, 2}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForLeAndGe1, + expLeAndGe1); + } + + private void testForItemTableWithNotNull() { + String expGe = "$operator$GREATER_THAN_OR_EQUAL:4(#0 , 76:1)"; + int[] resultKeepRowIdxForGe = {1, 2, 3, 4}; + filterOperatorMatch(dataSource, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForGe, expGe); + + String expLe = "$operator$LESS_THAN_OR_EQUAL:4(#0 , 106:1)"; + int[] resultKeepRowIdxForLe = {0, 1, 2, 3}; + filterOperatorMatch(dataSource, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForLe, expLe); + + String expIn = "IN:4(#1,512:1,409:1,677:1,16:1)"; + int[] resultKeepRowIdxForIn = {0, 2, 3, 4}; + filterOperatorMatch(dataSource, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForIn, expIn); + + String expLeAndGe = "AND:4(AND:4(AND:4(IS_NOT_NULL:4(#0) , $operator$GREATER_THAN_OR_EQUAL:4(#0 , 76:1)) , " + + "$operator$LESS_THAN_OR_EQUAL:4(#0 , 106:1)) , IN:4(#1,512:1,409:1,677:1,16:1))"; + int[] resultKeepRowIdxForLeAndGe = {2, 3}; + filterOperatorMatch(dataSource, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForLeAndGe, expLeAndGe); + } + + private void testForItemTableWithNull() { + String expIn1 = "IS_NOT_NULL:4(#0)"; + int[] resultKeepRowIdxForIn1 = {0, 1, 3, 4}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForIn1, expIn1); + + String expIn2 = "IS_NOT_NULL:4(#2)"; + int[] resultKeepRowIdxForIn2 = {0, 2, 4}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForIn2, expIn2); + + String expIn3 = "AND:4(IS_NOT_NULL:4(#0) , $operator$GREATER_THAN_OR_EQUAL:4(#0 , 76:1))"; + int[] resultKeepRowIdxForIn3 = {1, 3, 4}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForIn3, expIn3); + + String expIn4 = "AND:4(AND:4(IS_NOT_NULL:4(#0) , $operator$GREATER_THAN_OR_EQUAL:4(#0 , 76:1)) , " + + "$operator$LESS_THAN_OR_EQUAL:4(#0 , 106:1))"; + int[] resultKeepRowIdxForIn4 = {1, 3}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForIn4, expIn4); + + String expIn5 = "AND:4(AND:4(AND:4(IS_NOT_NULL:4(#0) , $operator$GREATER_THAN_OR_EQUAL:4(#0 , 76:1))," + + "$operator$LESS_THAN_OR_EQUAL:4(#0 , 106:1)) , IN:4(#1,512:1,409:1,677:1,16:1))"; + int[] resultKeepRowIdxForIn5 = {3}; + filterOperatorMatch(dataSourceWithNull, dataSourceTypes, dataSourceProjections, resultKeepRowIdxForIn5, expIn5); + } +} \ No newline at end of file diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql2ForOmniFilterOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql2ForOmniFilterOperatorTest.java new file mode 100644 index 0000000..b0939f7 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql2ForOmniFilterOperatorTest.java @@ -0,0 +1,242 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.tensql; + +import static nova.hetu.omniruntime.util.TestUtils.filterOperatorMatch; +import static nova.hetu.omniruntime.util.TestUtils.filterOperatorMatchWithJson; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonLiteral; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonAbsExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonAndExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonCastExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonEqualExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonFourArithmeticExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonGreaterThanExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonIfExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonIsNotNullExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonOrExpr; + +import com.google.common.collect.ImmutableList; + +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.VarcharDataType; + +import org.testng.annotations.Test; + +import java.util.List; + +/** + * sql2 for OmniFilter operator test. + * + * @since 2022-03-31 + */ +public class Sql2ForOmniFilterOperatorTest { + private static final String MUST_TEST_EXP1 = "AND:4(AND:4(IS_NOT_NULL:4(#0) , IS_NOT_NULL:4(#2)) , " + + "IS_NOT_NULL:4(#1))"; + private static final String MUST_TEST_EXP2 = "AND:4(AND:4(IS_NOT_NULL:4(#3) , IS_NOT_NULL:4(#5)) , " + + "IS_NOT_NULL:4(#4))"; + private static final String MUST_TEST_EXP3 = "AND:4(OR:4(OR:4($operator$EQUAL:4(#7 , 2000:1) , " + + "AND:4($operator$EQUAL:4(#7 , 1999:1) , $operator$EQUAL:4(#8 , 12:1))) , " + + "AND:4($operator$EQUAL:4(#7 , 2001:1) , $operator$EQUAL:4(#8 , 1:1))) , IS_NOT_NULL:4(#6))"; + private static final String MUST_TEST_EXP4 = "AND:4(AND:4(IS_NOT_NULL:4(#9) , IS_NOT_NULL:4(#10)) , " + + "IS_NOT_NULL:4(#11))"; + private static final String MUST_TEST_EXP5 = "IS_NOT_NULL:4(#14)"; + private static final String MUST_TEST_EXP6 = "AND:4(AND:4(AND:4(AND:4(AND:4(IS_NOT_NULL:4(#7) , " + + "IS_NOT_NULL:4(#13)) , $operator$EQUAL:4(#7 , 2000:1)) , $operator$GREATER_THAN:4(#13 , 0.0:3)) , " + + "$operator$GREATER_THAN:4(IF:3($operator$GREATER_THAN:4(#13 , 0.0:3), " + + "$operator$DIVIDE:3(abs:3($operator$SUBTRACT:3(CAST:3(#12) , #13)) , #13), null:3) , 0.1:3)) , " + + "IS_NOT_NULL:4(#14))"; + + private static final String MUST_TEST_EXP1_JSON = omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 0)), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(15, 2))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(15, 1))); + private static final String MUST_TEST_EXP2_JSON = omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 3)), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 5))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 4))); + private static final String MUST_TEST_EXP3_JSON = omniJsonAndExpr( + omniJsonOrExpr(omniJsonOrExpr( + omniJsonEqualExpr(getOmniJsonFieldReference(1, 7), getOmniJsonLiteral(1, false, 2000)), + omniJsonAndExpr( + omniJsonEqualExpr(getOmniJsonFieldReference(1, 7), getOmniJsonLiteral(1, false, 1999)), + omniJsonEqualExpr(getOmniJsonFieldReference(1, 8), getOmniJsonLiteral(1, false, 12)))), + omniJsonAndExpr( + omniJsonEqualExpr(getOmniJsonFieldReference(1, 7), getOmniJsonLiteral(1, false, 2001)), + omniJsonEqualExpr(getOmniJsonFieldReference(1, 8), getOmniJsonLiteral(1, false, 1)))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 6))); + private static final String MUST_TEST_EXP4_JSON = omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 9)), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(15, 10))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(15, 11))); + private static final String MUST_TEST_EXP5_JSON = omniJsonIsNotNullExpr(getOmniJsonFieldReference(1, 14)); + private static final String MUST_TEST_EXP6_JSON = omniJsonAndExpr(omniJsonAndExpr( + omniJsonAndExpr(omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(1, 7)), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(3, 13))), + omniJsonEqualExpr(getOmniJsonFieldReference(1, 7), getOmniJsonLiteral(1, false, 2000))), + omniJsonGreaterThanExpr(getOmniJsonFieldReference(3, 13), getOmniJsonLiteral(3, false, 0.0))), + omniJsonGreaterThanExpr(omniJsonIfExpr( + omniJsonGreaterThanExpr(getOmniJsonFieldReference(3, 13), getOmniJsonLiteral(3, false, 0.0)), 3, + omniJsonFourArithmeticExpr("DIVIDE", 3, omniJsonAbsExpr(3, omniJsonFourArithmeticExpr("SUBTRACT", 3, + omniJsonCastExpr(3, getOmniJsonFieldReference(2, 12)), getOmniJsonFieldReference(3, 13))), + getOmniJsonFieldReference(3, 13)), + getOmniJsonFieldReference(3, 13)), getOmniJsonLiteral(3, false, 0.1))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(1, 14))); + + Object[][] dataSourceValue = {{1L, 2L, 3L, 4L, 5L, 6L}, // i_item_sk #0 #long + {"ab", "", " ", " ab", "ab ", "a b c"}, // i_brand #1 char(50) + {"a", " ", "abccscss", " ab ", "ab ab", "abc"}, // i_brand #2 char(50) + {1L, 2L, 3L, 4L, 5L, 6L}, // ss_item_sk #3 long + {1L, 2L, 3L, 4L, 5L, 6L}, // ss_store_sk #4 #long + {1L, 2L, 3L, 4L, 5L, 6L}, // ss_sold_date_sk #5 #long + {1L, 2L, 3L, 4L, 5L, 6L}, // d_date_sk #6 #long + {-10, 1999, 2000, 2000, 1999, 2001}, // d_year #7 #int + {-2, 0, 1, 1, 12, 1}, // d_moy #8 #int + {1L, 2L, 3L, 4L, 5L, 6L}, // s_store_sk #9 long + {"abab", "", " ", " ab", "ab ", "a b c"}, // s_store_name #10 char(50) + {"ab", "", " ", " ab", "ab ", "a b c"}, // s_company_name #11 char(50) + {1L, 2L, 3L, 4L, 5L, 6L}, // sum_sales #12 long + {-2.0, 0.0, 1.0, 1.5, 2.0, 3.0}, // avg_monthly_sales #13 double + {-3, 12, 1215, 1223, 1217, 1214} // rn #14 #int + }; + + Object[][] dataSourceValueWithNull = {{null, 2L, 3L, 4L, 5L, 6L}, // i_item_sk #0 #long + {"ab", null, " ", " ab", "ab ", "a b c"}, // i_brand #1 char(50) + {"a", " ", null, " ab ", "ab ab", "abc"}, // i_brand #2 char(50) + {null, 2L, 3L, 4L, 5L, null}, // ss_item_sk #3 long + {1L, null, 3L, 4L, 5L, 6L}, // ss_store_sk #4 #long + {1L, 2L, null, null, 5L, 6L}, // ss_sold_date_sk #5 #long + {1L, null, 3L, 4L, 5L, 6L}, // d_date_sk #6 #long + {null, 1999, 2000, 2000, 1999, 2001}, // d_year #7 #int + {-2, 0, 1, 1, 12, 1}, // d_moy #8 #int + {null, 2L, 3L, 4L, 5L, 6L}, // s_store_sk #9 long + {"abab", "", " ", " ab", "ab ", null}, // s_store_name #10 char(50) + {"ab", "", " ", " ab", null, "a b c"}, // s_company_name #11 char(50) + {1L, 2L, 3L, 4L, 5L, 6L}, // sum_sales #12 long + {-2.0, 0.0, 1.0, 1.5, null, 3.0}, // avg_monthly_sales #13 double + {12, null, 1223, 1217, 1214, 1224} // rn #14 #int + }; + + DataType[] dataSourceTypes = {LongDataType.LONG, VarcharDataType.VARCHAR, VarcharDataType.VARCHAR, + LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, IntDataType.INTEGER, + IntDataType.INTEGER, LongDataType.LONG, VarcharDataType.VARCHAR, VarcharDataType.VARCHAR, LongDataType.LONG, + DoubleDataType.DOUBLE, IntDataType.INTEGER}; + + List dataSourceProjects = ImmutableList.of("#0", "#1", "#2", "#3", "#4", "#5", "#6", "#7", "#8", "#9", + "#10", "#11", "#12", "#13", "#14"); + + List dataSourceProjectionJson = ImmutableList.of( + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":0}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":50}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":50}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":3}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":4}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":5}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":6}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":7}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":8}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":9}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":10,\"width\":50}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":11,\"width\":50}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":12}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":13}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":14}"); + + @Test + public void testMustExpJson() { + int[] resultKeepRowIdxForEXP1 = {3, 4, 5}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP1, MUST_TEST_EXP1_JSON, 1); + + int[] resultKeepRowIdxForEXP2 = {4}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP2, MUST_TEST_EXP2_JSON, 1); + + int[] resultKeepRowIdxForEXP3 = {2, 3, 4, 5}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP3, MUST_TEST_EXP3_JSON, 1); + + int[] resultKeepRowIdxForEXP4 = {1, 2, 3}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP4, MUST_TEST_EXP4_JSON, 1); + + int[] resultKeepRowIdxForEXP5 = {0, 2, 3, 4, 5}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP5, MUST_TEST_EXP5_JSON, 1); + + int[] resultKeepRowIdxForEXP6 = {2, 3}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP6, MUST_TEST_EXP6_JSON, 1); + } + + @Test + public void testMustExp() { + int[] resultKeepRowIdxForEXP1 = {3, 4, 5}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP1, + MUST_TEST_EXP1); + + int[] resultKeepRowIdxForEXP2 = {4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP2, + MUST_TEST_EXP2); + + int[] resultKeepRowIdxForEXP3 = {2, 3, 4, 5}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP3, + MUST_TEST_EXP3); + + int[] resultKeepRowIdxForEXP4 = {1, 2, 3}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP4, + MUST_TEST_EXP4); + + int[] resultKeepRowIdxForEXP5 = {0, 2, 3, 4, 5}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP5, + MUST_TEST_EXP5); + + int[] resultKeepRowIdxForEXP6 = {2, 3}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP6, + MUST_TEST_EXP6); + } + + @Test + public void testForDateDimTable() { + testForDateDimTableWithNotNull(); + testForDateDimTableWithNull(); + } + + private void testForDateDimTableWithNotNull() { + String expEq1 = "$operator$EQUAL:4(#7 , 2000:1)"; + int[] resultKeepRowIdxForEq = {2, 3}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEq, expEq1); + + String expEqAnd1 = "AND:4($operator$EQUAL:4(#7 , 1999:1) , $operator$EQUAL:4(#8 , 12:1))"; + int[] resultKeepRowIdxForEqAnd1 = {4}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEqAnd1, expEqAnd1); + + String expEq2 = "$operator$EQUAL:4(#7 , 2001:1)"; + int[] resultKeepRowIdxForEq2 = {5}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEq2, expEq2); + + String expEqAnd2 = "AND:4($operator$EQUAL:4(#7 , 2001:1) , $operator$EQUAL:4(#8 , 1:1))"; + int[] resultKeepRowIdxForEqAnd2 = {5}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEqAnd2, expEqAnd2); + + // expEq1 or expEqAnd1 or expEqAnd2 + String expOr = "OR:4(OR:4($operator$EQUAL:4(#7 , 2000:1) , AND:4($operator$EQUAL:4(#7 , 1999:1) , " + + "$operator$EQUAL:4(#8 , 12:1))) , AND:4($operator$EQUAL:4(#7 , 2001:1) , " + + "$operator$EQUAL:4(#8 , 1:1)))"; + int[] resultKeepRowIdxForOr = {2, 3, 4, 5}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForOr, expOr); + } + + private void testForDateDimTableWithNull() { + String expNotNull = "IS_NOT_NULL:4(#6)"; + int[] resultKeepRowIdxForNotNull = {0, 2, 3, 4, 5}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull, + expNotNull); + } +} \ No newline at end of file diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql3ForOmniFilterOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql3ForOmniFilterOperatorTest.java new file mode 100644 index 0000000..bde33db --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql3ForOmniFilterOperatorTest.java @@ -0,0 +1,303 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.tensql; + +import static nova.hetu.omniruntime.util.TestUtils.filterOperatorMatch; +import static nova.hetu.omniruntime.util.TestUtils.filterOperatorMatchWithJson; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonLiteral; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonAbsExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonAndExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonCastExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonFourArithmeticExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonGreaterThanExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonIfExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonInExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonIsNotNullExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonOrExpr; + +import com.google.common.collect.ImmutableList; + +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.VarcharDataType; + +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.List; + +/** + * sql3 for OmniFilter operator test. + * + * @since 2022-03-31 + */ +public class Sql3ForOmniFilterOperatorTest { + private static final String MUST_TEST_EXP1 = "IS_NOT_NULL:4(#9)"; + private static final String MUST_TEST_EXP2 = "$operator$GREATER_THAN:4(IF:3($operator$GREATER_THAN:4(#11 , 0.0:3), " + + "$operator$DIVIDE:3(abs:3($operator$SUBTRACT:3(CAST:3(#10) , #11)) , #11), null:3) , 0.1:3)"; + private static final String MUST_TEST_EXP3 = "AND:4(IN:4(#7,1216:1,1213:1,1219:1,1221:1,1215:1,1223:1,1212:1," + + "1218:1,1222:1,1217:1,1214:1,1220:1) , IS_NOT_NULL:4(#8))"; + private static final String MUST_TEST_EXP4 = "AND:4(OR:4(AND:4(AND:4(" + + "IN:4(#0,'Books':15,'Children':15,'Electronics':15) , " + + "IN:4(#1,'personal':15,'portable':15,'reference':15,'self-help':15)) , " + + "IN:4(#2,'scholaramalgamalg #14':15,'scholaramalgamalg #7':15,'exportiunivamalg #9':15," + + "'scholaramalgamalg #9':15)) , AND:4(AND:4(IN:4(#0,'Women':15,'Music':15,'Men':15) , " + + "IN:4(#1,'accessories':15,'classical':15,'fragrances':15,'pants':15)) , " + + "IN:4(#2,'amalgimporto #1':15,'edu packscholar #1':15,'exportiimporto #1':15,'importoamalg #1':15))) , " + + "IS_NOT_NULL:4(#3))"; + + private static final String MUST_TEST_EXP1_JSON = omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 9)); + private static final String MUST_TEST_EXP2_JSON = omniJsonGreaterThanExpr( + omniJsonIfExpr(omniJsonGreaterThanExpr(getOmniJsonFieldReference(3, 11), getOmniJsonLiteral(3, false, 0.0)), + 3, + omniJsonFourArithmeticExpr("DIVIDE", 3, omniJsonAbsExpr(3, omniJsonFourArithmeticExpr("SUBTRACT", 3, + omniJsonCastExpr(3, getOmniJsonFieldReference(2, 10)), getOmniJsonFieldReference(3, 11))), + getOmniJsonFieldReference(3, 11)), + getOmniJsonFieldReference(3, 11)), + getOmniJsonLiteral(3, false, 0.1)); + private static final String MUST_TEST_EXP3_JSON = omniJsonAndExpr( + omniJsonInExpr(1, 7, Arrays.asList(1216, 1213, 1219, 1221, 1215, 1223, 1212, 1218, 1222, 1217, 1214, 1220)), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 8))); + private static final String MUST_TEST_EXP4_JSON = omniJsonAndExpr( + omniJsonOrExpr( + omniJsonAndExpr( + omniJsonAndExpr( + omniJsonInExpr(15, 0, + Arrays.asList("Books", "Children", "Electronics", "personal", "portable", + "reference", "self-help")), + omniJsonInExpr(15, 1, + Arrays.asList("personal", "portable", "reference", "self-help"))), + omniJsonInExpr(15, 2, + Arrays.asList("scholaramalgamalg #14", "scholaramalgamalg #7", + "exportiunivamalg #9", "scholaramalgamalg #9"))), + omniJsonAndExpr( + omniJsonAndExpr(omniJsonInExpr(15, 0, Arrays.asList("Women", "Music", "Men")), + omniJsonInExpr(15, 1, + Arrays.asList("accessories", "classical", "fragrances", "pants"))), + omniJsonInExpr(15, 2, + Arrays.asList("amalgimporto #1", "edu packscholar #1", "exportiimporto #1", + "importoamalg #1")))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 3))); + + /** + * i_category #0 char(50) + * i_class #1 char(50) + * i_brand #2 char(50) + * i_item_sk #3 long + * ss_item_sk #4 #long + * ss_sold_date_sk #5 #long + * ss_store_sk #6 #long + * d_month_seq #7 #int + * d_date_sk #8 #long + * s_store_sk #9 long + * sum_sales #10 long + * avg_quarterly_sales #11 double + */ + Object[][] dataSourceValue = {{"Books", "Children", "Books", "Women", "Men", "", "", ""}, + {"personal", "personal", "reference", "self-help", "accessories", "classical", "fragrances", "pants"}, + {"scholaramalgamalg #14", "scholaramalgamalg #7", "exportiunivamalg #9", "scholaramalgamalg #9", + "amalgimporto #1", "edu packscholar #1", "exportiimporto #1", "importoamalg #1"}, + {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, + {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, {12, 1215, 1223, 1217, 1214, 1219, 1213, 22144}, + {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, + {-2.0, -1.0, 0.0, 1.5, 2.0, 3.0, 4.0, 5.0}}; + + /** + * i_category #0 char(50) + * i_class #1 char(50) + * i_brand #2 char(50) + * i_item_sk #3 long + * ss_item_sk #4 #long + * ss_sold_date_sk #5 #long + * ss_store_sk #6 #long + * d_month_seq #7 #int + * d_date_sk #8 #long + * s_store_sk #9 long + * sum_sales #10 long + * avg_quarterly_sales #11 double + */ + Object[][] dataSourceValueWithNull = {{"Books", "Children", null, "Women", "Men", "", "", ""}, + {"personal", "personal", "reference", "self-help", "accessories", "classical", "fragrances", "pants"}, + {"scholaramalgamalg #14", "scholaramalgamalg #7", "exportiunivamalg #9", "scholaramalgamalg #9", + "amalgimporto #1", "edu packscholar #1", "exportiimporto #1", "importoamalg #1"}, + {null, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, {null, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, {1L, null, 3L, 4L, 5L, 6L, 7L, 8L}, + {1L, 2L, null, 4L, 5L, 6L, 7L, 8L}, {12, 1215, 1223, 1217, 1214, 1219, 1213, 22144}, + {1L, null, null, 4L, 5L, 6L, 7L, null}, {null, 2L, 3L, 4L, 5L, null, 7L, null}, + {1L, 2L, 3L, null, 5L, 6L, 7L, 8L}, {-2.0, -1.0, 0.0, 1.5, null, 3.0, 4.0, 5.0}}; + + DataType[] dataSourceTypes = {VarcharDataType.VARCHAR, VarcharDataType.VARCHAR, VarcharDataType.VARCHAR, + LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, IntDataType.INTEGER, + LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, DoubleDataType.DOUBLE}; + + List dataSourceProjects = ImmutableList.of("#0", "#1", "#2", "#3", "#4", "#5", "#6", "#7", "#8", "#9", + "#10", "#11"); + + List dataSourceProjectionJson = ImmutableList.of( + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":50}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":50}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":50}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":3}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":4}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":5}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":6}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":7}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":8}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":9}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":10}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":11}"); + + @Test + public void testMustExpJson() { + int[] resultKeepRowIdxForEXP1 = {1, 2, 3, 4, 6}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP1, MUST_TEST_EXP1_JSON, 1); + + int[] resultKeepRowIdxForEXP2WithNull = {5, 6, 7}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP2WithNull, MUST_TEST_EXP2_JSON, 1); + int[] resultKeepRowIdxForEXP2WithNotNull = {3, 4, 5, 6, 7}; + filterOperatorMatchWithJson(dataSourceValue, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP2WithNotNull, MUST_TEST_EXP2_JSON, 1); + + int[] resultKeepRowIdxForEXP3 = {3, 4, 5, 6}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP3, MUST_TEST_EXP3_JSON, 1); + + int[] resultKeepRowIdxForEXP4 = {1, 4}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP4, MUST_TEST_EXP4_JSON, 1); + } + + @Test + public void testMustExp() { + int[] resultKeepRowIdxForEXP1 = {1, 2, 3, 4, 6}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP1, + MUST_TEST_EXP1); + + int[] resultKeepRowIdxForEXP2WithNull = {5, 6, 7}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, + resultKeepRowIdxForEXP2WithNull, MUST_TEST_EXP2); + int[] resultKeepRowIdxForEXP2WithNotNull = {3, 4, 5, 6, 7}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP2WithNotNull, + MUST_TEST_EXP2); + + int[] resultKeepRowIdxForEXP3 = {3, 4, 5, 6}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP3, + MUST_TEST_EXP3); + + int[] resultKeepRowIdxForEXP4 = {1, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP4, + MUST_TEST_EXP4); + } + + @Test + public void testForItemTable() { + testForItemTableWithNotNull(); + } + + @Test + public void testForStoreSalesTable() { + testForStoreSalesTableWithNull(); + } + + @Test + public void testForDateDimTable() { + testForDateDimTableWithNotNull(); + testForDateDimTableWithNull(); + } + + @Test + public void testForStoreTable() { + testForStoreTableWithNull(); + } + + private void testForItemTableWithNotNull() { + String expIn1 = "IN:4(#0,'Books':15,'Children':15," + "'Electronics':15)"; + int[] resultKeepRowIdxForIn1 = {0, 1, 2}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForIn1, expIn1); + + String expIn2 = "IN:4 (#1,'personal':15,'portable':15," + "'reference':15,'self-help':15)"; + int[] resultKeepRowIdxForIn2 = {0, 1, 2, 3}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForIn2, expIn2); + + String expIn3 = "IN:4(#2,'scholaramalgamalg #14':15,'scholaramalgamalg #7':15," + + "'exportiunivamalg #9':15,'scholaramalgamalg #9':15)"; + int[] resultKeepRowIdxForIn3 = {0, 1, 2, 3}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForIn3, expIn3); + + String expIn4 = "IN:4 (#0,'Women':15," + "'Music':15,'Men':15)"; + int[] resultKeepRowIdxForIn4 = {3, 4}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForIn4, expIn4); + + String expIn5 = "IN:4(#1,'accessories':15,'classical':15," + "'fragrances':15,'pants':15)"; + int[] resultKeepRowIdxForIn5 = {4, 5, 6, 7}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForIn5, expIn5); + + String expIn6 = "IN:4(#2,'amalgimporto #1':15,'edu packscholar #1':15," + + "'exportiimporto #1':15,'importoamalg #1':15)"; + int[] resultKeepRowIdxForIn6 = {4, 5, 6, 7}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForIn6, expIn6); + + String expAnd1 = "AND:4(AND:4(IN:4(#0,'Books':15,'Children':15,'Electronics':15)," + + "IN:4(#1,'personal':15,'portable':15,'reference':15,'self-help':15))," + + "IN:4(#2,'scholaramalgamalg #14':15,'scholaramalgamalg #7':15,'exportiunivamalg #9':15," + + "'scholaramalgamalg #9':15))"; + int[] resultKeepRowIdxForAnd1 = {0, 1, 2}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForAnd1, expAnd1); + + String expAnd2 = "AND:4(AND:4(IN:4(#0,'Women':15,'Music':15,'Men':15)," + + "IN(#1,'accessories':15,'classical':15,'fragrances':15,'pants':15))," + + "IN(#2,'amalgimporto #1':15,'edu packscholar #1':15,'exportiimporto #1':15,'importoamalg #1':15))"; + int[] resultKeepRowIdxForAnd2 = {4}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForAnd2, expAnd2); + + // (In1 and In2 and In3) Or (In4 and Int5 and Int6) = {0, 1, 2, 4} + String expMixedOr = "OR:4(AND:4(AND:4(IN:4(#0,'Books':15,'Children':15,'Electronics':15)," + + "IN:4(#1,'personal':15,'portable':15,'reference':15,'self-help':15)), " + + "IN:4(#2,'scholaramalgamalg #14':15,'scholaramalgamalg #7':15,'exportiunivamalg #9':15," + + "'scholaramalgamalg #9':15)), AND:4(AND:4(IN:4 (#0, 'Women':15,'Music ':15,'Men':15), " + + "IN:4(#1,'accessories':15,'classical':15,'fragrances':15,'pants':15)), " + + "IN:4(#2,'amalgimporto #1':15,'edu packscholar #1':15,'exportiimporto #1':15,'importoamalg #1':15)))"; + int[] resultKeepRowIdxForMixedOr = {0, 1, 2, 4}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForMixedOr, + expMixedOr); + } + + private void testForStoreSalesTableWithNull() { + String expNotNull = "AND:4(AND:4(IS_NOT_NULL:4(#4) , IS_NOT_NULL:4(#5)) , IS_NOT_NULL:4(#6))"; + int[] resultKeepRowIdxForNotNull = {3, 4, 5, 6, 7}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull, + expNotNull); + } + + private void testForDateDimTableWithNotNull() { + String expIn = "IN:4(#7,1222:1,1215:1,1223:1,1217:1,1214:1,1219:1,1213:1,1218:1,1220:1,1221:1,1216:1,1212:1)"; + int[] resultKeepRowIdxForIn = {1, 2, 3, 4, 5, 6}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForIn, expIn); + } + + private void testForDateDimTableWithNull() { + String expNotNull = "IS_NOT_NULL:4(#8)"; + int[] resultKeepRowIdxForNotNull = {0, 3, 4, 5, 6}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull, + expNotNull); + + String expMixed = "AND:4(IN:4(#7,1222:1,1215:1,1223:1,1217:1,1214:1,1219:1,1213:1,1218:1,1220:1,1221:1,1216:1," + + "1212:1), IS_NOT_NULL:4(#8))"; + int[] resultKeepRowIdxForMixed = {3, 4, 5, 6}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForMixed, + expMixed); + } + + private void testForStoreTableWithNull() { + String expNotNull = "IS_NOT_NULL:4(#9)"; + int[] resultKeepRowIdxForNotNull = {1, 2, 3, 4, 6}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull, + expNotNull); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql4ForOmniFilterOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql4ForOmniFilterOperatorTest.java new file mode 100644 index 0000000..314888b --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql4ForOmniFilterOperatorTest.java @@ -0,0 +1,267 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.tensql; + +import static nova.hetu.omniruntime.util.TestUtils.filterOperatorMatch; +import static nova.hetu.omniruntime.util.TestUtils.filterOperatorMatchWithJson; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonLiteral; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonAndExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonEqualExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonIsNotNullExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonOrExpr; + +import com.google.common.collect.ImmutableList; + +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.VarcharDataType; + +import org.testng.annotations.Test; + +import java.util.List; + +/** + * sql4 for OmniFilter operator test. + * + * @since 2022-03-31 + */ +public class Sql4ForOmniFilterOperatorTest { + private static final String MUST_TEST_EXP1 = "AND:4(AND:4(IS_NOT_NULL:4(#0) , $operator$EQUAL:4(#0 , 1:1)) , " + + "IS_NOT_NULL:4(#1))"; + private static final String MUST_TEST_EXP2 = "AND:4(AND:4(IS_NOT_NULL:4(#2) , IS_NOT_NULL:4(#4)) , " + + "IS_NOT_NULL:4(#3))"; + private static final String MUST_TEST_EXP3 = "AND:4(AND:4(AND:4(AND:4(IS_NOT_NULL:4(#5) , IS_NOT_NULL:4(#6)) , " + + "$operator$EQUAL:4(#5 , 12:1)) , $operator$EQUAL:4(#6 , 2001:1)) , IS_NOT_NULL:4(#7))"; + private static final String MUST_TEST_EXP4 = "AND:4(AND:4(IS_NOT_NULL:4(#8) , IS_NOT_NULL:4(#9)) , " + + "IS_NOT_NULL:4(#10))"; + private static final String MUST_TEST_EXP5 = "AND:4(AND:4(IS_NOT_NULL:4(#11) , IS_NOT_NULL:4(#13)) , " + + "IS_NOT_NULL:4(#12))"; + private static final String MUST_TEST_EXP6 = "AND:4(OR:4($operator$EQUAL:4(#14 , 'breakfast ':15) , " + + "$operator$EQUAL:4(#14 , 'dinner ':15)) , IS_NOT_NULL:4(#15))"; + + private static final String MUST_TEST_EXP1_JSON = omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(1, 0)), + omniJsonEqualExpr(getOmniJsonFieldReference(1, 0), getOmniJsonLiteral(1, false, 1))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 1))); + private static final String MUST_TEST_EXP2_JSON = omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 2)), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 4))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 3))); + private static final String MUST_TEST_EXP3_JSON = omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(1, 5)), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(1, 6))), + omniJsonEqualExpr(getOmniJsonFieldReference(1, 5), getOmniJsonLiteral(1, false, 12))), + omniJsonEqualExpr(getOmniJsonFieldReference(1, 6), getOmniJsonLiteral(1, false, 2001))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 7))); + private static final String MUST_TEST_EXP4_JSON = omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 8)), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 9))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 10))); + private static final String MUST_TEST_EXP5_JSON = omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 11)), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 13))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 12))); + private static final String MUST_TEST_EXP6_JSON = omniJsonAndExpr(omniJsonOrExpr( + omniJsonEqualExpr(getOmniJsonFieldReference(15, 14), getOmniJsonLiteral(15, false, "breakfast ")), + omniJsonEqualExpr(getOmniJsonFieldReference(15, 14), + getOmniJsonLiteral(15, false, "dinner "))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 15))); + + DataType[] dataSourceTypes = {IntDataType.INTEGER, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, + LongDataType.LONG, IntDataType.INTEGER, IntDataType.INTEGER, LongDataType.LONG, LongDataType.LONG, + LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, + VarcharDataType.VARCHAR, LongDataType.LONG}; + + Object[][] dataSourceValue = {{1, 1, 1, 4, 5}, // #0 i_manager_id #int + {1L, 2L, 3L, 4L, 5L}, // #1 item_sk #long + {1L, 2L, 3L, 4L, 5L}, // #2 ws_sold_date_sk #long + {1L, 2L, 3L, 4L, 5L}, // #3 ws_item_sk #long + {1L, 2L, 3L, 4L, 5L}, // #4 ws_sold_time_sk #long + {12, 12, 12, 12, 40}, // #5 d_moy #int + {2001, 2001, 2001, 2001, 2021}, // #6 d_year #int + {1L, 2L, 3L, 4L, 5L}, // #7 d_date_sk #long + {1L, 2L, 3L, 4L, 5L}, // #8 cs_sold_date_sk #long + {1L, 2L, 3L, 4L, 5L}, // #9 cs_item_sk #long + {1L, 2L, 3L, 4L, 5L}, // #10 cs_sold_time_sk #long + {1L, 2L, 3L, 4L, 5L}, // #11 ss_sold_date_sk #long + {1L, 2L, 3L, 4L, 5L}, // #12 ss_item_sk #long + {1L, 2L, 3L, 4L, 5L}, // #13 ss_sold_time_sk #long + // #14 t_meal_time char(20) + {"breakfast ", "dinner", "dinner ", " dinner", " dinner "}, + {1L, 2L, 3L, 4L, 5L} // #15 t_time_sk #long + }; + + Object[][] dataSourceValueWithNull = {{1, 1, null, 4, 5}, // #0 i_manager_id #int + {1L, null, 3L, 4L, 5L}, // #1 item_sk #long + {null, 2L, 3L, 4L, 5L}, // #2 ws_sold_date_sk #long + {1L, null, 3L, 4L, 5L}, // #3 ws_item_sk #long + {1L, 2L, null, 4L, 5L}, // #4 ws_sold_time_sk #long + {12, null, 12, 12, 40}, // #5 d_moy #int + {2001, 2001, null, 2001, 2021}, // #6 d_year #int + {1L, 2L, 3L, null, 5L}, // #7 d_date_sk #long + {1L, null, 3L, 4L, 5L}, // #8 cs_sold_date_sk #long + {1L, 2L, null, 4L, 5L}, // #9 cs_item_sk #long + {1L, 2L, 3L, null, 5L}, // #10 cs_sold_time_sk #long + {1L, null, 3L, 4L, 5L}, // #11 ss_sold_date_sk #long + {1L, 2L, null, 4L, 5L}, // #12 ss_item_sk #long + {1L, 2L, 3L, null, 5L}, // #13 ss_sold_time_sk #long + {"breakfast ", null, "dinner ", "", " "}, // #14 t_meal_time char(20) + {1L, 2L, null, 4L, 5L} // #15 t_time_sk #long + }; + + List dataSourceProjects = ImmutableList.of("#0", "#1", "#2", "#3", "#4", "#5", "#6", "#7", "#8", "#9", + "#10", "#11", "#12", "#13", "#14", "#15"); + + List dataSourceProjectionJson = ImmutableList.of( + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":1}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":2}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":3}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":4}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":5}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":6}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":7}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":8}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":9}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":10}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":11}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":12}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":13}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":14,\"width\":50}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":15}"); + + @Test + public void testMustExpJson() { + int[] resultKeepRowIdxForEXP1 = {0}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP1, MUST_TEST_EXP1_JSON, 1); + + int[] resultKeepRowIdxForEXP2 = {3, 4}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP2, MUST_TEST_EXP2_JSON, 1); + + int[] resultKeepRowIdxForEXP3 = {0}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP3, MUST_TEST_EXP3_JSON, 1); + + int[] resultKeepRowIdxForEXP4 = {0, 4}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP4, MUST_TEST_EXP4_JSON, 1); + + int[] resultKeepRowIdxForEXP5 = {0, 4}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP5, MUST_TEST_EXP5_JSON, 1); + + int[] resultKeepRowIdxForEXP6 = {0}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP6, MUST_TEST_EXP6_JSON, 1); + } + + @Test + public void testMustExp() { + int[] resultKeepRowIdxForEXP1 = {0}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP1, + MUST_TEST_EXP1); + + int[] resultKeepRowIdxForEXP2 = {3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP2, + MUST_TEST_EXP2); + + int[] resultKeepRowIdxForEXP3 = {0}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP3, + MUST_TEST_EXP3); + + int[] resultKeepRowIdxForEXP4 = {0, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP4, + MUST_TEST_EXP4); + + int[] resultKeepRowIdxForEXP5 = {0, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP5, + MUST_TEST_EXP5); + + int[] resultKeepRowIdxForEXP6 = {0}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP6, + MUST_TEST_EXP6); + } + + @Test + public void testForItemTable() { + testForItemTableWithNotNull(); + testForItemTableWithNull(); + } + + @Test + public void testForDatedimTable() { + testForDatedimTableWithNotNull(); + testForDatedimTableWithNull(); + } + + @Test + public void testForTimedimTable() { + testForTimedimTableWithNotNull(); + testForTimedimTableWithNull(); + } + + private void testForTimedimTableWithNull() { + String expNotNull = "IS_NOT_NULL:4(#15)"; + int[] resultKeepRowIdxForNotNull = {0, 1, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull, + expNotNull); + } + + private void testForTimedimTableWithNotNull() { + String expEq1 = "$operator$EQUAL:4(#14 , 'dinner ':15)"; + int[] resultKeepRowIdxForEq1 = {2}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEq1, expEq1); + + String expEq2 = "$operator$EQUAL:4(#14 , 'breakfast ':15)"; + int[] resultKeepRowIdxForEq2 = {0}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEq2, expEq2); + } + + private void testForDatedimTableWithNotNull() { + String expEq1 = "$operator$EQUAL:4(#6 , 2001:1)"; + int[] resultKeepRowIdxForEq1 = {0, 1, 2, 3}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEq1, expEq1); + + String expEq2 = "$operator$EQUAL:4(#5 , 12:1)"; + int[] resultKeepRowIdxForEq2 = {0, 1, 2, 3}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEq2, expEq2); + } + + private void testForDatedimTableWithNull() { + String expNotNull1 = "IS_NOT_NULL:4(#5)"; + int[] resultKeepRowIdxForNotNull1 = {0, 2, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull1, + expNotNull1); + + String expNotNull2 = "IS_NOT_NULL:4(#6)"; + int[] resultKeepRowIdxForNotNull2 = {0, 1, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull2, + expNotNull2); + } + + private void testForItemTableWithNull() { + String expNotNull1 = "IS_NOT_NULL:4(#0)"; + int[] resultKeepRowIdxForNotNull1 = {0, 1, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull1, + expNotNull1); + + String expNotNull2 = "IS_NOT_NULL:4(#1)"; + int[] resultKeepRowIdxForNotNull2 = {0, 2, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull2, + expNotNull2); + } + + private void testForItemTableWithNotNull() { + String expEq = "$operator$EQUAL:4(#0 , 1:1)"; + int[] resultKeepRowIdxForEq = {0, 1, 2}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEq, expEq); + } +} \ No newline at end of file diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql5ForOmniFilterOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql5ForOmniFilterOperatorTest.java new file mode 100644 index 0000000..1607724 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql5ForOmniFilterOperatorTest.java @@ -0,0 +1,253 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.tensql; + +import static nova.hetu.omniruntime.util.TestUtils.filterOperatorMatch; +import static nova.hetu.omniruntime.util.TestUtils.filterOperatorMatchWithJson; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonLiteral; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonAndExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonEqualExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonGreaterThanOrEqualExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonIsNotNullExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonLessThanOrEqualExpr; + +import com.google.common.collect.ImmutableList; + +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.VarcharDataType; + +import org.testng.annotations.Test; + +import java.util.List; + +/** + * sql5 for OmniFilter operator test. + * + * @since 2022-03-31 + */ +public class Sql5ForOmniFilterOperatorTest { + private static final String MUST_TEST_EXP1 = "AND:4(AND:4(IS_NOT_NULL:4(#3) , " + + "$operator$EQUAL:4(#3 , 'Hopewell':15)) , IS_NOT_NULL:4(#4))"; + private static final String MUST_TEST_EXP2 = "IS_NOT_NULL:4(#10)"; + private static final String MUST_TEST_EXP3 = "AND:4(IS_NOT_NULL:4(#5) , IS_NOT_NULL:4(#6))"; + private static final String MUST_TEST_EXP4 = "AND:4(AND:4(AND:4(AND:4(IS_NOT_NULL:4(#7) , IS_NOT_NULL:4(#8)) , " + + "$operator$GREATER_THAN_OR_EQUAL:4(#7 , 32287:1)) , $operator$LESS_THAN_OR_EQUAL:4(#8 , 82287:1)) , " + + "IS_NOT_NULL:4(#9))"; + + private static final String MUST_TEST_EXP1_JSON = omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(15, 3)), + omniJsonEqualExpr(getOmniJsonFieldReference(15, 3), getOmniJsonLiteral(15, false, "Hopewell"))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 4))); + private static final String MUST_TEST_EXP2_JSON = omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 10)); + private static final String MUST_TEST_EXP3_JSON = omniJsonAndExpr( + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 5)), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 6))); + private static final String MUST_TEST_EXP4_JSON = omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(1, 7)), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(1, 8))), + omniJsonGreaterThanOrEqualExpr(getOmniJsonFieldReference(1, 7), + getOmniJsonLiteral(1, false, 32287))), + omniJsonLessThanOrEqualExpr(getOmniJsonFieldReference(1, 8), getOmniJsonLiteral(1, false, 82287))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 9))); + + List dataSourceProjects = ImmutableList.of("#0", "#1", "#2", "#3", "#4", "#5", "#6", "#7", "#8", "#9", + "#10"); + + List dataSourceProjectionJson = ImmutableList.of( + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":0}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":1}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":2}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":3,\"width\":50}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":4}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":5}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":6}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":7}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":8}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":9}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":10}"); + + DataType[] dataSourceTypes = {LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, VarcharDataType.VARCHAR, + LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, IntDataType.INTEGER, IntDataType.INTEGER, + LongDataType.LONG, LongDataType.LONG}; + + Object[][] dataSourceValue = {{1L, 2L, 3L, 4L, 5L}, // c_current_addr_sk #0 #long + {1L, 2L, 3L, 4L, 5L}, // c_current_cdemo_sk #1 #long + {1L, 2L, 3L, 4L, 5L}, // c_current_hdemo_sk #2 #long + {"Hopewell", "Hopewell", "Hopewell ", " ", " Hopewell"}, // ca_city #3 #char(10) + {1L, 2L, 3L, 4L, 5L}, // ca_address_sk #4 #long + {1L, 2L, 3L, 4L, 5L}, // hd_demo_sk #5 #long + {1L, 2L, 3L, 4L, 5L}, // hd_income_band_sk #6 #long + {0, 32287, 33300, 82287, 90000}, // ib_lower_bound #7 #int + {210, 50000, 82287, 82287, 95000}, // ib_upper_bound #8 #int + {1L, 2L, 3L, 4L, 5L}, // ib_income_band_sk #9 #long + {1L, 2L, 3L, 4L, 5L} // sr_cdemo_sk #10 #long + }; + + Object[][] dataSourceValueWithNull = {{null, 2L, 3L, 4L, 5L}, // c_current_addr_sk #0 #long + {1L, null, 3L, 4L, 5L}, // c_current_cdemo_sk #1 #long + {1L, 2L, null, 4L, 5L}, // c_current_hdemo_sk #2 #long + {"Hopewell", "Hopewell", "Hopewell ", null, " Hopewell"}, // ca_city #3 #char(10) + {null, 2L, 3L, 4L, 5L}, // ca_address_sk #4 #long + {null, 2L, 3L, 4L, 5L}, // hd_demo_sk #5 #long + {1L, null, 3L, 4L, 5L}, // hd_income_band_sk #6 #long + {0, null, 33300, 82200, 82287}, // ib_lower_bound #7 #int + {210, 50000, 82287, null, 82287}, // ib_upper_bound #8 #int + {1L, 2L, 3L, 4L, null}, // ib_income_band_sk #9 #long + {null, null, 3L, 4L, 5L} // sr_cdemo_sk #10 #long + }; + + @Test + public void testMustExpJson() { + int[] resultKeepRowIdxForEXP1 = {1}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP1, MUST_TEST_EXP1_JSON, 1); + + int[] resultKeepRowIdxForEXP2 = {2, 3, 4}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP2, MUST_TEST_EXP2_JSON, 1); + + int[] resultKeepRowIdxForEXP3 = {2, 3, 4}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP3, MUST_TEST_EXP3_JSON, 1); + + int[] resultKeepRowIdxForEXP4 = {2}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP4, MUST_TEST_EXP4_JSON, 1); + } + + @Test + public void testMustExp() { + int[] resultKeepRowIdxForEXP1 = {1}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP1, + MUST_TEST_EXP1); + + int[] resultKeepRowIdxForEXP2 = {2, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP2, + MUST_TEST_EXP2); + + int[] resultKeepRowIdxForEXP3 = {2, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP3, + MUST_TEST_EXP3); + + int[] resultKeepRowIdxForEXP4 = {2}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP4, + MUST_TEST_EXP4); + } + + @Test + public void testForCustomerTable() { + testForCustomerTableWithNull(); + } + + @Test + public void testForCustomerAddressTable() { + testForCustomeraddressTableWithNotNull(); + testForCustomeraddressTableWithNull(); + } + + @Test + public void testForHouseholdDemographicsTable() { + testForHouseholdDemographicsTableWithNull(); + } + + @Test + public void testForIncomeBandTable() { + testForIncomeBandTableWithNotNull(); + testForIncomeBandTableWithNull(); + } + + @Test + public void testForStoreReturnsTable() { + testForStoreReturnsTableWithNull(); + } + + private void testForCustomerTableWithNull() { + String expNotNull = "AND:4(AND:4(IS_NOT_NULL:4(#0),IS_NOT_NULL:4(#1)), IS_NOT_NULL:4(#2))"; + int[] resultKeepRowIdxForNotNull = {3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull, + expNotNull); + } + + private void testForCustomeraddressTableWithNotNull() { + String expEq = "$operator$EQUAL:4(#3 ,'Hopewell':15)"; + int[] resultKeepRowIdxForEq = {0, 1}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEq, expEq); + } + + private void testForCustomeraddressTableWithNull() { + String expNotNull1 = "IS_NOT_NULL:4(#3)"; + int[] resultKeepRowIdxForNotNull1 = {0, 1, 2, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull1, + expNotNull1); + + String expNotNull2 = "IS_NOT_NULL:4(#4)"; + int[] resultKeepRowIdxForNotNull2 = {1, 2, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull2, + expNotNull2); + + String expMixed = "AND:4(AND:4(IS_NOT_NULL:4(#3), $operator$EQUAL:4(#3 ,'Hopewell':15)), IS_NOT_NULL:4(#4))"; + int[] resultKeepRowIdxForMixed = {1}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForMixed, + expMixed); + } + + private void testForHouseholdDemographicsTableWithNull() { + String expNotNull1 = "IS_NOT_NULL:4(#5)"; + int[] resultKeepRowIdxForNotNull1 = {1, 2, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull1, + expNotNull1); + + String expNotNull2 = "IS_NOT_NULL:4(#6)"; + int[] resultKeepRowIdxForNotNull2 = {0, 2, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull2, + expNotNull2); + + String expMixed = "AND:4(IS_NOT_NULL:4(#5), IS_NOT_NULL:4(#6))"; + int[] resultKeepRowIdxForMixed = {2, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForMixed, + expMixed); + } + + private void testForIncomeBandTableWithNotNull() { + String expGe = "$operator$GREATER_THAN_OR_EQUAL:4(#7,32287:1)"; + int[] resultKeepRowIdxForGe = {1, 2, 3, 4}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForGe, expGe); + + String expLe = "$operator$LESS_THAN_OR_EQUAL:4(#8, 82287:1)"; + int[] resultKeepRowIdxForLe = {0, 1, 2, 3}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForLe, expLe); + } + + private void testForIncomeBandTableWithNull() { + String expNotNull1 = "IS_NOT_NULL:4(#7)"; + int[] resultKeepRowIdxForNotNull1 = {0, 2, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull1, + expNotNull1); + + String expNotNull2 = "IS_NOT_NULL:4(#8)"; + int[] resultKeepRowIdxForNotNull2 = {0, 1, 2, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull2, + expNotNull2); + + String expMixed = "AND:4(AND:4(AND:4(AND:4(IS_NOT_NULL:4(#7), IS_NOT_NULL:4(#8)), " + + "$operator$GREATER_THAN_OR_EQUAL:4(#7,32287:1)), $operator$LESS_THAN_OR_EQUAL:4(#8, 82287:1)), " + + "IS_NOT_NULL:4(#9))"; + int[] resultKeepRowIdxForMixed = {2}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForMixed, + expMixed); + } + + private void testForStoreReturnsTableWithNull() { + String expNotNull = "IS_NOT_NULL:4(#10)"; + int[] resultKeepRowIdxForNotNull = {2, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull, + expNotNull); + } +} \ No newline at end of file diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql6ForOmniFilterOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql6ForOmniFilterOperatorTest.java new file mode 100644 index 0000000..6b03b13 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql6ForOmniFilterOperatorTest.java @@ -0,0 +1,437 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.tensql; + +import static nova.hetu.omniruntime.util.TestUtils.filterOperatorMatch; +import static nova.hetu.omniruntime.util.TestUtils.filterOperatorMatchWithJson; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonLiteral; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonAbsExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonAndExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonCastExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonFourArithmeticExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonGreaterThanExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonGreaterThanOrEqualExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonIfExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonInExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonIsNotNullExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonLessThanOrEqualExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonOrExpr; + +import com.google.common.collect.ImmutableList; + +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.VarcharDataType; + +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.List; + +/** + * sql6 for OmniFilter operator test. + * + * @since 2022-03-21 + */ +public class Sql6ForOmniFilterOperatorTest { + private static final String MUST_TEST_EXP1 = "AND:4(OR:4(AND:4(AND:4(IN:4(#0," + + "'Books ':15," + + "'Children ':15," + + "'Electronics ':15) , " + + "IN:4(#1,'personal ':15," + + "'portable ':15," + + "'reference ':15," + + "'self-help ':15)) , " + + "IN:4(#2,'scholaramalgamalg #14 ':15," + + "'scholaramalgamalg #7 ':15," + + "'exportiunivamalg #9 ':15," + + "'scholaramalgamalg #9 ':15)) , " + + "AND:4(AND:4(IN:4(#0,'Women ':15," + + "'Music ':15," + + "'Men ':15) , " + + "IN:4(#1,'accessories ':15," + + "'classical ':15," + + "'fragrances ':15," + + "'pants ':15)) , " + + "IN:4(#2,'amalgimporto #1 ':15," + + "'edu packscholar #1 ':15," + + "'exportiimporto #1 ':15," + + "'importoamalg #1 ':15))) , IS_NOT_NULL:4(#3))"; + private static final String MUST_TEST_EXP2 = "AND:4(AND:4(AND:4(AND:4(IS_NOT_NULL:4(#4) , " + + "$operator$GREATER_THAN_OR_EQUAL:4(#5 , 2452123:2)) , $operator$LESS_THAN_OR_EQUAL:4(#5 , 2452487:2)) , " + + "IS_NOT_NULL:4(#5)) , IS_NOT_NULL:4(#6))"; + private static final String MUST_TEST_EXP3 = "AND:4(AND:4(AND:4(" + + "IN:4(#7,1227:1,1224:1,1219:1,1221:1,1226:1,1223:1,1229:1,1230:1,1222:1,1228:1,1225:1,1220:1) , " + + "$operator$GREATER_THAN_OR_EQUAL:4(#8 , 2452123:2)) , $operator$LESS_THAN_OR_EQUAL:4(#8 , 2452487:2)) , " + + "IS_NOT_NULL:4(#8))"; + private static final String MUST_TEST_EXP4 = "IS_NOT_NULL:4(#9)"; + private static final String MUST_TEST_EXP5 = "$operator$GREATER_THAN:4(IF:3($operator$GREATER_THAN:4(#11 , 0.0:3), " + + "$operator$DIVIDE:3(abs:3($operator$SUBTRACT:3(CAST:3(#10) , #11)) , #11), null:3) , 0.1:3)"; + + private static final String MUST_TEST_EXP1_JSON = omniJsonAndExpr( + omniJsonOrExpr( + omniJsonAndExpr( + omniJsonAndExpr( + omniJsonInExpr(15, 0, + Arrays.asList("Books ", + "Children ", + "Electronics ")), + omniJsonInExpr(15, 1, + Arrays.asList("personal ", + "portable ", + "reference ", + "self-help "))), + omniJsonInExpr(15, 2, + Arrays.asList("scholaramalgamalg #14 ", + "scholaramalgamalg #7 ", + "exportiunivamalg #9 ", + "scholaramalgamalg #9 "))), + omniJsonAndExpr( + omniJsonAndExpr( + omniJsonInExpr(15, 0, + Arrays.asList("Women ", + "Music ", + "Men ")), + omniJsonInExpr(15, 1, + Arrays.asList("accessories ", + "classical ", + "fragrances ", + "pants "))), + omniJsonInExpr(15, 2, + Arrays.asList("amalgimporto #1 ", + "edu packscholar #1 ", + "exportiimporto #1 ", + "importoamalg #1 ")))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 3))); + private static final String MUST_TEST_EXP2_JSON = omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 4)), + omniJsonGreaterThanOrEqualExpr(getOmniJsonFieldReference(2, 5), + getOmniJsonLiteral(2, false, 2452123))), + omniJsonLessThanOrEqualExpr(getOmniJsonFieldReference(2, 5), + getOmniJsonLiteral(2, false, 2452487))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 5))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 6))); + private static final String MUST_TEST_EXP3_JSON = omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr( + omniJsonInExpr(1, 7, + Arrays.asList(1227, 1224, 1219, 1221, 1226, 1223, 1229, 1230, 1222, 1228, 1225, + 1220)), + omniJsonGreaterThanOrEqualExpr(getOmniJsonFieldReference(2, 8), + getOmniJsonLiteral(2, false, 2452123))), + omniJsonLessThanOrEqualExpr(getOmniJsonFieldReference(2, 8), + getOmniJsonLiteral(2, false, 2452487))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 8))); + private static final String MUST_TEST_EXP4_JSON = omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 9)); + private static final String MUST_TEST_EXP5_JSON = omniJsonGreaterThanExpr( + omniJsonIfExpr(omniJsonGreaterThanExpr(getOmniJsonFieldReference(3, 11), getOmniJsonLiteral(3, false, 0.0)), + 3, + omniJsonFourArithmeticExpr("DIVIDE", 3, omniJsonAbsExpr(3, omniJsonFourArithmeticExpr("SUBTRACT", 3, + omniJsonCastExpr(3, getOmniJsonFieldReference(2, 10)), getOmniJsonFieldReference(3, 11))), + getOmniJsonFieldReference(3, 11)), + getOmniJsonFieldReference(3, 11)), + getOmniJsonLiteral(3, false, 0.1)); + + Object[][] dataSourceValue = { + {"Books ", "Children ", + "Books ", + "Women ", + "Men ", " ", " ", ""}, // i_category #0 char(50) + {"reference ", "reference ", + "reference ", + "self-help ", + "accessories ", + "classical ", + "fragrances ", + "pants "}, // i_class #1 char(50) + {"scholaramalgamalg #14 ", "scholaramalgamalg #7 ", + "exportiunivamalg #9 ", + "scholaramalgamalg #9 ", + "amalgimporto #1 ", + "edu packscholar #1 ", + "exportiimporto #1 ", + "importoamalg #1 "}, // i_brand #2 char(50) + {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, // i_item_sk #3 long + {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, // ss_item_sk #4 #long + {-12234L, 2452121L, 2452123L, 2452200L, 2452486L, 2452487L, 2452487L, 2455000L}, // ss_sold_date_sk #5 #long + {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, // ss_store_sk #6 #long + {-12, 12, 1228, 1223, 1227, 1219, 1226, 22144}, // d_month_seq #7 #int + {-12234L, 2452121L, 2452123L, 2452200L, 2452486L, 2452487L, 2452487L, 2455000L}, // d_date_sk #8 #long + {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, // s_store_sk #9 long + {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, // sum_sales #10 long + {-2.0, -1.0, 0.0, 1.5, 2.0, 3.0, 4.0, 5.0} // avg_monthly_sales #11 double + }; + + Object[][] dataSourceValueWithNull = { + {"Books ", "Children ", + null, "Women ", + "Men ", " ", " ", ""}, // i_category #0 char(50) + {"personal ", "personal ", + "reference ", + "self-help ", + "accessories ", + "classical ", + "fragrances ", + "pants "}, // i_class #1 char(50) + {"scholaramalgamalg #14 ", "scholaramalgamalg #7 ", + "exportiunivamalg #9 ", + "scholaramalgamalg #9 ", + "amalgimporto #1 ", + "edu packscholar #1 ", + "exportiimporto #1 ", + "importoamalg #1 "}, // i_brand #2 char(50) + {null, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, // i_item_sk #3 long + {null, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, // ss_item_sk #4 #long + {-12234L, null, 2452123L, 2452200L, 2452486L, 2452487L, 2452487L, 2455000L}, // ss_sold_date_sk #5 #long + {1L, 2L, null, 4L, 5L, 6L, 7L, 8L}, // ss_store_sk #6 #long + {-12, 12, 1228, 1223, 1227, 1219, 1226, 22144}, // d_month_seq #7 #int + {-12234L, null, null, 2452200L, 2452486L, 2452487L, 2452487L, null}, // d_date_sk #8 #long + {null, 2L, 3L, 4L, 5L, null, 7L, null}, // s_store_sk #9 long + {1L, 2L, 3L, 4L, null, 6L, 7L, 8L}, // sum_sales #10 long + {-2.0, -1.0, 0.0, null, 2.0, 3.0, 4.0, 5.0} // avg_monthly_sales #11 double + }; + + DataType[] dataSourceTypes = {VarcharDataType.VARCHAR, VarcharDataType.VARCHAR, VarcharDataType.VARCHAR, + LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, IntDataType.INTEGER, + LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, DoubleDataType.DOUBLE}; + + List dataSourceProjects = ImmutableList.of("#0", "#1", "#2", "#3", "#4", "#5", "#6", "#7", "#8", "#9", + "#10", "#11"); + + List dataSourceProjectionJson = ImmutableList.of( + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":50}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":50}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":50}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":3}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":4}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":5}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":6}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":7}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":8}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":9}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":10}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":11}"); + + @Test + public void testMustExpJson() { + int[] resultKeepRowIdxForEXP1 = {1, 4}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP1, MUST_TEST_EXP1_JSON, 1); + + int[] resultKeepRowIdxForEXP2 = {3, 4, 5, 6}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP2, MUST_TEST_EXP2_JSON, 1); + + int[] resultKeepRowIdxForEXP3 = {3, 4, 5, 6}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP3, MUST_TEST_EXP3_JSON, 1); + + int[] resultKeepRowIdxForEXP4 = {1, 2, 3, 4, 6}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP4, MUST_TEST_EXP4_JSON, 1); + + int[] resultKeepRowIdxForEXP5 = {5, 6, 7}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP5, MUST_TEST_EXP5_JSON, 1); + } + + @Test + public void testMustExp() { + int[] resultKeepRowIdxForEXP1 = {1, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP1, + MUST_TEST_EXP1); + + int[] resultKeepRowIdxForEXP2WithNull = {3, 4, 5, 6}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, + resultKeepRowIdxForEXP2WithNull, MUST_TEST_EXP2); + + int[] resultKeepRowIdxForEXP3 = {3, 4, 5, 6}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP3, + MUST_TEST_EXP3); + + int[] resultKeepRowIdxForEXP4 = {1, 2, 3, 4, 6}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP4, + MUST_TEST_EXP4); + + int[] resultKeepRowIdxForEXP5 = {5, 6, 7}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP5, + MUST_TEST_EXP5); + } + + @Test + public void testForItemTable() { + testForItemTableWithNotNull(); + } + + @Test + public void testForStoreSalesTable() { + testForStoreSalesTableWithNull(); + testForStoreSalesTableWithNotNull(); + } + + @Test + public void testForDateDimTable() { + testForDateDimTableWithNotNull(); + testForDateDimTableWithNull(); + } + + @Test + public void testForStoreTable() { + testForStoreTableWithNull(); + } + + private void testForItemTableWithNotNull() { + String expIn1 = "IN:4(#0,'Books ':15," + + "'Children ':15," + + "'Electronics ':15)"; + int[] resultKeepRowIdxForIn1 = {0, 1, 2}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForIn1, expIn1); + + String expIn2 = "IN:4 (#1,'personal ':15," + + "'portable ':15," + + "'reference ':15," + + "'self-help ':15)"; + int[] resultKeepRowIdxForIn2 = {0, 1, 2, 3}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForIn2, expIn2); + + String expIn3 = "IN:4(#2,'scholaramalgamalg #14 ':15," + + "'scholaramalgamalg #7 ':15," + + "'exportiunivamalg #9 ':15," + + "'scholaramalgamalg #9 ':15)"; + int[] resultKeepRowIdxForIn3 = {0, 1, 2, 3}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForIn3, expIn3); + + String expIn4 = "IN:4 (#0,'Women ':15," + "" + + "'Music ':15," + + "'Men ':15)"; + int[] resultKeepRowIdxForIn4 = {3, 4}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForIn4, expIn4); + + String expIn5 = "IN:4(#1,'accessories ':15," + + "'classical ':15," + + "'fragrances ':15," + + "'pants ':15)"; + int[] resultKeepRowIdxForIn5 = {4, 5, 6, 7}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForIn5, expIn5); + + String expIn6 = "IN:4(#2,'amalgimporto #1 ':15," + + "'edu packscholar #1 ':15," + + "'exportiimporto #1 ':15," + + "'importoamalg #1 ':15)"; + int[] resultKeepRowIdxForIn6 = {4, 5, 6, 7}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForIn6, expIn6); + + String expAnd1 = "AND:4(AND:4(IN:4(#0,'Books ':15," + + "'Children ':15," + + "'Electronics ':15)," + + "IN:4(#1,'personal ':15," + + "'portable ':15," + + "'reference ':15," + + "'self-help ':15))," + + "IN:4(#2,'scholaramalgamalg #14 ':15," + + "'scholaramalgamalg #7 ':15," + + "'exportiunivamalg #9 ':15," + + "'scholaramalgamalg #9 ':15))"; + int[] resultKeepRowIdxForAnd1 = {0, 1, 2}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForAnd1, expAnd1); + + String expAnd2 = "AND:4(AND:4(IN:4(#0,'Women ':15," + + "'Music ':15," + + "'Men ':15)," + + "IN(#1,'accessories ':15," + + "'classical ':15," + + "'fragrances ':15," + + "'pants ':15))," + + "IN(#2,'amalgimporto #1 ':15," + + "'edu packscholar #1 ':15," + + "'exportiimporto #1 ':15," + + "'importoamalg #1 ':15))"; + int[] resultKeepRowIdxForAnd2 = {4}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForAnd2, expAnd2); + + // (In1 and In2 and In3) Or (In4 and Int5 and Int6) = {0, 1, 2, 4} + String expMixedOr = "OR:4(AND:4(AND:4(IN:4(#0,'Books ':15," + + "'Children ':15," + + "'Electronics ':15)," + + "IN:4(#1,'personal ':15," + + "'portable ':15," + + "'reference ':15," + + "'self-help ':15)), " + + "IN:4(#2,'scholaramalgamalg #14 ':15," + + "'scholaramalgamalg #7 ':15," + + "'exportiunivamalg #9 ':15," + + "'scholaramalgamalg #9 ':15)), " + + "AND:4(AND:4(IN:4 (#0, 'Women ':15," + + "'Music ':15," + + "'Men ':15), " + + "IN:4(#1,'accessories ':15," + + "'classical ':15," + + "'fragrances ':15," + + "'pants ':15)), " + + "IN:4(#2,'amalgimporto #1 ':15," + + "'edu packscholar #1 ':15," + + "'exportiimporto #1 ':15," + + "'importoamalg #1 ':15)))"; + int[] resultKeepRowIdxForMixedOr = {0, 1, 2, 4}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForMixedOr, + expMixedOr); + } + + private void testForStoreSalesTableWithNull() { + String expNotNull = "AND:4(AND:4(IS_NOT_NULL:4(#4) , IS_NOT_NULL:4(#5)) , IS_NOT_NULL:4(#6))"; + int[] resultKeepRowIdxForNotNull = {3, 4, 5, 6, 7}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull, + expNotNull); + } + + private void testForStoreSalesTableWithNotNull() { + String expGe = "$operator$GREATER_THAN_OR_EQUAL:4(#5 , 2452123:2)"; + int[] resultKeepRowIdxForGe = {2, 3, 4, 5, 6, 7}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForGe, expGe); + + String expLe = "$operator$LESS_THAN_OR_EQUAL:4(#5 , 2452487:2)"; + int[] resultKeepRowIdxForLe = {0, 1, 2, 3, 4, 5, 6}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForLe, expLe); + } + + private void testForDateDimTableWithNotNull() { + String expIn = "IN:4(#7,1227:1,1224:1,1219:1,1221:1,1226:1,1223:1,1229:1,1230:1,1222:1,1228:1,1225:1,1220:1)"; + int[] resultKeepRowIdxForIn = {2, 3, 4, 5, 6}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForIn, expIn); + + String expGe = "$operator$GREATER_THAN_OR_EQUAL:4(#5 , 2452123:2)"; + int[] resultKeepRowIdxForGe = {2, 3, 4, 5, 6, 7}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForGe, expGe); + + String expLe = "$operator$LESS_THAN_OR_EQUAL:4(#5 , 2452487:2)"; + int[] resultKeepRowIdxForLe = {0, 1, 2, 3, 4, 5, 6}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForLe, expLe); + } + + private void testForDateDimTableWithNull() { + String expNotNull = "IS_NOT_NULL:4(#8)"; + int[] resultKeepRowIdxForNotNull = {0, 3, 4, 5, 6}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull, + expNotNull); + + String expMixed = "AND:4(IN:4(#7,1227:1,1224:1,1219:1,1221:1,1226:1,1223:1,1229:1,1230:1,1222:1,1228:1,1225:1," + + "1220:1), IS_NOT_NULL:4(#8))"; + int[] resultKeepRowIdxForMixed = {3, 4, 5, 6}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForMixed, + expMixed); + } + + private void testForStoreTableWithNull() { + String expNotNull = "IS_NOT_NULL:4(#9)"; + int[] resultKeepRowIdxForNotNull = {1, 2, 3, 4, 6}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull, + expNotNull); + } +} \ No newline at end of file diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql7ForOmniFilterOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql7ForOmniFilterOperatorTest.java new file mode 100644 index 0000000..c8d1858 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql7ForOmniFilterOperatorTest.java @@ -0,0 +1,167 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.tensql; + +import static nova.hetu.omniruntime.util.TestUtils.filterOperatorMatch; +import static nova.hetu.omniruntime.util.TestUtils.filterOperatorMatchWithJson; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonLiteral; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonAndExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonGreaterThanOrEqualExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonIsNotNullExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonLessThanOrEqualExpr; + +import com.google.common.collect.ImmutableList; + +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; + +import org.testng.annotations.Test; + +import java.util.List; + +/** + * sql7 for OmniFilter operator test. + * + * @since 2022-03-31 + */ +public class Sql7ForOmniFilterOperatorTest { + private static final String MUST_TEST_EXP1 = "AND:4(AND:4(AND:4(IS_NOT_NULL:4(#7) , " + + "$operator$GREATER_THAN_OR_EQUAL:4(#7 , 1202:1)) , $operator$LESS_THAN_OR_EQUAL:4(#7 , 1213:1)) , " + + "IS_NOT_NULL:4(#8))"; + + private static final String MUST_TEST_EXP1_JSON = omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(1, 7)), + omniJsonGreaterThanOrEqualExpr(getOmniJsonFieldReference(1, 7), + getOmniJsonLiteral(1, false, 1202))), + omniJsonLessThanOrEqualExpr(getOmniJsonFieldReference(1, 7), getOmniJsonLiteral(1, false, 1213))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 8))); + + DataType[] dataSourceTypes = {LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, + LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, IntDataType.INTEGER, LongDataType.LONG}; + + Object[][] dataSourceValue = {{1L, 2L, 3L, 4L, 5L}, // cs_warehouse_sk #0 long + {1L, 2L, 3L, 4L, 5L}, // cs_ship_mode_sk #1 long + {1L, 2L, 3L, 4L, 5L}, // cs_call_center_sk #2 long + {1L, 2L, 3L, 4L, 5L}, // cs_ship_date_sk #3 long + {1L, 2L, 3L, 4L, 5L}, // w_warehouse_sk #4 long + {1L, 2L, 3L, 4L, 5L}, // sm_ship_mode_sk #5 long + {1L, 2L, 3L, 4L, 5L}, // cc_call_center_sk #6 long + {-3, 1202, 1205, 1213, 1213}, // d_month_seq #7 int + {1L, 2L, 3L, 4L, 5L} // d_date_sk #8 long + }; + + Object[][] dataSourceValueWithNull = {{null, 2L, 3L, 4L, 5L}, // cs_warehouse_sk #0 long + {1L, null, 3L, 4L, 5L}, // cs_ship_mode_sk #1 long + {1L, 2L, null, 4L, 5L}, // cs_call_center_sk #2 long + {1L, 2L, 3L, null, 5L}, // cs_ship_date_sk #3 long + {null, null, 3L, 4L, 5L}, // w_warehouse_sk #4 long + {null, 2L, 3L, 4L, null}, // sm_ship_mode_sk #5 long + {1L, 2L, null, null, 5L}, // cc_call_center_sk #6 long + {-3, null, 1205, 1213, 1213}, // d_month_seq #7 int + {1L, 2L, 3L, null, 5L} // d_date_sk #8 long + }; + + List dataSourceProjects = ImmutableList.of("#0", "#1", "#2", "#3", "#4", "#5", "#6", "#7", "#8"); + + List dataSourceProjectionJson = ImmutableList.of( + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":0}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":1}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":2}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":3}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":4}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":5}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":6}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":7}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":8}"); + + @Test + public void testMustExpJson() { + int[] resultKeepRowIdxForEXP1 = {2, 4}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP1, MUST_TEST_EXP1_JSON, 1); + } + + @Test + public void testMustExp() { + int[] resultKeepRowIdxForEXP1 = {2, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP1, + MUST_TEST_EXP1); + } + + @Test + public void testForCatalogSalesTable() { + testForCatalogSalesTableWithNull(); + } + + @Test + public void testForWarehouseTable() { + testForWarehouseTableWithNull(); + } + + @Test + public void testForShipModeTable() { + testForShipModeTableWithNull(); + } + + @Test + public void testForCallCenterTable() { + testForCallCenterTableWithNull(); + } + + @Test + public void testForDateDimTable() { + testForDateDimTableWithNull(); + } + + private void testForCatalogSalesTableWithNull() { + String expNotNull = "AND:4(AND:4(AND:4(IS_NOT_NULL:4(#0), IS_NOT_NULL:4(#1)), IS_NOT_NULL:4(#2)), " + + "IS_NOT_NULL:4(#3))"; + int[] resultKeepRowIdxForNotNull1 = {4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull1, + expNotNull); + } + + private void testForWarehouseTableWithNull() { + String expNotNull = "IS_NOT_NULL:4(#4)"; + int[] resultKeepRowIdxForNotNull = {2, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull, + expNotNull); + } + + private void testForShipModeTableWithNull() { + String expNotNull = "IS_NOT_NULL:4(#5)"; + int[] resultKeepRowIdxForNotNull = {1, 2, 3}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull, + expNotNull); + } + + private void testForCallCenterTableWithNull() { + String expNotNull = "IS_NOT_NULL:4(#6)"; + int[] resultKeepRowIdxForNotNull = {0, 1, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull, + expNotNull); + } + + private void testForDateDimTableWithNull() { + String expNotNull1 = "IS_NOT_NULL:4(#7)"; + int[] resultKeepRowIdxForNotNull1 = {0, 2, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull1, + expNotNull1); + + String expNotNull2 = "IS_NOT_NULL:4(#8)"; + int[] resultKeepRowIdxForNotNull2 = {0, 1, 2, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull2, + expNotNull2); + + String expMixed = "AND:4(AND:4(AND:4(IS_NOT_NULL:4(#7), $operator$GREATER_THAN_OR_EQUAL:4(#7 , 1202:1)), " + + "$operator$LESS_THAN_OR_EQUAL:4(#7 , 1213:1)), IS_NOT_NULL:4(#8))"; + int[] resultKeepRowIdxForMixed = {2, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForMixed, + expMixed); + } +} \ No newline at end of file diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql8ForOmniFilterOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql8ForOmniFilterOperatorTest.java new file mode 100644 index 0000000..5036e5e --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql8ForOmniFilterOperatorTest.java @@ -0,0 +1,351 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.tensql; + +import static nova.hetu.omniruntime.util.TestUtils.filterOperatorMatch; +import static nova.hetu.omniruntime.util.TestUtils.filterOperatorMatchWithJson; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonLiteral; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonAbsExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonAndExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonCastExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonEqualExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonFourArithmeticExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonGreaterThanExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonGreaterThanOrEqualExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonIfExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonInExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonIsNotNullExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonLessThanOrEqualExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonNotEqualExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonOrExpr; + +import com.google.common.collect.ImmutableList; + +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.VarcharDataType; + +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.List; + +/** + * sql8 for OmniFilter operator test. + * + * @since 2022-03-31 + */ +public class Sql8ForOmniFilterOperatorTest { + private static final String MUST_TEST_EXP1 = "AND:4(OR:4(AND:4(IN:4(#0," + + "'Home ':15," + + "'Books ':15," + + "'Electronics ':15) , " + + "IN:4(#1,'wallpaper ':15," + + "'parenting ':15," + + "'musical ':15)) , " + + "AND:4(IN:4(#0,'Shoes ':15," + + "'Jewelry ':15," + + "'Men ':15) , " + + "IN:4(#1,'womens ':15," + + "'birdal ':15," + + "'pants ':15))) , IS_NOT_NULL:4(#2))"; + private static final String MUST_TEST_EXP2 = "AND:4(AND:4(AND:4(AND:4(IS_NOT_NULL:4(#3) , " + + "$operator$EQUAL:4(#3 , 2000:1)) , $operator$GREATER_THAN_OR_EQUAL:4(#4 , 2451545:2)) , " + + "$operator$LESS_THAN_OR_EQUAL:4(#4 , 2451910:2)) , IS_NOT_NULL:4(#4))"; + private static final String MUST_TEST_EXP3 = "AND:4(AND:4(AND:4(AND:4(IS_NOT_NULL:4(#5) , " + + "$operator$GREATER_THAN_OR_EQUAL:4(#5 , 2451545:2)) , $operator$LESS_THAN_OR_EQUAL:4(#5 , 2451910:2)) , " + + "IS_NOT_NULL:4(#6)) , IS_NOT_NULL:4(#7))"; + private static final String MUST_TEST_EXP4 = "IS_NOT_NULL:4(#8)"; + private static final String MUST_TEST_EXP5 = "$operator$GREATER_THAN:4(IF:3(not:4($operator$EQUAL:4(#10 , 0.0:3)), " + + "$operator$DIVIDE:3(abs:3($operator$SUBTRACT:3(CAST:3(#9) , #10)) , #10), null:3) , 0.1:3)"; + + private static final String MUST_TEST_EXP1_JSON = omniJsonAndExpr( + omniJsonOrExpr( + omniJsonAndExpr( + omniJsonInExpr(15, 0, + Arrays.asList("Home ", + "Books ", + "Electronics ")), + omniJsonInExpr(15, 1, + Arrays.asList("wallpaper ", + "parenting ", + "musical "))), + omniJsonAndExpr( + omniJsonInExpr(15, 0, + Arrays.asList("Shoes ", + "Jewelry ", + "Men ")), + omniJsonInExpr(15, 1, + Arrays.asList("womens ", + "birdal ", + "pants ")))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 2))); + private static final String MUST_TEST_EXP2_JSON = omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(1, 3)), + omniJsonEqualExpr(getOmniJsonFieldReference(1, 3), + getOmniJsonLiteral(1, false, 2000))), + omniJsonGreaterThanOrEqualExpr(getOmniJsonFieldReference(2, 4), + getOmniJsonLiteral(2, false, 2451545))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 6))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 4))); + private static final String MUST_TEST_EXP3_JSON = omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 5)), + omniJsonGreaterThanOrEqualExpr(getOmniJsonFieldReference(2, 5), + getOmniJsonLiteral(2, false, 2451545))), + omniJsonLessThanOrEqualExpr(getOmniJsonFieldReference(2, 5), + getOmniJsonLiteral(2, false, 2451910))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 6))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 7))); + private static final String MUST_TEST_EXP4_JSON = omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 8)); + private static final String MUST_TEST_EXP5_JSON = omniJsonGreaterThanExpr( + omniJsonIfExpr(omniJsonNotEqualExpr(getOmniJsonFieldReference(3, 10), getOmniJsonLiteral(3, false, 0.0)), 3, + omniJsonFourArithmeticExpr("DIVIDE", 3, omniJsonAbsExpr(3, omniJsonFourArithmeticExpr("SUBTRACT", 3, + omniJsonCastExpr(3, getOmniJsonFieldReference(2, 9)), getOmniJsonFieldReference(3, 10))), + getOmniJsonFieldReference(3, 10)), + getOmniJsonFieldReference(3, 10)), + getOmniJsonLiteral(3, false, 0.1)); + + Object[][] dataSourceValue = {{"Home ", + "Books ", "Books ", + "Shoes ", "Men ", + "Jewelry ", " ", ""}, // i_category #0 char(50) + {"wallpaper ", "parenting ", + "musical ", + "self-help ", + "womens ", + "birdal ", + "pants ", + "other "}, // i_class #1 char(50) + {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, // i_item_sk #2 long + {1000, 2000, 2000, 2000, 2001, 2010, 2020, 2021}, // d_year #3 int + {2246L, 2451545L, 2451545L, 2451560L, 2451910L, 2451910L, 2452000L, 2452000L}, // d_date_sk #4 long + {2246L, 2451545L, 2451545L, 2451560L, 2451910L, 2451910L, 2452000L, 2452000L}, // ss_sold_date_sk #5 #long + {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, // ss_item_sk #6 #long + {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, // ss_store_sk #7 #long + {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, // s_store_sk #8 long + {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L}, // sum_sales #9 long + {-2.0, -1.0, 0.0, 1.5, 2.0, 3.0, 4.0, 5.0} // avg_monthly_sales #10 double + }; + + Object[][] dataSourceValueWithNull = {{"Home ", + "Books ", "Books ", + "Shoes ", "Men ", + "Jewelry ", null, ""}, // i_category #0 char(50) + {"wallpaper ", null, + "musical ", + "self-help ", + "womens ", + "birdal ", + "pants ", + "pants "}, // i_class #1 char(50) + {null, 2L, null, 4L, 5L, 6L, 7L, 8L}, // i_item_sk #2 long + {1000, null, 2000, 2000, 2001, 2010, 2020, 2021}, // d_year #3 int + {2246L, 2451545L, null, 2451560L, 2451910L, 2451910L, 2452000L, 2452000L}, // d_date_sk #4 long + {2246L, null, null, 2451560L, 2451910L, 2451910L, 2452000L, 2452000L}, // ss_sold_date_sk #5 #long + {1L, 2L, 3L, 4L, null, 6L, 7L, 8L}, // ss_item_sk #6 #long + {1L, 2L, 3L, 4L, 5L, null, 7L, 8L}, // ss_store_sk #7 #long + {null, 2L, 3L, 4L, null, 6L, 7L, null}, // s_store_sk #8 long + {1L, 2L, 3L, 4L, null, 6L, 7L, 8L}, // sum_sales #9 long + {-2.0, -1.0, 0.0, null, 2.0, 3.0, 4.0, 5.0} // avg_monthly_sales #10 double + }; + + DataType[] dataSourceTypes = {VarcharDataType.VARCHAR, VarcharDataType.VARCHAR, LongDataType.LONG, + IntDataType.INTEGER, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, + LongDataType.LONG, LongDataType.LONG, DoubleDataType.DOUBLE}; + + List dataSourceProjects = ImmutableList.of("#0", "#1", "#2", "#3", "#4", "#5", "#6", "#7", "#8", "#9", + "#10"); + + List dataSourceProjectionJson = ImmutableList.of( + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":50}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":50}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":2}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":3}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":4}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":5}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":6}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":7}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":8}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":9}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":3,\"colVal\":10}"); + + @Test + public void testMustExpJson() { + int[] resultKeepRowIdxForEXP1 = {4, 5}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP1, MUST_TEST_EXP1_JSON, 1); + + int[] resultKeepRowIdxForEXP2WithNull = {3}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP2WithNull, MUST_TEST_EXP2_JSON, 1); + + int[] resultKeepRowIdxForEXP3 = {3}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP3, MUST_TEST_EXP3_JSON, 1); + + int[] resultKeepRowIdxForEXP4 = {1, 2, 3, 5, 6}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP4, MUST_TEST_EXP4_JSON, 1); + + int[] resultKeepRowIdxForEXP5 = {5, 6, 7}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP5, MUST_TEST_EXP5_JSON, 1); + } + + @Test + public void testMustExp() { + int[] resultKeepRowIdxForEXP1 = {4, 5}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP1, + MUST_TEST_EXP1); + + int[] resultKeepRowIdxForEXP2WithNull = {3}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, + resultKeepRowIdxForEXP2WithNull, MUST_TEST_EXP2); + + int[] resultKeepRowIdxForEXP3 = {3}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP3, + MUST_TEST_EXP3); + + int[] resultKeepRowIdxForEXP4 = {1, 2, 3, 5, 6}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP4, + MUST_TEST_EXP4); + + int[] resultKeepRowIdxForEXP5 = {5, 6, 7}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP5, + MUST_TEST_EXP5); + } + + @Test + public void testForItemTable() { + testForItemTableWithNotNull(); + } + + @Test + public void testForDateDimTable() { + testForDateDimTableWithNotNull(); + testForDateDimTableWithNull(); + } + + @Test + public void testForStoreSalesTable() { + testForStoreSalesTableWithNull(); + testForStoreSalesTableWithNotNull(); + } + + @Test + public void testForStoreTable() { + testForStoreTableWithNull(); + } + + private void testForItemTableWithNotNull() { + String expIn1 = "IN:4(#0,'Home ':15," + + "'Books ':15," + + "'Electronics ':15)"; + int[] resultKeepRowIdxForIn1 = {0, 1, 2}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForIn1, expIn1); + + String expIn2 = "IN:4(#1,'wallpaper ':15," + + "'parenting ':15," + + "'musical ':15)"; + int[] resultKeepRowIdxForIn2 = {0, 1, 2}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForIn2, expIn2); + + String expIn3 = "IN:4(#0,'Shoes ':15," + + "'Jewelry ':15," + + "'Men ':15)"; + int[] resultKeepRowIdxForIn3 = {3, 4, 5}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForIn3, expIn3); + + String expIn4 = "IN:4(#1,'womens ':15," + + "'birdal ':15," + + "'pants ':15)"; + int[] resultKeepRowIdxForIn4 = {4, 5, 6}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForIn4, expIn4); + + // (In1 and In2 ) Or (In3 and Int4) = {0, 1, 2, 4, 5} + String expMixedOr = "OR:4(AND:4(IN:4(#0,'Home ':15," + + "'Books ':15," + + "'Electronics ':15) , " + + "IN:4(#1,'wallpaper ':15," + + "'parenting ':15," + + "'musical ':15)) , " + + "AND:4(IN:4(#0,'Shoes ':15," + + "'Jewelry ':15," + + "'Men ':15) , " + + "IN:4(#1,'womens ':15," + + "'birdal ':15," + + "'pants ':15)))"; + int[] resultKeepRowIdxForMixedOr = {0, 1, 2, 4, 5}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForMixedOr, + expMixedOr); + } + + private void testForStoreSalesTableWithNotNull() { + String expGe = "$operator$GREATER_THAN_OR_EQUAL:4(#5 , 2451545:2)"; + int[] resultKeepRowIdxForGe = {1, 2, 3, 4, 5, 6, 7}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForGe, expGe); + + String expLe = "$operator$LESS_THAN_OR_EQUAL:4(#5 , 2451910:2)"; + int[] resultKeepRowIdxForLe = {0, 1, 2, 3, 4, 5}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForLe, expLe); + } + + private void testForStoreSalesTableWithNull() { + String expNotNull1 = "IS_NOT_NULL:4(#5)"; + int[] resultKeepRowIdxForNotNull1 = {0, 3, 4, 5, 6, 7}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull1, + expNotNull1); + + String expNotNull2 = "IS_NOT_NULL:4(#6)"; + int[] resultKeepRowIdxForNotNull2 = {0, 1, 2, 3, 5, 6, 7}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull2, + expNotNull2); + + String expNotNull3 = "IS_NOT_NULL:4(#7)"; + int[] resultKeepRowIdxForNotNull3 = {0, 1, 2, 3, 4, 6, 7}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull3, + expNotNull3); + } + + private void testForDateDimTableWithNotNull() { + String expEq = "$operator$EQUAL:4(#3, 2000:1)"; + int[] resultKeepRowIdxForEq = {1, 2, 3}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEq, expEq); + + String expGe = "$operator$GREATER_THAN_OR_EQUAL:4(#4 , 2451545:2)"; + int[] resultKeepRowIdxForGe = {1, 2, 3, 4, 5, 6, 7}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForGe, expGe); + + String expLe = "$operator$LESS_THAN_OR_EQUAL:4(#4 , 2451910:2)"; + int[] resultKeepRowIdxForLe = {0, 1, 2, 3, 4, 5}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForLe, expLe); + } + + private void testForDateDimTableWithNull() { + String expNotNull1 = "IS_NOT_NULL:4(#3)"; + int[] resultKeepRowIdxForNotNull = {0, 2, 3, 4, 5, 6, 7}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull, + expNotNull1); + + String expNotNull2 = "IS_NOT_NULL:4(#4)"; + int[] resultKeepRowIdxForNotNull2 = {0, 1, 3, 4, 5, 6, 7}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull2, + expNotNull2); + } + + private void testForStoreTableWithNull() { + String expNotNull = "IS_NOT_NULL:4(#8)"; + int[] resultKeepRowIdxForNotNull = {1, 2, 3, 5, 6}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull, + expNotNull); + } +} \ No newline at end of file diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql9ForOmniFilterOperatorTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql9ForOmniFilterOperatorTest.java new file mode 100644 index 0000000..dac7eaa --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/tensql/Sql9ForOmniFilterOperatorTest.java @@ -0,0 +1,303 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + */ + +package nova.hetu.omniruntime.tensql; + +import static nova.hetu.omniruntime.util.TestUtils.filterOperatorMatch; +import static nova.hetu.omniruntime.util.TestUtils.filterOperatorMatchWithJson; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonLiteral; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonAndExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonEqualExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonGreaterThanExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonGreaterThanOrEqualExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonInExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonIsNotNullExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonLessThanOrEqualExpr; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonOrExpr; + +import com.google.common.collect.ImmutableList; + +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; + +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.List; + +/** + * sql9 for OmniFilter operator test. + * + * @since 2022-03-31 + */ +public class Sql9ForOmniFilterOperatorTest { + private static final String MUST_TEST_EXP1 = "AND:4(AND:4(AND:4(AND:4(AND:4(IS_NOT_NULL:4(#0), " + + "$operator$GREATER_THAN_OR_EQUAL:4(#0 , 2450819:2)), $operator$LESS_THAN_OR_EQUAL:4(#0 , 2451904:2)), " + + "IS_NOT_NULL:4(#1)), IS_NOT_NULL:4(#2)), IS_NOT_NULL:4(#3))"; + private static final String MUST_TEST_EXP2 = "AND:4(AND:4(AND:4(AND:4(AND:4(IS_NOT_NULL:4(#4) , " + + "$operator$EQUAL:4(#4 , 1:1)) , IN:4(#5,1998:1,1999:1,2000:1)) , " + + "$operator$GREATER_THAN_OR_EQUAL:4(#6 , 2450819:2)) , " + + "$operator$LESS_THAN_OR_EQUAL:4(#6 , 2451904:2)) , IS_NOT_NULL:4(#6))"; + private static final String MUST_TEST_EXP3 = "AND:4(AND:4(AND:4(IS_NOT_NULL:4(#7) , " + + "$operator$GREATER_THAN_OR_EQUAL:4(#7 , 200:1)) , " + + "$operator$LESS_THAN_OR_EQUAL:4(#7 , 295:1)) , IS_NOT_NULL:4(#8))"; + private static final String MUST_TEST_EXP4 = "AND:4(OR:4($operator$EQUAL:4(#9 , 8:1) , " + + "$operator$GREATER_THAN:4(#10 , 0:1)) , IS_NOT_NULL:4(#11))"; + private static final String MUST_TEST_EXP5 = "IS_NOT_NULL:4(#12)"; + + private static final String MUST_TEST_EXP1_JSON = omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(1, 4)), + omniJsonGreaterThanOrEqualExpr(getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 2450819))), + omniJsonLessThanOrEqualExpr(getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 2451904))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 1))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 2))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 3))); + private static final String MUST_TEST_EXP2_JSON = omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(1, 4)), + omniJsonEqualExpr(getOmniJsonFieldReference(1, 4), + getOmniJsonLiteral(1, false, 1))), + omniJsonInExpr(1, 5, Arrays.asList(1998, 1999, 2000))), + omniJsonGreaterThanOrEqualExpr(getOmniJsonFieldReference(2, 6), + getOmniJsonLiteral(2, false, 2450819))), + omniJsonLessThanOrEqualExpr(getOmniJsonFieldReference(2, 6), + getOmniJsonLiteral(2, false, 2451904))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 6))); + private static final String MUST_TEST_EXP3_JSON = omniJsonAndExpr( + omniJsonAndExpr( + omniJsonAndExpr(omniJsonIsNotNullExpr(getOmniJsonFieldReference(1, 7)), + omniJsonGreaterThanOrEqualExpr(getOmniJsonFieldReference(1, 7), + getOmniJsonLiteral(1, false, 200))), + omniJsonLessThanOrEqualExpr(getOmniJsonFieldReference(1, 7), getOmniJsonLiteral(1, false, 295))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 8))); + private static final String MUST_TEST_EXP4_JSON = omniJsonAndExpr( + omniJsonOrExpr(omniJsonEqualExpr(getOmniJsonFieldReference(1, 9), getOmniJsonLiteral(1, false, 8)), + omniJsonGreaterThanExpr(getOmniJsonFieldReference(1, 10), getOmniJsonLiteral(1, false, 0))), + omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 11))); + private static final String MUST_TEST_EXP5_JSON = omniJsonIsNotNullExpr(getOmniJsonFieldReference(2, 12)); + + Object[][] dataSourceValue = {{1214L, 2450819L, 2450820L, 2451904L, 2451904L, 2455000L}, // #0 ss_sold_date_sk #long + {1L, 2L, 3L, 4L, 5L, 6L}, // #1 ss_store_sk #long + {1L, 2L, 3L, 4L, 5L, 6L}, // #2 ss_hdemo_sk #long + {1L, 2L, 3L, 4L, 5L, 6L}, // #3 ss_customer_sk #long + {-1, 0, 1, 1, 1, 4}, // #4 d_dow #int + {1000, 1998, 1999, 2000, 2000, 2001}, // #5 d_year #int + {1214L, 2450819L, 2450820L, 2451904L, 2451904L, 2455000L}, // #6 d_date_sk #long + {-10, 100, 200, 210, 295, 300}, // #7 s_number_employees #int + {1L, 2L, 3L, 4L, 5L, 6L}, // #8 s_store_sk #long + {-2, 4, 8, 8, 8, 20}, // #9 hd_dep_count #int + {-3, 0, 1, 2, 4, 10}, // #10 hd_vehicle_count #int + {1L, 2L, 3L, 4L, 5L, 6L}, // #11 hd_demo_sk #long + {1L, 2L, 3L, 4L, 5L, 6L}, // #12 c_customer_sk #long + }; + + Object[][] dataSourceValueWithNull = {{1214L, null, 2450820L, 2451904L, 2451904L, 2455000L}, + // #0 ss_sold_date_sk #long + {null, 2L, 3L, 4L, 5L, null}, // #1 ss_store_sk #long + {null, null, 3L, 4L, 5L, 6L}, // #2 ss_hdemo_sk #long + {1L, 2L, 3L, null, 5L, 6L}, // #3 ss_customer_sk #long + {-1, null, 1, 1, 1, 4}, // #4 d_dow #int + {1000, 1998, 1999, 2000, 2000, 2001}, // #5 d_year #int + {1214L, null, 2450820L, 2451904L, 2451904L, null}, // #6 d_date_sk #long + {-10, 100, null, 210, 295, 300}, // #7 s_number_employees #int + {1L, 2L, 3L, null, 5L, 6L}, // #8 s_store_sk #long + {-2, null, 8, 8, 8, 20}, // #9 hd_dep_count #int + {-3, 0, 1, 2, 4, 10}, // #10 hd_vehicle_count #int + {1L, 2L, null, 4L, 5L, null}, // #11 hd_demo_sk #long + {null, null, 3L, 4L, 5L, null}, // #12 c_customer_sk #long + }; + + DataType[] dataSourceTypes = {LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, LongDataType.LONG, + IntDataType.INTEGER, IntDataType.INTEGER, LongDataType.LONG, IntDataType.INTEGER, LongDataType.LONG, + IntDataType.INTEGER, IntDataType.INTEGER, LongDataType.LONG, LongDataType.LONG}; + + List dataSourceProjects = ImmutableList.of("#0", "#1", "#2", "#3", "#4", "#5", "#6", "#7", "#8", "#9", + "#10", "#11", "#12"); + + List dataSourceProjectionJson = ImmutableList.of( + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":0}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":1}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":2}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":3}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":4}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":5}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":6}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":7}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":8}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":9}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":10}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":11}", + "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":2,\"colVal\":12}"); + + @Test + public void testMustExpJson() { + int[] resultKeepRowIdxForEXP1 = {2, 4}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP1, MUST_TEST_EXP1_JSON, 1); + + int[] resultKeepRowIdxForEXP2WithNull = {2, 3, 4}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP2WithNull, MUST_TEST_EXP2_JSON, 1); + + int[] resultKeepRowIdxForEXP3 = {4}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP3, MUST_TEST_EXP3_JSON, 1); + + int[] resultKeepRowIdxForEXP4 = {3, 4}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP4, MUST_TEST_EXP4_JSON, 1); + + int[] resultKeepRowIdxForEXP5 = {2, 3, 4}; + filterOperatorMatchWithJson(dataSourceValueWithNull, dataSourceTypes, dataSourceProjectionJson, + resultKeepRowIdxForEXP5, MUST_TEST_EXP5_JSON, 1); + } + + @Test + public void testMustExp() { + int[] resultKeepRowIdxForEXP1 = {2, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP1, + MUST_TEST_EXP1); + + int[] resultKeepRowIdxForEXP2WithNull = {2, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, + resultKeepRowIdxForEXP2WithNull, MUST_TEST_EXP2); + + int[] resultKeepRowIdxForEXP3 = {4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP3, + MUST_TEST_EXP3); + + int[] resultKeepRowIdxForEXP4 = {3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP4, + MUST_TEST_EXP4); + + int[] resultKeepRowIdxForEXP5 = {2, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEXP5, + MUST_TEST_EXP5); + } + + @Test + public void testForStoreSalesTable() { + testForStoreSalesTableWithNull(); + testForStoreSalesTableWithNotNull(); + } + + @Test + public void testForDateDimTable() { + testForDateDimTableWithNotNull(); + testForDateDimTableWithNull(); + } + + @Test + public void testForStoreTable() { + testForStoreTableWithNull(); + testForStoreTableWithNotNull(); + } + + @Test + public void testForHouseholdDemographicsTable() { + testForHouseholdDemographicsTableWithNotNull(); + testForHouseholdDemographicsTableWithNull(); + } + + private void testForHouseholdDemographicsTableWithNull() { + String expNotNull1 = "IS_NOT_NULL:4(#11)"; + int[] resultKeepRowIdxForNotNull1 = {0, 1, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull1, + expNotNull1); + } + + private void testForHouseholdDemographicsTableWithNotNull() { + String expEq = "$operator$EQUAL:4(#9 , 8:1)"; + int[] resultKeepRowIdxForEq = {2, 3, 4}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEq, expEq); + + String expGt = "$operator$GREATER_THAN:4(#10, 0:1)"; + int[] resultKeepRowIdxForGt = {2, 3, 4, 5}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForGt, expGt); + } + + private void testForStoreTableWithNull() { + String expNotNull1 = "IS_NOT_NULL:4(#7)"; + int[] resultKeepRowIdxForNotNull1 = {0, 1, 3, 4, 5}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull1, + expNotNull1); + + String expNotNull2 = "IS_NOT_NULL:4(#8)"; + int[] resultKeepRowIdxForNotNull2 = {0, 1, 2, 4, 5}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull2, + expNotNull2); + } + + private void testForStoreTableWithNotNull() { + String expGe = "$operator$GREATER_THAN_OR_EQUAL:4(#7 , 200:1)"; + int[] resultKeepRowIdxForGe = {2, 3, 4, 5}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForGe, expGe); + + String expLe = "$operator$LESS_THAN_OR_EQUAL:4(#7 , 295:1)"; + int[] resultKeepRowIdxForLe = {0, 1, 2, 3, 4}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForLe, expLe); + } + + private void testForStoreSalesTableWithNull() { + String expNotNull1 = "IS_NOT_NULL:4(#0)"; + int[] resultKeepRowIdxForNotNull1 = {0, 2, 3, 4, 5}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull1, + expNotNull1); + + String expNotNull2 = "IS_NOT_NULL:4(#1)"; + int[] resultKeepRowIdxForNotNull2 = {1, 2, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull2, + expNotNull2); + + String expNotNull3 = "IS_NOT_NULL:4(#2)"; + int[] resultKeepRowIdxForNotNull3 = {2, 3, 4, 5}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull3, + expNotNull3); + + String expNotNull4 = "IS_NOT_NULL:4(#3)"; + int[] resultKeepRowIdxForNotNull4 = {0, 1, 2, 4, 5}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull4, + expNotNull4); + } + + private void testForStoreSalesTableWithNotNull() { + String expGe = "$operator$GREATER_THAN_OR_EQUAL:4(#0 , 2450819:2)"; + int[] resultKeepRowIdxForGe = {1, 2, 3, 4, 5}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForGe, expGe); + + String expLe = "$operator$LESS_THAN_OR_EQUAL:4(#0 , 2451904:2)"; + int[] resultKeepRowIdxForLe = {0, 1, 2, 3, 4}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForLe, expLe); + } + + private void testForDateDimTableWithNotNull() { + String expEq = "$operator$EQUAL:4(#4 , 1:1)"; + int[] resultKeepRowIdxForEq = {2, 3, 4}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForEq, expEq); + + String expGe = "$operator$GREATER_THAN_OR_EQUAL:4(#6 , 2450819:2)"; + int[] resultKeepRowIdxForGe = {1, 2, 3, 4, 5}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForGe, expGe); + + String expLe = "$operator$LESS_THAN_OR_EQUAL:4(#6 , 2451904:2)"; + int[] resultKeepRowIdxForLe = {0, 1, 2, 3, 4}; + filterOperatorMatch(dataSourceValue, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForLe, expLe); + } + + private void testForDateDimTableWithNull() { + String expNotNull = "IS_NOT_NULL:4(#6)"; + int[] resultKeepRowIdxForNotNull = {0, 2, 3, 4}; + filterOperatorMatch(dataSourceValueWithNull, dataSourceTypes, dataSourceProjects, resultKeepRowIdxForNotNull, + expNotNull); + } +} \ No newline at end of file diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/type/BenchmarkDataTypeSerializer.java b/bindings/java/src/test/java/nova/hetu/omniruntime/type/BenchmarkDataTypeSerializer.java new file mode 100644 index 0000000..0d28c59 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/type/BenchmarkDataTypeSerializer.java @@ -0,0 +1,67 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.VerboseMode; + +import java.util.ArrayList; +import java.util.List; + +/** + * DataType serialize benchmark + * + * @since 2022-5-18 + */ +@State(Scope.Thread) +@OutputTimeUnit(MILLISECONDS) +@Fork(1) +@Warmup(iterations = 5, batchSize = 1) +@Measurement(iterations = 20, batchSize = 1) +@BenchmarkMode(Mode.AverageTime) +public class BenchmarkDataTypeSerializer { + @Param("1000") + int times = 1000; + + /** + * create benchmark for datatype serialize + * + * @return Deserialized dataType list + */ + @Benchmark + public List dataTypeSerializeBenchmark() { + DataType[] dataTypes = new DataType[]{BooleanDataType.BOOLEAN, ShortDataType.SHORT, IntDataType.INTEGER, + LongDataType.LONG, DoubleDataType.DOUBLE, Decimal64DataType.DECIMAL64, Decimal128DataType.DECIMAL128, + Date32DataType.DATE32, Date64DataType.DATE64, VarcharDataType.VARCHAR, CharDataType.CHAR, + TimestampDataType.TIMESTAMP}; + List list = new ArrayList<>(); + for (int i = 0; i < times; i++) { + String allSerializedTypes = DataTypeSerializer.serialize(dataTypes); + DataType[] allDeserializedTypes = DataTypeSerializer.deserialize(allSerializedTypes); + list.add(allDeserializedTypes); + } + return list; + } + + public static void main(String[] args) throws Throwable { + Options options = new OptionsBuilder().verbosity(VerboseMode.NORMAL) + .include(".*" + BenchmarkDataTypeSerializer.class.getSimpleName() + ".*").build(); + new Runner(options).run(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/type/TestDataType.java b/bindings/java/src/test/java/nova/hetu/omniruntime/type/TestDataType.java new file mode 100644 index 0000000..2e51fa4 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/type/TestDataType.java @@ -0,0 +1,65 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +import static org.testng.Assert.assertEquals; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import nova.hetu.omniruntime.type.BooleanDataType; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.Decimal128DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.utils.OmniErrorType; +import nova.hetu.omniruntime.utils.OmniRuntimeException; + +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.List; + +/** + * test vec type + * + * @since 2021-7-2 + */ +public class TestDataType { + /** + * test vec type + */ + @Test + public void testDataType() { + DataType type = getDataTypeFromBase("BIGINT"); + assertEquals(type, LongDataType.LONG); + } + + private DataType getDataTypeFromBase(String base) { + switch (base) { + case "INT": + case "DATE": + return IntDataType.INTEGER; + case "BIGINT": + return LongDataType.LONG; + case "DOUBLE": + return DoubleDataType.DOUBLE; + case "BOOLEAN": + return BooleanDataType.BOOLEAN; + default: + throw new OmniRuntimeException(OmniErrorType.OMNI_UNDEFINED, "Not support Type " + base); + } + } + + @Test + public void testSerialization() throws JsonProcessingException { + ObjectMapper map = new ObjectMapper(); + List types = new ArrayList<>(); + types.add(LongDataType.LONG); + types.add(new Decimal128DataType(1, 2)); + assertEquals(map.writeValueAsString(types), "[{\"id\":2},{\"precision\":1,\"scale\":2,\"id\":7}]"); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/type/TestDataTypeSerializer.java b/bindings/java/src/test/java/nova/hetu/omniruntime/type/TestDataTypeSerializer.java new file mode 100644 index 0000000..0928da4 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/type/TestDataTypeSerializer.java @@ -0,0 +1,35 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.type; + +import static org.testng.AssertJUnit.assertEquals; + +import org.testng.annotations.Test; + +/** + * Data type serializer test + * + * @since 2022-2-17 + */ +public class TestDataTypeSerializer { + @Test + public void testAllTypes() { + DataType[] types = new DataType[]{BooleanDataType.BOOLEAN, ShortDataType.SHORT, IntDataType.INTEGER, + LongDataType.LONG, DoubleDataType.DOUBLE, Decimal64DataType.DECIMAL64, Decimal128DataType.DECIMAL128, + Date32DataType.DATE32, Date64DataType.DATE64, VarcharDataType.VARCHAR, CharDataType.CHAR, + new ContainerDataType(new DataType[]{IntDataType.INTEGER, CharDataType.CHAR}), InvalidDataType.INVALID, + NoneDataType.NONE, TimestampDataType.TIMESTAMP}; + String[] serializeds = new String[types.length]; + for (int i = 0; i < types.length; i++) { + serializeds[i] = DataTypeSerializer.serializeSingle(types[i]); + } + String serializedAll = DataTypeSerializer.serialize(types); + DataType[] allDeserilizedDataTypes = DataTypeSerializer.deserialize(serializedAll); + for (int i = 0; i < types.length; i++) { + assertEquals(DataTypeSerializer.deserializeSingle(serializeds[i]), types[i]); + assertEquals(allDeserilizedDataTypes[i], types[i]); + } + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/util/TestJsonUtils.java b/bindings/java/src/test/java/nova/hetu/omniruntime/util/TestJsonUtils.java new file mode 100644 index 0000000..20972ba --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/util/TestJsonUtils.java @@ -0,0 +1,66 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.util; + +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonFieldReference; +import static nova.hetu.omniruntime.util.TestUtils.getOmniJsonLiteral; +import static nova.hetu.omniruntime.util.TestUtils.omniJsonFourArithmeticExpr; +import static org.testng.Assert.assertTrue; + +import nova.hetu.omniruntime.utils.JsonUtils; + +import org.testng.annotations.Test; + +/** + * Test json serialization/deserialization + * + * @since 2022-9-22 + */ +public class TestJsonUtils { + private static boolean compareStringArray(String[] arr1, String[] arr2) { + if (arr1.length != arr2.length) { + return false; + } + for (int i = 0; i < arr1.length; i++) { + if (!arr1[i].equals(arr2[i])) { + return false; + } + } + return true; + } + + @Test + public void testJsonStringArray() { + String[] src = { + omniJsonFourArithmeticExpr("MODULUS", 2, getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 3)), + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 2), getOmniJsonLiteral(1, false, 5))}; + + String json = JsonUtils.jsonStringArray(src); + String[] deserializeJsons = JsonUtils.deserializeJson(json); + assertTrue(compareStringArray(src, deserializeJsons)); + } + + @Test + public void testJsonMultiDimStringArray() { + String[][] strings = new String[3][]; + String[] arr1 = { + omniJsonFourArithmeticExpr("MODULUS", 2, getOmniJsonFieldReference(2, 0), + getOmniJsonLiteral(2, false, 3)), + omniJsonFourArithmeticExpr("ADD", 1, getOmniJsonFieldReference(1, 2), getOmniJsonLiteral(1, false, 5))}; + String[] arr2 = {getOmniJsonFieldReference(2, 3)}; + String[] arr3 = {omniJsonFourArithmeticExpr("MULTIPLY", 2, getOmniJsonFieldReference(2, 1), + getOmniJsonLiteral(2, false, 5)), getOmniJsonFieldReference(1, 3)}; + strings[0] = arr1; + strings[1] = arr2; + strings[2] = arr3; + + String[] json = JsonUtils.jsonStringArray(strings); + String[][] deserializeJsons = JsonUtils.deserializeJson(json); + for (int i = 0; i < 3; i++) { + assertTrue(compareStringArray(strings[i], deserializeJsons[i])); + } + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/util/TestUtils.java b/bindings/java/src/test/java/nova/hetu/omniruntime/util/TestUtils.java new file mode 100644 index 0000000..1bfd309 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/util/TestUtils.java @@ -0,0 +1,1593 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.util; + +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DATE32; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DECIMAL128; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DOUBLE; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT; +import static nova.hetu.omniruntime.vector.VecEncoding.OMNI_VEC_ENCODING_DICTIONARY; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +import com.alibaba.fastjson.JSONArray; +import com.alibaba.fastjson.JSONObject; + +import nova.hetu.omniruntime.operator.OmniOperator; +import nova.hetu.omniruntime.operator.filter.OmniFilterAndProjectOperatorFactory; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.vector.BooleanVec; +import nova.hetu.omniruntime.vector.Decimal128Vec; +import nova.hetu.omniruntime.vector.DictionaryVec; +import nova.hetu.omniruntime.vector.DoubleVec; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.ShortVec; +import nova.hetu.omniruntime.vector.VarcharVec; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecBatch; +import nova.hetu.omniruntime.vector.VecEncoding; + +import org.testng.AssertJUnit; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; + +/** + * Test utils for data generate + * + * @since 2021-8-10 + */ +public class TestUtils { + /** + * Create data for blank vec batch + * + * @param types dataType + * @return VecBatch + */ + public static VecBatch createBlankVecBatch(DataType[] types) { + Object[] data = {}; + Vec[] vecs = new Vec[types.length]; + for (int i = 0; i < types.length; i++) { + vecs[i] = createVec(types[i], data); + } + return new VecBatch(vecs, 0); + } + + /** + * Create vec batch data + * + * @param types dataType + * @param datas data + * @return VecBatch + */ + public static VecBatch createVecBatch(DataType[] types, Object[][] datas) { + Vec[] vecs = new Vec[types.length]; + for (int i = 0; i < types.length; i++) { + vecs[i] = createVec(types[i], datas[i]); + } + return new VecBatch(vecs); + } + + /** + * Create vec batch data + * + * @param types dataType + * @param list dataList + * @return VecBatch + */ + public static VecBatch createVecBatch(DataType[] types, List> list) { + Vec[] vecs = new Vec[types.length]; + for (int i = 0; i < types.length; i++) { + vecs[i] = createVec(types[i], list.get(i).toArray()); + } + return new VecBatch(vecs); + } + + /** + * Create vec + * + * @param type dataType + * @param data data + * @return Vec + */ + public static Vec createVec(DataType type, Object[] data) { + switch (type.getId()) { + case OMNI_INT: + case OMNI_DATE32: + return createIntVec(data); + case OMNI_SHORT: + return createShortVec(data); + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DECIMAL64: + return createLongVec(data); + case OMNI_DOUBLE: + return createDoubleVec(data); + case OMNI_BOOLEAN: + return createBooleanVec(data); + case OMNI_VARCHAR: + case OMNI_CHAR: + return createVarcharVec((VarcharDataType) type, data); + default: + throw new UnsupportedOperationException("Unsupported type : " + type.getId()); + } + } + + /** + * Create vec for decimal128 + * + * @param type dataType + * @param data data + * @return Vec + */ + public static Vec createVec(DataType type, Object[][] data) { + switch (type.getId()) { + case OMNI_DECIMAL128: + return createDecimal128Vec(data); + default: + throw new UnsupportedOperationException("Unsupported type : " + type.getId()); + } + } + + /** + * Create short vec + * + * @param data data + * @return ShortVec + */ + public static ShortVec createShortVec(Object[] data) { + ShortVec result = new ShortVec(data.length); + for (int j = 0; j < data.length; j++) { + if (data[j] == null) { + result.setNull(j); + } else { + result.set(j, (short) data[j]); + } + } + return result; + } + + /** + * Create int vec + * + * @param data data + * @return IntVec + */ + public static IntVec createIntVec(Object[] data) { + IntVec result = new IntVec(data.length); + for (int j = 0; j < data.length; j++) { + if (data[j] == null) { + result.setNull(j); + } else { + result.set(j, (int) data[j]); + } + } + return result; + } + + /** + * Create long vec + * + * @param data data + * @return LongVec + */ + public static LongVec createLongVec(Object[] data) { + LongVec result = new LongVec(data.length); + for (int j = 0; j < data.length; j++) { + if (data[j] == null) { + result.setNull(j); + } else { + result.set(j, (long) data[j]); + } + } + return result; + } + + /** + * Create Double vec + * + * @param data data + * @return DoubleVec + */ + public static DoubleVec createDoubleVec(Object[] data) { + DoubleVec result = new DoubleVec(data.length); + for (int j = 0; j < data.length; j++) { + if (data[j] == null) { + result.setNull(j); + } else { + result.set(j, (double) data[j]); + } + } + return result; + } + + /** + * Create Boolean vec + * + * @param data data + * @return BooleanVec + */ + public static BooleanVec createBooleanVec(Object[] data) { + BooleanVec result = new BooleanVec(data.length); + for (int i = 0; i < data.length; i++) { + if (data[i] == null) { + result.setNull(i); + } else { + result.set(i, (boolean) data[i]); + } + } + return result; + } + + /** + * Create Varchar vec + * + * @param varcharVecType varchar vec type + * @param data data + * @return VarcharVec + */ + public static VarcharVec createVarcharVec(VarcharDataType varcharVecType, Object[] data) { + VarcharVec result = new VarcharVec(data.length); + for (int j = 0; j < data.length; j++) { + if (data[j] == null) { + result.setNull(j); + } else { + result.set(j, ((String) data[j]).getBytes(StandardCharsets.UTF_8)); + } + } + return result; + } + + /** + * Create Decimal128 vec + * + * @param data data + * @return Decimal128Vec + */ + public static Decimal128Vec createDecimal128Vec(Object[][] data) { + Decimal128Vec result = new Decimal128Vec(data.length); + for (int i = 0; i < data.length; i++) { + if (data[i] == null) { + result.setNull(i); + } else { + result.set(i, new long[]{(long) data[i][0], (long) data[i][1]}); + } + } + return result; + } + + /** + * Create Dictionary vec + * + * @param dataType dataType + * @param data input data + * @param ids id array + * @return DictionaryVec + */ + public static DictionaryVec createDictionaryVec(DataType dataType, Object[] data, int[] ids) { + Vec dictionary = createVec(dataType, data); + DictionaryVec dictionaryVec = new DictionaryVec(dictionary, ids); + dictionary.close(); + return dictionaryVec; + } + + /** + * Vec Batch equals + * + * @param vecBatch vecBatch + * @param expectedDatas data + */ + public static void assertVecBatchEqualsIgnoreOrder(VecBatch vecBatch, Object[][] expectedDatas) { + int vectorCount = vecBatch.getVectorCount(); + assertEquals(vectorCount, expectedDatas.length); + + Vec[] vecs = vecBatch.getVectors(); + for (int i = 0; i < vectorCount; i++) { + Vec vec = vecs[i]; + assertEquals(vec.getSize(), expectedDatas[i].length); + VecEncoding vecEncoding = vec.getEncoding(); + if (vecEncoding.equals(OMNI_VEC_ENCODING_DICTIONARY)) { + assertTrue(assertDictionaryVecEqualsIgnoreOrder((DictionaryVec) vec, expectedDatas[i])); + continue; + } + assertTrue(assertVecEqualsIgnoreOrder(vec, expectedDatas[i])); + } + } + + /** + * Vec Batch equals + * + * @param vecBatch vecBatch + * @param expectedDatas data + */ + public static void assertVecBatchEquals(VecBatch vecBatch, Object[][] expectedDatas) { + int vectorCount = vecBatch.getVectorCount(); + assertEquals(vectorCount, expectedDatas.length); + + Vec[] vecs = vecBatch.getVectors(); + for (int i = 0; i < vectorCount; i++) { + Vec vec = vecs[i]; + assertEquals(vec.getSize(), expectedDatas[i].length); + VecEncoding vecEncoding = vec.getEncoding(); + if (vecEncoding.equals(OMNI_VEC_ENCODING_DICTIONARY)) { + assertDictionaryVecEquals((DictionaryVec) vec, expectedDatas[i]); + continue; + } + assertVecEquals(vec, expectedDatas[i]); + } + } + + /** + * Vec Batch equals + * + * @param vecBatch vecBatch + * @param expectedVecBatch expectedVecBatch + */ + public static void assertVecBatchEqualsIgnoreOrder(VecBatch vecBatch, VecBatch expectedVecBatch) { + int vectorCount = vecBatch.getVectorCount(); + assertEquals(vectorCount, expectedVecBatch.getVectorCount()); + Vec[] vecs = vecBatch.getVectors(); + Vec[] expectedVecs = expectedVecBatch.getVectors(); + for (int i = 0; i < vectorCount; i++) { + Vec vec = vecs[i]; + Vec expectedVec = expectedVecs[i]; + assertEquals(vec.getSize(), expectedVec.getSize()); + VecEncoding vecEncoding = vec.getEncoding(); + if (vecEncoding.equals(OMNI_VEC_ENCODING_DICTIONARY)) { + assertTrue(assertDictionaryVecEqualsIgnoreOrder((DictionaryVec) vec, (DictionaryVec) expectedVec)); + continue; + } + assertTrue(assertVecEqualsIgnoreOrder(vec, expectedVec)); + } + } + + /** + * Vec Batch equals + * + * @param vecBatch vecBatch + * @param expectedVecBatch expectedVecBatch + */ + public static void assertVecBatchEquals(VecBatch vecBatch, VecBatch expectedVecBatch) { + int vectorCount = vecBatch.getVectorCount(); + assertEquals(vectorCount, expectedVecBatch.getVectorCount()); + Vec[] vecs = vecBatch.getVectors(); + Vec[] expectedVecs = expectedVecBatch.getVectors(); + for (int i = 0; i < vectorCount; i++) { + Vec vec = vecs[i]; + Vec expectedVec = expectedVecs[i]; + assertEquals(vec.getSize(), expectedVec.getSize()); + VecEncoding vecEncoding = vec.getEncoding(); + if (vecEncoding.equals(OMNI_VEC_ENCODING_DICTIONARY)) { + assertDictionaryVecEquals((DictionaryVec) vec, (DictionaryVec) expectedVec); + continue; + } + assertVecEquals(vec, expectedVec); + } + } + + /** + * Vec equals + * + * @param vec vec + * @param expectedData data + * @return boolean + */ + public static boolean assertDecimal128VecEqualsIgnoreOrder(Vec vec, Long[][] expectedData) { + int rowCount = vec.getSize(); + int resNUllCount = 0; + int expNullCount = 0; + for (int i = 0; i < rowCount; i++) { + if (vec.isNull(i)) { + resNUllCount++; + } + if (expectedData[i] == null) { + expNullCount++; + } + } + if (resNUllCount != expNullCount) { + return false; + } + long[][] resArr = new long[rowCount - resNUllCount][2]; + long[][] expArr = new long[rowCount - resNUllCount][2]; + for (int i = 0, j = 0, k = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resArr[j++] = ((Decimal128Vec) vec).get(i); + } + if (expectedData[i] != null) { + expArr[k][0] = (long) expectedData[i][0]; + expArr[k][1] = (long) expectedData[i][1]; + k++; + } + } + Arrays.sort(resArr, new Comparator() { + @Override + public int compare(long[] arr1, long[] arr2) { + if (arr1[0] == arr2[0]) { + return Long.compare(arr1[1], arr2[1]); + } + return Long.compare(arr1[0], arr2[0]); + } + }); + Arrays.sort(expArr, new Comparator() { + @Override + public int compare(long[] arr1, long[] arr2) { + if (arr1[0] == arr2[0]) { + return Long.compare(arr1[1], arr2[1]); + } + return Long.compare(arr1[0], arr2[0]); + } + }); + for (int i = 0; i < resArr.length; i++) { + if (resArr[i][0] != expArr[i][0] || resArr[i][1] != expArr[i][1]) { + return false; + } + } + return true; + } + + /** + * Vec equals + * + * @param vec vec + * @param expectedData data + * @return boolean + */ + public static boolean assertVecEqualsIgnoreOrder(Vec vec, Object[] expectedData) { + int rowCount = vec.getSize(); + int resNUllCount = 0; + int expNullCount = 0; + List resList = new ArrayList<>(); + List expList = new ArrayList<>(); + for (int i = 0; i < rowCount; i++) { + if (vec.isNull(i)) { + resNUllCount++; + } + if (expectedData[i] == null) { + expNullCount++; + } else { + expList.add(expectedData[i]); + } + } + + if (resNUllCount != expNullCount) { + return false; + } + switch (vec.getType().getId()) { + case OMNI_INT: + case OMNI_DATE32: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(((IntVec) vec).get(i)); + } + } + break; + case OMNI_SHORT: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(((ShortVec) vec).get(i)); + } + } + break; + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DECIMAL64: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(((LongVec) vec).get(i)); + } + } + break; + case OMNI_DOUBLE: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(((DoubleVec) vec).get(i)); + } + } + break; + case OMNI_BOOLEAN: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(((BooleanVec) vec).get(i)); + } + } + break; + case OMNI_VARCHAR: + case OMNI_CHAR: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(new String(((VarcharVec) vec).get(i))); + } + } + break; + default: + throw new UnsupportedOperationException("Unsupported type : " + vec.getType().getId()); + } + Object[] resArr = resList.toArray(); + Object[] expArr = expList.toArray(); + Arrays.sort(resArr); + Arrays.sort(expArr); + if (vec.getType().getId() == OMNI_DOUBLE) { + for (int i = 0; i < resArr.length; i++) { + if (Double.compare((double) resArr[i], (double) expArr[i]) != 0) { + return false; + } + } + } else { + for (int i = 0; i < resArr.length; i++) { + if (!resArr[i].equals(expArr[i])) { + return false; + } + } + } + return true; + } + + /** + * Vec equals + * + * @param vec vec + * @param expectedData data + */ + public static void assertVecEquals(Vec vec, Object[] expectedData) { + for (int i = 0; i < vec.getSize(); i++) { + if (vec.isNull(i)) { + assertEquals(null, expectedData[i]); + continue; + } + switch (vec.getType().getId()) { + case OMNI_INT: + case OMNI_DATE32: + assertEquals(((IntVec) vec).get(i), expectedData[i]); + break; + case OMNI_SHORT: + assertEquals(((ShortVec) vec).get(i), expectedData[i]); + break; + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DECIMAL64: + assertEquals(((LongVec) vec).get(i), expectedData[i]); + break; + case OMNI_DOUBLE: + assertTrue(Double.compare(((DoubleVec) vec).get(i), (Double) expectedData[i]) == 0); + break; + case OMNI_BOOLEAN: + assertEquals(((BooleanVec) vec).get(i), expectedData[i]); + break; + case OMNI_VARCHAR: + case OMNI_CHAR: + assertEquals(new String(((VarcharVec) vec).get(i)), expectedData[i]); + break; + default: + throw new UnsupportedOperationException("Unsupported type : " + vec.getType().getId()); + } + } + } + + /** + * Vec equals + * + * @param vec vec + * @param expectedVec expectedVec + * @return boolean + */ + public static boolean assertDecimal128VecEqualsIgnoreOrder(Vec vec, Vec expectedVec) { + int rowCount = vec.getSize(); + int resNUllCount = 0; + int expNullCount = 0; + for (int i = 0; i < rowCount; i++) { + if (vec.isNull(i)) { + resNUllCount++; + } + if (expectedVec.isNull(i)) { + expNullCount++; + } + } + if (resNUllCount != expNullCount) { + return false; + } + long[][] resArr = new long[rowCount - resNUllCount][2]; + long[][] expArr = new long[rowCount - resNUllCount][2]; + for (int i = 0, j = 0, k = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resArr[j++] = ((Decimal128Vec) vec).get(i); + } + if (expectedVec.isNull(i)) { + expArr[k++] = ((Decimal128Vec) expectedVec).get(i); + } + } + Arrays.sort(resArr, new Comparator() { + @Override + public int compare(long[] arr1, long[] arr2) { + if (arr1[0] == arr2[0]) { + return Long.compare(arr1[1], arr2[1]); + } + return Long.compare(arr1[0], arr2[0]); + } + }); + Arrays.sort(expArr, new Comparator() { + @Override + public int compare(long[] arr1, long[] arr2) { + if (arr1[0] == arr2[0]) { + return Long.compare(arr1[1], arr2[1]); + } + return Long.compare(arr1[0], arr2[0]); + } + }); + for (int i = 0; i < resArr.length; i++) { + if (resArr[i][0] != expArr[i][0] || resArr[i][1] != expArr[i][1]) { + return false; + } + } + return true; + } + + /** + * Vec equals + * + * @param vec vec + * @param expectedVec expectedVec + * @return boolean + */ + public static boolean assertVecEqualsIgnoreOrder(Vec vec, Vec expectedVec) { + int rowCount = vec.getSize(); + int resNUllCount = 0; + int expNullCount = 0; + List resList = new ArrayList<>(); + List expList = new ArrayList<>(); + for (int i = 0; i < rowCount; i++) { + if (vec.isNull(i)) { + resNUllCount++; + } + if (expectedVec.isNull(i)) { + expNullCount++; + } + } + + if (resNUllCount != expNullCount) { + return false; + } + switch (vec.getType().getId()) { + case OMNI_INT: + case OMNI_DATE32: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(((IntVec) vec).get(i)); + } + if (!expectedVec.isNull(i)) { + expList.add(((IntVec) expectedVec).get(i)); + } + } + break; + case OMNI_SHORT: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(((ShortVec) vec).get(i)); + } + if (!expectedVec.isNull(i)) { + expList.add(((ShortVec) expectedVec).get(i)); + } + } + break; + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DECIMAL64: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(((LongVec) vec).get(i)); + } + if (!expectedVec.isNull(i)) { + expList.add(((LongVec) expectedVec).get(i)); + } + } + break; + case OMNI_DOUBLE: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(((DoubleVec) vec).get(i)); + } + if (!expectedVec.isNull(i)) { + expList.add(((DoubleVec) expectedVec).get(i)); + } + } + break; + case OMNI_BOOLEAN: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(((BooleanVec) vec).get(i)); + } + if (!expectedVec.isNull(i)) { + expList.add(((BooleanVec) expectedVec).get(i)); + } + } + break; + case OMNI_VARCHAR: + case OMNI_CHAR: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(new String(((VarcharVec) vec).get(i))); + } + if (!expectedVec.isNull(i)) { + expList.add(new String(((VarcharVec) expectedVec).get(i))); + } + } + break; + default: + throw new UnsupportedOperationException("Unsupported type : " + vec.getType().getId()); + } + Object[] resArr = resList.toArray(); + Object[] expArr = expList.toArray(); + Arrays.sort(resArr); + Arrays.sort(expArr); + if (vec.getType().getId() == OMNI_DOUBLE) { + for (int i = 0; i < resArr.length; i++) { + if (Double.compare((double) resArr[i], (double) expArr[i]) != 0) { + return false; + } + } + } else { + for (int i = 0; i < resArr.length; i++) { + if (!resArr[i].equals(expArr[i])) { + return false; + } + } + } + return true; + } + + /** + * Vec equals + * + * @param vec vec + * @param expectedVec expectedVec + */ + public static void assertVecEquals(Vec vec, Vec expectedVec) { + for (int i = 0; i < vec.getSize(); i++) { + if (vec.isNull(i) && expectedVec.isNull(i)) { + continue; + } + switch (vec.getType().getId()) { + case OMNI_INT: + case OMNI_DATE32: + assertEquals(((IntVec) vec).get(i), ((IntVec) expectedVec).get(i)); + break; + case OMNI_SHORT: + assertEquals(((ShortVec) vec).get(i), ((ShortVec) expectedVec).get(i)); + break; + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DECIMAL64: + assertEquals(((LongVec) vec).get(i), ((LongVec) expectedVec).get(i)); + break; + case OMNI_DOUBLE: + assertTrue(Double.compare(((DoubleVec) vec).get(i), ((DoubleVec) expectedVec).get(i)) == 0); + break; + case OMNI_BOOLEAN: + assertEquals(((BooleanVec) vec).get(i), ((BooleanVec) expectedVec).get(i)); + break; + case OMNI_VARCHAR: + case OMNI_CHAR: + assertEquals(new String(((VarcharVec) vec).get(i)), new String(((VarcharVec) expectedVec).get(i))); + break; + case OMNI_DECIMAL128: + assertEquals(((Decimal128Vec) vec).get(i), ((Decimal128Vec) expectedVec).get(i)); + break; + default: + throw new UnsupportedOperationException("Unsupported type : " + vec.getType().getId()); + } + } + } + + /** + * Vec equals + * + * @param vec vec + * @param expectedData data + */ + public static void assertVecEquals(Vec vec, Object[][] expectedData) { + for (int i = 0; i < vec.getSize(); i++) { + if (vec.isNull(i)) { + assertEquals(null, expectedData[i]); + continue; + } + switch (vec.getType().getId()) { + case OMNI_DECIMAL128: + assertEquals(((Decimal128Vec) vec).get(i), + new long[]{(long) expectedData[i][0], (long) expectedData[i][1]}); + break; + default: + throw new UnsupportedOperationException("Unsupported type : " + vec.getType().getId()); + } + } + } + + /** + * Dictionary vec equals + * + * @param vec dictionary vec + * @param expectedData data + * @return boolean + */ + public static boolean assertDecimal128DictionaryVecEqualsIgnoreOrder(DictionaryVec vec, Long[][] expectedData) { + int rowCount = vec.getSize(); + int resNUllCount = 0; + int expNullCount = 0; + for (int i = 0; i < rowCount; i++) { + if (vec.isNull(i)) { + resNUllCount++; + } + if (expectedData[i] == null) { + expNullCount++; + } + } + if (resNUllCount != expNullCount) { + return false; + } + long[][] resArr = new long[rowCount - resNUllCount][2]; + long[][] expArr = new long[rowCount - resNUllCount][2]; + for (int i = 0, j = 0, k = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resArr[j++] = vec.getDecimal128(i); + } + if (expectedData[i] != null) { + expArr[k][0] = (long) expectedData[i][0]; + expArr[k][1] = (long) expectedData[i][1]; + k++; + } + } + Arrays.sort(resArr, new Comparator() { + @Override + public int compare(long[] arr1, long[] arr2) { + if (arr1[0] == arr2[0]) { + return Long.compare(arr1[1], arr2[1]); + } + return Long.compare(arr1[0], arr2[0]); + } + }); + Arrays.sort(expArr, new Comparator() { + @Override + public int compare(long[] arr1, long[] arr2) { + if (arr1[0] == arr2[0]) { + return Long.compare(arr1[1], arr2[1]); + } + return Long.compare(arr1[0], arr2[0]); + } + }); + for (int i = 0; i < resArr.length; i++) { + if (resArr[i][0] != expArr[i][0] || resArr[i][1] != expArr[i][1]) { + return false; + } + } + return true; + } + + /** + * Dictionary vec equals + * + * @param vec dictionary vec + * @param expectedData data + * @return boolean + */ + public static boolean assertDictionaryVecEqualsIgnoreOrder(DictionaryVec vec, Object[] expectedData) { + int rowCount = vec.getSize(); + int resNUllCount = 0; + int expNullCount = 0; + List resList = new ArrayList<>(); + List expList = new ArrayList<>(); + for (int i = 0; i < rowCount; i++) { + if (vec.isNull(i)) { + resNUllCount++; + } + if (expectedData[i] == null) { + expNullCount++; + } else { + expList.add(expectedData[i]); + } + } + + if (resNUllCount != expNullCount) { + return false; + } + switch (vec.getType().getId()) { + case OMNI_INT: + case OMNI_DATE32: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(vec.getInt(i)); + } + } + break; + case OMNI_SHORT: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(vec.getShort(i)); + } + } + break; + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DECIMAL64: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(vec.getLong(i)); + } + } + break; + case OMNI_DOUBLE: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(vec.getDouble(i)); + } + } + break; + case OMNI_BOOLEAN: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(vec.getBoolean(i)); + } + } + break; + case OMNI_VARCHAR: + case OMNI_CHAR: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(new String(vec.getBytes(i))); + } + } + break; + default: + throw new UnsupportedOperationException("Unsupported type : " + vec.getType().getId()); + } + Object[] resArr = resList.toArray(); + Object[] expArr = expList.toArray(); + Arrays.sort(resArr); + Arrays.sort(expArr); + if (vec.getType().getId() == OMNI_DOUBLE) { + for (int i = 0; i < resArr.length; i++) { + if (Double.compare((double) resArr[i], (double) expArr[i]) != 0) { + return false; + } + } + } else { + for (int i = 0; i < resArr.length; i++) { + if (!resArr[i].equals(expArr[i])) { + return false; + } + } + } + + return true; + } + + /** + * Dictionary vec equals + * + * @param vec dictionary vec + * @param expectedData data + */ + public static void assertDictionaryVecEquals(DictionaryVec vec, Object[] expectedData) { + DataType.DataTypeId typeId = vec.getType().getId(); + + for (int i = 0; i < vec.getSize(); i++) { + if (vec.isNull(i)) { + assertEquals(null, expectedData[i]); + continue; + } + switch (typeId) { + case OMNI_INT: + case OMNI_DATE32: + assertEquals(vec.getInt(i), expectedData[i]); + break; + case OMNI_SHORT: + assertEquals(vec.getShort(i), expectedData[i]); + break; + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DECIMAL64: + assertEquals(vec.getLong(i), expectedData[i]); + break; + case OMNI_BOOLEAN: + assertEquals(vec.getBoolean(i), expectedData[i]); + break; + case OMNI_DOUBLE: + assertEquals(Double.compare(vec.getDouble(i), (Double) expectedData[i]), 0); + break; + case OMNI_VARCHAR: + case OMNI_CHAR: + assertEquals(vec.getBytes(i), ((String) (expectedData[i])).getBytes(StandardCharsets.UTF_8)); + break; + case OMNI_DECIMAL128: + assertEquals(vec.getDecimal128(i), + new long[]{(long) expectedData[i * 2], (long) expectedData[i * 2 + 1]}); + break; + default: + throw new UnsupportedOperationException("Unsupported type : " + typeId); + } + } + } + + /** + * Dictionary vec equals + * + * @param vec dictionary vec + * @param expectedVec expectedVec + * @return boolean + */ + public static boolean assertDecimal128DictionaryVecEqualsIgnoreOrder(DictionaryVec vec, + DictionaryVec expectedVec) { + int rowCount = vec.getSize(); + int resNUllCount = 0; + int expNullCount = 0; + for (int i = 0; i < rowCount; i++) { + if (vec.isNull(i)) { + resNUllCount++; + } + if (expectedVec.isNull(i)) { + expNullCount++; + } + } + if (resNUllCount != expNullCount) { + return false; + } + long[][] resArr = new long[rowCount - resNUllCount][2]; + long[][] expArr = new long[rowCount - resNUllCount][2]; + for (int i = 0, j = 0, k = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resArr[j++] = vec.getDecimal128(i); + } + if (expectedVec.isNull(i)) { + expArr[k++] = expectedVec.getDecimal128(i); + } + } + Arrays.sort(resArr, new Comparator() { + @Override + public int compare(long[] arr1, long[] arr2) { + if (arr1[0] == arr2[0]) { + return Long.compare(arr1[1], arr2[1]); + } + return Long.compare(arr1[0], arr2[0]); + } + }); + Arrays.sort(expArr, new Comparator() { + @Override + public int compare(long[] arr1, long[] arr2) { + if (arr1[0] == arr2[0]) { + return Long.compare(arr1[1], arr2[1]); + } + return Long.compare(arr1[0], arr2[0]); + } + }); + for (int i = 0; i < resArr.length; i++) { + if (resArr[i][0] != expArr[i][0] || resArr[i][1] != expArr[i][1]) { + return false; + } + } + return true; + } + + /** + * Dictionary vec equals + * + * @param vec dictionary vec + * @param expectedVec expectedVec + * @return boolean + */ + public static boolean assertDictionaryVecEqualsIgnoreOrder(DictionaryVec vec, DictionaryVec expectedVec) { + int rowCount = vec.getSize(); + int resNUllCount = 0; + int expNullCount = 0; + List resList = new ArrayList<>(); + List expList = new ArrayList<>(); + for (int i = 0; i < rowCount; i++) { + if (vec.isNull(i)) { + resNUllCount++; + } + if (expectedVec.isNull(i)) { + expNullCount++; + } + } + + if (resNUllCount != expNullCount) { + return false; + } + switch (vec.getType().getId()) { + case OMNI_INT: + case OMNI_DATE32: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(vec.getInt(i)); + } + if (!expectedVec.isNull(i)) { + expList.add(expectedVec.getInt(i)); + } + } + break; + case OMNI_SHORT: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(vec.getShort(i)); + } + if (!expectedVec.isNull(i)) { + expList.add(expectedVec.getShort(i)); + } + } + break; + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DECIMAL64: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(vec.getLong(i)); + } + if (!expectedVec.isNull(i)) { + expList.add(expectedVec.getLong(i)); + } + } + break; + case OMNI_DOUBLE: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(vec.getDouble(i)); + } + if (!expectedVec.isNull(i)) { + expList.add(expectedVec.getDouble(i)); + } + } + break; + case OMNI_BOOLEAN: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(vec.getBoolean(i)); + } + if (!expectedVec.isNull(i)) { + expList.add(expectedVec.getBoolean(i)); + } + } + break; + case OMNI_VARCHAR: + case OMNI_CHAR: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(new String(vec.getBytes(i))); + } + if (!expectedVec.isNull(i)) { + expList.add(new String(expectedVec.getBytes(i))); + } + } + break; + case OMNI_DECIMAL128: + for (int i = 0; i < rowCount; i++) { + if (!vec.isNull(i)) { + resList.add(vec.getDecimal128(i)); + } + if (!expectedVec.isNull(i)) { + expList.add(expectedVec.getDecimal128(i)); + } + } + break; + default: + throw new UnsupportedOperationException("Unsupported type : " + vec.getType().getId()); + } + Object[] resArr = resList.toArray(); + Object[] expArr = expList.toArray(); + Arrays.sort(resArr); + Arrays.sort(expArr); + if (vec.getType().getId() == OMNI_DOUBLE) { + for (int i = 0; i < resArr.length; i++) { + if (Double.compare((double) resArr[i], (double) expArr[i]) != 0) { + return false; + } + } + } else { + for (int i = 0; i < resArr.length; i++) { + if (!resArr[i].equals(expArr[i])) { + return false; + } + } + } + return true; + } + + /** + * Dictionary vec equals + * + * @param vec dictionary vec + * @param expectedVec expectedVec + */ + public static void assertDictionaryVecEquals(DictionaryVec vec, DictionaryVec expectedVec) { + DataType.DataTypeId typeId = vec.getType().getId(); + + for (int i = 0; i < vec.getSize(); i++) { + if (vec.isNull(i) && expectedVec.isNull(i)) { + continue; + } + switch (typeId) { + case OMNI_INT: + case OMNI_DATE32: + assertEquals(vec.getInt(i), expectedVec.getInt(i)); + break; + case OMNI_SHORT: + assertEquals(vec.getShort(i), expectedVec.getShort(i)); + break; + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DECIMAL64: + assertEquals(vec.getLong(i), expectedVec.getLong(i)); + break; + case OMNI_BOOLEAN: + assertEquals(vec.getBoolean(i), expectedVec.getBoolean(i)); + break; + case OMNI_DOUBLE: + assertEquals(Double.compare(vec.getDouble(i), expectedVec.getDouble(i)), 0); + break; + case OMNI_VARCHAR: + case OMNI_CHAR: + assertEquals(vec.getBytes(i), expectedVec.getBytes(i)); + break; + case OMNI_DECIMAL128: + assertEquals(vec.getDecimal128(i), expectedVec.getDecimal128(i)); + break; + default: + throw new UnsupportedOperationException("Unsupported type : " + typeId); + } + } + } + + /** + * free the vecBatches. + * + * @param vecBatches vecBatch list. + */ + public static void freeVecBatches(List vecBatches) { + for (int i = 0; i < vecBatches.size(); i++) { + freeVecBatch(vecBatches.get(i)); + } + } + + /** + * Vec batch free + * + * @param vecBatch vecBatch + */ + public static void freeVecBatch(VecBatch vecBatch) { + vecBatch.releaseAllVectors(); + vecBatch.close(); + } + + /** + * match filter operator with json. + * + * @param inputData input data + * @param types the type array of input data + * @param projectIndices the list of project indices + * @param resultKeepRowIdxs the row ids of result + * @param filterExpression filter expression + * @param parFormat the format to parse expression + */ + public static void filterOperatorMatchWithJson(Object[][] inputData, DataType[] types, List projectIndices, + int[] resultKeepRowIdxs, String filterExpression, int parFormat) { + VecBatch vecBatch = createVecBatch(types, inputData); + OmniFilterAndProjectOperatorFactory factory = new OmniFilterAndProjectOperatorFactory(filterExpression, types, + projectIndices, parFormat); + OmniOperator op = factory.createOperator(); + op.addInput(vecBatch); + Iterator results = op.getOutput(); + + if (resultKeepRowIdxs.length == 0) { + assertFalse(results.hasNext()); + return; + } + + AssertJUnit.assertTrue(results.hasNext()); + + List keepedColumns = new ArrayList<>(); + for (int resultKeepRowIdx : resultKeepRowIdxs) { + keepedColumns.add(resultKeepRowIdx); + } + Object[][] expectedDatas = new Object[inputData.length][resultKeepRowIdxs.length]; + for (int i = 0; i < inputData.length; i++) { + for (int j = 0, m = 0; j < inputData[0].length && m < resultKeepRowIdxs.length; j++) { + if (keepedColumns.contains(j)) { + expectedDatas[i][m] = inputData[i][j]; + m++; + } + } + } + + VecBatch resultVecBatch = results.next(); + assertVecBatchEquals(resultVecBatch, expectedDatas); + op.close(); + factory.close(); + freeVecBatch(resultVecBatch); + } + + /** + * match filter operator. + * + * @param inputData input data + * @param types the type array of input data + * @param projectIndices the list of project indices + * @param resultKeepRowIdxs the row ids of result + * @param filterExpression filter expression + */ + public static void filterOperatorMatch(Object[][] inputData, DataType[] types, List projectIndices, + int[] resultKeepRowIdxs, String filterExpression) { + filterOperatorMatchWithJson(inputData, types, projectIndices, resultKeepRowIdxs, filterExpression, 0); + } + + /** + * generating "is not null" json expression. + * + * @param arguments arguments + * @return the formatted "is not null" json expression + */ + public static String omniJsonIsNotNullExpr(String arguments) { + return String.format(Locale.ROOT, "{\"exprType\":\"UNARY\",\"returnType\":4,\"operator\":\"not\"," + + "\"expr\":{\"exprType\":\"IS_NULL\",\"returnType\":4," + "\"arguments\":[%s]}}", arguments); + } + + /** + * generating "less than or equal" json expression. + * + * @param left the left argument + * @param right the right argument + * @return the formatted "less than or equal" json expression + */ + public static String omniJsonLessThanOrEqualExpr(String left, String right) { + return String.format(Locale.ROOT, "{\"exprType\":\"BINARY\",\"returnType\":4," + + "\"operator\":\"LESS_THAN_OR_EQUAL\",\"left\":%s,\"right\":%s}", left, right); + } + + /** + * generating "equal" json expression. + * + * @param left the left argument + * @param right the right argument + * @return the formatted "equal" json expression + */ + public static String omniJsonEqualExpr(String left, String right) { + return String.format(Locale.ROOT, + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"EQUAL\",\"left\":%s,\"right\":%s}", left, + right); + } + + /** + * generating "not equal" json expression. + * + * @param left the left argument + * @param right the right argument + * @param the generic parameter + * @return the formatted "not equal" json expression + */ + public static String omniJsonNotEqualExpr(String left, String right) { + return String.format(Locale.ROOT, + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":%s,\"right\":%s}", left, + right); + } + + /** + * generating "greater than or equal" json expression. + * + * @param left the left argument + * @param right the right argument + * @return the formatted "greater than or equal" json expression + */ + public static String omniJsonGreaterThanOrEqualExpr(String left, String right) { + return String.format(Locale.ROOT, "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":" + + "\"GREATER_THAN_OR_EQUAL\",\"left\":%s,\"right\":%s}", left, right); + } + + /** + * generating "greater than" json expression. + * + * @param left the left argument + * @param right the right argument + * @return the formatted "greater than" json expression + */ + public static String omniJsonGreaterThanExpr(String left, String right) { + return String.format(Locale.ROOT, "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"GREATER_THAN\"," + + "\"left\":%s,\"right\":%s}", left, right); + } + + /** + * generating "and" json expression. + * + * @param left the left argument + * @param right the right argument + * @return the formatted "and" json expression + */ + public static String omniJsonAndExpr(String left, String right) { + return String.format(Locale.ROOT, + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"AND\",\"left\":%s,\"right\":%s}", left, + right); + } + + /** + * generating "or" json expression. + * + * @param left the left argument + * @param right the right argument + * @return the formatted "or" json expression + */ + public static String omniJsonOrExpr(String left, String right) { + return String.format(Locale.ROOT, + "{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"OR\",\"left\":%s,\"right\":%s}", left, right); + } + + /** + * generating "if" json expression. + * + * @param condition the condition + * @param returnType the type of result + * @param trueOpr the operator of true + * @param falseOpr the operator of false + * @return the formatted "if" json expression + */ + public static String omniJsonIfExpr(String condition, int returnType, String trueOpr, String falseOpr) { + return String.format(Locale.ROOT, + "{\"exprType\":\"IF\",\"returnType\":%d,\"condition\":%s,\"if_true\":%s, \"if_false\":%s}", returnType, + condition, trueOpr, falseOpr); + } + + /** + * get the formatted field reference. + * + * @param dataType the type of data + * @param colVal the column + * @return the formatted field reference + */ + public static String getOmniJsonFieldReference(int dataType, int colVal) { + if (dataType == 15) { + return String.format(Locale.ROOT, + "{\"exprType\": \"FIELD_REFERENCE\",\"dataType\": %d,\"colVal\": %d,\"width\":50}", dataType, + colVal); + } + return String.format(Locale.ROOT, "{\"exprType\": \"FIELD_REFERENCE\",\"dataType\": %d,\"colVal\": %d}", + dataType, colVal); + } + + /** + * get the formatted literal. + * + * @param dataType the type of data + * @param isNull whether is null + * @param value the value + * @param the generic parameter + * @return the formatted literal + */ + public static String getOmniJsonLiteral(int dataType, boolean isNull, T value) { + if (value instanceof String) { + return String.format(Locale.ROOT, + "{\"exprType\":\"LITERAL\",\"dataType\":%d,\"isNull\":%b,\"value\":\"%s\",\"width\":50}", dataType, + isNull, value); + } + return String.format(Locale.ROOT, "{\"exprType\":\"LITERAL\",\"dataType\":%d,\"isNull\":%b,", dataType, isNull) + + "\"value\":" + value + "}"; + } + + /** + * generating "arithmetic" json expression. + * + * @param opr the operator + * @param returnType the type of result + * @param left the left argument + * @param right the right argument + * @return the formatted "arithmetic" json expression + */ + public static String omniJsonFourArithmeticExpr(String opr, int returnType, String left, String right) { + return String.format(Locale.ROOT, + "{\"exprType\":\"BINARY\",\"returnType\":%d,\"operator\":\"%s\",\"left\":%s,\"right\":%s}", returnType, + opr, left, right); + } + + /** + * generating "in" json expression. + * + * @param argDataType the type of data + * @param colVal the column + * @param values the list of value + * @param the generic parameter + * @return the formatted "in" json expression + */ + public static String omniJsonInExpr(int argDataType, int colVal, List values) { + JSONArray jsonArray = new JSONArray(); + JSONObject jsonObject = new JSONObject(); + jsonObject.put("exprType", "FIELD_REFERENCE"); + jsonObject.put("dataType", argDataType); + jsonObject.put("colVal", colVal); + if (argDataType == 15) { + jsonObject.put("width", 50); + } + jsonArray.add(jsonObject); + for (T value : values) { + JSONObject object = new JSONObject(); + object.put("exprType", "LITERAL"); + object.put("dataType", argDataType); + object.put("isNull", false); + object.put("value", value); + if (argDataType == 15) { + object.put("width", 50); + } + jsonArray.add(object); + } + String argString = jsonArray.toJSONString(); + return String.format(Locale.ROOT, "{\"exprType\":\"IN\",\"returnType\":4,\"arguments\":%s}", argString); + } + + /** + * generating "abs" json expression. + * + * @param dataType the type of data + * @param arguments the arguments + * @return the formatted "abs" json expression + */ + public static String omniJsonAbsExpr(int dataType, String arguments) { + return String.format(Locale.ROOT, + "{\"exprType\":\"FUNCTION\",\"returnType\":%d,\"function_name\":\"abs\",\"arguments\":[%s]}", dataType, + arguments); + } + + /** + * generating "cast" json expression. + * + * @param returnType the type of result + * @param arguments the arguments + * @return the formatted "cast" json expression + */ + public static String omniJsonCastExpr(int returnType, String arguments) { + if (returnType == 15) { + return String.format(Locale.ROOT, "{\"exprType\":\"FUNCTION\",\"returnType\":%d,\"width\":50," + + "\"function_name\":\"CAST\",\"arguments\":[%s]}", returnType, arguments); + } + return String.format(Locale.ROOT, + "{\"exprType\":\"FUNCTION\",\"returnType\":%d,\"function_name\":\"CAST\",\"arguments\":[%s]}", + returnType, arguments); + } + + /** + * generating general function json expression. + * + * @param function the name of function + * @param returnType the type of result + * @param arguments the arguments + * @return the formatted "cast" json expression + */ + public static String omniFunctionExpr(String function, int returnType, String arguments) { + if (returnType == 15) { + return String.format(Locale.ROOT, "{\"exprType\":\"FUNCTION\",\"returnType\":%d,\"width\":50," + + "\"function_name\":\"%s\",\"arguments\":[%s]}", returnType, function, arguments); + } + return String.format(Locale.ROOT, + "{\"exprType\":\"FUNCTION\",\"returnType\":%d,\"function_name\":\"%s\",\"arguments\":[%s]}", returnType, + function, arguments); + } + + /** + * get the corresponding addFlag enumerated value in resultCode. + * + * @param resultCode the resultCode + * @return the corresponding addFlag enumerated value in resultCode + */ + public static int decodeAddFlag(int resultCode) { + return resultCode >> 16; + } + + /** + * get the corresponding fetchFlag enumerated value in resultCode. + * + * @param resultCode the resultCode + * @return the corresponding fetchFlag enumerated value in resultCode + */ + public static int decodeFetchFlag(int resultCode) { + return resultCode & Short.MAX_VALUE; + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkDecimal128Vec.java b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkDecimal128Vec.java new file mode 100644 index 0000000..3f1588e --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkDecimal128Vec.java @@ -0,0 +1,352 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static java.util.concurrent.TimeUnit.MICROSECONDS; + +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.DecimalVector; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.VerboseMode; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.math.MathContext; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * Decimal128 vec benchmark + * + * @since 2022-4-9 + */ +@State(Scope.Thread) +@OutputTimeUnit(MICROSECONDS) +@Fork(1) +@Warmup(iterations = 1, batchSize = 1) +@Measurement(iterations = 10, batchSize = 1) +@BenchmarkMode(Mode.AverageTime) +public class BenchmarkDecimal128Vec { + private static final int ALLOCATOR_CAPACITY = 1024 * 1024; + private static final int COUNT = 1000; + private static final int LONG_COUNT = 2; // count of long values per decimal128 value + private static final int BITS_OF_DECIMAL = 128; + private static final int PRECISION = 38; + private static final int SCALE = 3; + private static final long MIN_VALUE = 100L; + + static { + // this parameter affects arrow get performance + System.setProperty("arrow.enable_null_check_for_get", "false"); + // this parameter affects arrow set performance + System.setProperty("arrow.enable_unsafe_memory_access", "true"); + } + + // arrow get + RootAllocator allocator2; + DecimalVector arrowDecimal128VecGet; + + @Param({"1024"}) + private int rows = 1024; + + // Decimal128Vec set/put + private Decimal128Vec vecPutData; + private Decimal128Vec vecSetData; + private long[] randomData; + + // Decimal128Vec get + private Decimal128Vec vecGetData; + private Decimal128Vec vecGetDatas; + private long[] results; + private int[] positions; + + // arrow set + private RootAllocator allocator1; + private DecimalVector arrowDecimal128VecSet; + + private final Random random = new Random(0); + + /** + * init + */ + @Setup(Level.Iteration) + public void init() { + // for vec set/put + vecPutData = new Decimal128Vec(rows); + vecSetData = new Decimal128Vec(rows); + randomData = new long[rows * LONG_COUNT]; + initValues(randomData, rows * LONG_COUNT); + + // for vec get + vecGetData = new Decimal128Vec(rows); + initValues(vecGetData, rows); + vecGetDatas = new Decimal128Vec(rows); + initValues(vecGetDatas, rows); + results = new long[rows * LONG_COUNT]; + + positions = new int[rows / 2]; + for (int i = 0; i < rows / 2; i++) { + positions[i] = i; + } + + // arrow set + allocator1 = new RootAllocator(ALLOCATOR_CAPACITY); + arrowDecimal128VecSet = new DecimalVector("decimal128Vector1", allocator1, PRECISION, SCALE); + arrowDecimal128VecSet.allocateNew(rows); + arrowDecimal128VecSet.setValueCount(rows); + + // arrow get + allocator2 = new RootAllocator(ALLOCATOR_CAPACITY); + arrowDecimal128VecGet = new DecimalVector("decimal128Vector2", allocator2, PRECISION, SCALE); + arrowDecimal128VecGet.allocateNew(rows); + initValues(arrowDecimal128VecGet, rows); + arrowDecimal128VecGet.setValueCount(rows); + } + + /** + * close + */ + @TearDown(Level.Iteration) + public void tearDown() { + vecPutData.close(); + vecSetData.close(); + vecGetData.close(); + + arrowDecimal128VecSet.close(); + allocator1.close(); + + arrowDecimal128VecGet.close(); + allocator2.close(); + } + + private void initValues(Decimal128Vec vec, int rowCount) { + for (int i = 0; i < rowCount; i++) { + long[] val = {random.nextLong(), random.nextLong()}; + vec.set(i, val); + } + } + + private void initValues(long[] array, int rowCount) { + for (int i = 0; i < rowCount; i++) { + array[i] = random.nextLong(); + } + } + + private BigDecimal getOneBigDecimalValue() { + MathContext mc = new MathContext(PRECISION); + BigInteger bigInteger; + while (true) { + bigInteger = new BigInteger(String.valueOf(random.nextLong())); + if (bigInteger.compareTo(BigInteger.valueOf(MIN_VALUE)) > 0) { + break; + } + } + + return new BigDecimal(bigInteger, SCALE, mc); + } + + private void initValues(DecimalVector arrayVec, int rowCount) { + for (int i = 0; i < rowCount; i++) { + arrayVec.set(i, getOneBigDecimalValue()); + } + } + + /** + * Create long vec benchmark + * + * @param blackhole blackhole + */ + @Benchmark + public void createDecimal128VecBenchmark(Blackhole blackhole) { + List vecs = new ArrayList<>(); + for (int i = 0; i < COUNT; i++) { + Decimal128Vec vec = new Decimal128Vec(rows); + blackhole.consume(vec); + vecs.add(vec); + } + closeVec(vecs); + } + + /** + * create benchmark for arrow vector + * + * @param blackhole blackhole + */ + @Benchmark + public void createArrowVecBenchmark(Blackhole blackhole) { + List vecs = new ArrayList<>(); + RootAllocator rootAllocator = new RootAllocator(Long.MAX_VALUE); + for (int i = 0; i < COUNT; i++) { + DecimalVector arrowVec = new DecimalVector("key", rootAllocator, PRECISION, SCALE); + arrowVec.allocateNew(rows); + blackhole.consume(arrowVec); + vecs.add(arrowVec); + } + closeArrowVec(vecs); + rootAllocator.close(); + } + + /** + * create long array benchmark + * + * @param benchmarkData benchmark data for test + */ + @Benchmark + public void createDecimal128ArrayBenchmark(Blackhole benchmarkData) { + for (int i = 0; i < COUNT; i++) { + benchmarkData.consume(new long[rows]); + } + } + + /** + * put data for decimal128 vector + */ + @Benchmark + public void setDecimal128VecBenchmark() { + for (int i = 0; i < rows; i++) { + long[] valArray = {randomData[i], randomData[i] / 2}; + vecSetData.set(i, valArray); + } + } + + /** + * copy benchmark data + */ + @Benchmark + public void copyArrayBenchmark() { + long[] data = new long[rows]; + System.arraycopy(randomData, 0, data, 0, rows); + } + + /** + * new decimal128 vec put benchmark + */ + @Benchmark + public void newPutReleaseDecimal128VecBenchmark() { + Decimal128Vec vec = new Decimal128Vec(rows); + vec.put(randomData, 0, 0, randomData.length); + vec.close(); + } + + /** + * Decimal128 vec put benchmark + */ + @Benchmark + public void putDecimal128VecBenchmark() { + vecPutData.put(randomData, 0, 0, randomData.length); + } + + /** + * Arrow vec set benchmark + */ + @Benchmark + public void setArrowVecBenchmark() { + for (int i = 0; i < rows; i++) { + arrowDecimal128VecSet.set(i, getOneBigDecimalValue()); + } + } + + /** + * Decimal128 vec get benchmark + * + * @return sum value + */ + @Benchmark + public long getDecimal128VecBenchmark() { + long sum = 0L; + for (int i = 0; i < rows; i++) { + sum += vecGetData.get(i)[0]; + } + return sum; + } + + /** + * Decimal128 vecs get benchmark + * + * @return sum value + */ + @Benchmark + public long getsDecimal128VecBenchmark() { + long sum = 0L; + long[] result = vecGetDatas.get(0, rows); + for (long datum : result) { + sum += datum; + } + return sum; + } + + /** + * Arrow vec get benchmark + * + * @return sum value + */ + @Benchmark + public long getArrowVecBenchmark() { + long sum = 0L; + int rowsTemp = this.rows; + for (int i = 0; i < rowsTemp; i++) { + sum += arrowDecimal128VecGet.get(i).getLong(0); + } + return sum; + } + + /** + * Get decimal128 vec slice size + * + * @return slice size + */ + @Benchmark + public int sliceDecimal128VecBenchmark() { + Decimal128Vec slice = vecGetData.slice(2, rows / 2); + return slice.getSize(); + } + + /** + * Copy position of decimal128 vec benchmark + * + * @return position size + */ + @Benchmark + public int copyPositionDecimal128VecBenchmark() { + Decimal128Vec copyPosition = vecGetData.copyPositions(positions, 0, positions.length); + return copyPosition.getSize(); + } + + private void closeArrowVec(List vecs) { + for (DecimalVector vec : vecs) { + vec.close(); + } + } + + private void closeVec(List vecs) { + for (Vec vec : vecs) { + vec.close(); + } + } + + public static void main(String[] args) throws Throwable { + Options options = new OptionsBuilder().verbosity(VerboseMode.NORMAL) + .include(".*" + BenchmarkDecimal128Vec.class.getSimpleName() + ".*").build(); + + new Runner(options).run(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkDoubleVec.java b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkDoubleVec.java new file mode 100644 index 0000000..4aebb3d --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkDoubleVec.java @@ -0,0 +1,359 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static java.util.concurrent.TimeUnit.MICROSECONDS; + +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.Float8Vector; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.VerboseMode; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * Double vec benchmark + * + * @since 2022-4-9 + */ +@State(Scope.Thread) +@OutputTimeUnit(MICROSECONDS) +@Fork(1) +@Warmup(iterations = 1, batchSize = 1) +@Measurement(iterations = 10, batchSize = 1) +@BenchmarkMode(Mode.AverageTime) +public class BenchmarkDoubleVec { + private static final int ALLOCATOR_CAPACITY = 1024 * 1024; + private static final int COUNT = 1000; + + static { + // this parameter affects arrow get performance + System.setProperty("arrow.enable_null_check_for_get", "false"); + // this parameter affects arrow set performance + System.setProperty("arrow.enable_unsafe_memory_access", "true"); + } + + // arrow get + RootAllocator allocator2; + Float8Vector arrowDoubleVecGet; + + @Param({"1024"}) + private int rows = 1024; + + // doubleVec set/put + private DoubleVec vecPutData; + private DoubleVec vecSetData; + private double[] arraySetData; + private double[] randomData; + + // doubleVec get + private DoubleVec vecGetData; + private DoubleVec vecGetDatas; + private double[] results; + private double[] arrayGetData; + private int[] positions; + + // arrow set + private RootAllocator allocator1; + private Float8Vector arrowDoubleVecSet; + + private final Random random = new Random(0); + + /** + * init + */ + @Setup(Level.Iteration) + public void init() { + // for vec set/put + vecPutData = new DoubleVec(rows); + vecSetData = new DoubleVec(rows); + arraySetData = new double[rows]; + randomData = new double[rows]; + initValues(randomData, rows); + + // for vec get + vecGetData = new DoubleVec(rows); + initValues(vecGetData, rows); + vecGetDatas = new DoubleVec(rows); + initValues(vecGetDatas, rows); + results = new double[rows]; + + positions = new int[rows / 2]; + for (int i = 0; i < rows / 2; i++) { + positions[i] = i; + } + + // arrow set + allocator1 = new RootAllocator(ALLOCATOR_CAPACITY); + arrowDoubleVecSet = new Float8Vector("doubleVector1", allocator1); + arrowDoubleVecSet.allocateNew(rows); + arrowDoubleVecSet.setValueCount(rows); + + // arrow get + allocator2 = new RootAllocator(ALLOCATOR_CAPACITY); + arrowDoubleVecGet = new Float8Vector("doubleVector2", allocator2); + arrowDoubleVecGet.allocateNew(rows); + initValues(arrowDoubleVecGet, rows); + arrowDoubleVecGet.setValueCount(rows); + + arrayGetData = new double[rows]; + initValues(arrayGetData, rows); + } + + /** + * close + */ + @TearDown(Level.Iteration) + public void tearDown() { + vecPutData.close(); + vecSetData.close(); + vecGetData.close(); + + arrowDoubleVecSet.close(); + allocator1.close(); + + arrowDoubleVecGet.close(); + allocator2.close(); + } + + private void initValues(DoubleVec vec, int rowCount) { + for (int i = 0; i < rowCount; i++) { + vec.set(i, random.nextDouble()); + } + } + + private void initValues(double[] array, int rowCount) { + for (int i = 0; i < rowCount; i++) { + array[i] = random.nextDouble(); + } + } + + private void initValues(Float8Vector arrayVec, int rowCount) { + for (int i = 0; i < rowCount; i++) { + arrayVec.set(i, random.nextDouble()); + } + } + + /** + * Create double vec benchmark + * + * @param blackhole blackhole + */ + @Benchmark + public void createDoubleVecBenchmark(Blackhole blackhole) { + List vecs = new ArrayList<>(); + for (int i = 0; i < COUNT; i++) { + DoubleVec vec = new DoubleVec(rows); + blackhole.consume(vec); + vecs.add(vec); + } + closeVec(vecs); + } + + /** + * create benchmark for arrow vector + * + * @param blackhole blackhole + */ + @Benchmark + public void createArrowVecBenchmark(Blackhole blackhole) { + List vecs = new ArrayList<>(); + RootAllocator rootAllocator = new RootAllocator(Long.MAX_VALUE); + for (int i = 0; i < COUNT; i++) { + Float8Vector arrowVec = new Float8Vector("key", rootAllocator); + arrowVec.allocateNew(rows); + blackhole.consume(arrowVec); + vecs.add(arrowVec); + } + closeArrowVec(vecs); + rootAllocator.close(); + } + + /** + * create double array benchmark + * + * @param benchmarkData benchmark data for test + */ + @Benchmark + public void createDoubleArrayBenchmark(Blackhole benchmarkData) { + for (int i = 0; i < COUNT; i++) { + benchmarkData.consume(new double[rows]); + } + } + + /** + * put data for double vector + */ + @Benchmark + public void setDoubleVecBenchmark() { + for (int i = 0; i < rows; i++) { + vecSetData.set(i, randomData[i]); + } + } + + /** + * copy benchmark data + */ + @Benchmark + public void copyArrayBenchmark() { + double[] data = new double[rows]; + System.arraycopy(randomData, 0, data, 0, rows); + } + + /** + * new double vec put benchmark + */ + @Benchmark + public void newPutReleaseDoubleVecBenchmark() { + DoubleVec vec = new DoubleVec(rows); + vec.put(randomData, 0, 0, randomData.length); + vec.close(); + } + + /** + * Double vec put benchmark + */ + @Benchmark + public void putDoubleVecBenchmark() { + vecPutData.put(randomData, 0, 0, randomData.length); + } + + /** + * Double array set benchmark + */ + @Benchmark + public void setDoubleArrayBenchmark() { + if (rows >= 0) { + System.arraycopy(randomData, 0, arraySetData, 0, rows); + } + } + + /** + * Arrow vec set benchmark + */ + @Benchmark + public void setArrowVecBenchmark() { + for (int i = 0; i < rows; i++) { + arrowDoubleVecSet.set(i, randomData[i]); + } + } + + /** + * Double array get benchmark + * + * @return sum value + */ + @Benchmark + public double getDoubleArrayBenchmark() { + double sum = 0.0d; + for (int i = 0; i < rows; i++) { + sum += arrayGetData[i]; + } + return sum; + } + + /** + * Double vec get benchmark + * + * @return sum value + */ + @Benchmark + public double getDoubleVecBenchmark() { + double sum = 0.0d; + for (int i = 0; i < rows; i++) { + sum += vecGetData.get(i); + } + return sum; + } + + /** + * Double vecs get benchmark + * + * @return sum value + */ + @Benchmark + public double getsDoubleVecBenchmark() { + double sum = 0.0d; + double[] result = vecGetDatas.get(0, rows); + for (double datum : result) { + sum += datum; + } + return sum; + } + + /** + * Arrow vec get benchmark + * + * @return sum value + */ + @Benchmark + public double getArrowVecBenchmark() { + double sum = 0.0d; + int rowsTemp = this.rows; + for (int i = 0; i < rowsTemp; i++) { + sum += arrowDoubleVecGet.get(i); + } + return sum; + } + + /** + * Get double vec slice size + * + * @return slice size + */ + @Benchmark + public int sliceDoubleVecBenchmark() { + DoubleVec slice = vecGetData.slice(2, rows / 2); + return slice.getSize(); + } + + /** + * Copy position of double vec benchmark + * + * @return position size + */ + @Benchmark + public int copyPositionDoubleVecBenchmark() { + DoubleVec copyPosition = vecGetData.copyPositions(positions, 0, positions.length); + return copyPosition.getSize(); + } + + private void closeArrowVec(List vecs) { + for (Float8Vector vec : vecs) { + vec.close(); + } + } + + private void closeVec(List vecs) { + for (Vec vec : vecs) { + vec.close(); + } + } + + public static void main(String[] args) throws Throwable { + Options options = new OptionsBuilder().verbosity(VerboseMode.NORMAL) + .include(".*" + BenchmarkDoubleVec.class.getSimpleName() + ".*").build(); + + new Runner(options).run(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkIntVec.java b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkIntVec.java new file mode 100644 index 0000000..9bed96d --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkIntVec.java @@ -0,0 +1,359 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static java.util.concurrent.TimeUnit.MICROSECONDS; + +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.VerboseMode; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * Int vec benchmark + * + * @since 2022-4-9 + */ +@State(Scope.Thread) +@OutputTimeUnit(MICROSECONDS) +@Fork(1) +@Warmup(iterations = 1, batchSize = 1) +@Measurement(iterations = 10, batchSize = 1) +@BenchmarkMode(Mode.AverageTime) +public class BenchmarkIntVec { + private static final int ALLOCATOR_CAPACITY = 1024 * 1024; + private static final int COUNT = 1000; + + static { + // this parameter affects arrow get performance + System.setProperty("arrow.enable_null_check_for_get", "false"); + // this parameter affects arrow set performance + System.setProperty("arrow.enable_unsafe_memory_access", "true"); + } + + // arrow get + RootAllocator allocator2; + IntVector arrowIntVecGet; + + @Param({"1024"}) + private int rows = 1024; + + // intVec set/put + private IntVec vecPutData; + private IntVec vecSetData; + private int[] arraySetData; + private int[] randomData; + + // intVec get + private IntVec vecGetData; + private IntVec vecGetDatas; + private int[] results; + private int[] arrayGetData; + private int[] positions; + + // arrow set + private RootAllocator allocator1; + private IntVector arrowIntVecSet; + + private final Random random = new Random(0); + + /** + * init + */ + @Setup(Level.Iteration) + public void init() { + // for vec set/put + vecPutData = new IntVec(rows); + vecSetData = new IntVec(rows); + arraySetData = new int[rows]; + randomData = new int[rows]; + initValues(randomData, rows); + + // for vec get + vecGetData = new IntVec(rows); + initValues(vecGetData, rows); + vecGetDatas = new IntVec(rows); + initValues(vecGetDatas, rows); + results = new int[rows]; + + positions = new int[rows / 2]; + for (int i = 0; i < rows / 2; i++) { + positions[i] = i; + } + + // arrow set + allocator1 = new RootAllocator(ALLOCATOR_CAPACITY); + arrowIntVecSet = new IntVector("IntVector1", allocator1); + arrowIntVecSet.allocateNew(rows); + arrowIntVecSet.setValueCount(rows); + + // arrow get + allocator2 = new RootAllocator(ALLOCATOR_CAPACITY); + arrowIntVecGet = new IntVector("IntVector2", allocator2); + arrowIntVecGet.allocateNew(rows); + initValues(arrowIntVecGet, rows); + arrowIntVecGet.setValueCount(rows); + + arrayGetData = new int[rows]; + initValues(arrayGetData, rows); + } + + /** + * close + */ + @TearDown(Level.Iteration) + public void tearDown() { + vecPutData.close(); + vecSetData.close(); + vecGetData.close(); + + arrowIntVecSet.close(); + allocator1.close(); + + arrowIntVecGet.close(); + allocator2.close(); + } + + private void initValues(IntVec vec, int rowCount) { + for (int i = 0; i < rowCount; i++) { + vec.set(i, random.nextInt()); + } + } + + private void initValues(int[] array, int rowCount) { + for (int i = 0; i < rowCount; i++) { + array[i] = random.nextInt(); + } + } + + private void initValues(IntVector arrayVec, int rowCount) { + for (int i = 0; i < rowCount; i++) { + arrayVec.set(i, random.nextInt()); + } + } + + /** + * Create int vec benchmark + * + * @param blackhole blackhole + */ + @Benchmark + public void createIntVecBenchmark(Blackhole blackhole) { + List vecs = new ArrayList<>(); + for (int i = 0; i < COUNT; i++) { + IntVec vec = new IntVec(rows); + blackhole.consume(vec); + vecs.add(vec); + } + closeVec(vecs); + } + + /** + * create benchmark for arrow vector + * + * @param blackhole blackhole + */ + @Benchmark + public void createArrowVecBenchmark(Blackhole blackhole) { + List vecs = new ArrayList<>(); + RootAllocator rootAllocator = new RootAllocator(Long.MAX_VALUE); + for (int i = 0; i < COUNT; i++) { + IntVector arrowVec = new IntVector("key", rootAllocator); + arrowVec.allocateNew(rows); + blackhole.consume(arrowVec); + vecs.add(arrowVec); + } + closeArrowVec(vecs); + rootAllocator.close(); + } + + /** + * create int array benchmark + * + * @param benchmarkData benchmark data for test + */ + @Benchmark + public void createIntArrayBenchmark(Blackhole benchmarkData) { + for (int i = 0; i < COUNT; i++) { + benchmarkData.consume(new int[rows]); + } + } + + /** + * put data for int vector + */ + @Benchmark + public void setIntVecBenchmark() { + for (int i = 0; i < rows; i++) { + vecSetData.set(i, randomData[i]); + } + } + + /** + * copy benchmark data + */ + @Benchmark + public void copyArrayBenchmark() { + int[] data = new int[rows]; + System.arraycopy(randomData, 0, data, 0, rows); + } + + /** + * new int vec put benchmark + */ + @Benchmark + public void newPutReleaseIntVecBenchmark() { + IntVec vec = new IntVec(rows); + vec.put(randomData, 0, 0, randomData.length); + vec.close(); + } + + /** + * Int vec put benchmark + */ + @Benchmark + public void putIntVecBenchmark() { + vecPutData.put(randomData, 0, 0, randomData.length); + } + + /** + * Int array set benchmark + */ + @Benchmark + public void setIntArrayBenchmark() { + if (rows >= 0) { + System.arraycopy(randomData, 0, arraySetData, 0, rows); + } + } + + /** + * Arrow vec set benchmark + */ + @Benchmark + public void setArrowVecBenchmark() { + for (int i = 0; i < rows; i++) { + arrowIntVecSet.set(i, randomData[i]); + } + } + + /** + * Int array get benchmark + * + * @return sum value + */ + @Benchmark + public long getIntArrayBenchmark() { + long sum = 0L; + for (int i = 0; i < rows; i++) { + sum += arrayGetData[i]; + } + return sum; + } + + /** + * Int vec get benchmark + * + * @return sum value + */ + @Benchmark + public long getIntVecBenchmark() { + long sum = 0L; + for (int i = 0; i < rows; i++) { + sum += vecGetData.get(i); + } + return sum; + } + + /** + * Int vecs get benchmark + * + * @return sum value + */ + @Benchmark + public long getsIntVecBenchmark() { + long sum = 0L; + int[] result = vecGetDatas.get(0, rows); + for (int datum : result) { + sum += datum; + } + return sum; + } + + /** + * Arrow vec get benchmark + * + * @return sum value + */ + @Benchmark + public long getArrowVecBenchmark() { + long sum = 0L; + int rowsTemp = this.rows; + for (int i = 0; i < rowsTemp; i++) { + sum += arrowIntVecGet.get(i); + } + return sum; + } + + /** + * Get int vec slice size + * + * @return slice size + */ + @Benchmark + public int sliceIntVecBenchmark() { + IntVec slice = vecGetData.slice(2, rows / 2); + return slice.getSize(); + } + + /** + * Copy position of int vec benchmark + * + * @return position size + */ + @Benchmark + public int copyPositionIntVecBenchmark() { + IntVec copyPosition = vecGetData.copyPositions(positions, 0, positions.length); + return copyPosition.getSize(); + } + + private void closeArrowVec(List vecs) { + for (IntVector vec : vecs) { + vec.close(); + } + } + + private void closeVec(List vecs) { + for (Vec vec : vecs) { + vec.close(); + } + } + + public static void main(String[] args) throws Throwable { + Options options = new OptionsBuilder().verbosity(VerboseMode.NORMAL) + .include(".*" + BenchmarkIntVec.class.getSimpleName() + ".*").build(); + + new Runner(options).run(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkLongVec.java b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkLongVec.java new file mode 100644 index 0000000..274c55d --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkLongVec.java @@ -0,0 +1,359 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static java.util.concurrent.TimeUnit.MICROSECONDS; + +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.VerboseMode; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * Long vec benchmark + * + * @since 2021-8-10 + */ +@State(Scope.Thread) +@OutputTimeUnit(MICROSECONDS) +@Fork(1) +@Warmup(iterations = 1, batchSize = 1) +@Measurement(iterations = 10, batchSize = 1) +@BenchmarkMode(Mode.AverageTime) +public class BenchmarkLongVec { + private static final int ALLOCATOR_CAPACITY = 1024 * 1024; + private static final int COUNT = 1000; + + static { + // this parameter affects arrow get performance + System.setProperty("arrow.enable_null_check_for_get", "false"); + // this parameter affects arrow set performance + System.setProperty("arrow.enable_unsafe_memory_access", "true"); + } + + // arrow get + RootAllocator allocator2; + BigIntVector arrowLongVecGet; + + @Param({"1024"}) + private int rows = 1024; + + // longVec set/put + private LongVec vecPutData; + private LongVec vecSetData; + private long[] arraySetData; + private long[] randomData; + + // longVec get + private LongVec vecGetData; + private LongVec vecGetDatas; + private long[] results; + private long[] arrayGetData; + private int[] positions; + + // arrow set + private RootAllocator allocator1; + private BigIntVector arrowLongVecSet; + + private final Random random = new Random(0); + + /** + * init + */ + @Setup(Level.Iteration) + public void init() { + // for vec set/put + vecPutData = new LongVec(rows); + vecSetData = new LongVec(rows); + arraySetData = new long[rows]; + randomData = new long[rows]; + initValues(randomData, rows); + + // for vec get + vecGetData = new LongVec(rows); + initValues(vecGetData, rows); + vecGetDatas = new LongVec(rows); + initValues(vecGetDatas, rows); + results = new long[rows]; + + positions = new int[rows / 2]; + for (int i = 0; i < rows / 2; i++) { + positions[i] = i; + } + + // arrow set + allocator1 = new RootAllocator(ALLOCATOR_CAPACITY); + arrowLongVecSet = new BigIntVector("longVector1", allocator1); + arrowLongVecSet.allocateNew(rows); + arrowLongVecSet.setValueCount(rows); + + // arrow get + allocator2 = new RootAllocator(ALLOCATOR_CAPACITY); + arrowLongVecGet = new BigIntVector("longVector2", allocator2); + arrowLongVecGet.allocateNew(rows); + initValues(arrowLongVecGet, rows); + arrowLongVecGet.setValueCount(rows); + + arrayGetData = new long[rows]; + initValues(arrayGetData, rows); + } + + /** + * close + */ + @TearDown(Level.Iteration) + public void tearDown() { + vecPutData.close(); + vecSetData.close(); + vecGetData.close(); + + arrowLongVecSet.close(); + allocator1.close(); + + arrowLongVecGet.close(); + allocator2.close(); + } + + private void initValues(LongVec vec, int rowCount) { + for (int i = 0; i < rowCount; i++) { + vec.set(i, random.nextLong()); + } + } + + private void initValues(long[] array, int rowCount) { + for (int i = 0; i < rowCount; i++) { + array[i] = random.nextLong(); + } + } + + private void initValues(BigIntVector arrayVec, int rowCount) { + for (int i = 0; i < rowCount; i++) { + arrayVec.set(i, random.nextLong()); + } + } + + /** + * Create long vec benchmark + * + * @param blackhole blackhole + */ + @Benchmark + public void createLongVecBenchmark(Blackhole blackhole) { + List vecs = new ArrayList<>(); + for (int i = 0; i < COUNT; i++) { + LongVec vec = new LongVec(rows); + blackhole.consume(vec); + vecs.add(vec); + } + closeVec(vecs); + } + + /** + * create benchmark for arrow vector + * + * @param blackhole blackhole + */ + @Benchmark + public void createArrowVecBenchmark(Blackhole blackhole) { + List vecs = new ArrayList<>(); + RootAllocator rootAllocator = new RootAllocator(Long.MAX_VALUE); + for (int i = 0; i < COUNT; i++) { + BigIntVector arrowVec = new BigIntVector("key", rootAllocator); + arrowVec.allocateNew(rows); + blackhole.consume(arrowVec); + vecs.add(arrowVec); + } + closeArrowVec(vecs); + rootAllocator.close(); + } + + /** + * create long array benchmark + * + * @param benchmarkData benchmark data for test + */ + @Benchmark + public void createLongArrayBenchmark(Blackhole benchmarkData) { + for (int i = 0; i < COUNT; i++) { + benchmarkData.consume(new long[rows]); + } + } + + /** + * put data for long vector + */ + @Benchmark + public void setLongVecBenchmark() { + for (int i = 0; i < rows; i++) { + vecSetData.set(i, randomData[i]); + } + } + + /** + * copy benchmark data + */ + @Benchmark + public void copyArrayBenchmark() { + long[] data = new long[rows]; + System.arraycopy(randomData, 0, data, 0, rows); + } + + /** + * new long vec put benchmark + */ + @Benchmark + public void newPutReleaseLongVecBenchmark() { + LongVec vec = new LongVec(rows); + vec.put(randomData, 0, 0, randomData.length); + vec.close(); + } + + /** + * Long vec put benchmark + */ + @Benchmark + public void putLongVecBenchmark() { + vecPutData.put(randomData, 0, 0, randomData.length); + } + + /** + * Long array set benchmark + */ + @Benchmark + public void setLongArrayBenchmark() { + if (rows >= 0) { + System.arraycopy(randomData, 0, arraySetData, 0, rows); + } + } + + /** + * Arrow vec set benchmark + */ + @Benchmark + public void setArrowVecBenchmark() { + for (int i = 0; i < rows; i++) { + arrowLongVecSet.set(i, randomData[i]); + } + } + + /** + * Long array get benchmark + * + * @return sum value + */ + @Benchmark + public long getLongArrayBenchmark() { + long sum = 0L; + for (int i = 0; i < rows; i++) { + sum += arrayGetData[i]; + } + return sum; + } + + /** + * Long vec get benchmark + * + * @return sum value + */ + @Benchmark + public long getLongVecBenchmark() { + long sum = 0L; + for (int i = 0; i < rows; i++) { + sum += vecGetData.get(i); + } + return sum; + } + + /** + * Long vecs get benchmark + * + * @return sum value + */ + @Benchmark + public long getsLongVecBenchmark() { + long sum = 0L; + long[] result = vecGetDatas.get(0, rows); + for (long datum : result) { + sum += datum; + } + return sum; + } + + /** + * Arrow vec get benchmark + * + * @return sum value + */ + @Benchmark + public long getArrowVecBenchmark() { + long sum = 0L; + int rowsTemp = this.rows; + for (int i = 0; i < rowsTemp; i++) { + sum += arrowLongVecGet.get(i); + } + return sum; + } + + /** + * Get long vec slice size + * + * @return slice size + */ + @Benchmark + public int sliceLongVecBenchmark() { + LongVec slice = vecGetData.slice(2, rows / 2); + return slice.getSize(); + } + + /** + * Copy position of long vec benchmark + * + * @return position size + */ + @Benchmark + public int copyPositionLongVecBenchmark() { + LongVec copyPosition = vecGetData.copyPositions(positions, 0, positions.length); + return copyPosition.getSize(); + } + + private void closeArrowVec(List vecs) { + for (BigIntVector vec : vecs) { + vec.close(); + } + } + + private void closeVec(List vecs) { + for (Vec vec : vecs) { + vec.close(); + } + } + + public static void main(String[] args) throws Throwable { + Options options = new OptionsBuilder().verbosity(VerboseMode.NORMAL) + .include(".*" + BenchmarkLongVec.class.getSimpleName() + ".*").build(); + + new Runner(options).run(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkShortVec.java b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkShortVec.java new file mode 100644 index 0000000..6667834 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkShortVec.java @@ -0,0 +1,364 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static java.util.concurrent.TimeUnit.MICROSECONDS; + +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.SmallIntVector; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.VerboseMode; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * Short vec benchmark + * + * @since 2022-8-2 + */ +@State(Scope.Thread) +@OutputTimeUnit(MICROSECONDS) +@Fork(1) +@Warmup(iterations = 1, batchSize = 1) +@Measurement(iterations = 10, batchSize = 1) +@BenchmarkMode(Mode.AverageTime) +public class BenchmarkShortVec { + private static final int ALLOCATOR_CAPACITY = 1024 * 1024; + private static final int COUNT = 1000; + + static { + // this parameter affects arrow get performance + System.setProperty("arrow.enable_null_check_for_get", "false"); + // this parameter affects arrow set performance + System.setProperty("arrow.enable_unsafe_memory_access", "true"); + } + + // arrow get + RootAllocator allocator2; + SmallIntVector arrowShortVecGet; + + @Param({"1024"}) + private int rows = 1024; + + // ShortVec set/put + private ShortVec vecPutData; + private ShortVec vecSetData; + private short[] arraySetData; + private short[] randomData; + + // ShortVec get + private ShortVec vecGetData; + private ShortVec vecGetDatas; + private short[] results; + private short[] arrayGetData; + private int[] positions; + + // arrow set + private RootAllocator allocator1; + private SmallIntVector arrowShortVecSet; + + private final Random random = new Random(0); + + /** + * init + */ + @Setup(Level.Iteration) + public void init() { + // for vec set/put + vecPutData = new ShortVec(rows); + vecSetData = new ShortVec(rows); + arraySetData = new short[rows]; + randomData = new short[rows]; + initValues(randomData, rows); + + // for vec get + vecGetData = new ShortVec(rows); + initValues(vecGetData, rows); + vecGetDatas = new ShortVec(rows); + initValues(vecGetDatas, rows); + results = new short[rows]; + + positions = new int[rows / 2]; + for (int i = 0; i < rows / 2; i++) { + positions[i] = i; + } + + // arrow set + allocator1 = new RootAllocator(ALLOCATOR_CAPACITY); + arrowShortVecSet = new SmallIntVector("SmallIntVector1", allocator1); + arrowShortVecSet.allocateNew(rows); + arrowShortVecSet.setValueCount(rows); + + // arrow get + allocator2 = new RootAllocator(ALLOCATOR_CAPACITY); + arrowShortVecGet = new SmallIntVector("SmallIntVector2", allocator2); + arrowShortVecGet.allocateNew(rows); + initValues(arrowShortVecGet, rows); + arrowShortVecGet.setValueCount(rows); + + arrayGetData = new short[rows]; + initValues(arrayGetData, rows); + } + + /** + * close + */ + @TearDown(Level.Iteration) + public void tearDown() { + vecPutData.close(); + vecSetData.close(); + vecGetData.close(); + vecGetDatas.close(); + + arrowShortVecSet.close(); + allocator1.close(); + + arrowShortVecGet.close(); + allocator2.close(); + } + + private void initValues(ShortVec vec, int rowCount) { + for (int i = 0; i < rowCount; i++) { + vec.set(i, (short) random.nextInt(Short.MAX_VALUE)); + } + } + + private void initValues(short[] array, int rowCount) { + for (int i = 0; i < rowCount; i++) { + array[i] = (short) random.nextInt(Short.MAX_VALUE); + } + } + + private void initValues(SmallIntVector arrayVec, int rowCount) { + for (int i = 0; i < rowCount; i++) { + arrayVec.set(i, (short) random.nextInt(Short.MAX_VALUE)); + } + } + + /** + * Create short vec benchmark + * + * @param blackhole blackhole + */ + @Benchmark + public void createShortVecBenchmark(Blackhole blackhole) { + List vecs = new ArrayList<>(); + for (int i = 0; i < COUNT; i++) { + ShortVec vec = new ShortVec(rows); + blackhole.consume(vec); + vecs.add(vec); + } + closeVec(vecs); + } + + /** + * create benchmark for arrow vector + * + * @param blackhole blackhole + */ + @Benchmark + public void createArrowVecBenchmark(Blackhole blackhole) { + List vecs = new ArrayList<>(); + RootAllocator rootAllocator = new RootAllocator(Long.MAX_VALUE); + for (int i = 0; i < COUNT; i++) { + SmallIntVector arrowVec = new SmallIntVector("key", rootAllocator); + arrowVec.allocateNew(rows); + blackhole.consume(arrowVec); + vecs.add(arrowVec); + } + closeArrowVec(vecs); + rootAllocator.close(); + } + + /** + * create short array benchmark + * + * @param benchmarkData benchmark data for test + */ + @Benchmark + public void createShortArrayBenchmark(Blackhole benchmarkData) { + for (int i = 0; i < COUNT; i++) { + benchmarkData.consume(new short[rows]); + } + } + + /** + * put data for short vector + */ + @Benchmark + public void setShortVecBenchmark() { + for (int i = 0; i < rows; i++) { + vecSetData.set(i, randomData[i]); + } + } + + /** + * copy benchmark data + */ + @Benchmark + public void copyArrayBenchmark() { + short[] data = new short[rows]; + System.arraycopy(randomData, 0, data, 0, rows); + } + + /** + * new short vec put benchmark + */ + @Benchmark + public void newPutReleaseShortVecBenchmark() { + ShortVec vec = new ShortVec(rows); + vec.put(randomData, 0, 0, randomData.length); + vec.close(); + } + + /** + * Short vec put benchmark + */ + @Benchmark + public void putShortVecBenchmark() { + vecPutData.put(randomData, 0, 0, randomData.length); + } + + /** + * Short array set benchmark + */ + @Benchmark + public void setShortArrayBenchmark() { + if (rows >= 0) { + System.arraycopy(randomData, 0, arraySetData, 0, rows); + } + } + + /** + * Arrow vec set benchmark + */ + @Benchmark + public void setArrowVecBenchmark() { + for (int i = 0; i < rows; i++) { + arrowShortVecSet.set(i, randomData[i]); + } + } + + /** + * Short array get benchmark + * + * @return sum value + */ + @Benchmark + public long getShortArrayBenchmark() { + long sum = 0L; + for (int i = 0; i < rows; i++) { + sum += arrayGetData[i]; + } + return sum; + } + + /** + * Short vec get benchmark + * + * @return sum value + */ + @Benchmark + public long getShortVecBenchmark() { + long sum = 0L; + for (int i = 0; i < rows; i++) { + sum += vecGetData.get(i); + } + return sum; + } + + /** + * Short vecs get benchmark + * + * @return sum value + */ + @Benchmark + public long getsShortVecBenchmark() { + long sum = 0L; + short[] result = vecGetDatas.get(0, rows); + for (short datum : result) { + sum += datum; + } + return sum; + } + + /** + * Arrow vec get benchmark + * + * @return sum value + */ + @Benchmark + public long getArrowVecBenchmark() { + long sum = 0L; + int rowsTemp = this.rows; + for (int i = 0; i < rowsTemp; i++) { + sum += arrowShortVecGet.get(i); + } + return sum; + } + + /** + * Get Short vec slice size + * + * @return slice size + */ + @Benchmark + public int sliceShortVecBenchmark() { + ShortVec slice = vecGetData.slice(2, rows / 2); + int size = slice.getSize(); + slice.close(); + return size; + } + + /** + * Copy position of int vec benchmark + * + * @return position size + */ + @Benchmark + public int copyPositionShortVecBenchmark() { + ShortVec copyPosition = vecGetData.copyPositions(positions, 0, positions.length); + int size = copyPosition.getSize(); + copyPosition.close(); + return size; + } + + private void closeArrowVec(List vecs) { + for (SmallIntVector vec : vecs) { + vec.close(); + } + } + + private void closeVec(List vecs) { + for (Vec vec : vecs) { + vec.close(); + } + } + + public static void main(String[] args) throws Throwable { + Options options = new OptionsBuilder().verbosity(VerboseMode.NORMAL) + .include(".*" + BenchmarkShortVec.class.getSimpleName() + ".*").build(); + + new Runner(options).run(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkVarcharVec.java b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkVarcharVec.java new file mode 100644 index 0000000..15b4a16 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/BenchmarkVarcharVec.java @@ -0,0 +1,371 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static java.util.concurrent.TimeUnit.MICROSECONDS; + +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VarCharVector; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.VerboseMode; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +/** + * Varchar vec benchmark + * + * @since 2021-8-10 + */ +@State(Scope.Thread) +@OutputTimeUnit(MICROSECONDS) +@Fork(1) +@Warmup(iterations = 1, batchSize = 1) +@Measurement(iterations = 10, batchSize = 1) +@BenchmarkMode(Mode.AverageTime) +public class BenchmarkVarcharVec { + private static final int ALLOCATOR_CAPACITY = 1024 * 1024; + private static final int COUNT = 1000; + + static { + // this parameter affects arrow get performance + System.setProperty("arrow.enable_null_check_for_get", "false"); + // this parameter affects arrow set performance + System.setProperty("arrow.enable_unsafe_memory_access", "true"); + } + + @Param({"1024"}) + private int rows = 1024; + + // varcharVec set/put + private VarcharVec vecSetData; + private VarcharVec vecPutData; + private VarcharVecTest putDataSource; + + // heap byteBuffer + private VarcharVecTest vecTestSetData; + + private ByteBuffer[] byteValues; + + // varcharVec get + private VarcharVec vecGetData; + private VarcharVecTest varcharVecTest; + private int[] positions; + + // arrow set + private RootAllocator allocator1; + private VarCharVector arrowVecSet; + + // arrow get + private RootAllocator allocator2; + private VarCharVector arrowVecGet; + + /** + * init + */ + @Setup(Level.Iteration) + public void init() { + // for varchar set/put + vecPutData = new VarcharVec(rows); + putDataSource = new VarcharVecTest(rows * 8, rows); + initValues(putDataSource, rows); + + vecSetData = new VarcharVec(rows); + vecTestSetData = new VarcharVecTest(rows * 8, rows); + + byteValues = new ByteBuffer[rows]; + for (int i = 0; i < rows; i++) { + String str = String.valueOf(i * 1000); + ByteBuffer buffer = ByteBuffer.allocate(str.length()); + buffer.put(str.getBytes(StandardCharsets.UTF_8), 0, str.length()); + byteValues[i] = buffer; + } + + // varcharVec get + vecGetData = new VarcharVec(rows); + initValues(vecGetData, rows); + varcharVecTest = new VarcharVecTest(rows * 8, rows); + initValues(varcharVecTest, rows); + + positions = new int[rows / 2]; + for (int i = 0; i < rows / 2; i++) { + positions[i] = i; + } + + // arrow set + allocator1 = new RootAllocator(ALLOCATOR_CAPACITY); + arrowVecSet = new VarCharVector("varchar1", allocator1); + arrowVecSet.allocateNew(rows * 8); + arrowVecSet.setValueCount(rows); + + // arrow get + allocator2 = new RootAllocator(ALLOCATOR_CAPACITY); + arrowVecGet = new VarCharVector("longVector2", allocator2); + arrowVecGet.allocateNew(rows); + initValues(arrowVecGet, rows); + arrowVecGet.setValueCount(rows); + } + + /** + * close + */ + @TearDown(Level.Iteration) + public void tearDown() { + vecPutData.close(); + vecSetData.close(); + vecGetData.close(); + + arrowVecSet.close(); + allocator1.close(); + + arrowVecGet.close(); + allocator2.close(); + } + + private void initValues(VarcharVec vec, int rowCount) { + for (int i = 0; i < rowCount; i++) { + vec.set(i, String.valueOf(i * 1000).getBytes(StandardCharsets.UTF_8)); + } + } + + private void initValues(VarcharVecTest heapByteBuf, int rowCount) { + for (int i = 0; i < rowCount; i++) { + heapByteBuf.set(i, String.valueOf(i * 1000).getBytes(StandardCharsets.UTF_8)); + } + } + + private void initValues(VarCharVector arrayVec, int rowCount) { + for (int i = 0; i < rowCount; i++) { + arrayVec.set(i, String.valueOf(i * 1000).getBytes(StandardCharsets.UTF_8)); + } + } + + /** + * Create varchar vec benchmark + * + * @param blackhole blackhole + */ + @Benchmark + public void createVarcharVecBenchmark(Blackhole blackhole) { + List vecs = new ArrayList<>(); + for (int i = 0; i < COUNT; i++) { + VarcharVec vec = new VarcharVec(rows); + blackhole.consume(vec); + vecs.add(vec); + } + closeVec(vecs); + } + + /** + * Create arrow vec benchmark + * + * @param blackhole blackhole + */ + @Benchmark + public void createArrowVecBenchmark(Blackhole blackhole) { + List vecs = new ArrayList<>(); + RootAllocator rootAllocator = new RootAllocator(Long.MAX_VALUE); + for (int i = 0; i < COUNT; i++) { + VarCharVector arrowVec = new VarCharVector("key", rootAllocator); + arrowVec.allocateNew(rows, rows); + blackhole.consume(arrowVec); + vecs.add(arrowVec); + } + closeArrowVec(vecs); + rootAllocator.close(); + } + + /** + * Put varchar vec benchmark + * + * @param blackhole blackhole + */ + @Benchmark + public void putVarcharVecBenchmark(Blackhole blackhole) { + vecPutData.put(0, putDataSource.getData(), 0, putDataSource.offsets, 0, rows); + } + + /** + * Set arrow vec benchmark + * + * @param blackhole blackhole + */ + @Benchmark + public void setArrowVecBenchmark(Blackhole blackhole) { + for (int i = 0; i < rows; i++) { + arrowVecSet.set(i, byteValues[i].array()); + } + } + + /** + * Set heap bytebuffer benchmark + * + * @param blackhole blackhole + */ + @Benchmark + public void setHeapBytebufferBenchmark(Blackhole blackhole) { + for (int i = 0; i < rows; i++) { + vecTestSetData.set(i, byteValues[i].array()); + } + } + + /** + * Set varchar vec benchmark + */ + @Benchmark + public void setVarcharVecBenchmark() { + for (int i = 0; i < rows; i++) { + vecSetData.set(i, byteValues[i].array()); + } + } + + /** + * Get heap bytebuffer benchmark + * + * @return sum value + */ + @Benchmark + public long getHeapBytebufferBenchmark() { + long sum = 0L; + for (int i = 0; i < rows; i++) { + sum += varcharVecTest.get(i).length; + } + return sum; + } + + /** + * Get varchar vec benchmark + * + * @return sum value + */ + @Benchmark + public long getVarcharVecBenchmark() { + long sum = 0L; + for (int i = 0; i < rows; i++) { + sum += vecGetData.get(i).length; + } + return sum; + } + + /** + * Get arrow vec benchmark + * + * @return sum value + */ + @Benchmark + public long getArrowVecBenchmark() { + long sum = 0L; + for (int i = 0; i < rows; i++) { + sum += arrowVecGet.get(i).length; + } + return sum; + } + + /** + * Get varchar slice size + * + * @return slice size + */ + @Benchmark + public int sliceVarcharVecBenchmark() { + VarcharVec slice = vecGetData.slice(2, rows / 2); + return slice.getSize(); + } + + /** + * Copy position varchar vec benchmark + * + * @return copy position size + */ + @Benchmark + public int copyPositionVarcharVecBenchmark() { + VarcharVec copyPosition = vecGetData.copyPositions(positions, 0, positions.length); + return copyPosition.getSize(); + } + + static class VarcharVecTest { + int[] offsets; + ByteBuffer byteBuffer; + + public VarcharVecTest(int capacityInBytes, int size) { + offsets = new int[size + 1]; + byteBuffer = ByteBuffer.allocate(capacityInBytes); + } + + /** + * Get data + * + * @param index index + * @return data + */ + public byte[] get(int index) { + int startOffset = offsets[index]; + int dataLen = offsets[index + 1] - offsets[index]; + byteBuffer.position(startOffset); + byte[] data = new byte[dataLen]; + byteBuffer.get(data, 0, dataLen); + return data; + } + + /** + * Set data + * + * @param index index + * @param value data + */ + public void set(int index, byte[] value) { + int startOffset = offsets[index]; + offsets[index + 1] = startOffset + value.length; + byteBuffer.position(startOffset); + byteBuffer.put(value, 0, value.length); + } + + /** + * Get Data + * + * @return data + */ + public byte[] getData() { + return byteBuffer.array(); + } + } + + private void closeArrowVec(List vecs) { + for (VarCharVector vec : vecs) { + vec.close(); + } + } + + private void closeVec(List vecs) { + for (Vec vec : vecs) { + vec.close(); + } + } + + public static void main(String[] args) throws Throwable { + Options options = new OptionsBuilder().verbosity(VerboseMode.NORMAL) + .include(".*" + BenchmarkVarcharVec.class.getSimpleName() + ".*").build(); + + new Runner(options).run(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/vector/MergeVectorsTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/MergeVectorsTest.java new file mode 100644 index 0000000..75c3c70 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/MergeVectorsTest.java @@ -0,0 +1,146 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static org.testng.Assert.assertEquals; + +import org.testng.annotations.Test; + +/** + * Merge vector test + * + * @since 2021-7-7 + */ +public class MergeVectorsTest { + /** + * test int vector merge + */ + @Test + public void testIntVectorsMerge() { + IntVec vec1 = new IntVec(10000); + IntVec vec2 = new IntVec(20000); + IntVec vec3 = new IntVec(40000); + IntVec vec = new IntVec(70000); + for (int i = 0; i < vec1.getSize(); i++) { + vec1.set(i, i); + } + for (int i = vec1.getSize(), j = 0; i < vec2.getSize() + vec1.getSize(); i++, j++) { + vec2.set(j, i); + } + for (int i = vec1.getSize() + vec2.getSize(), + j = 0; i < vec3.getSize() + vec2.getSize() + vec1.getSize(); i++, j++) { + vec3.set(j, i); + } + vec.append(vec1, 0, vec1.getSize()); + vec.append(vec2, vec1.getSize(), vec2.getSize()); + vec.append(vec3, vec1.getSize() + vec2.getSize(), vec3.getSize()); + for (int i = 0; i < vec1.getSize() + vec2.getSize() + vec3.getSize(); i++) { + assertEquals(vec.get(i), i); + } + + closeVecs(new Vec[]{vec1, vec2, vec3, vec}); + } + + /** + * test double vector merge + */ + @Test + public void testDoubleVectorsMerge() { + DoubleVec vec1 = new DoubleVec(10000); + DoubleVec vec2 = new DoubleVec(20000); + DoubleVec vec3 = new DoubleVec(40000); + DoubleVec vec = new DoubleVec(70000); + // Creating and appending Vector 1 + for (int i = 0; i < vec1.getSize(); i++) { + vec1.set(i, (double) i); + } + vec.append(vec1, 0, vec1.getSize()); + // Creating and appending Vector 2 + for (int i = vec1.getSize(), j = 0; i < vec2.getSize() + vec1.getSize(); i++, j++) { + vec2.set(j, (double) i); + } + vec.append(vec2, vec1.getSize(), vec2.getSize()); + // Creating and appending Vector 3 + for (int i = vec1.getSize() + vec2.getSize(), + j = 0; i < vec3.getSize() + vec2.getSize() + vec1.getSize(); i++, j++) { + vec3.set(j, (double) i); + } + vec.append(vec3, vec1.getSize() + vec2.getSize(), vec3.getSize()); + + for (int i = 0; i < vec1.getSize() + vec2.getSize() + vec3.getSize(); i++) { + assertEquals(vec.get(i), (double) i); + } + + closeVecs(new Vec[]{vec1, vec2, vec3, vec}); + } + + /** + * test short vector merge + */ + @Test + public void testShortVectorsMerge() { + ShortVec vec1 = new ShortVec(10000); + ShortVec vec2 = new ShortVec(10000); + ShortVec vec3 = new ShortVec(10000); + ShortVec vec = new ShortVec(30000); + for (int i = 0; i < vec1.getSize(); i++) { + vec1.set(i, (short) i); + } + for (int i = vec1.getSize(), j = 0; i < vec2.getSize() + vec1.getSize(); i++, j++) { + vec2.set(j, (short) i); + } + for (int i = vec1.getSize() + vec2.getSize(), + j = 0; i < vec3.getSize() + vec2.getSize() + vec1.getSize(); i++, j++) { + vec3.set(j, (short) i); + } + vec.append(vec1, 0, vec1.getSize()); + vec.append(vec2, vec1.getSize(), vec2.getSize()); + vec.append(vec3, vec1.getSize() + vec2.getSize(), vec3.getSize()); + for (int i = 0; i < vec1.getSize() + vec2.getSize() + vec3.getSize(); i++) { + assertEquals(vec.get(i), i); + } + + closeVecs(new Vec[]{vec1, vec2, vec3, vec}); + } + + /** + * test long vector merge + */ + @Test + public void testLongVectorsMerge() { + LongVec vec1 = new LongVec(10000); + LongVec vec2 = new LongVec(20000); + LongVec vec3 = new LongVec(40000); + LongVec vec = new LongVec(70000); + // Creating and appending Vector 1 + for (int i = 0; i < vec1.getSize(); i++) { + vec1.set(i, (long) i); + } + vec.append(vec1, 0, vec1.getSize()); + // Creating and appending Vector 2 + for (int i = vec1.getSize(), j = 0; i < vec2.getSize() + vec1.getSize(); i++, j++) { + vec2.set(j, (long) i); + } + vec.append(vec2, vec1.getSize(), vec2.getSize()); + // Creating and appending Vector 3 + for (int i = vec1.getSize() + vec2.getSize(), + j = 0; i < vec3.getSize() + vec2.getSize() + vec1.getSize(); i++, j++) { + vec3.set(j, (long) i); + } + vec.append(vec3, vec1.getSize() + vec2.getSize(), vec3.getSize()); + + for (int i = 0; i < vec1.getSize() + vec2.getSize() + vec3.getSize(); i++) { + assertEquals(vec.get(i), (long) i); + } + + closeVecs(new Vec[]{vec1, vec2, vec3, vec}); + } + + private void closeVecs(Vec[] vecs) { + for (Vec vec : vecs) { + vec.close(); + } + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestBooleanVec.java b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestBooleanVec.java new file mode 100644 index 0000000..4a5735d --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestBooleanVec.java @@ -0,0 +1,179 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_BOOLEAN; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +import org.testng.annotations.Test; + +/** + * Test boolean vec + * + * @since 2021-7-2 + */ +public class TestBooleanVec { + /** + * test new vector + */ + @Test + public void testNewVector() { + BooleanVec vector1 = new BooleanVec(256); + assertEquals(vector1.getSize(), 256); + assertEquals(vector1.getRealValueBufCapacityInBytes(), 256); + assertEquals(vector1.getType().getId(), OMNI_BOOLEAN); + vector1.close(); + + BooleanVec vector2 = new BooleanVec(251); + assertEquals(vector2.getSize(), 251); + assertEquals(vector2.getRealValueBufCapacityInBytes(), 251); + assertEquals(vector2.getType().getId(), OMNI_BOOLEAN); + vector2.close(); + } + + /** + * test slice + */ + @Test + public void testSlice() { + BooleanVec originalVec = new BooleanVec(10); + for (int i = 0; i < originalVec.getSize(); i++) { + originalVec.set(i, i % 2 == 0); + } + int offset = 3; + BooleanVec slice1 = originalVec.slice(offset, 4); + assertEquals(slice1.getSize(), 4); + for (int i = 0; i < slice1.getSize(); i++) { + assertEquals(slice1.get(i), originalVec.get(i + offset), "Error item value at: " + i); + } + + BooleanVec slice2 = slice1.slice(1, 2); + + for (int i = 0; i < slice2.getSize(); i++) { + assertEquals(slice2.get(i), originalVec.get(i + offset + 1), "Error item value at: " + i); + } + + originalVec.close(); + slice1.close(); + slice2.close(); + } + + /** + * test set and get value + */ + @Test + public void testSetAndGetValue() { + final int size = 1024; + BooleanVec vec1 = new BooleanVec(size); + for (int i = 0; i < size; i++) { + vec1.set(i, i % 2 == 0); + } + + for (int i = 0; i < size; i++) { + if (i % 2 == 0) { + assertTrue(vec1.get(i)); + } else { + assertFalse(vec1.get(i)); + } + } + vec1.close(); + } + + /** + * test set values + */ + @Test + public void testSetValues() { + boolean[] values = {true, false, false, true, true}; + BooleanVec vector1 = new BooleanVec(values.length); + vector1.put(values, 0, 0, values.length); + for (int i = 0; i < values.length; i++) { + assertEquals(vector1.get(i), values[i]); + } + vector1.close(); + + BooleanVec vector2 = new BooleanVec(values.length); + vector2.put(values, 1, 2, 3); + for (int i = 0; i < 3; i++) { + assertEquals(vector2.get(i + 1), values[i + 2]); + } + vector2.close(); + + byte[] byteValues = {1, 0, 0, 1, 1}; + BooleanVec vector3 = new BooleanVec(byteValues.length); + vector3.put(byteValues, 0, 0, byteValues.length); + for (int i = 0; i < byteValues.length; i++) { + assertEquals(vector3.get(i), byteValues[i] == Vec.NULL); + } + vector3.close(); + BooleanVec vector4 = new BooleanVec(byteValues.length); + vector4.put(byteValues, 1, 2, 3); + for (int i = 0; i < 3; i++) { + assertEquals(vector4.get(i + 1), byteValues[i + 2] == Vec.NULL); + } + vector4.close(); + } + + /** + * test value null + */ + @Test + public void testValueNull() { + BooleanVec vector1 = new BooleanVec(256); + for (int i = 0; i < vector1.getSize(); i++) { + if (i % 5 == 0) { + vector1.setNull(i); + } else { + vector1.set(i, i % 2 == 0); + } + } + for (int i = 0; i < vector1.getSize(); i++) { + if (i % 5 == 0) { + assertTrue(vector1.isNull(i)); + } else { + assertEquals(vector1.get(i), i % 2 == 0); + } + } + + vector1.close(); + } + + /** + * test copy positions + */ + @Test + public void testCopyPositions() { + BooleanVec originalVector = new BooleanVec(4); + for (int i = 0; i < originalVector.getSize(); i++) { + originalVector.set(i, i % 2 == 0); + } + + int[] positions = {1, 3}; + BooleanVec copyPositionVector = originalVector.copyPositions(positions, 0, 2); + assertEquals(copyPositionVector.getRealValueBufCapacityInBytes(), 2); + for (int i = 0; i < copyPositionVector.getSize(); i++) { + assertEquals(copyPositionVector.get(i), originalVector.get(positions[i])); + } + + originalVector.close(); + copyPositionVector.close(); + } + + @Test + public void testGetValues() { + boolean[] values = new boolean[1024]; + for (int i = 0; i < values.length; i++) { + values[i] = i % 3 == 0; + } + BooleanVec originalVec = new BooleanVec(values.length); + originalVec.put(values, 0, 0, values.length); + + boolean[] actual = originalVec.get(0, values.length); + assertEquals(actual, values); + originalVec.close(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestContainerVec.java b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestContainerVec.java new file mode 100644 index 0000000..c7150dc --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestContainerVec.java @@ -0,0 +1,176 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static nova.hetu.omniruntime.type.DoubleDataType.DOUBLE; +import static nova.hetu.omniruntime.type.LongDataType.LONG; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +import nova.hetu.omniruntime.type.ContainerDataType; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.Decimal128DataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.ShortDataType; +import nova.hetu.omniruntime.type.VarcharDataType; + +import org.testng.annotations.Test; + +/** + * Container vec test: Now(2023.3.4), c++ do not support append\copyPositions\slice method + * + * @since 2021-7-6 + */ +public class TestContainerVec { + @Test(enabled = false) + public void testSlice() { + int rows = 10; + DoubleVec field1 = new DoubleVec(rows); + double[] data1 = new double[]{0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9}; + field1.put(data1, 0, 0, rows); + LongVec field2 = new LongVec(rows); + long[] data2 = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + field2.put(data2, 0, 0, rows); + ContainerVec originalVec = new ContainerVec(2, rows, + new long[]{field1.getNativeVector(), field2.getNativeVector()}, new DataType[]{DOUBLE, LONG}); + + int offset = 1; + ContainerVec sliced = originalVec.slice(offset, 4); + DoubleVec result1 = new DoubleVec(sliced.get(0)); + LongVec result2 = new LongVec(sliced.get(1)); + for (int i = 0; i < 5; i++) { + assertEquals(result1.get(i), data1[offset + i]); + assertEquals(result2.get(i), data2[offset + i]); + } + originalVec.close(); + sliced.close(); + } + + @Test(enabled = false) + public void testCopyPositions() { + int rows = 10; + DoubleVec field1 = new DoubleVec(rows); + double[] data1 = new double[]{0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9}; + field1.put(data1, 0, 0, rows); + LongVec field2 = new LongVec(rows); + long[] data2 = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + field2.put(data2, 0, 0, rows); + ContainerVec originalVec = new ContainerVec(2, rows, + new long[]{field1.getNativeVector(), field2.getNativeVector()}, new DataType[]{DOUBLE, LONG}); + + int[] positions = new int[]{1, 3, 5, 7, 9}; + ContainerVec copyPositionsed = originalVec.copyPositions(positions, 0, 5); + DoubleVec result1 = new DoubleVec(copyPositionsed.get(0)); + LongVec result2 = new LongVec(copyPositionsed.get(1)); + for (int i = 0; i < 5; i++) { + assertEquals(result1.get(i), data1[positions[i]]); + assertEquals(result2.get(i), data2[positions[i]]); + } + originalVec.close(); + copyPositionsed.close(); + } + + @Test + public void testAppend() { + int rows = 5; + DoubleVec field1 = new DoubleVec(rows); + double[] doubles = new double[]{1.1, 2.2, 3.3, 4.4, 5.5}; + field1.put(doubles, 0, 0, rows); + LongVec field2 = new LongVec(rows); + long[] longs = new long[]{1, 2, 3, 4, 5}; + field2.put(longs, 0, 0, rows); + + DoubleVec field11 = new DoubleVec(rows); + double[] doubles1 = new double[]{6.6, 7.7, 8.8, 9.9, 10.1}; + field11.put(doubles1, 0, 0, rows); + LongVec field22 = new LongVec(rows); + long[] longs1 = new long[]{6, 7, 8, 9, 10}; + field22.put(longs1, 0, 0, rows); + ContainerVec originalVec1 = new ContainerVec(2, rows, + new long[]{field11.getNativeVector(), field22.getNativeVector()}, new DataType[]{DOUBLE, LONG}); + + DoubleVec appendedDouble = new DoubleVec(rows * 2); + LongVec appendedLong = new LongVec(rows * 2); + ContainerVec appended = new ContainerVec(2, rows * 2, + new long[]{appendedDouble.getNativeVector(), appendedLong.getNativeVector()}, + new DataType[]{DOUBLE, LONG}); + + ContainerVec originalVec = new ContainerVec(2, rows, + new long[]{field1.getNativeVector(), field2.getNativeVector()}, new DataType[]{DOUBLE, LONG}); + appended.append(originalVec, 0, 5); + appended.append(originalVec1, 5, 5); + + double[] expected1 = new double[]{1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1}; + DoubleVec result1 = new DoubleVec(appended.get(0)); + for (int i = 0; i < result1.getSize(); i++) { + assertEquals(result1.get(i), expected1[i]); + } + + long[] expected2 = new long[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + LongVec result2 = new LongVec(appended.get(1)); + for (int i = 0; i < result2.getSize(); i++) { + assertEquals(result2.get(i), expected2[i]); + } + + originalVec.close(); + originalVec1.close(); + appended.close(); + } + + @Test + public void TestContainerVecSerialize() { + int rows = 3; + ShortVec shortVec = new ShortVec(rows); + IntVec intVec = new IntVec(rows); + LongVec longVec = new LongVec(rows); + Decimal128Vec decimal128Vec = new Decimal128Vec(rows); + VarcharVec varcharVec = new VarcharVec(10); + + long[] subAddr = new long[]{shortVec.getNativeVector(), intVec.getNativeVector(), longVec.getNativeVector()}; + DataType[] subTypes = new DataType[]{ShortDataType.SHORT, IntDataType.INTEGER, LongDataType.LONG}; + ContainerVec subContainerVec = new ContainerVec(3, rows, subAddr, subTypes); + + long[] addr = new long[]{decimal128Vec.getNativeVector(), varcharVec.getNativeVector(), + subContainerVec.getNativeVector()}; + DataType[] dataTypes = new DataType[]{Decimal128DataType.DECIMAL128, VarcharDataType.VARCHAR, + new ContainerDataType(new DataType[]{ShortDataType.SHORT, IntDataType.INTEGER, LongDataType.LONG})}; + + ContainerVec vec = new ContainerVec(dataTypes.length, rows, addr, dataTypes); + ContainerVec vecFromNative = new ContainerVec(vec.getNativeVector()); + DataType[] dataTypesFromNative = vecFromNative.getDataTypes(); + + for (int i = 0; i < dataTypes.length; i++) { + assertEquals(dataTypes[i], dataTypesFromNative[i]); + } + + vec.close(); + } + + @Test + public void testNullFlagWithSet() { + int rows = 10; + IntVec sub1 = new IntVec(rows); + LongVec sub2 = new LongVec(rows); + + long[] subAddrs = new long[]{sub1.slice(0, rows).getNativeVector(), sub2.slice(0, rows).getNativeVector()}; + DataType[] subTypes = new DataType[]{IntDataType.INTEGER, LONG}; + ContainerVec hasNulls = new ContainerVec(2, rows, subAddrs, subTypes); + byte[] nulls = new byte[]{1, 0, 1, 0, 1, 0, 1, 0, 1, 0}; + hasNulls.setNulls(0, nulls, 0, rows); + assertTrue(hasNulls.hasNull()); + hasNulls.close(); + + subAddrs = new long[]{sub1.getNativeVector(), sub2.getNativeVector()}; + ContainerVec hasNull = new ContainerVec(2, rows, subAddrs, subTypes); + for (int i = 0; i < rows; i++) { + if (i % 2 == 0) { + hasNull.setNull(i); + } + } + assertTrue(hasNull.hasNull()); + hasNull.close(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestDecimal128Vec.java b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestDecimal128Vec.java new file mode 100644 index 0000000..fdc79f6 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestDecimal128Vec.java @@ -0,0 +1,252 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.Decimal128DataType; + +import org.testng.annotations.Test; + +import java.math.BigInteger; + +/** + * test decimal 128-bit vec + * + * @since 2021-6-23 + */ +public class TestDecimal128Vec { + /** + * test new vector + */ + @Test + public void testNewVector() { + Decimal128Vec vec = new Decimal128Vec(256); + assertEquals(vec.getSize(), 256); + assertEquals(vec.getRealValueBufCapacityInBytes(), 4096); + assertEquals(vec.getType().getId(), DataType.DataTypeId.OMNI_DECIMAL128); + assertEquals(((Decimal128DataType) (vec.getType())).getPrecision(), 38); + assertEquals(((Decimal128DataType) (vec.getType())).getScale(), 0); + vec.close(); + } + + /** + * test slice + */ + @Test + public void testSlice() { + final int size = 10; + Decimal128Vec vec1 = new Decimal128Vec(size); + for (int i = 0; i < size; i++) { + long[] value = {1 << (i + 1), 3}; + vec1.set(i, value); + } + Decimal128Vec slice1 = vec1.slice(3, 5); + Decimal128Vec slice2 = vec1.slice(0, vec1.getSize()); + for (int i = 0; i < slice1.getSize(); i++) { + assertEquals(vec1.get(i + 3)[0], slice1.get(i)[0], "Error item value at: " + i); + assertEquals(vec1.get(i + 3)[1], slice1.get(i)[1], "Error item value at: " + i); + } + for (int i = 0; i < slice2.getSize(); i++) { + assertEquals(vec1.get(i)[0], slice2.get(i)[0], "Error item value at: " + i); + assertEquals(vec1.get(i)[1], slice2.get(i)[1], "Error item value at: " + i); + } + vec1.close(); + slice1.close(); + slice2.close(); + } + + /** + * test set and get value + */ + @Test + public void testSetAndGetValue() { + final int size = 1024; + Decimal128Vec vec1 = new Decimal128Vec(size); + long[] values = new long[size * 2]; + for (int i = 0; i < size; i++) { + long[] value = {i, 3}; + vec1.set(i, value); + values[i * 2] = i; + values[i * 2 + 1] = 3; + } + + for (int i = 0; i < size; i++) { + assertEquals(values[i * 2], vec1.get(i)[0]); + assertEquals(values[i * 2 + 1], vec1.get(i)[1]); + } + vec1.close(); + } + + /** + * test put value + */ + @Test + public void testPutValues() { + long[] values = {1, 3, 4, 6, 7, 8}; + Decimal128Vec vec1 = new Decimal128Vec(values.length / 2); + vec1.put(values, 0, 0, values.length); + for (int i = 0; i < values.length / 2; i++) { + assertEquals(vec1.get(i)[0], values[i * 2]); + assertEquals(vec1.get(i)[1], values[i * 2 + 1]); + } + + Decimal128Vec vec2 = new Decimal128Vec(values.length / 2); + vec2.put(values, 1, 2, 4); + for (int i = 0; i < 2; i++) { + assertEquals(vec2.get(i + 1)[0], values[i * 2 + 2]); + assertEquals(vec2.get(i + 1)[1], values[i * 2 + 3]); + } + + vec1.close(); + vec2.close(); + } + + /** + * test value null + */ + @Test + public void testValueNull() { + final int size = 10; + Decimal128Vec vec = new Decimal128Vec(size); + for (int i = 0; i < size; i++) { + long[] value = {i, 3}; + vec.set(i, value); + } + + for (int i = 0; i < size; i++) { + vec.setNull(i); + } + + for (int i = 0; i < size; i++) { + assertTrue(vec.isNull(i)); + } + vec.close(); + } + + /** + * test copy position + */ + @Test + public void testCopyPositions() { + Decimal128Vec originalVector = new Decimal128Vec(4); + for (int i = 0; i < originalVector.getSize(); i++) { + long[] value = {0, i}; + originalVector.set(i, value); + } + + int[] positions = {1, 3}; + Decimal128Vec copyPositionVector = originalVector.copyPositions(positions, 0, 2); + assertEquals(copyPositionVector.getRealValueBufCapacityInBytes(), 32); + for (int i = 0; i < copyPositionVector.getSize(); i++) { + assertEquals(copyPositionVector.get(i), originalVector.get(positions[i])); + } + + originalVector.close(); + copyPositionVector.close(); + } + + /** + * test zero sized allocate + */ + @Test + public void testZeroSizeAllocate() { + Decimal128Vec v1 = new Decimal128Vec(0); + long[] values = new long[0]; + v1.put(values, 0, 0, values.length); + v1.close(); + } + + /** + * test BigInteger to long array and vice versa + */ + @Test + public void testBigIntegerTrans() { + BigInteger bigInteger = new BigInteger("11111111111111111111111111111111111111"); + + long[] longs = Decimal128Vec.putDecimal(bigInteger); + assertEquals(longs[1], 602334540269724685L); + assertEquals(longs[0], -8122175193715281465L); + + BigInteger newBigInteger = Decimal128Vec.getDecimal(longs); + assertEquals(newBigInteger, bigInteger); + } + + /** + * test Decimal128Vec set/get BigInteger + */ + @Test + public void testSetGetBigInteger() { + final int size = 1024; + BigInteger decimal128 = new BigInteger("11111111111111111111111111111111111111"); + Decimal128Vec vec1 = new Decimal128Vec(size); + BigInteger decimal64 = new BigInteger("111111"); + Decimal128Vec vec2 = new Decimal128Vec(size); + + for (int i = 0; i < size; ++i) { + vec1.setBigInteger(i, decimal128); + } + + for (int i = 0; i < size; ++i) { + BigInteger val = vec1.getBigInteger(i); + assertEquals(val, decimal128); + } + + for (int i = 0; i < size; ++i) { + vec2.setBigInteger(i, decimal64); + } + + for (int i = 0; i < size; ++i) { + BigInteger val = vec2.getBigInteger(i); + assertEquals(val, decimal64); + } + vec1.close(); + vec2.close(); + } + + /** + * test Decimal128Vec set/get BigInteger using bytes + */ + @Test + public void testSetGetBigIntegerBytes() { + final int size = 1024; + BigInteger decimal128 = new BigInteger("11111111111111111111111111111111111111"); + byte[] bytes = decimal128.toByteArray(); + boolean isNegative = decimal128.compareTo(new BigInteger("0")) == -1; + Decimal128Vec vec1 = new Decimal128Vec(size); + + for (int i = 0; i < size; ++i) { + vec1.setBigInteger(i, bytes, isNegative); + } + + for (int i = 0; i < size; ++i) { + byte[] val = vec1.getBytes(i); + BigInteger bigInteger = new BigInteger(val); + assertEquals(bigInteger, decimal128); + } + vec1.close(); + } + + /** + * test Decimal128Vec set/get BigInteger + */ + @Test + public void testBigIntegerByteLengthBetweenEightAndSixteen() { + BigInteger decimal1 = new BigInteger("111311100000000000000000000"); + BigInteger decimal2 = new BigInteger("-99999999999999999999999999"); + Decimal128Vec vec = new Decimal128Vec(2); + + vec.setBigInteger(0, decimal1); + vec.setBigInteger(1, decimal2); + + BigInteger val1 = vec.getBigInteger(0); + BigInteger val2 = vec.getBigInteger(1); + assertEquals(val1, decimal1); + assertEquals(val2, decimal2); + vec.close(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestDictionaryVec.java b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestDictionaryVec.java new file mode 100644 index 0000000..705166a --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestDictionaryVec.java @@ -0,0 +1,259 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotEquals; +import static org.testng.Assert.assertTrue; + +import nova.hetu.omniruntime.util.TestUtils; + +import org.testng.annotations.Test; + +import java.nio.charset.StandardCharsets; + +/** + * test dictionary vec + * + * @since 2021-9-8 + */ +public class TestDictionaryVec { + /** + * test slice + */ + @Test + public void testSlice() { + LongVec originalVec = new LongVec(100); + for (int i = 0; i < originalVec.getSize(); i++) { + originalVec.set(i, i); + } + + int[] ids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + DictionaryVec dictionaryVec = new DictionaryVec(originalVec, ids); + int offset1 = 3; + DictionaryVec slice1 = dictionaryVec.slice(offset1, 4); + assertEquals(slice1.getSize(), 4); + for (int i = 0; i < slice1.getSize(); i++) { + long value = slice1.getLong(i); + assertEquals(value, originalVec.get(i + offset1), "Error item value from slice1 at: " + i); + } + + int offset2 = 1; + DictionaryVec slice2 = slice1.slice(offset2, 2); + assertEquals(slice2.getSize(), 2); + for (int i = 0; i < slice2.getSize(); i++) { + long value = slice2.getLong(i); + assertEquals(value, slice1.getLong(i + offset2), "Error item value from slice2 at: " + i); + assertEquals(value, originalVec.get(i + offset2 + offset1), "Error item value from slice2 at: " + i); + } + originalVec.close(); + slice2.close(); + slice1.close(); + dictionaryVec.close(); + } + + @Test + public void testGetLong() { + LongVec originalVec = new LongVec(10); + for (int i = 0; i < 10; i++) { + originalVec.set(i, i); + } + + int[] ids = {6, 8, 9}; + DictionaryVec dictionaryVec = new DictionaryVec(originalVec, ids); + assertEquals(dictionaryVec.getLong(0), originalVec.get(6)); + assertEquals(dictionaryVec.getLong(1), originalVec.get(8)); + assertEquals(dictionaryVec.getLong(2), originalVec.get(9)); + + originalVec.close(); + dictionaryVec.close(); + } + + @Test + public void testGetShort() { + ShortVec originalVec = new ShortVec(10); + for (int i = 0; i < 10; i++) { + originalVec.set(i, (short) i); + } + + int[] ids = {6, 8, 9}; + DictionaryVec dictionaryVec = new DictionaryVec(originalVec, ids); + assertEquals(dictionaryVec.getShort(0), originalVec.get(6)); + assertEquals(dictionaryVec.getShort(1), originalVec.get(8)); + assertEquals(dictionaryVec.getShort(2), originalVec.get(9)); + + originalVec.close(); + dictionaryVec.close(); + } + + @Test + public void testGetInt() { + IntVec originalVec = new IntVec(10); + for (int i = 0; i < 10; i++) { + originalVec.set(i, i); + } + + int[] ids = {6, 8, 9}; + DictionaryVec dictionaryVec = new DictionaryVec(originalVec, ids); + assertEquals(dictionaryVec.getInt(0), originalVec.get(6)); + assertEquals(dictionaryVec.getInt(1), originalVec.get(8)); + assertEquals(dictionaryVec.getInt(2), originalVec.get(9)); + + originalVec.close(); + dictionaryVec.close(); + } + + @Test + public void testGetBoolean() { + BooleanVec originalVec = new BooleanVec(10); + for (int i = 0; i < 10; i++) { + originalVec.set(i, i % 2 == 0); + } + + int[] ids = {6, 8, 9}; + DictionaryVec dictionaryVec = new DictionaryVec(originalVec, ids); + assertEquals(dictionaryVec.getBoolean(0), originalVec.get(6)); + assertEquals(dictionaryVec.getBoolean(1), originalVec.get(8)); + assertEquals(dictionaryVec.getBoolean(2), originalVec.get(9)); + + originalVec.close(); + dictionaryVec.close(); + } + + @Test + public void testGetDouble() { + DoubleVec originalVec = new DoubleVec(10); + for (int i = 0; i < 10; i++) { + originalVec.set(i, 2.3d * i); + } + + int[] ids = {6, 8, 9}; + DictionaryVec dictionaryVec = new DictionaryVec(originalVec, ids); + assertEquals(Double.compare(dictionaryVec.getDouble(0), originalVec.get(6)), 0); + assertEquals(Double.compare(dictionaryVec.getDouble(0), originalVec.get(6)), 0); + assertEquals(Double.compare(dictionaryVec.getDouble(1), originalVec.get(8)), 0); + assertEquals(Double.compare(dictionaryVec.getDouble(2), originalVec.get(9)), 0); + + originalVec.close(); + dictionaryVec.close(); + } + + @Test + public void testGetBytes() { + VarcharVec originalVec = new VarcharVec(10); + for (int i = 0; i < 10; i++) { + originalVec.set(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8)); + } + + int[] ids = {6, 8, 9}; + DictionaryVec dictionaryVec = new DictionaryVec(originalVec, ids); + assertEquals(dictionaryVec.getBytes(0), originalVec.get(6)); + assertEquals(dictionaryVec.getBytes(1), originalVec.get(8)); + assertEquals(dictionaryVec.getBytes(2), originalVec.get(9)); + + originalVec.close(); + dictionaryVec.close(); + } + + @Test + public void testGetDecimal128() { + Decimal128Vec originalVec = new Decimal128Vec(10); + for (int i = 0; i < 10; i++) { + long[] value = {i, i * 2}; + originalVec.set(i, value); + } + + int[] ids = {6, 8, 9}; + DictionaryVec dictionaryVec = new DictionaryVec(originalVec, ids); + // means decimal1={6, 12},decimal2={8, 16}, decimal3={9, 18}; + Object[] expected = {6L, 12L, 8L, 16L, 9L, 18L}; + TestUtils.assertDictionaryVecEquals(dictionaryVec, expected); + + originalVec.close(); + dictionaryVec.close(); + } + + /** + * test copy position + */ + @Test + public void testCopyPositions() { + LongVec originalVector = new LongVec(10); + for (int i = 0; i < originalVector.getSize(); i++) { + originalVector.set(i, i); + } + + int[] ids = {2, 3, 4, 5, 6, 8, 9}; + DictionaryVec dictionaryVec = new DictionaryVec(originalVector, ids); + int[] positions = {1, 3, 5, 6}; + DictionaryVec copyPositions = dictionaryVec.copyPositions(positions, 1, 3); + assertEquals(copyPositions.getLong(0), originalVector.get(5)); + assertEquals(copyPositions.getLong(1), originalVector.get(8)); + assertEquals(copyPositions.getLong(2), originalVector.get(9)); + + originalVector.close(); + dictionaryVec.close(); + copyPositions.close(); + + // dictionary data compress + originalVector = new LongVec(2); + originalVector.set(0, 100); + originalVector.set(1, 200); + int[] ids1 = {0, 0, 0, 1, 1, 1}; + dictionaryVec = new DictionaryVec(originalVector, ids1); + int[] positions1 = {1, 2, 3, 5}; + copyPositions = dictionaryVec.copyPositions(positions1, 0, 4); + assertEquals(copyPositions.getLong(0), originalVector.get(0)); + assertEquals(copyPositions.getLong(1), originalVector.get(0)); + assertEquals(copyPositions.getLong(2), originalVector.get(1)); + assertEquals(copyPositions.getLong(3), originalVector.get(1)); + + originalVector.close(); + dictionaryVec.close(); + copyPositions.close(); + } + + @Test + public void testNullFlag() { + LongVec originalVector = new LongVec(10); + for (int i = 0; i < 10; i++) { + if (i % 2 == 0) { + originalVector.setNull(i); + } else { + originalVector.set(i, i); + } + } + + int[] ids = {6, 8, 9}; + DictionaryVec dictionaryVec = new DictionaryVec(originalVector, ids); + originalVector.close(); + assertTrue(dictionaryVec.hasNull()); + + DictionaryVec slice = dictionaryVec.slice(2, 1); + assertFalse(slice.hasNull()); + slice.close(); + + int[] positions = {0, 2}; + DictionaryVec copyPosition = dictionaryVec.copyPositions(positions, 0, 2); + assertTrue(copyPosition.hasNull()); + copyPosition.close(); + + dictionaryVec.close(); + } + + @Test + public void testGetDictionaryOfEmptyStrings() { + VarcharVec originalVec = new VarcharVec(0, 10); + int[] ids = {6, 8, 9}; + DictionaryVec dictionaryVec = new DictionaryVec(originalVec, ids); + + long dictAddr = dictionaryVec.getDataAddress(); + assertNotEquals(dictAddr, 0); + + originalVec.close(); + dictionaryVec.close(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestDoubleVec.java b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestDoubleVec.java new file mode 100644 index 0000000..6dad094 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestDoubleVec.java @@ -0,0 +1,183 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_DOUBLE; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +import org.testng.annotations.Test; + +import java.util.Arrays; + +/** + * test double vec + * + * @since 2021-7-2 + */ +public class TestDoubleVec { + /** + * test new vector + */ + @Test + public void testNewVector() { + DoubleVec vec = new DoubleVec(256); + assertEquals(vec.getSize(), 256); + assertEquals(vec.getRealValueBufCapacityInBytes(), 2048); + assertEquals(vec.getType().getId(), OMNI_DOUBLE); + vec.close(); + } + + /** + * test slice + */ + @Test + public void testSlice() { + DoubleVec originalVec = new DoubleVec(10); + for (int i = 0; i < originalVec.getSize(); i++) { + originalVec.set(i, (double) i / 3); + } + int offset = 3; + DoubleVec slice1 = originalVec.slice(offset, 4); + assertEquals(slice1.getSize(), 4); + for (int i = 0; i < slice1.getSize(); i++) { + assertEquals(slice1.get(i), originalVec.get(i + offset), "Error item value at: " + i); + } + + DoubleVec slice2 = slice1.slice(1, 3); + + for (int i = 0; i < slice2.getSize(); i++) { + assertEquals(slice2.get(i), originalVec.get(i + offset + 1), "Error item value at: " + i); + } + + originalVec.close(); + slice1.close(); + slice2.close(); + } + + /** + * test set and get value + */ + @Test + public void testSetAndGetValue() { + DoubleVec vec = new DoubleVec(256); + for (int i = 0; i < vec.getSize(); i++) { + vec.set(i, (double) i / 3); + } + + for (int i = 0; i < vec.getSize(); i++) { + assertEquals(vec.get(i), (double) i / 3); + } + vec.close(); + } + + /** + * test put value + */ + @Test + public void testPutValues() { + double[] values = {1.13, 3.33, 4.44, 6.66, 7.81}; + DoubleVec doubleVec1 = new DoubleVec(values.length); + doubleVec1.put(values, 0, 0, values.length); + for (int i = 0; i < values.length; i++) { + assertEquals(doubleVec1.get(i), values[i]); + } + + DoubleVec doubleVec2 = new DoubleVec(values.length); + doubleVec2.put(values, 1, 2, 3); + for (int i = 0; i < 3; i++) { + assertEquals(doubleVec2.get(i + 1), values[i + 2]); + } + + doubleVec1.close(); + doubleVec2.close(); + } + + /** + * test value null + */ + @Test + public void testValueNull() { + DoubleVec doubleVec = new DoubleVec(256); + for (int i = 0; i < doubleVec.getSize(); i++) { + if (i % 5 == 0) { + doubleVec.setNull(i); + } else { + doubleVec.set(i, (double) i / 3); + } + } + for (int i = 0; i < doubleVec.getSize(); i++) { + if (i % 5 == 0) { + assertTrue(doubleVec.isNull(i)); + } else { + assertEquals(doubleVec.get(i), (double) i / 3); + } + } + + doubleVec.close(); + } + + /** + * test copy positions + */ + @Test + public void testCopyPositions() { + DoubleVec originalVector = new DoubleVec(4); + for (int i = 0; i < originalVector.getSize(); i++) { + originalVector.set(i, i); + } + + int[] positions = {1, 3}; + DoubleVec copyPositionVector = originalVector.copyPositions(positions, 0, 2); + assertEquals(copyPositionVector.getRealValueBufCapacityInBytes(), 16); + for (int i = 0; i < copyPositionVector.getSize(); i++) { + assertEquals(copyPositionVector.get(i), originalVector.get(positions[i])); + } + + originalVector.close(); + copyPositionVector.close(); + } + + /** + * test zero sized allocate + */ + @Test + public void testZeroSizeAllocate() { + DoubleVec v1 = new DoubleVec(0); + double[] values = new double[0]; + v1.put(values, 0, 0, values.length); + v1.close(); + } + + @Test + public void testGetValues() { + double[] values = {1.13, 3.33, 4.44, 6.66, 7.81}; + DoubleVec doubleVec1 = new DoubleVec(values.length); + doubleVec1.put(values, 0, 0, values.length); + assertEquals(doubleVec1.get(0, values.length), values); + double[] expected = {3.33, 4.44, 6.66}; + double[] actual = doubleVec1.get(1, 3); + for (int i = 0; i < actual.length; i++) { + assertEquals(actual[i], expected[i]); + } + doubleVec1.close(); + } + + @Test + public void setDoubleMax() { + int len = 1024 * 1024; + double[] values = new double[len]; + Arrays.fill(values, Double.MAX_VALUE); + DoubleVec max = new DoubleVec(len); + max.put(values, 0, 0, values.length); + + for (int i = 0; i < max.getSize(); i++) { + assertEquals(max.get(i), Double.MAX_VALUE); + } + + assertEquals(max.get(0, values.length), values); + max.close(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestIntVec.java b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestIntVec.java new file mode 100644 index 0000000..f6fbb2c --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestIntVec.java @@ -0,0 +1,277 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +import org.testng.annotations.Test; + +import java.util.Arrays; + +/** + * test int vec + * + * @since 2021-7-2 + */ +public class TestIntVec { + /** + * test new vector + */ + @Test + public void testNewVector() { + IntVec vec = new IntVec(256); + assertEquals(vec.getSize(), 256); + assertEquals(vec.getRealValueBufCapacityInBytes(), 1024); + assertEquals(vec.getType().getId(), OMNI_INT); + vec.close(); + } + + /** + * test slice + */ + @Test + public void testSlice() { + IntVec oritinalVec = new IntVec(10); + for (int i = 0; i < oritinalVec.getSize(); i++) { + oritinalVec.set(i, i); + } + int offset = 3; + IntVec slice1 = oritinalVec.slice(offset, 4); + assertEquals(slice1.getSize(), 4); + for (int i = 0; i < slice1.getSize(); i++) { + assertEquals(slice1.get(i), oritinalVec.get(i + offset), "Error item value at: " + i); + } + + IntVec slice2 = slice1.slice(1, 2); + + for (int i = 0; i < slice2.getSize(); i++) { + assertEquals(slice2.get(i), oritinalVec.get(i + offset + 1), "Error item value at: " + i); + } + + oritinalVec.close(); + slice1.close(); + slice2.close(); + } + + /** + * test set and get value + */ + @Test + public void testSetAndGetValue() { + IntVec vec = new IntVec(256); + for (int i = 0; i < vec.getSize(); i++) { + vec.set(i, i * 2); + } + + for (int i = 0; i < vec.getSize(); i++) { + assertEquals(vec.get(i), i * 2); + } + vec.close(); + } + + /** + * test put value + */ + @Test + public void testPutValues() { + int[] values = {1, 3, 4, 6, 7}; + IntVec vec1 = new IntVec(values.length); + vec1.put(values, 0, 0, values.length); + for (int i = 0; i < values.length; i++) { + assertEquals(vec1.get(i), values[i]); + } + + IntVec vec2 = new IntVec(values.length); + vec2.put(values, 1, 2, 3); + for (int i = 0; i < 3; i++) { + assertEquals(vec2.get(i + 1), values[i + 2]); + } + + vec1.close(); + vec2.close(); + } + + /** + * test value null + */ + @Test + public void testValueNull() { + IntVec vec = new IntVec(256); + for (int i = 0; i < vec.getSize(); i++) { + if (i % 5 == 0) { + vec.setNull(i); + } else { + vec.set(i, i); + } + } + for (int i = 0; i < vec.getSize(); i++) { + if (i % 5 == 0) { + assertTrue(vec.isNull(i)); + } else { + assertEquals(vec.get(i), i); + } + } + + vec.close(); + } + + /** + * test copy postion + */ + @Test + public void testCopyPositions() { + IntVec originalVector = new IntVec(4); + for (int i = 0; i < originalVector.getSize(); i++) { + originalVector.set(i, i); + } + + int[] positions = {1, 3}; + IntVec copyPositionVector = originalVector.copyPositions(positions, 0, 2); + assertEquals(copyPositionVector.getRealValueBufCapacityInBytes(), 8); + for (int i = 0; i < copyPositionVector.getSize(); i++) { + assertEquals(copyPositionVector.get(i), originalVector.get(positions[i])); + } + + originalVector.close(); + copyPositionVector.close(); + } + + /** + * test zero size allocate + */ + @Test + public void testZeroSizeAllocate() { + IntVec v1 = new IntVec(0); + int[] values = new int[0]; + v1.put(values, 0, 0, values.length); + v1.close(); + } + + @Test + public void testGetValues() { + int[] values = {1, 3, 4, 6, 7}; + IntVec vec = new IntVec(values.length); + vec.put(values, 0, 0, values.length); + assertEquals(vec.get(0, values.length), values); + int[] expected = {3, 4, 6}; + int[] actual = vec.get(1, 3); + for (int i = 0; i < actual.length; i++) { + assertEquals(actual[i], expected[i]); + } + vec.close(); + } + + @Test + public void setIntMax() { + int len = 1024 * 1024; + int[] values = new int[len]; + Arrays.fill(values, Integer.MAX_VALUE); + IntVec max = new IntVec(len); + max.put(values, 0, 0, values.length); + + for (int i = 0; i < max.getSize(); i++) { + assertEquals(max.get(i), Integer.MAX_VALUE); + } + + assertEquals(max.get(0, values.length), values); + max.close(); + } + + @Test + public void testNullFlagWithSet() { + // no null value + IntVec noNull = new IntVec(10); + assertFalse(noNull.hasNull()); + noNull.close(); + + // has null value + IntVec hasNulls = new IntVec(10); + byte[] nulls = new byte[] {0, 1, 0, 1, 0, 1, 0, 1, 0, 1}; + hasNulls.setNulls(0, nulls, 0, nulls.length); + assertTrue(hasNulls.hasNull()); + hasNulls.close(); + + IntVec hasNull = new IntVec(10); + for (int i = 0; i < hasNull.size; i++) { + if (i % 2 == 0) { + hasNull.setNull(i); + } else { + hasNull.set(i, i); + } + } + assertTrue(hasNull.hasNull()); + hasNull.close(); + } + + @Test + public void testNullFlagWithCopyPosition() { + // has null value + IntVec hasNulls = new IntVec(10); + byte[] nulls = new byte[] {0, 0, 1, 1, 0, 1, 0, 1, 0, 1}; + hasNulls.setNulls(0, nulls, 0, nulls.length); + assertTrue(hasNulls.hasNull()); + + int[] positions = new int[]{0, 1}; + IntVec copyPositionNoNull = hasNulls.copyPositions(positions, 0, 2); + assertFalse(copyPositionNoNull.hasNull()); + copyPositionNoNull.close(); + + positions = new int[]{1, 2, 3, 4}; + IntVec copyPositionHasNull = hasNulls.copyPositions(positions, 0, 4); + assertTrue(copyPositionHasNull.hasNull()); + copyPositionHasNull.close(); + + hasNulls.close(); + } + + @Test + public void testNullFlagWithSlice() { + // has null value + IntVec hasNulls = new IntVec(10); + byte[] nulls = new byte[] {0, 0, 1, 1, 0, 1, 0, 1, 0, 1}; + hasNulls.setNulls(0, nulls, 0, nulls.length); + assertTrue(hasNulls.hasNull()); + + IntVec sliceNoNull = hasNulls.slice(0, 1); + assertFalse(sliceNoNull.hasNull()); + sliceNoNull.close(); + + IntVec sliceHasNull = hasNulls.slice(1, 3); + assertTrue(sliceHasNull.hasNull()); + sliceHasNull.close(); + + hasNulls.close(); + } + + @Test + public void testNullFlagWithAppend() { + int rowCount = 5; + IntVec src1 = new IntVec(rowCount); + + for (int i = 0; i < rowCount; i++) { + src1.set(i, i + 1); + } + + IntVec appended = new IntVec(15); + appended.append(src1, 0, rowCount); + src1.close(); + assertFalse(appended.hasNull()); + + IntVec withNull = new IntVec(rowCount); + byte[] nulls = new byte[] {0, 1, 1, 0, 1}; + withNull.setNulls(0, nulls, 0, 5); + appended.append(withNull, 5, rowCount); + assertTrue(appended.hasNull()); + + appended.append(withNull, 10, rowCount); + assertTrue(appended.hasNull()); + withNull.close(); + + appended.close(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestLongVec.java b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestLongVec.java new file mode 100644 index 0000000..5deae70 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestLongVec.java @@ -0,0 +1,184 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +import nova.hetu.omniruntime.type.LongDataType; + +import org.testng.annotations.Test; + +import java.util.Arrays; + +/** + * test long vec + * + * @since 2021-7-2 + */ +public class TestLongVec { + /** + * test new vector + */ + @Test + public void testNewVector() { + LongVec vec = new LongVec(256); + assertEquals(vec.getSize(), 256); + assertEquals(vec.getRealValueBufCapacityInBytes(), 2048); + assertEquals(vec.getType(), LongDataType.LONG); + vec.close(); + } + + /** + * test slice + */ + @Test + public void testSlice() { + LongVec originalVec = new LongVec(10); + for (int i = 0; i < originalVec.getSize(); i++) { + originalVec.set(i, i); + } + int offset = 3; + LongVec slice1 = originalVec.slice(offset, 4); + assertEquals(slice1.getSize(), 4); + for (int i = 0; i < slice1.getSize(); i++) { + assertEquals(slice1.get(i), originalVec.get(i + offset), "Error item value at: " + i); + } + + LongVec slice2 = slice1.slice(1, 2); + + for (int i = 0; i < slice2.getSize(); i++) { + assertEquals(slice2.get(i), originalVec.get(i + offset + 1), "Error item value at: " + i); + } + + originalVec.close(); + slice1.close(); + slice2.close(); + } + + /** + * test set and get value + */ + @Test + public void testSetAndGetValue() { + LongVec vec = new LongVec(256); + for (int i = 0; i < vec.getSize(); i++) { + vec.set(i, i * 2); + } + + for (int i = 0; i < vec.getSize(); i++) { + assertEquals(vec.get(i), i * 2); + } + vec.close(); + } + + /** + * test put value + */ + @Test + public void testPutValues() { + long[] values = {1, 3, 4, 6, 7}; + LongVec vec1 = new LongVec(values.length); + vec1.put(values, 0, 0, values.length); + for (int i = 0; i < values.length; i++) { + assertEquals(vec1.get(i), values[i]); + } + + LongVec vec2 = new LongVec(values.length); + vec2.put(values, 1, 2, 3); + for (int i = 0; i < 3; i++) { + assertEquals(vec2.get(i + 1), values[i + 2]); + } + + vec1.close(); + vec2.close(); + } + + /** + * test value null + */ + @Test + public void testValueNull() { + LongVec longVec = new LongVec(256); + for (int i = 0; i < longVec.getSize(); i++) { + if (i % 5 == 0) { + longVec.setNull(i); + } else { + longVec.set(i, i); + } + } + for (int i = 0; i < longVec.getSize(); i++) { + if (i % 5 == 0) { + assertTrue(longVec.isNull(i)); + } else { + assertEquals(longVec.get(i), i); + } + } + + longVec.close(); + } + + /** + * test copy postion + */ + @Test + public void testCopyPositions() { + LongVec originalVector = new LongVec(4); + for (int i = 0; i < originalVector.getSize(); i++) { + originalVector.set(i, i); + } + + int[] positions = {1, 3}; + LongVec copyPositionVector = originalVector.copyPositions(positions, 0, 2); + assertEquals(copyPositionVector.getRealValueBufCapacityInBytes(), 16); + for (int i = 0; i < copyPositionVector.getSize(); i++) { + assertEquals(copyPositionVector.get(i), originalVector.get(positions[i])); + } + + originalVector.close(); + copyPositionVector.close(); + } + + /** + * test zero sized allocate + */ + @Test + public void testZeroSizeAllocate() { + LongVec v1 = new LongVec(0); + long[] values = new long[0]; + v1.put(values, 0, 0, values.length); + v1.close(); + } + + @Test + public void testGetValues() { + long[] values = {1, 3, 4, 6, 7}; + LongVec vec = new LongVec(values.length); + vec.put(values, 0, 0, values.length); + assertEquals(vec.get(0, values.length), values); + long[] expected = {3, 4, 6}; + long[] actual = vec.get(1, 3); + for (int i = 0; i < actual.length; i++) { + assertEquals(actual[i], expected[i]); + } + vec.close(); + } + + @Test + public void setLongMax() { + int len = 1024 * 1024; + long[] values = new long[len]; + Arrays.fill(values, Long.MAX_VALUE); + LongVec max = new LongVec(len); + max.put(values, 0, 0, values.length); + + for (int i = 0; i < max.getSize(); i++) { + assertEquals(max.get(i), Long.MAX_VALUE); + } + + assertEquals(max.get(0, values.length), values); + max.close(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestOmniRow.java b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestOmniRow.java new file mode 100644 index 0000000..e60818b --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestOmniRow.java @@ -0,0 +1,162 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; + +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.vector.serialize.OmniRowDeserializer; + +import org.testng.annotations.Test; + +import java.lang.reflect.InvocationTargetException; +import java.util.Arrays; + +/** + * test omni row vec + * + * @since 2024-5-16 + */ +class GenericClass { + private Class clazz; + + /** + * set special value + * + * @param clazz the class object of TVec + * + */ + public GenericClass(Class clazz) { + this.clazz = clazz; + } + + /** + * use reflect function to create vector + * + * @param rowCount row count of vector + * @return TVec return vector in heap + */ + public TVec createInstance(int rowCount) { + try { + return clazz.getDeclaredConstructor(int.class).newInstance(rowCount); + } catch (InstantiationException | IllegalAccessException | NoSuchMethodException + | InvocationTargetException e) { + e.printStackTrace(); + return null; + } + } +} + +/** + * lambda express to set value into omniBuf + */ +@FunctionalInterface +interface OmniBufSet { + /** + * set special value + * + * @param buffer OmniBuffer which will be set value + * @param index set special index row 's value + */ + void operate(OmniBuffer buffer, int index); +} + +/** + * Function to test template vector + */ +public class TestOmniRow { + /** + * set special value + * + * @param clazz object of class + * @param type data type of new vector + * @param vecCount vec count of vector batch + * @param rowCount row count of vector batch + * @param setFunc used to set value for every row + */ + private void testRowBatchNewAndParse(Class clazz, int type, int vecCount, int rowCount, + OmniBufSet setFunc) { + Vec[] vecArray = new Vec[vecCount]; + GenericClass vecGenericClass = new GenericClass<>(clazz); + + for (int i = 0; i < vecCount; i++) { + vecArray[i] = vecGenericClass.createInstance(rowCount); + for (int j = 0; j < rowCount; ++j) { + setFunc.operate(vecArray[i].getValuesBuf(), j); + } + } + + VecBatch expectVecBatch = new VecBatch(vecArray, rowCount); + RowBatch rowBatch = new RowBatch(expectVecBatch); + + int[] types = new int[vecCount]; + Arrays.fill(types, type); + + VecBatch resultVb = new VecBatch(vecArray, rowCount); + long[] vecAddresses = new long[vecCount]; + for (int i = 0; i < vecCount; i++) { + vecAddresses[i] = resultVb.getVector(i).nativeVector; + } + OmniRowDeserializer deserializer = new OmniRowDeserializer(types, vecAddresses); + + deserializer.parseAll(rowBatch.getNativeRowBatch()); + + assertVecBatchEquals(resultVb, expectVecBatch); + deserializer.close(); + rowBatch.close(); + expectVecBatch.close(); + resultVb.close(); + } + + /** + * test row batch, only value is positive + */ + @Test + public void testRowBatch() { + testRowBatchNewAndParse(BooleanVec.class, DataType.DataTypeId.OMNI_BOOLEAN.toValue(), 10, 1024, + (omniBuf, i) -> omniBuf.setByte(i, (i % 2 == 0) ? (byte) 1 : (byte) 0)); + + testRowBatchNewAndParse(ShortVec.class, DataType.DataTypeId.OMNI_SHORT.toValue(), 10, 1024, + (omniBuf, i) -> omniBuf.setShort(i, (short) (i + 1))); + + testRowBatchNewAndParse(IntVec.class, DataType.DataTypeId.OMNI_INT.toValue(), 10, 1024, + (omniBuf, i) -> omniBuf.setInt(i, i)); + + testRowBatchNewAndParse(LongVec.class, DataType.DataTypeId.OMNI_LONG.toValue(), 10, 1024, + (omniBuf, i) -> omniBuf.setLong(i, i)); + + testRowBatchNewAndParse(DoubleVec.class, DataType.DataTypeId.OMNI_DOUBLE.toValue(), 10, 1024, + (omniBuf, i) -> omniBuf.setDouble(i, i)); + + testRowBatchNewAndParse(Decimal128Vec.class, DataType.DataTypeId.OMNI_DECIMAL128.toValue(), 10, 1024, + (omniBuf, i) -> { + omniBuf.setLong((2 * i), i); + omniBuf.setLong((2 * i) + 1, i); + }); + + testRowBatchNewAndParse(LongVec.class, DataType.DataTypeId.OMNI_DECIMAL64.toValue(), 10, 1024, (omniBuf, i) -> { + omniBuf.setLong((2 * i), i); + omniBuf.setLong((2 * i) + 1, i); + }); + } + + /** + * test row batch, all value is negative + */ + @Test + public void testNegativeRowBatch() { + // negative short + testRowBatchNewAndParse(ShortVec.class, DataType.DataTypeId.OMNI_SHORT.toValue(), 10, 1024, + (omniBuf, i) -> omniBuf.setShort(i, (short) (-i))); + + // negative int + testRowBatchNewAndParse(IntVec.class, DataType.DataTypeId.OMNI_INT.toValue(), 10, 1024, + (omniBuf, i) -> omniBuf.setInt(i, -i)); + + // negative long to test negative compress + testRowBatchNewAndParse(LongVec.class, DataType.DataTypeId.OMNI_LONG.toValue(), 10, 1024, + (omniBuf, i) -> omniBuf.setLong(i, -i)); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestShortVec.java b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestShortVec.java new file mode 100644 index 0000000..0541700 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestShortVec.java @@ -0,0 +1,277 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_SHORT; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +import org.testng.annotations.Test; + +import java.util.Arrays; + +/** + * test short vec + * + * @since 2022-8-2 + */ +public class TestShortVec { + /** + * test new vector + */ + @Test + public void testNewVector() { + ShortVec vec = new ShortVec(256); + assertEquals(vec.getSize(), 256); + assertEquals(vec.getRealValueBufCapacityInBytes(), 512); + assertEquals(vec.getType().getId(), OMNI_SHORT); + vec.close(); + } + + /** + * test slice + */ + @Test + public void testSlice() { + ShortVec originalVec = new ShortVec(10); + for (int i = 0; i < originalVec.getSize(); i++) { + originalVec.set(i, (short) i); + } + int offset = 3; + ShortVec slice1 = originalVec.slice(offset, 4); + assertEquals(slice1.getSize(), 4); + for (int i = 0; i < slice1.getSize(); i++) { + assertEquals(slice1.get(i), originalVec.get(i + offset), "Error item value at: " + i); + } + + ShortVec slice2 = slice1.slice(1, 2); + + for (int i = 0; i < slice2.getSize(); i++) { + assertEquals(slice2.get(i), originalVec.get(i + offset + 1), "Error item value at: " + i); + } + + originalVec.close(); + slice1.close(); + slice2.close(); + } + + /** + * test set and get value + */ + @Test + public void testSetAndGetValue() { + ShortVec vec = new ShortVec(256); + for (int i = 0; i < vec.getSize(); i++) { + vec.set(i, (short) (i * 2)); + } + + for (int i = 0; i < vec.getSize(); i++) { + assertEquals(vec.get(i), i * 2); + } + vec.close(); + } + + /** + * test put value + */ + @Test + public void testPutValues() { + short[] values = {1, 3, 4, 6, 7}; + ShortVec vec1 = new ShortVec(values.length); + vec1.put(values, 0, 0, values.length); + for (int i = 0; i < values.length; i++) { + assertEquals(vec1.get(i), values[i]); + } + + ShortVec vec2 = new ShortVec(values.length); + vec2.put(values, 1, 2, 3); + for (int i = 0; i < 3; i++) { + assertEquals(vec2.get(i + 1), values[i + 2]); + } + + vec1.close(); + vec2.close(); + } + + /** + * test value null + */ + @Test + public void testValueNull() { + ShortVec vec = new ShortVec(256); + for (int i = 0; i < vec.getSize(); i++) { + if (i % 5 == 0) { + vec.setNull(i); + } else { + vec.set(i, (short) i); + } + } + for (int i = 0; i < vec.getSize(); i++) { + if (i % 5 == 0) { + assertTrue(vec.isNull(i)); + } else { + assertEquals(vec.get(i), i); + } + } + + vec.close(); + } + + /** + * test copy postion + */ + @Test + public void testCopyPositions() { + ShortVec originalVector = new ShortVec(4); + for (int i = 0; i < originalVector.getSize(); i++) { + originalVector.set(i, (short) i); + } + + int[] positions = {1, 3}; + ShortVec copyPositionVector = originalVector.copyPositions(positions, 0, 2); + assertEquals(copyPositionVector.getRealValueBufCapacityInBytes(), 4); + for (int i = 0; i < copyPositionVector.getSize(); i++) { + assertEquals(copyPositionVector.get(i), originalVector.get(positions[i])); + } + + originalVector.close(); + copyPositionVector.close(); + } + + /** + * test zero size allocate + */ + @Test + public void testZeroSizeAllocate() { + ShortVec v1 = new ShortVec(0); + short[] values = new short[0]; + v1.put(values, 0, 0, values.length); + v1.close(); + } + + @Test + public void testGetValues() { + short[] values = {1, 3, 4, 6, 7}; + ShortVec vec = new ShortVec(values.length); + vec.put(values, 0, 0, values.length); + assertEquals(vec.get(0, values.length), values); + short[] expected = {3, 4, 6}; + short[] actual = vec.get(1, 3); + for (int i = 0; i < actual.length; i++) { + assertEquals(actual[i], expected[i]); + } + vec.close(); + } + + @Test + public void setShortMax() { + int len = 1024 * 1024; + short[] values = new short[len]; + Arrays.fill(values, Short.MAX_VALUE); + ShortVec max = new ShortVec(len); + max.put(values, 0, 0, values.length); + + for (int i = 0; i < max.getSize(); i++) { + assertEquals(max.get(i), Short.MAX_VALUE); + } + + assertEquals(max.get(0, values.length), values); + max.close(); + } + + @Test + public void testNullFlagWithSet() { + // no null value + ShortVec noNull = new ShortVec(10); + assertFalse(noNull.hasNull()); + noNull.close(); + + // has null value + ShortVec hasNulls = new ShortVec(10); + byte[] nulls = new byte[]{0, 1, 0, 1, 0, 1, 0, 1, 0, 1}; + hasNulls.setNulls(0, nulls, 0, nulls.length); + assertTrue(hasNulls.hasNull()); + hasNulls.close(); + + ShortVec hasNull = new ShortVec(10); + for (int i = 0; i < hasNull.size; i++) { + if (i % 2 == 0) { + hasNull.setNull(i); + } else { + hasNull.set(i, (short) i); + } + } + assertTrue(hasNull.hasNull()); + hasNull.close(); + } + + @Test + public void testNullFlagWithCopyPosition() { + // has null value + ShortVec hasNulls = new ShortVec(10); + byte[] nulls = new byte[]{0, 0, 1, 1, 0, 1, 0, 1, 0, 1}; + hasNulls.setNulls(0, nulls, 0, nulls.length); + assertTrue(hasNulls.hasNull()); + + int[] positions = new int[]{0, 1}; + ShortVec copyPositionNoNull = hasNulls.copyPositions(positions, 0, 2); + assertFalse(copyPositionNoNull.hasNull()); + copyPositionNoNull.close(); + + positions = new int[]{1, 2, 3, 4}; + ShortVec copyPositionHasNull = hasNulls.copyPositions(positions, 0, 4); + assertTrue(copyPositionHasNull.hasNull()); + copyPositionHasNull.close(); + + hasNulls.close(); + } + + @Test + public void testNullFlagWithSlice() { + // has null value + ShortVec hasNulls = new ShortVec(10); + byte[] nulls = new byte[]{0, 0, 1, 1, 0, 1, 0, 1, 0, 1}; + hasNulls.setNulls(0, nulls, 0, nulls.length); + assertTrue(hasNulls.hasNull()); + + ShortVec sliceNoNull = hasNulls.slice(0, 1); + assertFalse(sliceNoNull.hasNull()); + sliceNoNull.close(); + + ShortVec sliceHasNull = hasNulls.slice(1, 4); + assertTrue(sliceHasNull.hasNull()); + sliceHasNull.close(); + + hasNulls.close(); + } + + @Test + public void testNullFlagWithAppend() { + int rowCount = 5; + ShortVec src1 = new ShortVec(rowCount); + + for (int i = 0; i < rowCount; i++) { + src1.set(i, (short) (i + 1)); + } + + ShortVec appended = new ShortVec(15); + appended.append(src1, 0, rowCount); + src1.close(); + assertFalse(appended.hasNull()); + + ShortVec withNull = new ShortVec(rowCount); + byte[] nulls = new byte[]{0, 1, 1, 0, 1}; + withNull.setNulls(0, nulls, 0, 5); + appended.append(withNull, 5, rowCount); + assertTrue(appended.hasNull()); + + appended.append(withNull, 10, rowCount); + assertTrue(appended.hasNull()); + withNull.close(); + + appended.close(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestVarcharVec.java b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestVarcharVec.java new file mode 100644 index 0000000..f06a7bb --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestVarcharVec.java @@ -0,0 +1,464 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; + +/** + * test varchar vec + * + * @since 2021-7-2 + */ +public class TestVarcharVec { + private static final int INIT_CAPACITY_IN_BYTES = 4 * 1024; // 4K + + /** + * test new vector + */ + @Test + public void testNewVector() { + VarcharVec vec = new VarcharVec(256); + assertEquals(vec.getSize(), 256); + assertEquals(vec.getRealValueBufCapacityInBytes(), 0); + assertEquals(vec.getType().getId(), OMNI_VARCHAR); + vec.close(); + } + + /** + * test slice + */ + @Test + public void testSlice() { + int size = 10; + VarcharVec originalVec = new VarcharVec(size); + String tmpStr = "testvarchar"; + for (int i = 0; i < size; i++) { + String str = tmpStr.substring(0, i) + i; + originalVec.set(i, str.getBytes(StandardCharsets.UTF_8)); + } + assertEquals(originalVec.getRealValueBufCapacityInBytes(), 55); + + int offset = 3; + VarcharVec sliceVec1 = originalVec.slice(offset, 4); + assertEquals(sliceVec1.getSize(), 4); + assertEquals(sliceVec1.getRealValueBufCapacityInBytes(), 22); + + for (int i = 0; i < sliceVec1.getSize(); i++) { + byte[] actualValue = sliceVec1.get(i); + byte[] expectedValue = originalVec.get(i + offset); + assertEquals(actualValue, expectedValue); + } + + VarcharVec sliceVec2 = sliceVec1.slice(1, 2); + assertEquals(sliceVec2.getSize(), 2); + + for (int i = 0; i < sliceVec2.getSize(); i++) { + byte[] actualValue = sliceVec2.get(i); + byte[] expectedValue = originalVec.get(i + offset + 1); + assertEquals(actualValue, expectedValue); + } + sliceVec2.close(); + sliceVec1.close(); + originalVec.close(); + } + + /** + * test set and get value + */ + @Test + public void testSetAndGetValue() { + int size = 4; + VarcharVec varcharVec = new VarcharVec(size); + String tmpStr = "test"; + for (int i = 0; i < 4; i++) { + String str = tmpStr.substring(0, i) + i; + varcharVec.set(i, str.getBytes(StandardCharsets.UTF_8)); + } + + for (int i = 0; i < 4; i++) { + String str = tmpStr.substring(0, i) + i; + byte[] actualValue = varcharVec.get(i); + assertEquals(actualValue, str.getBytes(StandardCharsets.UTF_8)); + } + + varcharVec.close(); + } + + @Test + public void testPutValues() { + int size = 100; + int[] offsets = new int[size * 2 + 1]; + StringBuilder data = new StringBuilder(); + for (int i = 0; i < size; i++) { + String str = "test" + i; + offsets[i + 1] = str.length() + offsets[i]; + data.append(str); + } + + for (int i = 0; i < size; i++) { + String str = i + "put"; + offsets[size + i + 1] = str.length() + offsets[size + i]; + data.append(str); + } + + VarcharVec values = new VarcharVec(size * 2); + values.put(0, data.toString().getBytes(StandardCharsets.UTF_8), 0, offsets, 0, size); + values.put(size, data.toString().getBytes(StandardCharsets.UTF_8), 0, offsets, size, size); + ByteBuffer buffer = ByteBuffer.wrap(data.toString().getBytes(StandardCharsets.UTF_8)); + for (int i = 0; i < size * 2; i++) { + assertEquals(new String(values.get(i)), + new String(getDataFromBuffer(buffer, offsets[i], offsets[i + 1] - offsets[i]))); + } + values.close(); + } + + private byte[] getDataFromBuffer(ByteBuffer buffer, int offsetInBytes, int length) { + byte[] data = new byte[length]; + buffer.position(offsetInBytes); + buffer.get(data, 0, length); + return data; + } + + /** + * test value null + */ + @Test + public void testValueNull() { + VarcharVec varcharVec = new VarcharVec(256); + for (int i = 0; i < varcharVec.getSize(); i++) { + if (i % 5 == 0) { + varcharVec.setNull(i); + } else { + varcharVec.set(i, ("test" + i).getBytes(StandardCharsets.UTF_8)); + } + } + for (int i = 0; i < varcharVec.getSize(); i++) { + if (i % 5 == 0) { + assertTrue(varcharVec.isNull(i)); + } else { + assertEquals("test" + i, new String(varcharVec.get(i), StandardCharsets.UTF_8)); + } + } + + varcharVec.close(); + } + + @Test + public void testBatchSetValueNull() { + int size = 256; + boolean[] isNulls = new boolean[size]; + for (int i = 0; i < size; i++) { + isNulls[i] = i % 2 == 0; + } + VarcharVec varcharVec = new VarcharVec(size); + varcharVec.setNulls(0, isNulls, 0, isNulls.length); + assertTrue(varcharVec.hasNull()); + assertEquals(varcharVec.getValuesNulls(0, size), isNulls); + int offset = 3; + boolean[] acutal = varcharVec.getValuesNulls(offset, size / 2); + for (int i = 0; i < size / 2; i++) { + assertEquals(acutal[i], isNulls[i + offset]); + } + varcharVec.close(); + } + + /** + * test copy position + */ + @Test + public void testCopyPositions() { + VarcharVec originalVector = new VarcharVec(4); + String tmpStr = "test"; + for (int i = 0; i < 4; i++) { + String str = tmpStr.substring(0, i) + i; + originalVector.set(i, str.getBytes(StandardCharsets.UTF_8)); + } + + int[] positions = {1, 3}; + VarcharVec copyPostionVector = originalVector.copyPositions(positions, 0, 2); + + for (int i = 0; i < copyPostionVector.getSize(); i++) { + byte[] expectedValue = originalVector.get(positions[i]); + byte[] actualValue = copyPostionVector.get(i); + assertEquals(actualValue, expectedValue); + } + + originalVector.close(); + copyPostionVector.close(); + } + + @Test + public void testGetValues() { + int size = 10; + StringBuilder getData = new StringBuilder(); + VarcharVec getVec = new VarcharVec(size); + for (int i = 0; i < size; i++) { + String str = "gets" + i; + getVec.set(i, str.getBytes(StandardCharsets.UTF_8)); + getData.append(str); + } + + byte[] actual = getVec.get(0, size / 2); + ByteBuffer buffer = ByteBuffer.wrap(getData.toString().getBytes(StandardCharsets.UTF_8)); + int total = 0; + for (int i = 0; i < size / 2; i++) { + total += getVec.getDataLength(i); + } + byte[] expected = getDataFromBuffer(buffer, 0, total); + assertEquals(getString(actual), getString(expected)); + + int getLen = 5; + int offset = 2; + byte[] acutal1 = getVec.get(offset, getLen); + ByteBuffer buffer1 = ByteBuffer.wrap(acutal1); + int[] offsets = new int[getLen + 1]; + offsets[0] = 0; + + for (int i = 1; i < offsets.length; i++) { + offsets[i] = getVec.getDataLength(offset + i - 1) + offsets[i - 1]; + } + + for (int i = 0; i < getLen; i++) { + assertEquals(getString(getDataFromBuffer(buffer1, offsets[i], offsets[i + 1] - offsets[i])), + getString(getVec.get(i + offset))); + } + getVec.close(); + } + + @Test + public void testEmptyString() { + String[] data = new String[]{"a", "ef", "", "ef", "", ""}; + String[] expected = new String[]{"a", "ef", "", "ef", "", ""}; + int size = 6; + VarcharVec varcharVec = new VarcharVec(size); + for (int i = 0; i < size; i++) { + varcharVec.set(i, data[i].getBytes(StandardCharsets.UTF_8)); + } + + String[] result = new String[size]; + for (int i = 0; i < size; i++) { + result[i] = getString(varcharVec.get(i)); + } + + Assert.assertEquals(result, expected); + + VarcharVec vec2 = new VarcharVec(size); + int[] offsets = new int[]{0, 1, 3, 3, 5, 5, 5}; + StringBuilder sb = new StringBuilder(); + for (String str : data) { + sb.append(str); + } + vec2.put(0, sb.toString().getBytes(StandardCharsets.UTF_8), 0, offsets, 0, size); + + String[] result1 = new String[size]; + for (int i = 0; i < size; i++) { + result1[i] = getString(vec2.get(i)); + } + + Assert.assertEquals(result1, expected); + + // slice + VarcharVec sliceEmpty = varcharVec.slice(2, 3); + String emptyString = ""; + Assert.assertEquals(getString(sliceEmpty.get(0)), emptyString); + + // copyPosition + int[] positions = new int[]{2, 4, 5}; + VarcharVec copyPosition = varcharVec.copyPositions(positions, 0, 3); + for (int i = 0; i < copyPosition.size; i++) { + Assert.assertEquals(getString(copyPosition.get(i)), emptyString); + } + + varcharVec.close(); + vec2.close(); + sliceEmpty.close(); + copyPosition.close(); + } + + @Test + public void testSetExpandCapacity() { + int rowCount = 4; + VarcharVec varcharVec = new VarcharVec(rowCount); + String baseStr = "test"; + for (int i = 0; i < rowCount; i++) { + String str = baseStr.substring(0, i); + str += i; + varcharVec.set(i, str.getBytes(StandardCharsets.UTF_8)); + } + Assert.assertEquals(varcharVec.getCapacityInBytes(), INIT_CAPACITY_IN_BYTES); + + for (int i = 0; i < rowCount; i++) { + String str = baseStr.substring(0, i); + str += i; + Assert.assertEquals(new String(varcharVec.get(i)), str); + } + + // no capacity specified when created, init capacity is 32K + rowCount = 8000; + VarcharVec vector1 = new VarcharVec(rowCount); + for (int i = 0; i < rowCount; i++) { + String str = baseStr + i; + vector1.set(i, str.getBytes(StandardCharsets.UTF_8)); + } + + // init capacity is 32K, expansion to 64k at a time + int expectedExpandedCapacity1 = 65536; + Assert.assertEquals(vector1.getCapacityInBytes(), expectedExpandedCapacity1); + for (int i = 0; i < rowCount; i++) { + Assert.assertEquals(getString(vector1.get(i)), baseStr + i); + } + + VarcharVec initZeroCapacityVector = new VarcharVec(1); + initZeroCapacityVector.set(0, "".getBytes(StandardCharsets.UTF_8)); + Assert.assertEquals(initZeroCapacityVector.getCapacityInBytes(), INIT_CAPACITY_IN_BYTES); + initZeroCapacityVector.set(0, baseStr.getBytes(StandardCharsets.UTF_8)); + Assert.assertEquals(initZeroCapacityVector.getCapacityInBytes(), INIT_CAPACITY_IN_BYTES); + Assert.assertEquals(new String(initZeroCapacityVector.get(0)), baseStr); + initZeroCapacityVector.close(); + + varcharVec.close(); + vector1.close(); + } + + @Test + public void testAppendExpandCapacity() { + int rowCount = 5; + VarcharVec src1 = new VarcharVec(rowCount); + VarcharVec src2 = new VarcharVec(rowCount); + + for (int i = 0; i < rowCount; i++) { + src1.set(i, String.valueOf(i + 1).getBytes(StandardCharsets.UTF_8)); + src2.set(i, String.valueOf(i + 6).getBytes(StandardCharsets.UTF_8)); + } + + VarcharVec appended = new VarcharVec(10); + appended.append(src1, 0, rowCount); + appended.append(src2, 5, rowCount); + + int expectedExpandCapacity = 20; + Assert.assertEquals(appended.getCapacityInBytes(), INIT_CAPACITY_IN_BYTES); + + for (int i = 0; i < 10; i++) { + Assert.assertEquals(getString(appended.get(i)), String.valueOf(i + 1)); + } + + src1.close(); + src2.close(); + appended.close(); + } + + @Test + public void testNullFlagWithSet() { + // no null value + VarcharVec noNull = new VarcharVec(10); + assertFalse(noNull.hasNull()); + noNull.close(); + + // has null value + VarcharVec hasNulls = new VarcharVec(10); + byte[] nulls = new byte[]{0, 1, 0, 1, 0, 1, 0, 1, 0, 1}; + hasNulls.setNulls(0, nulls, 0, nulls.length); + assertTrue(hasNulls.hasNull()); + hasNulls.close(); + + VarcharVec hasNull = new VarcharVec(10); + for (int i = 0; i < hasNull.size; i++) { + if (i % 2 == 0) { + hasNull.setNull(i); + } else { + hasNull.set(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8)); + } + } + assertTrue(hasNull.hasNull()); + hasNull.close(); + } + + @Test + public void testNullFlagWithCopyPosition() { + // has null value + VarcharVec hasNulls = new VarcharVec(10); + byte[] nulls = new byte[]{0, 0, 1, 1, 0, 1, 0, 1, 0, 1}; + hasNulls.setNulls(0, nulls, 0, nulls.length); + for (int i = 0; i < 10; i++) { + if (nulls[i] == 0) { + hasNulls.set(i, "".getBytes(StandardCharsets.UTF_8)); + } + } + assertTrue(hasNulls.hasNull()); + + int[] positions = new int[]{0, 1}; + VarcharVec copyPositionNoNull = hasNulls.copyPositions(positions, 0, 2); + assertFalse(copyPositionNoNull.hasNull()); + copyPositionNoNull.close(); + + positions = new int[]{1, 2, 3, 4}; + VarcharVec copyPositionHasNull = hasNulls.copyPositions(positions, 0, 4); + assertTrue(copyPositionHasNull.hasNull()); + copyPositionHasNull.close(); + + hasNulls.close(); + } + + @Test + public void testNullFlagWithSlice() { + // has null value + VarcharVec hasNulls = new VarcharVec(10); + byte[] nulls = new byte[]{0, 0, 1, 1, 0, 1, 0, 1, 0, 1}; + hasNulls.setNulls(0, nulls, 0, nulls.length); + assertTrue(hasNulls.hasNull()); + + VarcharVec sliceNoNull = hasNulls.slice(0, 1); + assertFalse(sliceNoNull.hasNull()); + sliceNoNull.close(); + + VarcharVec sliceHasNull = hasNulls.slice(1, 4); + assertTrue(sliceHasNull.hasNull()); + sliceHasNull.close(); + + hasNulls.close(); + } + + @Test + public void testNullFlagWithAppend() { + int rowCount = 5; + VarcharVec src = new VarcharVec(rowCount); + + for (int i = 0; i < rowCount; i++) { + src.set(i, String.valueOf(i + 1).getBytes(StandardCharsets.UTF_8)); + } + + VarcharVec appended = new VarcharVec(15); + appended.append(src, 0, rowCount); + src.close(); + assertFalse(appended.hasNull()); + + VarcharVec withNull = new VarcharVec(rowCount); + byte[] nulls = new byte[]{0, 1, 1, 0, 1}; + withNull.setNulls(0, nulls, 0, 5); + int[] offsets = new int[]{0, 2, 2, 2, 3, 3}; + withNull.put(0, "abe".getBytes(StandardCharsets.UTF_8), 0, offsets, 0, rowCount); + appended.append(withNull, 5, rowCount); + assertTrue(appended.hasNull()); + + appended.append(withNull, 10, rowCount); + assertTrue(appended.hasNull()); + withNull.close(); + + appended.close(); + } + + private String getString(byte[] strInBytes) { + return new String(strInBytes, StandardCharsets.UTF_8); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestVecBatch.java b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestVecBatch.java new file mode 100644 index 0000000..9eaf95c --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/TestVecBatch.java @@ -0,0 +1,56 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import static org.testng.Assert.assertEquals; + +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.List; + +/** + * test vec batch + * + * @since 2021-6-23 + */ +public class TestVecBatch { + /** + * test new vec batch + */ + @Test + public void testNewVecBatch() { + int vecCount = 10; + int rowCount = 1024; + Vec[] vecArray = new Vec[vecCount]; + for (int i = 0; i < vecCount; i++) { + vecArray[i] = new LongVec(rowCount); + } + VecBatch vecBatch = new VecBatch(vecArray, rowCount); + vecBatch.releaseAllVectors(); + vecBatch.close(); + } + + @Test + public void testNewVecBatchWithEmptyVectors() { + // for load libomni_runtime.so + LongVec vec = new LongVec(1); + vec.close(); + List emptyVecs = new ArrayList<>(); + int rowCount = 100; + VecBatch vecBatch = new VecBatch(emptyVecs, rowCount); + assertEquals(vecBatch.getRowCount(), rowCount); + assertEquals(vecBatch.getVectorCount(), 0); + vecBatch.releaseAllVectors(); + vecBatch.close(); + + // rowcount and vectorcount is 0 + VecBatch vecBatch1 = new VecBatch(emptyVecs, 0); + assertEquals(vecBatch1.getRowCount(), 0); + assertEquals(vecBatch1.getVectorCount(), 0); + vecBatch1.releaseAllVectors(); + vecBatch1.close(); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/vector/VecUtil.java b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/VecUtil.java new file mode 100644 index 0000000..e8f62f1 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/VecUtil.java @@ -0,0 +1,28 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.vector; + +import nova.hetu.omniruntime.type.DataType; + +/** + * Vec util + * + * @since 2021-12-17 + */ +public class VecUtil { + private VecUtil() { + + } + + /** + * Set data type + * + * @param vec vec + * @param dataType data type + */ + public static void setDataType(Vec vec, DataType dataType) { + vec.setDataType(dataType); + } +} diff --git a/bindings/java/src/test/java/nova/hetu/omniruntime/vector/serialize/VecBatchSerializerTest.java b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/serialize/VecBatchSerializerTest.java new file mode 100644 index 0000000..af9ca96 --- /dev/null +++ b/bindings/java/src/test/java/nova/hetu/omniruntime/vector/serialize/VecBatchSerializerTest.java @@ -0,0 +1,640 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + */ + +package nova.hetu.omniruntime.vector.serialize; + +import static nova.hetu.omniruntime.type.CharDataType.CHAR; +import static nova.hetu.omniruntime.type.Date32DataType.DATE32; +import static nova.hetu.omniruntime.type.Date64DataType.DATE64; +import static nova.hetu.omniruntime.type.Decimal64DataType.DECIMAL64; +import static nova.hetu.omniruntime.type.InvalidDataType.INVALID; +import static nova.hetu.omniruntime.util.TestUtils.assertVecBatchEquals; +import static nova.hetu.omniruntime.util.TestUtils.freeVecBatch; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +import nova.hetu.omniruntime.type.ContainerDataType; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.Decimal128DataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.ShortDataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.utils.OmniRuntimeException; +import nova.hetu.omniruntime.vector.BooleanVec; +import nova.hetu.omniruntime.vector.ContainerVec; +import nova.hetu.omniruntime.vector.Decimal128Vec; +import nova.hetu.omniruntime.vector.DictionaryVec; +import nova.hetu.omniruntime.vector.DoubleVec; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.ShortVec; +import nova.hetu.omniruntime.vector.VarcharVec; +import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.VecBatch; +import nova.hetu.omniruntime.vector.VecUtil; + +import org.testng.annotations.Test; + +import java.nio.charset.StandardCharsets; + +/** + * Vec batch serializer test + * + * @since 2021-9-14 + */ +public class VecBatchSerializerTest { + private static final int ROW_COUNT = 1024; + + @Test + public void testSerializeCommonTypes() { + // prepare vector batch + LongVec longVec = new LongVec(ROW_COUNT); + IntVec intVec = new IntVec(ROW_COUNT); + VarcharVec varCharVec = new VarcharVec(ROW_COUNT); + Decimal128Vec decimal128Vec = new Decimal128Vec(ROW_COUNT); + ShortVec shortVec = new ShortVec(ROW_COUNT); + for (int i = 0; i < ROW_COUNT; i++) { + shortVec.set(i, (short) i); + longVec.set(i, i); + intVec.set(i, i); + varCharVec.set(i, ("test" + i).getBytes(StandardCharsets.UTF_8)); + decimal128Vec.set(i, new long[]{i, i + 1}); + } + Vec[] vecs = {longVec, intVec, varCharVec, decimal128Vec, shortVec}; + VecBatch vecBatch = new VecBatch(vecs); + + // serialize + VecBatchSerializer serializer = VecBatchSerializerFactory.create(); + byte[] str = serializer.serialize(vecBatch); + + // deserialize + VecBatch checkVecBatch = serializer.deserialize(str); + + // check result + LongVec checkLongVec = (LongVec) checkVecBatch.getVectors()[0]; + IntVec checkIntVec = (IntVec) checkVecBatch.getVectors()[1]; + VarcharVec checkVarCharVec = (VarcharVec) checkVecBatch.getVectors()[2]; + Decimal128Vec checkDecimal128Vec = (Decimal128Vec) checkVecBatch.getVectors()[3]; + ShortVec checkShortVec = (ShortVec) checkVecBatch.getVectors()[4]; + for (int i = 0; i < ROW_COUNT; i++) { + assertEquals(i, checkLongVec.get(i)); + assertEquals(i, checkIntVec.get(i)); + assertEquals("test" + i, new String(checkVarCharVec.get(i))); + assertEquals(i, checkDecimal128Vec.get(i)[0]); + assertEquals(i + 1, checkDecimal128Vec.get(i)[1]); + assertEquals(i, checkShortVec.get(i)); + } + freeVecBatch(vecBatch); + freeVecBatch(checkVecBatch); + } + + @Test + public void testSerializeDirectoryVecContainsLongVec() { + // prepare vector batch + LongVec dictionary = new LongVec(ROW_COUNT); + for (int i = 0; i < ROW_COUNT; i++) { + dictionary.set(i, i); + } + DictionaryVec dictionaryVec = new DictionaryVec(dictionary, new int[]{1, 2, 1000}); + dictionary.close(); + VecBatch vecBatch = new VecBatch(new Vec[]{dictionaryVec}); + + // serialize + VecBatchSerializer serializer = VecBatchSerializerFactory.create(); + byte[] str = serializer.serialize(vecBatch); + + // deserialize + VecBatch checkVecBatch = serializer.deserialize(str); + + // check result + LongVec checkLongVec = (LongVec) checkVecBatch.getVectors()[0]; + assertEquals(3, checkLongVec.getSize()); + assertEquals(1, checkLongVec.get(0)); + assertEquals(2, checkLongVec.get(1)); + assertEquals(1000, checkLongVec.get(2)); + + freeVecBatch(vecBatch); + freeVecBatch(checkVecBatch); + } + + @Test + public void testSerializeDirectoryVecContainsVarcharVec() { + // prepare vector batch + VarcharVec dictionary = new VarcharVec(ROW_COUNT); + for (int i = 0; i < ROW_COUNT; i++) { + dictionary.set(i, ("test" + i).getBytes(StandardCharsets.UTF_8)); + } + DictionaryVec dictionaryVec = new DictionaryVec(dictionary, new int[]{1, 2, 1000}); + dictionary.close(); + VecBatch vecBatch = new VecBatch(new Vec[]{dictionaryVec}); + + // serialize + VecBatchSerializer serializer = VecBatchSerializerFactory.create(); + byte[] str = serializer.serialize(vecBatch); + + // deserialize + VecBatch checkVecBatch = serializer.deserialize(str); + + // check result + VarcharVec checkLongVec = (VarcharVec) checkVecBatch.getVectors()[0]; + assertEquals(3, checkLongVec.getSize()); + assertEquals("test1", new String(checkLongVec.get(0))); + assertEquals("test2", new String(checkLongVec.get(1))); + assertEquals("test1000", new String(checkLongVec.get(2))); + + freeVecBatch(vecBatch); + freeVecBatch(checkVecBatch); + } + + // dictionary vector don't support nested dictionary vector + @Test(enabled = false) + public void testSerializeNestedDirectoryVec() { + // prepare vector batch + VarcharVec dictionary = new VarcharVec(ROW_COUNT); + for (int i = 0; i < ROW_COUNT; i++) { + dictionary.set(i, ("test" + i).getBytes(StandardCharsets.UTF_8)); + } + DictionaryVec dictionaryVec = new DictionaryVec(dictionary, new int[]{1, 2, 3, 4, 5, 6, 7, 1000}); + dictionary.close(); + DictionaryVec nestedDictionaryVec = new DictionaryVec(dictionaryVec, new int[]{1, 2, 7}); + dictionaryVec.close(); + VecBatch vecBatch = new VecBatch(new Vec[]{nestedDictionaryVec}); + + // serialize + VecBatchSerializer serializer = VecBatchSerializerFactory.create(); + byte[] str = serializer.serialize(vecBatch); + + // deserialize + VecBatch checkVecBatch = serializer.deserialize(str); + + // check result + VarcharVec checkLongVec = (VarcharVec) checkVecBatch.getVectors()[0]; + assertEquals(3, checkLongVec.getSize()); + assertEquals("test2", new String(checkLongVec.get(0))); + assertEquals("test3", new String(checkLongVec.get(1))); + assertEquals("test1000", new String(checkLongVec.get(2))); + + freeVecBatch(vecBatch); + freeVecBatch(checkVecBatch); + } + + // container vector don't support copyPositions + @Test + public void testSerializeContainerVec() { + // prepare vector batch + LongVec longVec = new LongVec(ROW_COUNT); + IntVec intVec = new IntVec(ROW_COUNT); + VarcharVec varCharVec = new VarcharVec(ROW_COUNT); + Decimal128Vec decimal128Vec = new Decimal128Vec(ROW_COUNT); + ShortVec shortVec = new ShortVec(ROW_COUNT); + for (int i = 0; i < ROW_COUNT; i++) { + longVec.set(i, i); + intVec.set(i, i); + varCharVec.set(i, ("test" + i).getBytes(StandardCharsets.UTF_8)); + decimal128Vec.set(i, new long[]{i, i + 1}); + shortVec.set(i, (short) i); + } + long[] vecAddresses = new long[]{longVec.getNativeVector(), intVec.getNativeVector(), + varCharVec.getNativeVector(), decimal128Vec.getNativeVector(), shortVec.getNativeVector()}; + ContainerVec containerVec = new ContainerVec(vecAddresses.length, ROW_COUNT, vecAddresses, + new DataType[]{new LongDataType(), new IntDataType(), new VarcharDataType(20), + new Decimal128DataType(10, 1), new ShortDataType()}); + VecBatch vecBatch = new VecBatch(new Vec[]{containerVec}); + + // serialize + VecBatchSerializer serializer = VecBatchSerializerFactory.create(); + byte[] str = serializer.serialize(vecBatch); + + // deserialize + VecBatch checkVecBatch = serializer.deserialize(str); + + // check result + ContainerVec checkContainerVec = (ContainerVec) checkVecBatch.getVectors()[0]; + assertEquals(1024, checkContainerVec.getSize()); + LongVec checkLongVec = new LongVec(checkContainerVec.getVector(0)); + IntVec checkIntVec = new IntVec(checkContainerVec.getVector(1)); + VarcharVec checkVarCharVec = new VarcharVec(checkContainerVec.getVector(2)); + Decimal128Vec checkDecimal128Vec = new Decimal128Vec(checkContainerVec.getVector(3)); + ShortVec checkShortVec = new ShortVec(checkContainerVec.getVector(4)); + for (int i = 0; i < ROW_COUNT; i++) { + assertEquals(i, checkLongVec.get(i)); + assertEquals(i, checkIntVec.get(i)); + assertEquals("test" + i, new String(checkVarCharVec.get(i))); + assertEquals(i, checkDecimal128Vec.get(i)[0]); + assertEquals(i + 1, checkDecimal128Vec.get(i)[1]); + assertEquals(i, checkShortVec.get(i)); + } + + freeVecBatch(vecBatch); + freeVecBatch(checkVecBatch); + } + + // container vector don't support slice + @Test + public void testSerializeNestedContainerVec() { + // prepare vector batch + LongVec longVec = new LongVec(ROW_COUNT); + IntVec intVec = new IntVec(ROW_COUNT); + VarcharVec varCharVec = new VarcharVec(ROW_COUNT); + Decimal128Vec decimal128Vec = new Decimal128Vec(ROW_COUNT); + ShortVec shortVec = new ShortVec(ROW_COUNT); + for (int i = 0; i < ROW_COUNT; i++) { + longVec.set(i, i); + intVec.set(i, i); + varCharVec.set(i, ("test" + i).getBytes(StandardCharsets.UTF_8)); + decimal128Vec.set(i, new long[]{i, i + 1}); + shortVec.set(i, (short) i); + } + long[] vecAddresses = new long[]{longVec.getNativeVector(), intVec.getNativeVector(), + varCharVec.getNativeVector(), decimal128Vec.getNativeVector(), shortVec.getNativeVector()}; + ContainerVec containerVec = new ContainerVec(vecAddresses.length, ROW_COUNT, vecAddresses, + new DataType[]{new LongDataType(), new IntDataType(), new VarcharDataType(20), + new Decimal128DataType(10, 1), new ShortDataType()}); + ContainerVec nestedContainerVec = new ContainerVec(1, ROW_COUNT, new long[]{containerVec.getNativeVector()}, + new DataType[]{new ContainerDataType()}); + VecBatch vecBatch = new VecBatch(new Vec[]{nestedContainerVec}); + + // serialize + VecBatchSerializer serializer = VecBatchSerializerFactory.create(); + byte[] str = serializer.serialize(vecBatch); + + // deserialize + VecBatch checkVecBatch = serializer.deserialize(str); + + // check result + ContainerVec nestCheckContainerVec = (ContainerVec) checkVecBatch.getVectors()[0]; + assertEquals(1024, nestCheckContainerVec.getSize()); + ContainerVec checkContainerVec = new ContainerVec(nestCheckContainerVec.getVector(0)); + LongVec checkLongVec = new LongVec(checkContainerVec.getVector(0)); + IntVec checkIntVec = new IntVec(checkContainerVec.getVector(1)); + VarcharVec checkVarCharVec = new VarcharVec(checkContainerVec.getVector(2)); + Decimal128Vec checkDecimal128Vec = new Decimal128Vec(checkContainerVec.getVector(3)); + ShortVec checkShortVec = new ShortVec(checkContainerVec.getVector(4)); + for (int i = 0; i < ROW_COUNT; i++) { + assertEquals(i, checkLongVec.get(i)); + assertEquals(i, checkIntVec.get(i)); + assertEquals("test" + i, new String(checkVarCharVec.get(i))); + assertEquals(i, checkDecimal128Vec.get(i)[0]); + assertEquals(i + 1, checkDecimal128Vec.get(i)[1]); + assertEquals(i, checkShortVec.get(i)); + } + + freeVecBatch(vecBatch); + freeVecBatch(checkVecBatch); + } + + @Test + public void testSerializeSameVectorMultipleTimes() { + int size = 10; + LongVec col1 = new LongVec(size); + for (int i = 0; i < size; i++) { + col1.set(i, i); + } + + Vec[] vecs = new Vec[1]; + for (int count = 0; count < 2; count++) { + vecs[0] = col1; + VecBatch vecBatch = new VecBatch(vecs, size); + // serialize + VecBatchSerializer serializer = VecBatchSerializerFactory.create(); + byte[] str = serializer.serialize(vecBatch); + // deserialize + VecBatch resultVecBatch = serializer.deserialize(str); + Object[][] expectedDatas = {{0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + vecBatch.close(); + freeVecBatch(resultVecBatch); + } + col1.close(); + } + + @Test + public void testSerializeCharVec() { + // prepare vector batch + VarcharVec dictionary = new VarcharVec(ROW_COUNT); + for (int i = 0; i < ROW_COUNT; i++) { + dictionary.set(i, ("test" + i).getBytes(StandardCharsets.UTF_8)); + } + DictionaryVec dictionaryVec = new DictionaryVec(dictionary, new int[]{1, 2, 1000}); + dictionary.close(); + VecBatch vecBatch = new VecBatch(new Vec[]{dictionaryVec}); + + // serialize + VecBatchSerializer serializer = VecBatchSerializerFactory.create(); + byte[] str = serializer.serialize(vecBatch); + + // deserialize + VecBatch checkVecBatch = serializer.deserialize(str); + + // check result + VarcharVec checkResultVec = (VarcharVec) checkVecBatch.getVectors()[0]; + assertEquals(3, checkResultVec.getSize()); + assertEquals("test1", new String(checkResultVec.get(0))); + assertEquals("test2", new String(checkResultVec.get(1))); + assertEquals("test1000", new String(checkResultVec.get(2))); + + freeVecBatch(vecBatch); + freeVecBatch(checkVecBatch); + } + + @Test + public void testSerializeSlicedVarCharVec() { + // prepare vector batch + VarcharVec varcharVec = new VarcharVec(ROW_COUNT); + for (int i = 0; i < ROW_COUNT; i++) { + varcharVec.set(i, ("test" + i).getBytes(StandardCharsets.UTF_8)); + } + + int size = 10; + VarcharVec slicedVec = varcharVec.slice(10, size); + + VecBatch vecBatch = new VecBatch(new Vec[]{slicedVec}); + + // serialize + VecBatchSerializer serializer = VecBatchSerializerFactory.create(); + byte[] str = serializer.serialize(vecBatch); + + // deserialize + VecBatch checkVecBatch = serializer.deserialize(str); + + // check result + VarcharVec checkResultVec = (VarcharVec) checkVecBatch.getVectors()[0]; + assertEquals(size, checkResultVec.getSize()); + for (int i = 0; i < size; i++) { + int tmp = i + 10; + assertEquals("test" + tmp, new String(checkResultVec.get(i))); + } + + varcharVec.close(); + freeVecBatch(vecBatch); + freeVecBatch(checkVecBatch); + } + + @Test + public void testSerializeSlicedVarCharVecWithNull() { + // prepare vector batch + VarcharVec varcharVec = new VarcharVec(ROW_COUNT); + for (int i = 0; i < ROW_COUNT; i++) { + if (i % 2 == 0) { + varcharVec.set(i, ("test" + i).getBytes(StandardCharsets.UTF_8)); + } else { + varcharVec.setNull(i); + } + } + + int positionOffset = 5; + int size = 10; + VarcharVec slicedVec = varcharVec.slice(positionOffset, size); + + VecBatch vecBatch = new VecBatch(new Vec[]{slicedVec}); + + // serialize + VecBatchSerializer serializer = VecBatchSerializerFactory.create(); + byte[] str = serializer.serialize(vecBatch); + + // deserialize + VecBatch checkVecBatch = serializer.deserialize(str); + + // check result + VarcharVec checkResultVec = (VarcharVec) checkVecBatch.getVectors()[0]; + assertEquals(size, checkResultVec.getSize()); + for (int i = 0; i < size; i++) { + if (i % 2 == positionOffset % 2) { + int tmp = i + positionOffset; + assertEquals("test" + tmp, new String(checkResultVec.get(i))); + } else { + assertTrue(checkResultVec.isNull(i)); + } + } + + varcharVec.close(); + freeVecBatch(vecBatch); + freeVecBatch(checkVecBatch); + } + + @Test + public void testSerializeSlicedVarCharVecWithFullNull() { + // prepare vector batch + VarcharVec varcharVec = new VarcharVec(ROW_COUNT); + for (int i = 0; i < ROW_COUNT; i++) { + varcharVec.setNull(i); + } + + int positionOffset = 5; + int size = 10; + VarcharVec slicedVec = varcharVec.slice(positionOffset, size); + + VecBatch vecBatch = new VecBatch(new Vec[]{slicedVec}); + + // serialize + VecBatchSerializer serializer = VecBatchSerializerFactory.create(); + byte[] str = serializer.serialize(vecBatch); + + // deserialize + VecBatch checkVecBatch = serializer.deserialize(str); + + // check result + VarcharVec checkResultVec = (VarcharVec) checkVecBatch.getVectors()[0]; + assertEquals(size, checkResultVec.getSize()); + for (int i = 0; i < size; i++) { + assertTrue(checkResultVec.isNull(i)); + } + + varcharVec.close(); + freeVecBatch(vecBatch); + freeVecBatch(checkVecBatch); + } + + @Test + public void testSerializeSlicedVarCharVecWithEmptyValue() { + // prepare vector batch + VarcharVec varcharVec = new VarcharVec(ROW_COUNT); + for (int i = 0; i < ROW_COUNT; i++) { + varcharVec.set(i, (" ").getBytes(StandardCharsets.UTF_8)); + } + + int positionOffset = 5; + int size = 10; + VarcharVec slicedVec = varcharVec.slice(positionOffset, size); + + VecBatch vecBatch = new VecBatch(new Vec[]{slicedVec}); + + // serialize + VecBatchSerializer serializer = VecBatchSerializerFactory.create(); + byte[] str = serializer.serialize(vecBatch); + + // deserialize + VecBatch checkVecBatch = serializer.deserialize(str); + + // check result + VarcharVec checkResultVec = (VarcharVec) checkVecBatch.getVectors()[0]; + assertEquals(size, checkResultVec.getSize()); + for (int i = 0; i < size; i++) { + assertEquals(" ", new String(checkResultVec.get(i))); + } + + varcharVec.close(); + freeVecBatch(vecBatch); + freeVecBatch(checkVecBatch); + } + + @Test + public void testSerializeVectorSizeReset() { + int size = 1000; + LongVec col1 = new LongVec(size); + for (int i = 0; i < size; i++) { + col1.set(i, i); + } + col1.setSize(5); + Vec[] vecs = new Vec[]{col1}; + VecBatch vecBatch = new VecBatch(vecs, size); + // serialize + VecBatchSerializer serializer = VecBatchSerializerFactory.create(); + byte[] str = serializer.serialize(vecBatch); + // deserialize + VecBatch resultVecBatch = serializer.deserialize(str); + Object[][] expectedDatas = {{0L, 1L, 2L, 3L, 4L}}; + assertVecBatchEquals(resultVecBatch, expectedDatas); + + freeVecBatch(vecBatch); + freeVecBatch(resultVecBatch); + } + + @Test + public void testSerializeVarcharVecWithNull() { + // prepare vector batch + int row = 10; + VarcharVec vec = new VarcharVec(row); + for (int i = 0; i < row; i++) { + if (i % 2 == 0) { + vec.set(i, ("test" + i).getBytes(StandardCharsets.UTF_8)); + } else { + vec.setNull(i); + } + } + VecBatch vecBatch = new VecBatch(new Vec[]{vec}); + + // serialize + VecBatchSerializer serializer = VecBatchSerializerFactory.create(); + byte[] str = serializer.serialize(vecBatch); + + // deserialize + VecBatch checkVecBatch = serializer.deserialize(str); + + // check result + VarcharVec checkResultVec = (VarcharVec) checkVecBatch.getVector(0); + assertEquals(row, checkResultVec.getSize()); + for (int i = 0; i < row; i++) { + if (i % 2 == 0) { + assertEquals("test" + i, new String(checkResultVec.get(i))); + } else { + assertTrue(checkResultVec.isNull(i)); + } + } + + freeVecBatch(vecBatch); + freeVecBatch(checkVecBatch); + } + + @Test + public void testSerializeWithSetDataType() { + int row = 5; + IntVec data32 = new IntVec(row); + VecUtil.setDataType(data32, DATE32); + data32.put(new int[]{1, 2, 3, 4, 5}, 0, 0, row); + LongVec data64 = new LongVec(row); + VecUtil.setDataType(data64, DATE64); + data64.put(new long[]{1, 2, 3, 4, 5}, 0, 0, row); + LongVec decimal64 = new LongVec(row); + VecUtil.setDataType(decimal64, DECIMAL64); + decimal64.put(new long[]{1, 2, 3, 4, 5}, 0, 0, row); + VarcharVec charVec = new VarcharVec(row); + VecUtil.setDataType(charVec, CHAR); + charVec.put(0, "12345".getBytes(StandardCharsets.UTF_8), 0, new int[]{0, 1, 2, 3, 4, 5}, 0, row); + + DoubleVec doubleVec = new DoubleVec(row); + doubleVec.put(new double[]{1.1, 2.2, 3.3, 4.4, 5.5}, 0, 0, row); + BooleanVec booleanVec = new BooleanVec(row); + booleanVec.put(new boolean[]{true, false, true, false, true}, 0, 0, row); + + VecBatch vecBatch = new VecBatch(new Vec[]{data32, data64, decimal64, charVec, doubleVec, booleanVec}); + + // serialize + VecBatchSerializer serializer = VecBatchSerializerFactory.create(); + byte[] serialized = serializer.serialize(vecBatch); + + // deserialize + VecBatch checkVecBatch = serializer.deserialize(serialized); + + // check result + Object[][] expectedDatas = {{1, 2, 3, 4, 5}, {1L, 2L, 3L, 4L, 5L}, {1L, 2L, 3L, 4L, 5L}, + {"1", "2", "3", "4", "5"}, {1.1D, 2.2D, 3.3D, 4.4D, 5.5D}, {true, false, true, false, true}}; + assertVecBatchEquals(checkVecBatch, expectedDatas); + + freeVecBatch(vecBatch); + freeVecBatch(checkVecBatch); + } + + @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = "Unexpected data type: " + + "OMNI_INVALID") + public void testSerializeInvalidType() { + int row = 5; + IntVec invalidType = new IntVec(row); + VecUtil.setDataType(invalidType, INVALID); + VecBatch vecBatch = new VecBatch(new Vec[]{invalidType}); + VecBatchSerializer serializer = VecBatchSerializerFactory.create(); + try { + serializer.serialize(vecBatch); + } finally { + freeVecBatch(vecBatch); + } + } + + @Test(expectedExceptions = OmniRuntimeException.class, expectedExceptionsMessageRegExp = "deserialize failed.null") + public void deserializeInvalid() { + VecBatchSerializer serializer = VecBatchSerializerFactory.create(); + serializer.deserialize("invalid".getBytes(StandardCharsets.UTF_8)); + } + + @Test + public void testSerializeDecimal128Vec() { + int row = 8; + IntVec intVec = new IntVec(row); + VarcharVec varcharVec = new VarcharVec(row); + Decimal128Vec decimal128Vec = new Decimal128Vec(row); + for (int i = 0; i < row; i++) { + intVec.set(i, i); + varcharVec.set(i, ("test" + i).getBytes(StandardCharsets.UTF_8)); + decimal128Vec.set(i, new long[]{i, i + 1}); + } + Vec[] vecArray = new Vec[]{intVec, varcharVec, decimal128Vec}; + + long[] nativeVectors = new long[vecArray.length]; + long[] nativeVectorValueBufAddresses = new long[vecArray.length]; + long[] nativeVectorNullBufAddresses = new long[vecArray.length]; + long[] nativeVectorOffsetBufAddresses = new long[]{0, varcharVec.getOffsetsBuf().getAddress(), 0}; + int[] encodings = new int[vecArray.length]; + int[] dataTypeIds = new int[vecArray.length]; + for (int i = 0; i < vecArray.length; i++) { + nativeVectors[i] = vecArray[i].getNativeVector(); + nativeVectorValueBufAddresses[i] = vecArray[i].getValuesBuf().getAddress(); + nativeVectorNullBufAddresses[i] = vecArray[i].getValueNullsBuf().getAddress(); + encodings[i] = vecArray[i].getEncoding().ordinal(); + dataTypeIds[i] = vecArray[i].getType().getId().ordinal(); + } + VecBatch vecBatch = new VecBatch(vecArray, row); + VecBatch vecBatchFromNative = new VecBatch(vecBatch.getNativeVectorBatch(), nativeVectors, + nativeVectorValueBufAddresses, nativeVectorNullBufAddresses, nativeVectorOffsetBufAddresses, encodings, + dataTypeIds, row); + VecBatchSerializer serializer = VecBatchSerializerFactory.create(); + byte[] vecBatchSerialized = serializer.serialize(vecBatchFromNative); + VecBatch vecBatchDeserialized = serializer.deserialize(vecBatchSerialized); + assertVecBatchEquals(vecBatchDeserialized, vecBatchFromNative); + assertVecBatchEquals(vecBatch, vecBatchFromNative); + + freeVecBatch(vecBatch); + freeVecBatch(vecBatchDeserialized); + } +} diff --git a/build.sh b/build.sh new file mode 100644 index 0000000..8aace8c --- /dev/null +++ b/build.sh @@ -0,0 +1,137 @@ +#!/bin/bash +# build file for OmniOperatorJit +# Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + +set -e + +source $(cd $(dirname ${BASH_SOURCE[0]}) && pwd)/env_check.sh + +TARGZ_NAME=boostkit-omniop-operator-1.9.0-aarch64 +ZIP_NAME=BoostKit-omniop_1.9.0 + +# if either help or --help is provided, the usage should be printed prior to exit +if [ "$1" = 'help' ] || [ "$1" = '--help' ]; then + print_usage + exit 0 +fi + +# check if required env vars are set +check_java_home +check_omni_home + +### main build begins here ### +echo "Start building modules using $1" +echo "-- Enter" $(dirname $(readlink -f $0)) + +# save working directory +CWD=$(pwd) +# check for $1 param +case "$1" in + package) + setup_dependencies package + + echo "-- Package without test" + cd ${CWD} && build release:java --exclude-test + + cd $CWD/bindings/java && mvn clean install -Domni.home=$OMNI_HOME -DskipTests + cd $CWD/core/src/udf/java && mvn clean install -DskipTests + + cd $CWD + # clean environment + [ -d "$TARGZ_NAME" ] && rm -rf $TARGZ_NAME + [ -f "$TARGZ_NAME.tar.gz" ] && rm -rf $TARGZ_NAME.tar.gz + [ -f "$ZIP_NAME.zip" ] && rm -rf $ZIP_NAME.zip + + cp -r $OMNI_HOME/lib $TARGZ_NAME + cp $CWD/bindings/java/target/*.jar $TARGZ_NAME + cp $CWD/core/src/udf/java/target/*.jar $TARGZ_NAME + tar --owner root --group root -zcvf $TARGZ_NAME.tar.gz $TARGZ_NAME + zip $ZIP_NAME.zip $TARGZ_NAME.tar.gz + ;; + svepackage) + setup_dependencies package + + echo "-- Package without test, with sve" + cd ${CWD} && build release:java --exclude-test --enable-sve + + cd $CWD/bindings/java && mvn clean install -Domni.home=$OMNI_HOME -DskipTests + cd $CWD/core/src/udf/java && mvn clean install -DskipTests + + cd $CWD + # clean environment + [ -d "$TARGZ_NAME" ] && rm -rf $TARGZ_NAME + [ -f "$TARGZ_NAME.tar.gz" ] && rm -rf $TARGZ_NAME.tar.gz + [ -f "$ZIP_NAME.zip" ] && rm -rf $ZIP_NAME.zip + + cp -r $OMNI_HOME/lib $TARGZ_NAME + cp $CWD/bindings/java/target/*.jar $TARGZ_NAME + cp $CWD/core/src/udf/java/target/*.jar $TARGZ_NAME + tar --owner root --group root -zcvf $TARGZ_NAME.tar.gz $TARGZ_NAME + zip $ZIP_NAME.zip $TARGZ_NAME.tar.gz + ;; + release) + setup_dependencies release + + echo "-- Only build" + cd ${CWD} && build release:java --exclude-test + + cd $CWD/bindings/java && mvn clean install -Domni.home=$OMNI_HOME -DskipTests + cd $CWD/core/src/udf/java && mvn clean install -DskipTests + ;; + test) + setup_dependencies + + echo "-- Enable build and test" + cd ${CWD} && build release:java + $CWD/build/core/test/omtest --gtest_output=xml:test_detail.xml + + cd $CWD/bindings/java && mvn clean install -Domni.home=$OMNI_HOME + cd $CWD/core/src/udf/java && mvn clean install + ;; + coverage-java) + echo "-- Enable coverage for java" + cd ${CWD} && build release:java + + cd $CWD/bindings/java && mvn clean install devtestcov:atest -Domni.home=$OMNI_HOME -Dactive.devtest=true -Dmaven.test.failure.ignore=true -Djacoco-agent.destfile=target/jacoco.exec -Dmaven.wagon.http.ssl.insecure=true -Dmaven.wagon.http.ssl.allowall=true + cd $CWD/core/src/udf/java && mvn clean install + ;; + coverage-c++) + echo "-- Enable coverage for c++" + cd ${CWD} && build coverage:java + $CWD/build/core/test/omtest --gtest_output=xml:${CWD}/core/build/test_detail.xml + + lcov --d $CWD/build --c --output-file test.info --rc lcov_branch_coverage=1 + lcov --remove test.info '*/opt/buildtools/include/*' '*/usr/include/*' '*/usr/lib/*' '*/usr/lib64/*' '*/usr/local/include/*' '*/usr/local/lib/*' '*/usr/local/lib64/*' '*/test/*' -o final.info --rc lcov_branch_coverage=1 + genhtml final.info -o ${CWD}/core/build/test_coverage --branch-coverage --rc lcov_branch_coverage=1 + + ;; + coverage) + setup_dependencies package + + echo "-- Package asan without test" + cd ${CWD} && build coverage:java --exclude-test + + cd $CWD/bindings/java && mvn clean install -Domni.home=$OMNI_HOME -DskipTests + cd $CWD/core/src/udf/java && mvn clean install -DskipTests + + cd $CWD + # clean environment + [ -d "$TARGZ_NAME" ] && rm -rf $TARGZ_NAME + [ -f "$TARGZ_NAME.tar.gz" ] && rm -rf $TARGZ_NAME.tar.gz + [ -f "$ZIP_NAME.zip" ] && rm -rf $ZIP_NAME.zip + + cp -r $OMNI_HOME/lib $TARGZ_NAME + cp $CWD/bindings/java/target/*.jar $TARGZ_NAME + cp $CWD/core/src/udf/java/target/*.jar $TARGZ_NAME + tar --owner root --group root -zcvf $TARGZ_NAME.tar.gz $TARGZ_NAME + zip $ZIP_NAME.zip $TARGZ_NAME.tar.gz + ;; + *) + echo "-- Enable default options" + cd ${CWD} && build release:java + $CWD/build/core/test/omtest --gtest_output=xml:${CWD}/core/build/test_detail.xml + + cd $CWD/bindings/java && mvn clean -Domni.home=$OMNI_HOME install + cd $CWD/core/src/udf/java && mvn clean install + ;; +esac diff --git a/build_scripts/build.sh b/build_scripts/build.sh new file mode 100644 index 0000000..0f19b39 --- /dev/null +++ b/build_scripts/build.sh @@ -0,0 +1,102 @@ +#!/bin/bash +# build file for OmniOperatorJit +# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved. + +set -e + +source $(cd $(dirname ${BASH_SOURCE[0]}) && pwd)/env_check.sh + +# if either help or --help is provided, the usage should be printed prior to exit +if [ "$1" = 'help' ] || [ "$1" = '--help' ]; then + print_help + exit 0 +fi + +# run the checks to ensure the prerequisites are ready +check_set_prerequisites + +# init variables +BINDING_TARGET_EXPR='boostkit-omniop-\1-binding-1.9.0-aarch64' +CWD=$(pwd) +OPTIONS="" +TARGETS="--target all" + +if [ $# = 0 ]; then + # if no params are passed, default to Release build + echo "-- Enable Release" + OPTIONS="-DCMAKE_BUILD_TYPE=Release" +else + # $1 has build type and targets to built + target_build_type=$1 + build_type=${target_build_type%%:*} + build_targets=$(test $(echo $target_build_type | grep ':' | wc -l) != 0 && echo "${target_build_type#*:}" || echo "") + + if [ ! -z "$build_targets" ]; then + TARGETS+=" $(echo ":$build_targets" | sed "s@:\([^\:]\+\)@ --target ${BINDING_TARGET_EXPR} @g" )" + fi + + if [ "$build_type" = 'debug' ]; then + echo "-- Enable Debug" + OPTIONS+=" -DCMAKE_BUILD_TYPE=Debug -DDEBUG=ON" + elif [ "$build_type" = 'trace' ]; then + echo "-- Enable Trace" + OPTIONS+=" -DCMAKE_BUILD_TYPE=Debug -DTRACE=ON" + elif [ "$build_type" = 'coverage' ]; then + echo "-- Enable Coverage" + OPTIONS+=" -DCMAKE_BUILD_TYPE=Debug -DCOVERAGE=ON" + elif [ "$build_type" = 'release' ]; then + echo "-- Enable Release" + OPTIONS+=" -DCMAKE_BUILD_TYPE=Release" + else + exit_with_message_and_print_help "ERROR: Invalid type: $build_type" + fi + + # $2 has build options + if [ "$2" = "all" ]; then + echo "-- Enable All Module Debug, Include: OPERATOR,VECTOR,LLVM" + OPTIONS+=" -DDEBUG_OPERATOR=ON -DDEBUG_VECTOR=ON -DDEBUG_LLVM=ON" + else + for i in ${*:2} ; do + if [ "$i" == 'op' ] || [ "$i" == '--enable-operator-debug' ]; then + echo "-- Enable Operator Debug" + OPTIONS+=" -DDEBUG_OPERATOR=ON" + elif [ "$i" == 'vec' ] || [ "$i" == '--enable-vector-debug' ]; then + echo "-- Enable Vector Debug" + OPTIONS+=" -DDEBUG_VECTOR=ON" + elif [ "$i" == 'llvm' ] || [ "$i" == '--enable-llvm-debug' ]; then + echo "-- Enable LLVM Debug" + OPTIONS+=" -DDEBUG_LLVM=ON" + elif [ "$i" == '--disable-cpuchecker' ]; then + echo "-- Disable CPU checker" + OPTIONS+=" -DDISABLE_CPU_CHECKER=ON" + elif [ "$i" == '--enable-dt' ]; then + [ "$build_type" != 'coverage' ] && exit_with_message "-- Please use coverage with --enable-dt" + echo "-- Enable DT checker" + OPTIONS+=" -DENABLE_DT=ON -DCOVERAGE=ON" + elif [ "$i" == '--enable-benchmark' ]; then + echo "-- Enable benchmark" + OPTIONS+=" -DENABLE_BENCHMARK=ON" + elif [ "$i" == '--enable-compile-time-report' ]; then + echo " --Enable Compile Time Report" + OPTIONS+=" -DENABLE_COMPILE_TIME_REPORT=ON" + elif [ "$i" == '--exclude-test' ]; then + echo "-- Exclude Test Source" + OPTIONS+=" -DEXCLUDE_TEST=ON" + elif [ "$i" == '--enable-sve' ]; then + echo "-- ENABLE SVE" + OPTIONS+=" -DENABLE_SVE=ON" + else + exit_with_message_and_print_help "ERROR: Invalid option: $i" + fi + done + fi +fi + +print_gcc_lib + +# need to delete the CMakeCache.txt to refresh the options +rm -rf $CWD/build/CMakeCache.txt && cmake -S $(cd $(dirname ${BASH_SOURCE[0]})/.. && pwd) -B $CWD/build $OPTIONS +# use all available cpu cores to speed up build process +cmake --build $CWD/build --clean-first $TARGETS -j $(test -z "${OMNI_COMPILER_THREAD_COUNT}" && echo $(nproc) || echo ${OMNI_COMPILER_THREAD_COUNT}) +# install requires root privilege +cmake --install $CWD/build diff --git a/build_scripts/env_check.sh b/build_scripts/env_check.sh new file mode 100644 index 0000000..0e2e01e --- /dev/null +++ b/build_scripts/env_check.sh @@ -0,0 +1,97 @@ +#!/bin/bash +# check env for OmniOperatorJit +# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved. + +set -e + +print_gcc_lib() { + gcc -print-search-dirs | sed '/^lib/b 1;d;:1;s,/[^/.][^/]*/\.\./,/,;t 1;s,:[^=]*=,:;,;s,;,; ,g' | tr \; \\012 +} + +# print help manual +print_help() { + echo " + Usage: + build.sh [type:binding-target] [options] + + Binding Targets: + java = Java binding library + + Types: + debug = Enable Debug + release = Enable Release + trace = Enable Trace + coverage = Enable Coverage + + Options: + all = Enable All Module Debug, Include: OPERATOR,VECTOR,LLVM + op, --enable-operator-debug = Enable Operator Debug + vec, --enable-vector-debug = Enable Vector Debug + llvm, --enable-llvm-debug = Enable LLVM Debug + --disable-cpuchecker = Disable CPU checker + --enable-dt = Enable DT checker + --exclude-test = Exclude Test Source + " +} + +# if JAVA_HOME is not set, +# prompt to set to detected location before exiting +check_java_home() { + if [ -z "$JAVA_HOME" ]; then + local java=$(which java|xargs readlink -f) + + echo "ERROR: JAVA_HOME is not set!" + echo "If it's ${java%/bin/java}," + echo "you can set it as follows:" + echo "export JAVA_HOME=${java%/bin/java}" + echo "" + echo "Please set JAVA_HOME and try again" + + exit 1 + fi +} + +# if OMNI_HOME is not set, +# prompt to set to suggested location before exiting +check_omni_home() { + if [ -z "$OMNI_HOME" ]; then + echo "ERROR: OMNI_HOME is not set!" + echo "You can set it as follows:" + echo "for system level configuration (requiring root privilege)," + echo "export OMNI_HOME=/opt" + echo "or for user level configuration," + echo "export OMNI_HOME=$HOME/opt" + echo "" + echo "Please set OMNI_HOME and try again" + + exit 1 + fi +} + +check_set_prerequisites() { + check_java_home + check_omni_home + + echo "OMNI_HOME = $OMNI_HOME" + LIB_HOME=$OMNI_HOME/lib + [ ! -d "$LIB_HOME" ] && mkdir -p $LIB_HOME + + echo "LIB_HOME = $LIB_HOME, LD_LIBRARY_PATH = $LD_LIBRARY_PATH" + + rm -rf $LIB_HOME/libboostkit*.so $OMNI_HOME/*-binding +} + +exit_with_message_and_print_help() +{ + echo "" + echo "$1" + print_help + exit 1 +} + +exit_with_message() +{ + echo "" + echo "$1" + exit 1 +} \ No newline at end of file diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt new file mode 100644 index 0000000..41a4532 --- /dev/null +++ b/core/CMakeLists.txt @@ -0,0 +1,15 @@ +set(SOURCE_ROOT ${CMAKE_CURRENT_LIST_DIR}) + +add_subdirectory(src) + +# unit test +if(NOT EXCLUDE_TEST) + enable_testing() + add_subdirectory(test) +endif() + +# configure file +configure_file( + "${CMAKE_CURRENT_LIST_DIR}/config.h.in" + "${CMAKE_CURRENT_LIST_DIR}/config.h" +) diff --git a/core/config.h.in b/core/config.h.in new file mode 100644 index 0000000..4698b12 --- /dev/null +++ b/core/config.h.in @@ -0,0 +1,8 @@ +#cmakedefine DEBUG +#cmakedefine TRACE +#cmakedefine DEBUG_OPERATOR +#cmakedefine DEBUG_VECTOR +#cmakedefine DEBUG_LLVM +#cmakedefine COVERAGE +#cmakedefine DISABLE_CPU_CHECKER +#cmakedefine ENABLE_BENCHMARK diff --git a/core/secDTFuzz/Dockerfile_build b/core/secDTFuzz/Dockerfile_build new file mode 100644 index 0000000..42fc9cb --- /dev/null +++ b/core/secDTFuzz/Dockerfile_build @@ -0,0 +1,2 @@ +FROM swr.cn-east-204-dev.myhuaweicloud.com/compute-fuzz-secdtfuzz/boostkit_openeuler_arm_20220424_new:v_omni_operator_10.0 +ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${SRC_ROOT}/tools/DTFrame/dist/opencv/lib:/out/dist/opencv/lib \ No newline at end of file diff --git a/core/secDTFuzz/Dockerfile_run b/core/secDTFuzz/Dockerfile_run new file mode 100644 index 0000000..8c2d32e --- /dev/null +++ b/core/secDTFuzz/Dockerfile_run @@ -0,0 +1 @@ +FROM swr.cn-east-204-dev.myhuaweicloud.com/compute-fuzz-secdtfuzz/boostkit_openeuler_arm_20220424_new:v_omni_operator_10.0 \ No newline at end of file diff --git a/core/secDTFuzz/SecDTFuzz.yaml b/core/secDTFuzz/SecDTFuzz.yaml new file mode 100644 index 0000000..b417611 --- /dev/null +++ b/core/secDTFuzz/SecDTFuzz.yaml @@ -0,0 +1,15 @@ +apiVersion: OmniOperatorJIT +kind: dtfuzz +fuzzEngine: secodefuzz +name: OmniOperatorJIT +language: C++ +parameters: + fuzz_args: -fuzz.C count -fuzz.T time -fuzz.rp reportPath -fuzz.cf corpusPath + coverage_args: -fuzz.C 1000 -fuzz.T 10 -fuzz.rp reportPath -fuzz.cf corpusPath + reproduce_args: -fuzz.reproduce 1 -fuzz.reporducePath testcaseName -fuzz.rp reportPath -fuzz.cf corpusPath +spec: +- name: suite-01.case1 + bin: /out/dist + cmd: LD_PRELOAD=/opt/buildtools/gcc-7.3.0/lib64/libasan.so /out/dist/bin/dt_engine --config /src/OmniOperatorJIT/core/test/dt/testtree/dtframe.cfg -f.select='suite-01.GlobalFuzz($FuzzData, $loopCount, $filterExpr, $opCnt, $chooseFunc)' ${args} + fun: case1 + version: 0.0.10 \ No newline at end of file diff --git a/core/secDTFuzz/build.sh b/core/secDTFuzz/build.sh new file mode 100644 index 0000000..3f62a5d --- /dev/null +++ b/core/secDTFuzz/build.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# build file for OmniOperatorJit +# Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + +set -e +set -x +echo "build.sh+" + +SCRIPT_DIR=$(dirname "$(readlink -f "$0")") +SRC_ROOT="$(cd $SCRIPT_DIR/../..;pwd)" + +echo ${SRC_ROOT} +mkdir ${SRC_ROOT}/tools +cp -rf /SecDTFuzz/DTFrame ${SRC_ROOT}/tools/ + +cd ${SRC_ROOT}/tools/DTFrame +cd build +cmake -DCMAKE_BUILD_TYPE=Debug -DENABLE_PRODUCT=product .. && make -j 2 && make install +cp -r ${SRC_ROOT}/tools/DTFrame/dist ${SRC_ROOT}/core/test/dt/testtree + +cd ${SRC_ROOT} +mkdir omni-operator +cp -r /usr1/huawei_secure_c/lib omni-operator +cp -r /usr1/huawei_secure_c/include omni-operator/lib + +dos2unix core/build/build.sh +bash core/build/build.sh coverage --enable-dt + +cp -r ${SRC_ROOT}/core/test/dt/testtree/dist/ /out +cp -r ${SRC_ROOT}/core/test/dt/testtree/cases/ /out +echo "build.sh-" \ No newline at end of file diff --git a/core/src/CMakeLists.txt b/core/src/CMakeLists.txt new file mode 100644 index 0000000..91c735f --- /dev/null +++ b/core/src/CMakeLists.txt @@ -0,0 +1,21 @@ +if (DEFINED ENABLE_DT) + if (${ENABLE_DT} STREQUAL "ON") + message(STATUS "ENABLE_DT ON: enable -fsanitize=address -fsanitize=undefined -fsanitize-coverage=trace-pc") + add_compile_options(-fprofile-arcs -ftest-coverage -fdump-rtl-expand) + endif () +endif () + +add_subdirectory(simd) +add_subdirectory(memory) +add_subdirectory(operator) +add_subdirectory(plannode) +add_subdirectory(util) +add_subdirectory(vector) +add_subdirectory(codegen) +add_subdirectory(expression) +add_subdirectory(type) +add_subdirectory(cpu_checker) +add_subdirectory(udf/cplusplus) +add_subdirectory(compute) + +include_directories(/usr/lib/llvm-15/include) diff --git a/core/src/README.MD b/core/src/README.MD new file mode 100644 index 0000000..8a6c9e9 --- /dev/null +++ b/core/src/README.MD @@ -0,0 +1,110 @@ +# joy + +Joy is project hoping to make it easier to create high performance data processing logic. + +* Joy leverage LLVM to create the code dynamically. +* Joy doesn't define any intermediate IR, since +1. SQL can be the IR +2. the number of operations in SQL is not that many, we can create optimised operators for all of them +* Joy provides a API, hopefully simplified, , without requiring LLVM knowledge, to create and optimise the needed operators + +#### Introduction +Bring happiness to data processing. :) + +API overview: + +table api +groupby api + +code gen api + +#### Architecture + + + +#### Usage + +###### requirements: +1. provide high performance `atom operators`, which can be combined into `task` +2. SqlJit compiler `fusion` capability to performan `task` level optimization such as Weld, + optimizations: + * dynamic vector size: the compiler should take into account the CPU capabilities and decide on for example `vector size`, to leverage SIMD and at the same reduce CPU cache miss + * type compaction, requires statistics, for example long -> int -> short, phone -> long +3. ensure cacheline alignment + +Use TPC-H Q1 as an example to create the group by aggregator functionality + + +the purpose of this project is to create a sql processing engine using miri + +https://github.com/rust-lang/miri +Rust miri provides a mid-level IR which can be used to interpret and run rust code + +The idea is to compile sql into rust closure and run using miri + +Weld: maintains it's own language +Joy: use closure instead + +1. crate a new SQL JIT compiler leveraging llvm-sys +2. take over optimizer from Weld? + +Weld: + +1. use closure syntax with its own parser -- lots of code to maintain +2. require more time to parse and visit the AST, which could impact the JIT performance +3. We can potentially borrow the optimizer passes and implement more + + +##### Data Types + +The type system should be as transparent as possible, ideally we should be able to +use the native data types such as `i32`, `i64`, `f32`, `u8` directly. + +Column::create() + +##### Code Gen +1. Direct LLVM code gen without parser to reduce code gen latency +2. direct optimization of the code to reduce the time needed for optimization pass +3. expose high level data processing codegen API for the community to create optimized data processing logic (vs Weld expose IR) + +MCJit?? --> ORCJit (On request compiler) + +##### Code Gen Simplified + +The Joy project target to + +1. provide a codegen framework requires NO knowledge of LLVM +2. zero overhead codegen: framework should not bring any overhead to the generated code. + + + +1. A framework which provides a `uniop` (single input) and a `binop` (2 input) +can we provide a trait for each of the operator type? +how is the trait plug into the codegen? +what's the benefit of using codegen for join? +2. built-in Vectorize input support +3. pluggable logic such as groupby, join + +###### Debug using Visual Studio Code +1. Install the RUST and LLDB plugin for the vscode +2. Config the debug launch.json and input attach under `configurations` in launch.json. The `LLDB attach` is automatically displayed. +3. In the debug panel, click `Launch` in debug window and select the process to debug. +4. You can add breakpoints in the vscode and debug the code. + + +##### gen() +* The gen function provides boilerplate code which loops over each row. +* allows generate code while looping over each row +* allows composition of generated code processing each row + +the context of the generated code: + 1. has access to all of the columns + 2. knows what columns needed is needed + 3. all columns are access via column index + 3. which column to store the output + +##### C++ Building +* Build with llvm options(-S -O3 -emit-llvm -fno-discard-value-names) +> ./build.sh release +* Build without llvm option +> ./build.sh debug diff --git a/core/src/codegen/CMakeLists.txt b/core/src/codegen/CMakeLists.txt new file mode 100644 index 0000000..7f6e551 --- /dev/null +++ b/core/src/codegen/CMakeLists.txt @@ -0,0 +1,14 @@ +file(GLOB_RECURSE CODE_GEN_LIST ${CMAKE_CURRENT_LIST_DIR}/*.cpp) +set(CODEGEN_TARGET ${OMNI_CODEGEN_SO}) +add_library(${CODEGEN_TARGET} SHARED ${CODE_GEN_LIST}) + +# from command 'llvm-config-15 --cxxflags --ldflags --libs' +target_include_directories(${CODEGEN_TARGET} PUBLIC /usr/lib/llvm-15/include) +target_compile_options(${CODEGEN_TARGET} PRIVATE -D_GNU_SOURCE -D__STDC_CONSTANT_MACROS -D__STDC_FORMAT_MACROS -D__STDC_LIMIT_MACROS) +target_link_libraries(${CODEGEN_TARGET} PRIVATE LLVM-15 securec ${OMNI_VECTOR_SO} util expression udf) +install(TARGETS ${CODEGEN_TARGET} DESTINATION ${CMAKE_INSTALL_PREFIX}) + +file(GLOB CODEGEN_HEAD_FILES ${SOURCE_ROOT}/src/codegen/*.h) +file(GLOB CODEGEN_FUNCTIONS_HEAD_FILES ${SOURCE_ROOT}/src/codegen/functions/*.h) +install(FILES ${CODEGEN_HEAD_FILES} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/codegen) +install(FILES ${CODEGEN_FUNCTIONS_HEAD_FILES} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/codegen/functions) \ No newline at end of file diff --git a/core/src/codegen/batch_codegen_context.h b/core/src/codegen/batch_codegen_context.h new file mode 100644 index 0000000..303516d --- /dev/null +++ b/core/src/codegen/batch_codegen_context.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch codegen context + */ + +#ifndef OMNI_RUNTIME_BATCH_CODEGEN_CONTEXT_H +#define OMNI_RUNTIME_BATCH_CODEGEN_CONTEXT_H + +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Value.h" + +namespace omniruntime::codegen { +class BatchCodegenContext { +public: + explicit BatchCodegenContext() + : data(nullptr), + nullBitmap(nullptr), + offsets(nullptr), + rowCnt(nullptr), + rowIdxArray(nullptr), + executionContext(nullptr), + dictionaryVectors(nullptr) + {} + + explicit BatchCodegenContext(llvm::Value *data, llvm::Value *nullBitmap, llvm::Value *offsets, llvm::Value *rowIdx, + llvm::Value *rowIdxArray, llvm::Value *executionContext, llvm::Value *dictionaryVectors) + : data(data), + nullBitmap(nullBitmap), + offsets(offsets), + rowCnt(rowIdx), + rowIdxArray(rowIdxArray), + executionContext(executionContext), + dictionaryVectors(dictionaryVectors) + {} + + ~BatchCodegenContext() = default; + + friend class BatchExpressionCodeGen; + + friend class CodegenBase; + +private: + llvm::Value *data; + llvm::Value *nullBitmap; + llvm::Value *offsets; + llvm::Value *rowCnt; + llvm::Value *rowIdxArray; + llvm::Value *executionContext; + llvm::Value *dictionaryVectors; +}; +} +#endif // OMNI_RUNTIME_BATCH_CODEGEN_CONTEXT_H diff --git a/core/src/codegen/batch_expression_codegen.cpp b/core/src/codegen/batch_expression_codegen.cpp new file mode 100644 index 0000000..fb988ae --- /dev/null +++ b/core/src/codegen/batch_expression_codegen.cpp @@ -0,0 +1,1730 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch expression codegen + */ +#include "batch_expression_codegen.h" + +namespace omniruntime::codegen { +namespace { +const int BATCH_EXPRFUNC_ROWCNT_INDEX = 3; +const int BATCH_EXPRFUNC_OUT_LENGTH_ARG_INDEX = 5; +const int BATCH_EXPRFUNC_OUT_NULL_INDEX = 8; +const int BATCH_EXPRFUNC_OUT_DATA_INDEX = 9; +} + +BatchExpressionCodeGen::BatchExpressionCodeGen(std::string name, const Expr &cpExpr, op::OverflowConfig *overflowConfig) + : CodegenBase(name, cpExpr, overflowConfig) +{} + +bool BatchExpressionCodeGen::InitializeBatchCodegenContext(iterator_range args) +{ + this->batchCodegenContext = std::make_unique(); + for (auto &arg : args) { + auto argName = arg.getName().str(); + if (argName == "data") { + batchCodegenContext->data = &arg; + } else if (argName == "nullBitmap") { + batchCodegenContext->nullBitmap = &arg; + } else if (argName == "offsets") { + batchCodegenContext->offsets = &arg; + } else if (argName == "rowCnt") { + batchCodegenContext->rowCnt = &arg; + } else if (argName == "rowIdxArray") { + batchCodegenContext->rowIdxArray = &arg; + } else if (argName == "outputLength" || argName == "outputNull" || argName == "outputData") { + continue; + } else if (argName == "executionContext") { + batchCodegenContext->executionContext = &arg; + } else if (argName == "dictionaryVectors") { + batchCodegenContext->dictionaryVectors = &arg; + } else { + LogWarn("Invalid argument %s", argName.c_str()); + return false; + } + } + + return true; +} + +llvm::Function *BatchExpressionCodeGen::CreateBatchFunction() +{ + std::vector args { + llvmTypes->I64PtrType(), // data + llvmTypes->I64PtrType(), // bitmap + llvmTypes->I64PtrType(), // offsets + llvmTypes->I32Type(), // rowCnt + llvmTypes->I32PtrType(), // rowIdxArray + llvmTypes->I32PtrType(), // outputLength + llvmTypes->I64Type(), // executionCon + llvmTypes->I64PtrType(), // dictionaryVe + llvmTypes->I1PtrType(), // outputNull + llvmTypes->ToBatchDataPointerType(expr->GetReturnTypeId()) // outputData + }; + + FunctionType *prototype = FunctionType::get(llvmTypes->I32Type(), args, false); + func = llvm::Function::Create(prototype, llvm::Function::ExternalLinkage, funcName, modulePtr); + + std::string argNames[] = { + "data", "nullBitmap", "offsets", "rowCnt", "rowIdxArray", + "outputLength", "executionContext", "dictionaryVectors", "outputNull", "outputData" + }; + int32_t idx = 0; + for (auto &arg : func->args()) { + arg.setName(argNames[idx]); + idx++; + } + + BasicBlock *body = BasicBlock::Create(*context, "CREATED_BATCH_FUNC_BODY", func); + builder->SetInsertPoint(body); + + if (!InitializeBatchCodegenContext(func->args())) { + return nullptr; + } + + auto result = VisitExpr(*expr); + if (result->data == nullptr) { + return nullptr; + } + + // copy length + if (result->length != nullptr) { + CallExternFunction("batch_copy", { OMNI_INT }, OMNI_INT, + { func->getArg(BATCH_EXPRFUNC_OUT_LENGTH_ARG_INDEX), result->length, + func->getArg(BATCH_EXPRFUNC_ROWCNT_INDEX) }, + nullptr, "copy_length"); + } + // copy data + CallExternFunction("batch_copy", { expr->GetReturnTypeId() }, expr->GetReturnTypeId(), + { func->getArg(BATCH_EXPRFUNC_OUT_DATA_INDEX), result->data, func->getArg(BATCH_EXPRFUNC_ROWCNT_INDEX) }, + nullptr, "copy_data"); + + // copy null + CallExternFunction("batch_copy", { OMNI_BOOLEAN }, OMNI_BOOLEAN, + { func->getArg(BATCH_EXPRFUNC_OUT_NULL_INDEX), result->isNull, func->getArg(BATCH_EXPRFUNC_ROWCNT_INDEX) }, + nullptr, "copy_null"); + + // Return rowCnt + builder->CreateRet(func->getArg(BATCH_EXPRFUNC_ROWCNT_INDEX)); + verifyFunction(*func); + return func; +} + +CodeGenValuePtr BatchExpressionCodeGen::VisitExpr(const Expr &e) +{ + e.Accept(*this); + return this->value; +} + +void BatchExpressionCodeGen::Visit(const LiteralExpr &lExpr) +{ + this->value.reset(BatchLiteralExprConstantHelper(lExpr)); +} + +CodeGenValue *BatchExpressionCodeGen::BatchLiteralExprConstantHelper(const LiteralExpr &lExpr) +{ + bool isNullLiteral = lExpr.isNull; + Value *isNull = llvmTypes->CreateConstantBool(isNullLiteral); + AllocaInst *nullArrayPtr = + builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "IS_NULL_PTR"); + AllocaInst *literalArrayPtr = GetResultArray(lExpr.GetReturnTypeId(), this->batchCodegenContext->rowCnt); + Value *literalValue = nullptr; + Value *length = nullptr; + Value *lengthArrayPtr = nullptr; + Value *precisionVal = nullptr; + Value *scaleVal = nullptr; + switch (lExpr.GetReturnTypeId()) { + case OMNI_INT: + case OMNI_DATE32: { + literalValue = llvmTypes->CreateConstantInt(lExpr.intVal); + break; + } + case OMNI_TIMESTAMP: + case OMNI_LONG: { + literalValue = llvmTypes->CreateConstantLong(lExpr.longVal); + break; + } + case OMNI_DOUBLE: { + literalValue = llvmTypes->CreateConstantDouble(lExpr.doubleVal); + break; + } + case OMNI_CHAR: + case OMNI_VARCHAR: { + literalValue = this->CreateConstantString(*(lExpr.stringVal)); + lengthArrayPtr = + builder->CreateAlloca(llvmTypes->I32Type(), this->batchCodegenContext->rowCnt, "LENGTH_PTR"); + length = llvmTypes->CreateConstantInt(lExpr.stringVal->length()); + break; + } + case OMNI_BOOLEAN: { + literalValue = llvmTypes->CreateConstantBool(lExpr.boolVal); + break; + } + case OMNI_DECIMAL64: { + precisionVal = llvmTypes->CreateConstantInt( + static_cast(lExpr.GetReturnType().get())->GetPrecision()); + scaleVal = + llvmTypes->CreateConstantInt(static_cast(lExpr.GetReturnType().get())->GetScale()); + literalValue = llvmTypes->CreateConstantLong(lExpr.longVal); + break; + } + case OMNI_DECIMAL128: { + std::string dec128String = isNullLiteral ? "0" : *lExpr.stringVal; + __uint128_t dec128 = Decimal128Utils::StrToUint128_t(dec128String.c_str()); + dec128String = Decimal128Utils::Uint128_tToStr(dec128); + precisionVal = llvmTypes->CreateConstantInt( + dynamic_cast(lExpr.GetReturnType().get())->GetPrecision()); + scaleVal = llvmTypes->CreateConstantInt( + dynamic_cast(lExpr.GetReturnType().get())->GetScale()); + literalValue = llvm::ConstantInt::get(llvm::Type::getInt128Ty(*context), dec128String, 10); + break; + } + default: { + LogWarn("Unsupported data type in LITERAL Expr %d", lExpr.GetReturnTypeId()); + return new CodeGenValue(nullptr, nullptr); + } + } + + std::vector funcArgs; + if (TypeUtil::IsStringType(lExpr.GetReturnTypeId())) { + funcArgs = { this->batchCodegenContext->executionContext, + literalArrayPtr, + nullArrayPtr, + lengthArrayPtr, + literalValue, + isNull, + length, + this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_fill_literal", { OMNI_VARCHAR }, OMNI_VARCHAR, funcArgs, nullptr, + "fill_literal_array"); + return new CodeGenValue(literalArrayPtr, nullArrayPtr, lengthArrayPtr); + } else if (TypeUtil::IsDecimalType(lExpr.GetReturnTypeId())) { + funcArgs = { literalArrayPtr, nullArrayPtr, literalValue, isNull, this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_fill_literal", { lExpr.GetReturnTypeId() }, lExpr.GetReturnTypeId(), funcArgs, + nullptr, "fill_literal_array"); + return new DecimalValue(literalArrayPtr, nullArrayPtr, precisionVal, scaleVal); + } else { + funcArgs = { literalArrayPtr, nullArrayPtr, literalValue, isNull, this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_fill_literal", { lExpr.GetReturnTypeId() }, lExpr.GetReturnTypeId(), funcArgs, + nullptr, "fill_literal_array"); + return new CodeGenValue(literalArrayPtr, nullArrayPtr); + } +} + +void BatchExpressionCodeGen::Visit(const FieldExpr &fExpr) +{ + Value *rowCnt = this->batchCodegenContext->rowCnt; + Value *vecBatch = this->batchCodegenContext->data; + Value *bitmap = this->batchCodegenContext->nullBitmap; + Value *offsets = this->batchCodegenContext->offsets; + Value *dictionaryVectors = this->batchCodegenContext->dictionaryVectors; + Value *rowIdxArray = this->batchCodegenContext->rowIdxArray; + + Value *colIdx = llvmTypes->CreateConstantInt(fExpr.colVal); + Value *gep = builder->CreateGEP(llvmTypes->I64Type(), vecBatch, colIdx); + auto dictionaryVectorGEP = builder->CreateGEP(llvmTypes->I64Type(), dictionaryVectors, colIdx); + Value *dictionaryVectorPtr = builder->CreateLoad(llvmTypes->I64Type(), dictionaryVectorGEP); + auto condition = builder->CreateIsNotNull(dictionaryVectorPtr); + + BasicBlock *trueBlock = BasicBlock::Create(*context, "DICTIONARY_NOT_NULL", func); + BasicBlock *falseBlock = BasicBlock::Create(*context, "DICTIONARY_IS_NULL"); + BasicBlock *mergeBlock = BasicBlock::Create(*context, "field_data"); + builder->CreateCondBr(condition, trueBlock, falseBlock); + + builder->SetInsertPoint(trueBlock); + AllocaInst *dicLengthArray = builder->CreateAlloca(llvmTypes->I32Type(), rowCnt, "dic_varchar_length"); + auto dicArrayPtr = this->GetDictionaryVectorValue(*(fExpr.GetReturnType()), rowIdxArray, rowCnt, + dictionaryVectorPtr, dicLengthArray); + builder->CreateBr(mergeBlock); + trueBlock = builder->GetInsertBlock(); + + func->getBasicBlockList().push_back(falseBlock); + builder->SetInsertPoint(falseBlock); + Value *elementAddr = builder->CreateLoad(llvmTypes->I64Type(), gep); + AllocaInst *lengthArray = nullptr; + Value *dataArrayPtr = nullptr; + if (TypeUtil::IsStringType(fExpr.GetReturnTypeId())) { + lengthArray = builder->CreateAlloca(llvmTypes->I32Type(), rowCnt, "varchar_length"); + auto offsetsGEP = builder->CreateGEP(llvmTypes->I64Type(), offsets, colIdx); + Value *offsetPtr = builder->CreateLoad(llvmTypes->I64Type(), offsetsGEP); + offsetPtr = builder->CreateIntToPtr(offsetPtr, llvmTypes->I32PtrType()); + dataArrayPtr = + this->GetVectorValue(*(fExpr.GetReturnType()), rowIdxArray, rowCnt, elementAddr, offsetPtr, lengthArray); + } else { + dataArrayPtr = + this->GetVectorValue(*(fExpr.GetReturnType()), rowIdxArray, rowCnt, elementAddr, nullptr, nullptr); + } + builder->CreateBr(mergeBlock); + falseBlock = builder->GetInsertBlock(); + + func->getBasicBlockList().push_back(mergeBlock); + builder->SetInsertPoint(mergeBlock); + int32_t numReservedValues = 2; + Type *phiType = llvmTypes->ToBatchDataPointerType(fExpr.GetReturnTypeId()); + PHINode *phiValue = builder->CreatePHI(phiType, numReservedValues, "data"); + phiValue->addIncoming(dicArrayPtr, trueBlock); + phiValue->addIncoming(dataArrayPtr, falseBlock); + + PHINode *phiLength = nullptr; + if (TypeUtil::IsStringType(fExpr.GetReturnTypeId())) { + phiLength = builder->CreatePHI(llvmTypes->I32PtrType(), numReservedValues, "length"); + phiLength->addIncoming(dicLengthArray, trueBlock); + phiLength->addIncoming(lengthArray, falseBlock); + } + + // Get isNull value + auto bitmapGEP = builder->CreateGEP(llvmTypes->I64Type(), bitmap, colIdx); + Value *nullBitsPtr = builder->CreateLoad(llvmTypes->I64Type(), bitmapGEP); + nullBitsPtr = builder->CreateIntToPtr(nullBitsPtr, llvmTypes->I32PtrType()); + auto dstNullArray = GetResultArray(OMNI_BOOLEAN, rowCnt); + CallExternFunction("batch_BitsToNullArray", { OMNI_BOOLEAN }, OMNI_BOOLEAN, + { dstNullArray, nullBitsPtr, rowIdxArray, rowCnt }, nullptr, "copy_null"); + + if (TypeUtil::IsDecimalType(fExpr.GetReturnTypeId())) { + Value *precision = + llvmTypes->CreateConstantInt(static_cast(fExpr.GetReturnType().get())->GetPrecision()); + Value *scale = + llvmTypes->CreateConstantInt(static_cast(fExpr.GetReturnType().get())->GetScale()); + this->value = std::make_shared(phiValue, dstNullArray, precision, scale); + } else if (TypeUtil::IsStringType(fExpr.GetReturnTypeId())) { + this->value = std::make_shared(phiValue, dstNullArray, phiLength); + } else { + this->value = std::make_shared(phiValue, dstNullArray); + } +} + +Value *BatchExpressionCodeGen::GetDictionaryVectorValue(const DataType &dataType, Value *rowIdxArray, + llvm::Value *rowCnt, llvm::Value *dictionaryVectorPtr, AllocaInst *lengthArrayPtr) +{ + std::vector paramTypes = { OMNI_LONG }; + DataTypeId retTypeId = dataType.GetId(); + AllocaInst *dataArrayPtr = GetResultArray(retTypeId, rowCnt); + std::vector funcArgs; + if (TypeUtil::IsStringType(retTypeId)) { + funcArgs = { batchCodegenContext->executionContext, + dictionaryVectorPtr, + rowIdxArray, + rowCnt, + dataArrayPtr, + lengthArrayPtr }; + } else { + funcArgs = { dictionaryVectorPtr, rowIdxArray, rowCnt, dataArrayPtr }; + } + + CallExternFunction("batch_GetDic", { OMNI_LONG }, retTypeId, funcArgs, nullptr, "get_dictionary_value"); + return dataArrayPtr; +} + +Value *BatchExpressionCodeGen::GetVectorValue(const DataType &dataType, Value *rowIdxArray, llvm::Value *rowCnt, + llvm::Value *dataVectorPtr, Value *offsetArrayPtr, llvm::Value *lengthArrayPtr) +{ + std::vector paramTypes = { OMNI_LONG }; + DataTypeId retTypeId = dataType.GetId(); + AllocaInst *dataArrayPtr = GetResultArray(retTypeId, rowCnt); + std::vector funcArgs; + if (TypeUtil::IsStringType(retTypeId)) { + funcArgs = { batchCodegenContext->executionContext, + offsetArrayPtr, + dataVectorPtr, + rowIdxArray, + rowCnt, + dataArrayPtr, + lengthArrayPtr }; + } else { + funcArgs = { dataVectorPtr, rowIdxArray, rowCnt, dataArrayPtr }; + } + + CallExternFunction("batch_GetData", { OMNI_LONG }, retTypeId, funcArgs, nullptr, "get_vector_value"); + return dataArrayPtr; +} + +void BatchExpressionCodeGen::Visit(const UnaryExpr &uExpr) +{ + auto val = VisitExpr(*(uExpr.exp)); + if (!val->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + switch (uExpr.op) { + case omniruntime::expressions::Operator::NOT: { + std::vector funcArgs { val->data, this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_not", { uExpr.exp->GetReturnTypeId() }, uExpr.GetReturnTypeId(), funcArgs, + nullptr, "logical_not"); + this->value = std::make_shared(val->data, val->isNull); + break; + } + default: { + this->value = CreateInvalidCodeGenValue(); + break; + } + } +} + +void BatchExpressionCodeGen::Visit(const BinaryExpr &binaryExpr) +{ + auto *bExpr = const_cast(&binaryExpr); + + CodeGenValuePtr left = VisitExpr(*(bExpr->left)); + if (!left->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + Value *leftValue = left->data; + Value *leftLen = left->length; + Value *leftNull = left->isNull; + + CodeGenValuePtr right = VisitExpr(*(bExpr->right)); + if (!right->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + Value *rightValue = right->data; + Value *rightLen = right->length; + Value *rightNull = right->isNull; + + if (bExpr->op == omniruntime::expressions::Operator::AND) { + std::vector boolParams { OMNI_BOOLEAN, OMNI_BOOLEAN }; + std::vector andFuncParams { leftValue, leftNull, rightValue, rightNull, + this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_and_expr", boolParams, OMNI_BOOLEAN, andFuncParams, nullptr, "and_expr"); + this->value = std::make_shared(leftValue, leftNull); + return; + } + + if (bExpr->op == omniruntime::expressions::Operator::OR) { + std::vector boolParams { OMNI_BOOLEAN, OMNI_BOOLEAN }; + std::vector orFuncParams { leftValue, leftNull, rightValue, rightNull, + this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_or_expr", boolParams, OMNI_BOOLEAN, orFuncParams, nullptr, "or_expr"); + this->value = std::make_shared(leftValue, leftNull); + return; + } + + if (bExpr->left->GetReturnTypeId() == OMNI_INT || bExpr->left->GetReturnTypeId() == OMNI_DATE32 || + bExpr->left->GetReturnTypeId() == OMNI_LONG || bExpr->left->GetReturnTypeId() == OMNI_TIMESTAMP) { + this->BatchBinaryExprIntLongHelper(bExpr, leftValue, rightValue, leftNull, rightNull); + return; + } else if (bExpr->left->GetReturnTypeId() == OMNI_DOUBLE) { + this->BatchBinaryExprDoubleHelper(bExpr, leftValue, rightValue, leftNull, rightNull); + return; + } else if (TypeUtil::IsStringType(bExpr->left->GetReturnTypeId())) { + this->BatchBinaryExprStringHelper(bExpr, leftValue, leftLen, rightValue, rightLen, leftNull, rightNull); + return; + } else if (TypeUtil::IsDecimalType(bExpr->left->GetReturnTypeId())) { + this->BatchBinaryExprDecimalHelper(bExpr, static_cast(*left.get()), + static_cast(*right.get()), leftNull, rightNull); + return; + } + + LogWarn("Unsupported data type for BINARY expr %d", bExpr->left->GetReturnTypeId()); + this->value = CreateInvalidCodeGenValue(); +} + +void BatchExpressionCodeGen::Visit(const BetweenExpr &btExpr) +{ + auto bExpr = const_cast(&btExpr); + + auto val = VisitExpr(*(bExpr->value)); + if (!val->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + auto lowerVal = VisitExpr(*(bExpr->lowerBound)); + if (!lowerVal->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + auto upperVal = VisitExpr(*(bExpr->upperBound)); + if (!upperVal->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + AllocaInst *cmpLeft = builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "cmpLeft"); + AllocaInst *cmpRight = builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "cmpRight"); + std::pair cmpPair = std::make_pair(&cmpLeft, &cmpRight); + + BatchVisitBetweenExprHelper(*bExpr, val, lowerVal, upperVal, cmpPair); +} + +void BatchExpressionCodeGen::Visit(const IsNullExpr &isNullExpr) +{ + Expr *valueExpr = isNullExpr.value; + auto isNullValue = VisitExpr(*valueExpr); + if (!isNullValue->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + std::vector funcArgs { isNullValue->isNull, llvmTypes->CreateConstantBool(true), + this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_equal", { OMNI_BOOLEAN, OMNI_BOOLEAN }, OMNI_BOOLEAN, funcArgs, nullptr, "is_null"); + + AllocaInst *nullArrayPtr = + builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "IS_NULL_PTR"); + funcArgs = { nullArrayPtr, llvmTypes->CreateConstantBool(false), this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_fill_null", { OMNI_BOOLEAN }, OMNI_BOOLEAN, funcArgs, nullptr, "batch_fill_null"); + + this->value = std::make_shared(isNullValue->isNull, nullArrayPtr); +} + +llvm::AllocaInst *BatchExpressionCodeGen::GetResultArray(omniruntime::type::DataTypeId dataTypeId, Value *rowCnt) +{ + AllocaInst *resultArray = nullptr; + switch (dataTypeId) { + case OMNI_INT: + case OMNI_DATE32: { + resultArray = builder->CreateAlloca(llvmTypes->I32Type(), rowCnt, "DATA_PTR"); + break; + } + case OMNI_TIMESTAMP: + case OMNI_DECIMAL64: + case OMNI_LONG: { + resultArray = builder->CreateAlloca(llvmTypes->I64Type(), rowCnt, "DATA_PTR"); + break; + } + case OMNI_DOUBLE: { + resultArray = builder->CreateAlloca(llvmTypes->DoubleType(), rowCnt, "DATA_PTR"); + break; + } + case OMNI_CHAR: + case OMNI_VARCHAR: { + resultArray = builder->CreateAlloca(llvmTypes->I8PtrType(), rowCnt, "DATA_PTR"); + break; + } + case OMNI_BOOLEAN: { + resultArray = builder->CreateAlloca(llvmTypes->I1Type(), rowCnt, "DATA_PTR"); + break; + } + case OMNI_DECIMAL128: { + resultArray = builder->CreateAlloca(llvmTypes->I128Type(), rowCnt, "DATA_PTR"); + break; + } + default: { + LogWarn("Unsupported type when creating array %d", dataTypeId); + break; + } + } + if (resultArray == nullptr) { + LogWarn("Failed to create result array"); + } + return resultArray; +} + +static std::string ChangeFuncNameToNull(const FuncExpr &fExpr) +{ + auto typeSize = static_cast(fExpr.arguments.size() + 1); + auto originalFuncName = fExpr.function->GetId(); + auto originalFuncChars = originalFuncName.c_str(); + int32_t separatorIdx = 0; + auto pos = static_cast(originalFuncName.length() - 1); + for (; pos >= 0; pos--) { + if (originalFuncChars[pos] == '_') { + separatorIdx++; + if (separatorIdx == typeSize) { + break; + } + } + } + return originalFuncName.insert(pos, "_null"); +} + +void BatchExpressionCodeGen::FuncExprOverflowNullHelper(const FuncExpr &fExpr) +{ + Value *falseValue = llvmTypes->CreateConstantBool(false); + AllocaInst *isAnyNull = + builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "IS_NULL_PTR"); + std::vector funcArgs { isAnyNull, falseValue, this->batchCodegenContext->rowCnt }; + auto ret = + CallExternFunction("batch_fill_null", { OMNI_BOOLEAN }, OMNI_BOOLEAN, funcArgs, nullptr, "fill_null_array"); + AllocaInst *overflowNull = + builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "OVERFLOW_NULL_PTR"); + funcArgs = { overflowNull, falseValue, this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_fill_null", { OMNI_BOOLEAN }, OMNI_BOOLEAN, funcArgs, nullptr, + "fill_overflow_null_array"); + + DataTypeId funcRetType = fExpr.GetReturnTypeId(); + bool isInvalidExpr = false; + + auto argVals = GetDataAndOverflowNullArgs(fExpr, isAnyNull, isInvalidExpr, overflowNull); + if (isInvalidExpr) { + this->value = CreateInvalidCodeGenValue(); + return; + } + Value *isNullArray = PushAndGetNullFlagArray(fExpr, argVals, isAnyNull, false); + AllocaInst *resultArray = GetResultArray(funcRetType, this->batchCodegenContext->rowCnt); + argVals.push_back(resultArray); + AllocaInst *outputLenPtr = nullptr; + + if (TypeUtil::IsStringType(funcRetType)) { + outputLenPtr = builder->CreateAlloca(llvmTypes->I32Type(), this->batchCodegenContext->rowCnt, "output_len"); + auto defaultLength = + llvmTypes->CreateConstantInt(static_cast(fExpr.GetReturnType().get())->GetWidth()); + funcArgs = { outputLenPtr, defaultLength, this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_fill_length_literal", { OMNI_INT }, OMNI_INT, funcArgs, nullptr, + "fill_literal_array"); + argVals.push_back(outputLenPtr); + if (FuncExpr::IsCastStrStr(fExpr)) { + argVals.push_back( + llvmTypes->CreateConstantInt(static_cast(fExpr.GetReturnType().get())->GetWidth())); + } + } else if (TypeUtil::IsDecimalType(funcRetType)) { + argVals.push_back( + llvmTypes->CreateConstantInt(static_cast(fExpr.GetReturnType().get())->GetPrecision())); + argVals.push_back( + llvmTypes->CreateConstantInt(static_cast(fExpr.GetReturnType().get())->GetScale())); + } + argVals.push_back(this->batchCodegenContext->rowCnt); + + auto f = modulePtr->getFunction("batch_" + ChangeFuncNameToNull(fExpr)); + if (f) { + ret = CreateCall(f, argVals, fExpr.function->GetId()); + InlineFunctionInfo inlineFunctionInfo; + llvm::InlineFunction(*((CallInst *)ret), inlineFunctionInfo); + } else { + LogWarn("Unable to generate function : %s", fExpr.funcName.c_str()); + this->value = CreateInvalidCodeGenValue(); + return; + } + + CallExternFunction("batch_or", { OMNI_BOOLEAN, OMNI_BOOLEAN }, OMNI_BOOLEAN, + { isNullArray, overflowNull, this->batchCodegenContext->rowCnt }, nullptr); + + if (TypeUtil::IsDecimalType(funcRetType)) { + this->value = std::make_shared(resultArray, isNullArray, + llvmTypes->CreateConstantInt(static_cast(fExpr.GetReturnType().get())->GetPrecision()), + llvmTypes->CreateConstantInt(static_cast(fExpr.GetReturnType().get())->GetScale())); + } else { + this->value = std::make_shared(resultArray, isNullArray, outputLenPtr); + } +} + +std::vector BatchExpressionCodeGen::GetDataAndOverflowNullArgs( + const omniruntime::expressions::FuncExpr &fExpr, AllocaInst *isAnyNull, bool &isInvalidExpr, + AllocaInst *overflowNull) +{ + std::vector argVals; + argVals.push_back(overflowNull); + CodeGenValuePtr resultPtr; + int numArgs = fExpr.arguments.size(); + std::vector nullFuncParams; + + auto signature = fExpr.function->GetSignatures()[0]; + if (FunctionRegistry::IsNullExecutionContextSet(&signature)) { + argVals.push_back(this->batchCodegenContext->executionContext); + } + for (int i = 0; i < numArgs; i++) { + Expr *argN = fExpr.arguments[i]; + resultPtr = VisitExpr(*argN); + if (!resultPtr->IsValidValue()) { + isInvalidExpr = true; + return argVals; + } + argVals.push_back(resultPtr->data); + + nullFuncParams = { isAnyNull, resultPtr->isNull, this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_or", { OMNI_BOOLEAN, OMNI_BOOLEAN }, OMNI_BOOLEAN, nullFuncParams, nullptr, + "either_null"); + + if ((TypeUtil::IsStringType(argN->GetReturnTypeId()))) { + if (argN->GetReturnTypeId() == OMNI_CHAR) { + argVals.push_back( + llvmTypes->CreateConstantInt(static_cast(argN->GetReturnType().get())->GetWidth())); + } + if (FuncExpr::IsCastStrStr(fExpr)) { + argVals.push_back(llvmTypes->CreateConstantInt( + static_cast(argN->GetReturnType().get())->GetWidth())); + } + argVals.push_back(this->value->length); + } + if (TypeUtil::IsDecimalType(argN->GetReturnTypeId())) { + argVals.push_back(llvmTypes->CreateConstantInt( + static_cast(argN->GetReturnType().get())->GetPrecision())); + argVals.push_back( + llvmTypes->CreateConstantInt(static_cast(argN->GetReturnType().get())->GetScale())); + } + if (fExpr.function->GetNullableResultType() == INPUT_DATA_AND_NULL_AND_RETURN_NULL) { + argVals.push_back(this->value->isNull); + } + } + return argVals; +} + +template +std::vector BatchExpressionCodeGen::GetDefaultFunctionArgValues( + const FuncExpr &fExpr, + AllocaInst *isAnyNull, + bool &isInvalidExpr) +{ + std::vector argVals; + CodeGenValuePtr resultPtr; + int numArgs = fExpr.arguments.size(); + std::vector nullFuncParams; + + if (fExpr.function->IsExecutionContextSet()) { + argVals.push_back(this->batchCodegenContext->executionContext); + } + for (int i = 0; i < numArgs; i++) { + Expr *argN = fExpr.arguments[i]; + resultPtr = VisitExpr(*argN); + if (!resultPtr->IsValidValue()) { + isInvalidExpr = true; + return argVals; + } + argVals.push_back(resultPtr->data); + if constexpr (isNeedVerifyResult) { + nullFuncParams = { isAnyNull, resultPtr->isNull, this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_or", { OMNI_BOOLEAN, OMNI_BOOLEAN }, OMNI_BOOLEAN, nullFuncParams, nullptr, + "either_null"); + } + if ((TypeUtil::IsStringType(argN->GetReturnTypeId()))) { + if (argN->GetReturnTypeId() == OMNI_CHAR) { + argVals.push_back( + llvmTypes->CreateConstantInt( + static_cast(argN->GetReturnType().get())->GetWidth())); + } + if (FuncExpr::IsCastStrStr(fExpr)) { + argVals.push_back(llvmTypes->CreateConstantInt( + static_cast(argN->GetReturnType().get())->GetWidth())); + } + argVals.push_back(this->value->length); + } + if (TypeUtil::IsDecimalType(argN->GetReturnTypeId())) { + argVals.push_back(llvmTypes->CreateConstantInt( + static_cast(argN->GetReturnType().get())->GetPrecision())); + argVals.push_back( + llvmTypes->CreateConstantInt( + static_cast(argN->GetReturnType().get())->GetScale())); + } + if constexpr (isNeedVerifyVal) { + argVals.push_back(this->value->isNull); + } + } + return argVals; +} + +inline std::vector BatchExpressionCodeGen::GetDataArgs(const FuncExpr &fExpr, AllocaInst *isAnyNull, + bool &isInvalidExpr) +{ + return GetDefaultFunctionArgValues(fExpr, isAnyNull, isInvalidExpr); +} + +inline std::vector BatchExpressionCodeGen::GetDataAndNullArgs(const FuncExpr &fExpr, + AllocaInst *isAnyNull, bool &isInvalidExpr) +{ + return GetDefaultFunctionArgValues(fExpr, isAnyNull, isInvalidExpr); +} + +inline std::vector BatchExpressionCodeGen::GetDataAndNullArgsAndReturnNull(const FuncExpr &fExpr, + AllocaInst *isAnyNull, bool &isInvalidExpr) +{ + return GetDefaultFunctionArgValues(fExpr, isAnyNull, isInvalidExpr); +} + +std::vector BatchExpressionCodeGen::GetFunctionArgValues(const omniruntime::expressions::FuncExpr &fExpr, + AllocaInst *isAnyNull, bool &isInvalidExpr) +{ + switch (fExpr.function->GetNullableResultType()) { + case INPUT_DATA: + return GetDataArgs(fExpr, isAnyNull, isInvalidExpr); + case INPUT_DATA_AND_NULL: + return GetDataAndNullArgs(fExpr, isAnyNull, isInvalidExpr); + case INPUT_DATA_AND_NULL_AND_RETURN_NULL: + return GetDataAndNullArgsAndReturnNull(fExpr, isAnyNull, isInvalidExpr); + default: + return GetDataArgs(fExpr, isAnyNull, isInvalidExpr); + } +} + +Value *BatchExpressionCodeGen::ArenaAlloc(Value *sizeInBytes) +{ + return CallExternFunction("ArenaAllocatorMalloc", { OMNI_LONG, OMNI_INT }, OMNI_CHAR, + { batchCodegenContext->executionContext, sizeInBytes }, nullptr); +} + +Value *BatchExpressionCodeGen::GetTypeSize(DataTypeId dataTypeId) +{ + int32_t typeSize = 0; + switch (dataTypeId) { + case OMNI_INT: + case OMNI_DATE32: + typeSize = sizeof(int32_t); + break; + case OMNI_TIMESTAMP: + case OMNI_LONG: + case OMNI_DECIMAL64: + typeSize = sizeof(int64_t); + break; + case OMNI_DOUBLE: + typeSize = sizeof(double); + break; + case OMNI_BOOLEAN: + typeSize = sizeof(bool); + break; + case OMNI_SHORT: + typeSize = sizeof(int16_t); + break; + case OMNI_DECIMAL128: + typeSize = 2 * sizeof(int64_t); + break; + case OMNI_CHAR: + case OMNI_VARCHAR: + typeSize = sizeof(int64_t); // for pointer + break; + default: + LogWarn("Unsupported data type in UDF funcExpr %d", dataTypeId); + return nullptr; + } + return llvmTypes->CreateConstantInt(typeSize); +} + +std::vector BatchExpressionCodeGen::GetHiveUdfArgValues(const FuncExpr &fExpr, bool &isInvalidExpr) +{ + auto argSize = static_cast(fExpr.arguments.size()); + auto size = llvmTypes->CreateConstantInt(argSize); + auto valueAddrArray = builder->CreateAlloca(llvmTypes->I64Type(), size); + auto nullAddrArray = builder->CreateAlloca(llvmTypes->I64Type(), size); + auto lengthAddrArray = builder->CreateAlloca(llvmTypes->I64Type(), size); + + std::vector argVals; + for (int32_t i = 0; i < argSize; i++) { + auto argExpr = fExpr.arguments[i]; + auto argExprResult = VisitExpr(*argExpr); + if (!argExprResult->IsValidValue()) { + isInvalidExpr = true; + return argVals; + } + + auto valuePtr = builder->CreateGEP(llvmTypes->I64Type(), valueAddrArray, llvmTypes->CreateConstantInt(i)); + auto nullPtr = builder->CreateGEP(llvmTypes->I64Type(), nullAddrArray, llvmTypes->CreateConstantInt(i)); + auto lengthPtr = builder->CreateGEP(llvmTypes->I64Type(), lengthAddrArray, llvmTypes->CreateConstantInt(i)); + builder->CreateStore(argExprResult->data, valuePtr); + builder->CreateStore(argExprResult->isNull, nullPtr); + auto length = TypeUtil::IsStringType(argExpr->GetReturnTypeId()) ? argExprResult->length : + llvmTypes->CreateConstantLong(0); + builder->CreateStore(length, lengthPtr); + } + + argVals.push_back(valueAddrArray); + argVals.push_back(nullAddrArray); + argVals.push_back(lengthAddrArray); + return argVals; +} + +Value *BatchExpressionCodeGen::CreateHiveUdfArgTypes(const FuncExpr &fExpr) +{ + auto elementSize = static_cast(fExpr.arguments.size()); + auto alloca = builder->CreateAlloca(llvmTypes->I32Type(), llvmTypes->CreateConstantInt(elementSize)); + for (int32_t i = 0; i < elementSize; i++) { + auto ptr = builder->CreateGEP(llvmTypes->I32Type(), alloca, llvmTypes->CreateConstantInt(i)); + builder->CreateStore(llvmTypes->CreateConstantInt(fExpr.arguments[i]->GetReturnTypeId()), ptr); + } + return alloca; +} + +void BatchExpressionCodeGen::CallHiveUdfFunction(const FuncExpr &fExpr) +{ + auto returnTypeId = fExpr.GetReturnTypeId(); + std::vector argVals; + argVals.emplace_back(batchCodegenContext->executionContext); + argVals.emplace_back(CreateConstantString(fExpr.funcName)); // for udf class name + argVals.emplace_back(CreateHiveUdfArgTypes(fExpr)); // for inputTypes + argVals.emplace_back(llvmTypes->CreateConstantInt(returnTypeId)); // for return type + argVals.emplace_back(llvmTypes->CreateConstantInt(fExpr.arguments.size())); // for vec count + argVals.emplace_back(batchCodegenContext->rowCnt); // for row count + + bool isInvalidExpr = false; + auto inputArgs = GetHiveUdfArgValues(fExpr, isInvalidExpr); + if (isInvalidExpr) { + this->value = CreateInvalidCodeGenValue(); + return; + } + argVals.insert(argVals.end(), inputArgs.begin(), + inputArgs.end()); // for inputValues, inputNulls, inputLengths + + // for output value, output null, output length + auto returnTypeSize = GetTypeSize(returnTypeId); + if (returnTypeSize == nullptr) { + this->value = CreateInvalidCodeGenValue(); + return; + } + auto arraySize = batchCodegenContext->rowCnt; + auto outputValuePtr = ArenaAlloc(builder->CreateMul(returnTypeSize, arraySize)); + auto outputNullPtr = ArenaAlloc(arraySize); + auto outputLengthPtr = TypeUtil::IsStringType(returnTypeId) ? + ArenaAlloc(builder->CreateMul(GetTypeSize(OMNI_INT), arraySize)) : + llvmTypes->CreateConstantLong(0); + argVals.emplace_back(outputValuePtr); + argVals.emplace_back(outputNullPtr); + argVals.emplace_back(outputLengthPtr); + + auto signature = FunctionSignature("EvaluateHiveUdfBatch", std::vector {}, OMNI_INT); + auto function = FunctionRegistry::LookupFunction(&signature); + auto f = modulePtr->getFunction(function->GetId()); + if (f) { + auto ret = CreateCall(f, argVals, "call_evaluate_hive_udf"); + InlineFunctionInfo inlineFunctionInfo; + llvm::InlineFunction(*((CallInst *)ret), inlineFunctionInfo); + this->value = std::make_shared(outputValuePtr, outputNullPtr, + TypeUtil::IsStringType(returnTypeId) ? outputLengthPtr : nullptr); + } else { + LogWarn("Unable to generate udf function : %s", fExpr.funcName.c_str()); + this->value = CreateInvalidCodeGenValue(); + } +} + +void BatchExpressionCodeGen::Visit(const FuncExpr &fExpr) +{ + if (fExpr.functionType == HIVE_UDF) { + CallHiveUdfFunction(fExpr); + return; + } + + if (this->overflowConfig != nullptr && + this->overflowConfig->GetOverflowConfigId() == omniruntime::op::OVERFLOW_CONFIG_NULL) { + auto signature = fExpr.function->GetSignatures()[0]; + if (FunctionRegistry::LookupNullFunction(&signature)) { + FuncExprOverflowNullHelper(fExpr); + return; + } + } + + Value *falseValue = llvmTypes->CreateConstantBool(false); + AllocaInst *isAnyNull = + builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "IS_NULL_PTR"); + std::vector funcArgs { isAnyNull, falseValue, this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_fill_null", { OMNI_BOOLEAN }, OMNI_BOOLEAN, funcArgs, nullptr, "fill_null_array"); + + DataTypeId funcRetType = fExpr.GetReturnTypeId(); + bool isInvalidExpr = false; + + auto argVals = GetFunctionArgValues(fExpr, isAnyNull, isInvalidExpr); + if (isInvalidExpr) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + AllocaInst *resultArray = GetResultArray(funcRetType, this->batchCodegenContext->rowCnt); + Value *isNullArray = PushAndGetNullFlagArray(fExpr, argVals, isAnyNull, true); + argVals.push_back(resultArray); + AllocaInst *outputLenPtr = nullptr; + + if (TypeUtil::IsStringType(funcRetType)) { + outputLenPtr = builder->CreateAlloca(llvmTypes->I32Type(), this->batchCodegenContext->rowCnt, "output_len"); + auto defaultLength = + llvmTypes->CreateConstantInt(static_cast(fExpr.GetReturnType().get())->GetWidth()); + funcArgs = { outputLenPtr, defaultLength, this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_fill_length_literal", { OMNI_INT }, OMNI_INT, funcArgs, nullptr, + "fill_literal_array"); + argVals.push_back(outputLenPtr); + if (FuncExpr::IsCastStrStr(fExpr)) { + argVals.push_back( + llvmTypes->CreateConstantInt(static_cast(fExpr.GetReturnType().get())->GetWidth())); + } + } else if (TypeUtil::IsDecimalType(funcRetType)) { + argVals.push_back( + llvmTypes->CreateConstantInt(static_cast(fExpr.GetReturnType().get())->GetPrecision())); + argVals.push_back( + llvmTypes->CreateConstantInt(static_cast(fExpr.GetReturnType().get())->GetScale())); + } + argVals.push_back(this->batchCodegenContext->rowCnt); + + auto f = modulePtr->getFunction("batch_" + fExpr.function->GetId()); + if (f) { + auto ret = CreateCall(f, argVals, fExpr.function->GetId()); + InlineFunctionInfo inlineFunctionInfo; + llvm::InlineFunction(*((CallInst *)ret), inlineFunctionInfo); + } else { + LogWarn("Unable to generate function : %s", fExpr.funcName.c_str()); + this->value = CreateInvalidCodeGenValue(); + return; + } + + if (TypeUtil::IsDecimalType(funcRetType)) { + this->value = std::make_shared(resultArray, isNullArray, + llvmTypes->CreateConstantInt(static_cast(fExpr.GetReturnType().get())->GetPrecision()), + llvmTypes->CreateConstantInt(static_cast(fExpr.GetReturnType().get())->GetScale())); + } else { + this->value = std::make_shared(resultArray, isNullArray, outputLenPtr); + } +} + +void BatchExpressionCodeGen::Visit(const IfExpr &ifExpr) +{ + Expr *cond = ifExpr.condition; + Expr *ifTrue = ifExpr.trueExpr; + Expr *ifFalse = ifExpr.falseExpr; + + auto baseType = ifExpr.GetReturnTypeId(); + + CodeGenValuePtr evCond = VisitExpr(*cond); + if (!evCond->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + auto evTrue = VisitExpr(*ifTrue); + if (!evTrue->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + Value *evTrueValue = evTrue->data; + Value *evTrueLength = evTrue->length; + Value *evTrueNull = evTrue->isNull; + + auto evFalse = VisitExpr(*ifFalse); + if (!evFalse->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + Value *evFalseValue = evFalse->data; + Value *evFalseLength = evFalse->length; + Value *evFalseNull = evFalse->isNull; + + switch (baseType) { + case OMNI_INT: + case OMNI_DATE32: + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DOUBLE: + case OMNI_BOOLEAN: { + CallExternFunction("batch_if", { baseType }, baseType, + { evCond->data, evCond->isNull, evTrueValue, evTrueNull, evFalseValue, evFalseNull, + this->batchCodegenContext->rowCnt }, + nullptr); + this->value = std::make_shared(evTrueValue, evTrueNull); + return; + } + case OMNI_DECIMAL64: + case OMNI_DECIMAL128: { + auto &left = static_cast(*evTrue); + auto &right = static_cast(*evFalse); + + std::vector argValsCmp { evCond->data, + evCond->isNull, + left.data, + left.isNull, + right.data, + right.isNull, + this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_if", { baseType }, baseType, argValsCmp, nullptr); + this->value = std::make_shared(left.data, left.isNull, + const_cast(left.GetPrecision()), const_cast(left.GetScale())); + return; + } + case OMNI_CHAR: + case OMNI_VARCHAR: { + CallExternFunction("batch_if", { baseType }, baseType, + { evCond->data, evCond->isNull, evTrueValue, evTrueNull, evTrueLength, evFalseValue, evFalseNull, + evFalseLength, this->batchCodegenContext->rowCnt }, + nullptr); + this->value = std::make_shared(evTrueValue, evTrueNull, evTrueLength); + return; + } + default: { + LogWarn("Unsupported data type in IF expr %d", baseType); + this->value = CreateInvalidCodeGenValue(); + return; + } + } +} + +void BatchExpressionCodeGen::Visit(const CoalesceExpr &cExpr) +{ + Expr *value1Expr = cExpr.value1; + Expr *value2Expr = cExpr.value2; + CodeGenValuePtr value1 = VisitExpr(*value1Expr); + if (!value1->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + auto value2 = VisitExpr(*value2Expr); + if (!value2->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + if (cExpr.GetReturnTypeId() == OMNI_BOOLEAN || cExpr.GetReturnTypeId() == OMNI_INT || + cExpr.GetReturnTypeId() == OMNI_LONG || cExpr.GetReturnTypeId() == OMNI_DOUBLE || + cExpr.GetReturnTypeId() == OMNI_DATE32 || cExpr.GetReturnTypeId() == OMNI_TIMESTAMP) { + CallExternFunction("batch_coalesce", { cExpr.GetReturnTypeId(), cExpr.GetReturnTypeId() }, + cExpr.GetReturnTypeId(), + { value1->data, value1->isNull, value2->data, value2->isNull, this->batchCodegenContext->rowCnt }, nullptr); + this->value = std::make_shared(value1->data, value1->isNull); + } else if (TypeUtil::IsStringType(cExpr.GetReturnTypeId())) { + CallExternFunction("batch_coalesce", { OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_VARCHAR, + { value1->data, value1->isNull, value1->length, value2->data, value2->isNull, value2->length, + this->batchCodegenContext->rowCnt }, + nullptr); + this->value = std::make_shared(value1->data, value1->isNull, value1->length); + } else if (TypeUtil::IsDecimalType(cExpr.GetReturnTypeId())) { + auto value1Precision = (Value *)static_cast(*value1.get()).GetPrecision(); + auto value1Scale = (Value *)static_cast(*value1.get()).GetScale(); + + CallExternFunction("batch_coalesce", { cExpr.GetReturnTypeId(), cExpr.GetReturnTypeId() }, + cExpr.GetReturnTypeId(), + { value1->data, value1->isNull, value2->data, value2->isNull, this->batchCodegenContext->rowCnt }, nullptr); + this->value = std::make_shared(value1->data, value1->isNull, value1Precision, value1Scale); + } else { + LogWarn("Unsupported data type in COALESCE expr %d", cExpr.GetReturnTypeId()); + this->value = CreateInvalidCodeGenValue(); + return; + } +} + +void BatchExpressionCodeGen::Visit(const InExpr &inExpr) +{ + auto iExpr = const_cast(&inExpr); + Expr *toCompare = iExpr->arguments[0]; + auto baseType = iExpr->arguments[0]->GetReturnTypeId(); + int32_t size = iExpr->arguments.size(); + + auto valueToCompare = VisitExpr(*toCompare); + if (!valueToCompare->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + Value *inArray = GetResultArray(OMNI_BOOLEAN, this->batchCodegenContext->rowCnt); + Value *isNull = GetResultArray(OMNI_BOOLEAN, this->batchCodegenContext->rowCnt); + + std::vector cmps(size - 1); + for (int i = 1; i < size; ++i) { + Expr *cmp = iExpr->arguments[i]; + cmps[i - 1] = VisitExpr(*cmp); + if (!cmps[i - 1]->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + } + + auto cmpValues = GetResultArray(OMNI_LONG, llvmTypes->CreateConstantInt(size - 1)); + auto cmpBools = GetResultArray(OMNI_LONG, llvmTypes->CreateConstantInt(size - 1)); + for (int i = 0; i < size - 1; ++i) { + Value *gep = + builder->CreateGEP(llvmTypes->I64Type(), cmpValues, llvmTypes->CreateConstantInt(i), "cmp_value_address"); + builder->CreateStore(cmps[i]->data, gep); + + gep = builder->CreateGEP(llvmTypes->I64Type(), cmpBools, llvmTypes->CreateConstantInt(i), "cmp_null_address"); + builder->CreateStore(cmps[i]->isNull, gep); + } + + std::vector args; + switch (baseType) { + case OMNI_BOOLEAN: + case OMNI_INT: + case OMNI_DATE32: + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DOUBLE: + case OMNI_DECIMAL64: + case OMNI_DECIMAL128: { + args = { llvmTypes->CreateConstantInt(size - 1), + cmpValues, + cmpBools, + valueToCompare->data, + valueToCompare->isNull, + inArray, + isNull, + this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_in", { baseType }, OMNI_BOOLEAN, args, nullptr); + break; + } + case OMNI_CHAR: + case OMNI_VARCHAR: { + auto cmpLengths = GetResultArray(OMNI_LONG, llvmTypes->CreateConstantInt(size - 1)); + for (int i = 0; i < size - 1; ++i) { + Value *gep = builder->CreateGEP(llvmTypes->I64Type(), cmpLengths, llvmTypes->CreateConstantInt(i), + "cmp_length_address"); + builder->CreateStore(cmps[i]->length, gep); + } + args = { llvmTypes->CreateConstantInt(size - 1), + cmpValues, + cmpBools, + cmpLengths, + valueToCompare->data, + valueToCompare->isNull, + valueToCompare->length, + inArray, + isNull, + this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_in", { baseType }, OMNI_BOOLEAN, args, nullptr); + break; + } + default: { + LogWarn("Unsupported data type in IN expr %d", baseType); + this->value = CreateInvalidCodeGenValue(); + return; + } + } + + this->value = std::make_shared(inArray, isNull); +} + +void BatchExpressionCodeGen::Visit(const SwitchExpr &switchExpr) +{ + auto switchDataType = switchExpr.GetReturnTypeId(); + Expr *elseExpr = switchExpr.falseExpr; + std::vector> whenClause = switchExpr.whenClause; + const int size = whenClause.size(); + + AllocaInst *finalValue = GetResultArray(switchDataType, this->batchCodegenContext->rowCnt); + AllocaInst *finalNull = GetResultArray(OMNI_BOOLEAN, this->batchCodegenContext->rowCnt); + + std::vector conditions(size); + std::vector results(size); + for (int i = 0; i < size; ++i) { + Expr *cond = whenClause[i].first; + Expr *resExpr = whenClause[i].second; + conditions[i] = VisitExpr(*cond); + if (!conditions[i]->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + results[i] = VisitExpr(*resExpr); + if (!results[i]->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + } + auto whenClauses = GetResultArray(OMNI_LONG, llvmTypes->CreateConstantInt(size)); + auto whenBools = GetResultArray(OMNI_LONG, llvmTypes->CreateConstantInt(size)); + auto resultValues = GetResultArray(OMNI_LONG, llvmTypes->CreateConstantInt(size)); + auto resultNulls = GetResultArray(OMNI_LONG, llvmTypes->CreateConstantInt(size)); + + for (int i = 0; i < size; ++i) { + Value *gep = builder->CreateGEP(llvmTypes->I64Type(), whenClauses, llvmTypes->CreateConstantInt(i), + "when_value_address"); + builder->CreateStore(conditions[i]->data, gep); + + gep = builder->CreateGEP(llvmTypes->I64Type(), whenBools, llvmTypes->CreateConstantInt(i), "when_null_address"); + builder->CreateStore(conditions[i]->isNull, gep); + + gep = builder->CreateGEP(llvmTypes->I64Type(), resultValues, llvmTypes->CreateConstantInt(i), + "result_value_address"); + builder->CreateStore(results[i]->data, gep); + + gep = builder->CreateGEP(llvmTypes->I64Type(), resultNulls, llvmTypes->CreateConstantInt(i), + "result_null_address"); + builder->CreateStore(results[i]->isNull, gep); + } + + auto evFalse = VisitExpr(*elseExpr); + if (!evFalse->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + std::vector args; + switch (switchDataType) { + case OMNI_INT: + case OMNI_BOOLEAN: + case OMNI_DATE32: + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DOUBLE: { + args = { llvmTypes->CreateConstantInt(size), + whenClauses, + whenBools, + resultValues, + resultNulls, + evFalse->data, + evFalse->isNull, + finalValue, + finalNull, + this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_switch", { switchDataType }, switchDataType, args, nullptr); + this->value = std::make_shared(finalValue, finalNull); + return; + } + case OMNI_DECIMAL64: + case OMNI_DECIMAL128: { + auto returnDecimalValue = BuildDecimalValue(nullptr, *switchExpr.GetReturnType(), nullptr); + args = { llvmTypes->CreateConstantInt(size), + whenClauses, + whenBools, + resultValues, + resultNulls, + evFalse->data, + evFalse->isNull, + finalValue, + finalNull, + this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_switch", { switchDataType }, switchDataType, args, nullptr); + this->value = std::make_shared(finalValue, finalNull, + const_cast(returnDecimalValue->GetPrecision()), + const_cast(returnDecimalValue->GetScale())); + return; + } + case OMNI_CHAR: + case OMNI_VARCHAR: { + auto resultLengths = GetResultArray(OMNI_LONG, llvmTypes->CreateConstantInt(size)); + for (int i = 0; i < size; ++i) { + Value *gep = builder->CreateGEP(llvmTypes->I64Type(), resultLengths, llvmTypes->CreateConstantInt(i), + "result_length_address"); + builder->CreateStore(results[i]->length, gep); + } + AllocaInst *finalLength = GetResultArray(OMNI_INT, this->batchCodegenContext->rowCnt); + args = { llvmTypes->CreateConstantInt(size), + whenClauses, + whenBools, + resultValues, + resultNulls, + resultLengths, + evFalse->data, + evFalse->isNull, + evFalse->length, + finalValue, + finalNull, + finalLength, + this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_switch", { switchDataType }, switchDataType, args, nullptr); + this->value = std::make_shared(finalValue, finalNull, finalLength); + return; + } + default: { + LogWarn("Unsupported data type in SWITCH expr %d", switchDataType); + this->value = CreateInvalidCodeGenValue(); + return; + } + } +} + +void BatchExpressionCodeGen::BatchBinaryExprIntLongHelper(const omniruntime::expressions::BinaryExpr *binaryExpr, + llvm::Value *left, llvm::Value *right, llvm::Value *leftIsNull, llvm::Value *rightIsNull) +{ + DataTypeId returnTypeId = binaryExpr->GetReturnTypeId(); + std::vector typeParams; + if (binaryExpr->left->GetReturnTypeId() == OMNI_INT || binaryExpr->left->GetReturnTypeId() == OMNI_DATE32) { + typeParams = { OMNI_INT, OMNI_INT }; + } else { + typeParams = { OMNI_LONG, OMNI_LONG }; + } + std::vector boolParams { OMNI_BOOLEAN, OMNI_BOOLEAN }; + AllocaInst *logicalArrayPtr = nullptr; + std::vector logicalFuncParams; + if (returnTypeId == OMNI_BOOLEAN) { + logicalArrayPtr = builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "LOGICAL_PTR"); + logicalFuncParams = { left, right, logicalArrayPtr, this->batchCodegenContext->rowCnt }; + } + std::vector arithFuncParams { left, right, this->batchCodegenContext->rowCnt }; + + std::vector nullFuncParams { leftIsNull, rightIsNull, this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_or", boolParams, OMNI_BOOLEAN, nullFuncParams, nullptr, "either_null"); + + switch (binaryExpr->op) { + case omniruntime::expressions::Operator::LT: + CallExternFunction("batch_lessThan", typeParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, "relational_lt"); + this->value = std::make_shared(logicalArrayPtr, leftIsNull); + return; + case omniruntime::expressions::Operator::GT: + CallExternFunction("batch_greaterThan", typeParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, + "relational_gt"); + this->value = std::make_shared(logicalArrayPtr, leftIsNull); + return; + case omniruntime::expressions::Operator::LTE: + CallExternFunction("batch_lessThanEqual", typeParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, + "relational_le"); + this->value = std::make_shared(logicalArrayPtr, leftIsNull); + return; + case omniruntime::expressions::Operator::GTE: + CallExternFunction("batch_greaterThanEqual", typeParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, + "relational_ge"); + this->value = std::make_shared(logicalArrayPtr, leftIsNull); + return; + case omniruntime::expressions::Operator::EQ: + CallExternFunction("batch_equal", typeParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, "relational_eq"); + this->value = std::make_shared(logicalArrayPtr, leftIsNull); + return; + case omniruntime::expressions::Operator::NEQ: + CallExternFunction("batch_notEqual", typeParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, + "relational_neq"); + this->value = std::make_shared(logicalArrayPtr, leftIsNull); + return; + case omniruntime::expressions::Operator::ADD: + CallExternFunction("batch_add", typeParams, returnTypeId, arithFuncParams, nullptr, "arithmetic_add"); + this->value = std::make_shared(left, leftIsNull); + return; + case omniruntime::expressions::Operator::SUB: + CallExternFunction("batch_subtract", typeParams, returnTypeId, arithFuncParams, nullptr, "arithmetic_sub"); + this->value = std::make_shared(left, leftIsNull); + return; + case omniruntime::expressions::Operator::MUL: + CallExternFunction("batch_multiply", typeParams, returnTypeId, arithFuncParams, nullptr, "arithmetic_mul"); + this->value = std::make_shared(left, leftIsNull); + return; + case omniruntime::expressions::Operator::DIV: + arithFuncParams.push_back(leftIsNull); + CallExternFunction("batch_divide", typeParams, returnTypeId, arithFuncParams, + batchCodegenContext->executionContext, "arithmetic_div"); + this->value = std::make_shared(left, leftIsNull); + return; + case omniruntime::expressions::Operator::MOD: + arithFuncParams.push_back(leftIsNull); + CallExternFunction("batch_modulus", typeParams, returnTypeId, arithFuncParams, + batchCodegenContext->executionContext, "arithmetic_mod"); + this->value = std::make_shared(left, leftIsNull); + return; + default: { + LogError("Unsupported int or long binary operator %u", static_cast(binaryExpr->op)); + this->value = CreateInvalidCodeGenValue(); + return; + } + } +} + +void BatchExpressionCodeGen::BatchBinaryExprDoubleHelper(const omniruntime::expressions::BinaryExpr *binaryExpr, + llvm::Value *left, llvm::Value *right, llvm::Value *leftIsNull, llvm::Value *rightIsNull) +{ + std::vector doubleParams { OMNI_DOUBLE, OMNI_DOUBLE }; + std::vector boolParams { OMNI_BOOLEAN, OMNI_BOOLEAN }; + DataTypeId returnTypeId = binaryExpr->GetReturnTypeId(); + AllocaInst *logicalArrayPtr = nullptr; + std::vector logicalFuncParams; + if (returnTypeId == OMNI_BOOLEAN) { + logicalArrayPtr = builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "LOGICAL_PTR"); + logicalFuncParams = { left, right, logicalArrayPtr, this->batchCodegenContext->rowCnt }; + } + std::vector arithFuncParams { left, right, this->batchCodegenContext->rowCnt }; + + std::vector nullFuncParams { leftIsNull, rightIsNull, this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_or", boolParams, OMNI_BOOLEAN, nullFuncParams, nullptr, "either_null"); + + switch (binaryExpr->op) { + case omniruntime::expressions::Operator::LT: + CallExternFunction("batch_lessThan", doubleParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, + "relational_lt"); + this->value = std::make_shared(logicalArrayPtr, leftIsNull); + return; + case omniruntime::expressions::Operator::GT: + CallExternFunction("batch_greaterThan", doubleParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, + "relational_gt"); + this->value = std::make_shared(logicalArrayPtr, leftIsNull); + return; + case omniruntime::expressions::Operator::LTE: + CallExternFunction("batch_lessThanEqual", doubleParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, + "relational_le"); + this->value = std::make_shared(logicalArrayPtr, leftIsNull); + return; + case omniruntime::expressions::Operator::GTE: + CallExternFunction("batch_greaterThanEqual", doubleParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, + "relational_ge"); + this->value = std::make_shared(logicalArrayPtr, leftIsNull); + return; + case omniruntime::expressions::Operator::EQ: + CallExternFunction("batch_equal", doubleParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, "relational_eq"); + this->value = std::make_shared(logicalArrayPtr, leftIsNull); + return; + case omniruntime::expressions::Operator::NEQ: + CallExternFunction("batch_notEqual", doubleParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, + "relational_neq"); + this->value = std::make_shared(logicalArrayPtr, leftIsNull); + return; + case omniruntime::expressions::Operator::ADD: + CallExternFunction("batch_add", doubleParams, OMNI_DOUBLE, arithFuncParams, nullptr, "arithmetic_add"); + this->value = std::make_shared(left, leftIsNull); + return; + case omniruntime::expressions::Operator::SUB: + CallExternFunction("batch_subtract", doubleParams, OMNI_DOUBLE, arithFuncParams, nullptr, "arithmetic_sub"); + this->value = std::make_shared(left, leftIsNull); + return; + case omniruntime::expressions::Operator::MUL: + CallExternFunction("batch_multiply", doubleParams, OMNI_DOUBLE, arithFuncParams, nullptr, "arithmetic_mul"); + this->value = std::make_shared(left, leftIsNull); + return; + case omniruntime::expressions::Operator::DIV: + CallExternFunction("batch_divide", doubleParams, OMNI_DOUBLE, arithFuncParams, nullptr, "arithmetic_div"); + this->value = std::make_shared(left, leftIsNull); + return; + case omniruntime::expressions::Operator::MOD: + CallExternFunction("batch_modulus", doubleParams, OMNI_DOUBLE, arithFuncParams, nullptr, "arithmetic_mod"); + this->value = std::make_shared(left, leftIsNull); + return; + default: { + LogError("Unsupported double binary operator %u", static_cast(binaryExpr->op)); + this->value = CreateInvalidCodeGenValue(); + return; + } + } +} + +void BatchExpressionCodeGen::BatchBinaryExprDecimalHelper(const omniruntime::expressions::BinaryExpr *binaryExpr, + DecimalValue &left, DecimalValue &right, llvm::Value *leftIsNull, llvm::Value *rightIsNull) +{ + std::vector decimalParams { binaryExpr->left->GetReturnTypeId(), binaryExpr->right->GetReturnTypeId() }; + std::vector boolParams { OMNI_BOOLEAN, OMNI_BOOLEAN }; + DataTypeId returnTypeId = binaryExpr->GetReturnTypeId(); + std::shared_ptr returnDecimalValue = nullptr; + AllocaInst *logicalArrayPtr = nullptr; + AllocaInst *arithArrayPtr = nullptr; + std::vector logicalFuncParams; + std::vector arithFuncParams; + + if (returnTypeId == OMNI_BOOLEAN) { + logicalArrayPtr = builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "LOGICAL_PTR"); + logicalFuncParams = { + left.data, const_cast(left.GetPrecision()), const_cast(left.GetScale()), + right.data, const_cast(right.GetPrecision()), const_cast(right.GetScale()), + logicalArrayPtr, this->batchCodegenContext->rowCnt + }; + } else if (TypeUtil::IsDecimalType(returnTypeId)) { + returnDecimalValue = BuildDecimalValue(nullptr, *binaryExpr->GetReturnType(), nullptr); + arithFuncParams = { left.data, + const_cast(left.GetPrecision()), + const_cast(left.GetScale()), + right.data, + const_cast(right.GetPrecision()), + const_cast(right.GetScale()), + const_cast(returnDecimalValue->GetPrecision()), + const_cast(returnDecimalValue->GetScale()), + this->batchCodegenContext->rowCnt }; + if (decimalParams == std::vector { OMNI_DECIMAL128, OMNI_DECIMAL128 } && + returnTypeId == OMNI_DECIMAL64) { + arithArrayPtr = builder->CreateAlloca(llvmTypes->I64Type(), this->batchCodegenContext->rowCnt, "ARITH_PTR"); + arithFuncParams.insert(std::begin(arithFuncParams) + 6, arithArrayPtr); + } else if (decimalParams == std::vector { OMNI_DECIMAL64, OMNI_DECIMAL64 } && + returnTypeId == OMNI_DECIMAL128) { + arithArrayPtr = + builder->CreateAlloca(llvmTypes->I128Type(), this->batchCodegenContext->rowCnt, "ARITH_PTR"); + arithFuncParams.insert(std::begin(arithFuncParams) + 6, arithArrayPtr); + } + } + + std::vector nullFuncParams { leftIsNull, rightIsNull, this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_or", boolParams, OMNI_BOOLEAN, nullFuncParams, nullptr, "either_null"); + // when val->isNull = true, val->data is a random number, may cause false exception. + // so push leftIsNull into args. + // for functions throwing exception, if leftIsNull == true, do nothing. + if (!overflowConfig || overflowConfig->GetOverflowConfigId() != omniruntime::op::OVERFLOW_CONFIG_NULL) { + arithFuncParams.insert(arithFuncParams.begin(), leftIsNull); + } + + Value *falseValue = llvmTypes->CreateConstantBool(false); + AllocaInst *overflowNull = + builder->CreateAlloca(Type::getInt1Ty(*context), this->batchCodegenContext->rowCnt, "OVERFLOW_NULL_PTR"); + std::vector funcArgs { overflowNull, falseValue, this->batchCodegenContext->rowCnt }; + std::vector paramsVec = { OMNI_BOOLEAN }; + CallExternFunction("batch_fill_null", paramsVec, OMNI_BOOLEAN, funcArgs, nullptr, "fill_null_array"); + switch (binaryExpr->op) { + case omniruntime::expressions::Operator::LT: + CallExternFunction("batch_lessThan", decimalParams, returnTypeId, logicalFuncParams, nullptr, + "relational_lt"); + break; + case omniruntime::expressions::Operator::GT: + CallExternFunction("batch_greaterThan", decimalParams, returnTypeId, logicalFuncParams, nullptr, + "relational_gt"); + break; + case omniruntime::expressions::Operator::LTE: + CallExternFunction("batch_lessThanEqual", decimalParams, returnTypeId, logicalFuncParams, nullptr, + "relational_le"); + break; + case omniruntime::expressions::Operator::GTE: + CallExternFunction("batch_greaterThanEqual", decimalParams, returnTypeId, logicalFuncParams, nullptr, + "relational_ge"); + break; + case omniruntime::expressions::Operator::EQ: + CallExternFunction("batch_equal", decimalParams, returnTypeId, logicalFuncParams, nullptr, "relational_eq"); + break; + case omniruntime::expressions::Operator::NEQ: + CallExternFunction("batch_notEqual", decimalParams, returnTypeId, logicalFuncParams, nullptr, + "relational_neq"); + break; + case omniruntime::expressions::Operator::ADD: + CallExternFunction("batch_add", decimalParams, returnTypeId, arithFuncParams, + batchCodegenContext->executionContext, "arithmetic_add", this->overflowConfig, overflowNull); + break; + case omniruntime::expressions::Operator::SUB: + CallExternFunction("batch_subtract", decimalParams, returnTypeId, arithFuncParams, + batchCodegenContext->executionContext, "arithmetic_sub", this->overflowConfig, overflowNull); + break; + case omniruntime::expressions::Operator::MUL: + CallExternFunction("batch_multiply", decimalParams, returnTypeId, arithFuncParams, + batchCodegenContext->executionContext, "arithmetic_mul", this->overflowConfig, overflowNull); + break; + case omniruntime::expressions::Operator::DIV: + CallExternFunction("batch_divide", decimalParams, returnTypeId, arithFuncParams, + batchCodegenContext->executionContext, "arithmetic_div", this->overflowConfig, overflowNull); + break; + case omniruntime::expressions::Operator::MOD: + CallExternFunction("batch_modulus", decimalParams, returnTypeId, arithFuncParams, + batchCodegenContext->executionContext, "arithmetic_mod", this->overflowConfig, overflowNull); + break; + default: { + LogError("Unsupported decimal binary operator %u", static_cast(binaryExpr->op)); + this->value = CreateInvalidCodeGenValue(); + return; + } + } + + if (TypeUtil::IsDecimalType(returnTypeId)) { + std::vector isAnyNullParams { leftIsNull, overflowNull, this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_or", boolParams, OMNI_BOOLEAN, isAnyNullParams, nullptr, "either_null"); + llvm::Value *dataValue = nullptr; + if (returnTypeId == omniruntime::type::OMNI_DECIMAL128) { + if (decimalParams[0] == OMNI_DECIMAL128) { + dataValue = left.data; + } else if (decimalParams[1] == OMNI_DECIMAL128) { + dataValue = right.data; + } else { + dataValue = arithArrayPtr; + } + this->value = std::make_shared(dataValue, leftIsNull, + const_cast(returnDecimalValue->GetPrecision()), + const_cast(returnDecimalValue->GetScale())); + } else if (returnTypeId == omniruntime::type::OMNI_DECIMAL64) { + if (decimalParams[0] == OMNI_DECIMAL64) { + dataValue = left.data; + } else if (decimalParams[1] == OMNI_DECIMAL64) { + dataValue = right.data; + } else { + dataValue = arithArrayPtr; + } + this->value = std::make_shared(dataValue, leftIsNull, + const_cast(returnDecimalValue->GetPrecision()), + const_cast(returnDecimalValue->GetScale())); + } + } else { + this->value = std::make_shared(logicalArrayPtr, leftIsNull); + } +} + +void BatchExpressionCodeGen::BatchBinaryExprStringHelper(const omniruntime::expressions::BinaryExpr *binaryExpr, + llvm::Value *left, llvm::Value *leftLen, llvm::Value *right, llvm::Value *rightLen, llvm::Value *leftIsNull, + llvm::Value *rightIsNull) +{ + std::vector strParams { OMNI_VARCHAR, OMNI_VARCHAR }; + std::vector boolParams { OMNI_BOOLEAN, OMNI_BOOLEAN }; + auto logicalArrayPtr = builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "LOGICAL_PTR"); + std::vector logicalFuncParams { left, leftLen, right, + rightLen, logicalArrayPtr, this->batchCodegenContext->rowCnt }; + + std::vector nullFuncParams { leftIsNull, rightIsNull, this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_or", boolParams, OMNI_BOOLEAN, nullFuncParams, nullptr, "either_null"); + + switch (binaryExpr->op) { + case omniruntime::expressions::Operator::LT: + CallExternFunction("batch_lessThan", strParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, "relational_lt"); + this->value = std::make_shared(logicalArrayPtr, leftIsNull); + return; + case omniruntime::expressions::Operator::GT: + CallExternFunction("batch_greaterThan", strParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, + "relational_gt"); + this->value = std::make_shared(logicalArrayPtr, leftIsNull); + return; + case omniruntime::expressions::Operator::LTE: + CallExternFunction("batch_lessThanEqual", strParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, + "relational_le"); + this->value = std::make_shared(logicalArrayPtr, leftIsNull); + return; + case omniruntime::expressions::Operator::GTE: + CallExternFunction("batch_greaterThanEqual", strParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, + "relational_ge"); + this->value = std::make_shared(logicalArrayPtr, leftIsNull); + return; + case omniruntime::expressions::Operator::EQ: + CallExternFunction("batch_equal", strParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, "relational_eq"); + this->value = std::make_shared(logicalArrayPtr, leftIsNull); + return; + case omniruntime::expressions::Operator::NEQ: + CallExternFunction("batch_notEqual", strParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, "relational_neq"); + this->value = std::make_shared(logicalArrayPtr, leftIsNull); + return; + default: { + LogError("Unsupported string binary operator %u", static_cast(binaryExpr->op)); + this->value = CreateInvalidCodeGenValue(); + return; + } + } +} + +void BatchExpressionCodeGen::BatchVisitBetweenExprHelper(BetweenExpr &bExpr, const std::shared_ptr &val, + const std::shared_ptr &lowerVal, const std::shared_ptr &upperVal, + std::pair cmpPair) +{ + auto cmpLeft = cmpPair.first; + auto cmpRight = cmpPair.second; + std::vector params; + if (TypeUtil::IsStringType(bExpr.value->GetReturnTypeId())) { + params = { OMNI_VARCHAR, OMNI_VARCHAR }; + } else if (bExpr.value->GetReturnTypeId() == type::OMNI_DATE32) { + params = { OMNI_INT, OMNI_INT }; + } else if (bExpr.value->GetReturnTypeId() == type::OMNI_TIMESTAMP) { + params = { OMNI_LONG, OMNI_LONG }; + } else { + params = { bExpr.value->GetReturnTypeId(), bExpr.value->GetReturnTypeId() }; + } + std::vector logicalFuncParams1; + std::vector logicalFuncParams2; + bool isSupportedType = false; + + if (bExpr.value->GetReturnTypeId() == OMNI_INT || bExpr.value->GetReturnTypeId() == OMNI_LONG || + bExpr.value->GetReturnTypeId() == OMNI_DATE32 || bExpr.value->GetReturnTypeId() == OMNI_DOUBLE || + bExpr.value->GetReturnTypeId() == type::OMNI_TIMESTAMP) { + logicalFuncParams1 = { lowerVal->data, val->data, *cmpLeft, this->batchCodegenContext->rowCnt }; + logicalFuncParams2 = { val->data, upperVal->data, *cmpRight, this->batchCodegenContext->rowCnt }; + isSupportedType = true; + } else if (TypeUtil::IsStringType(bExpr.value->GetReturnTypeId())) { + logicalFuncParams1 = { lowerVal->data, lowerVal->length, val->data, + val->length, *cmpLeft, this->batchCodegenContext->rowCnt }; + logicalFuncParams2 = { val->data, val->length, upperVal->data, + upperVal->length, *cmpRight, this->batchCodegenContext->rowCnt }; + isSupportedType = true; + } else if (TypeUtil::IsDecimalType(bExpr.value->GetReturnTypeId())) { + logicalFuncParams1 = { lowerVal->data, + const_cast(static_cast(*lowerVal).GetPrecision()), + const_cast(static_cast(*lowerVal).GetScale()), + val->data, + const_cast(static_cast(*val).GetPrecision()), + const_cast(static_cast(*val).GetScale()), + *cmpLeft, + this->batchCodegenContext->rowCnt }; + logicalFuncParams2 = { val->data, + const_cast(static_cast(*val).GetPrecision()), + const_cast(static_cast(*val).GetScale()), + upperVal->data, + const_cast(static_cast(*upperVal).GetPrecision()), + const_cast(static_cast(*upperVal).GetScale()), + *cmpRight, + this->batchCodegenContext->rowCnt }; + isSupportedType = true; + } + if (isSupportedType) { + CallExternFunction("batch_lessThanEqual", params, OMNI_BOOLEAN, logicalFuncParams1, nullptr, "relational_le"); + CallExternFunction("batch_lessThanEqual", params, OMNI_BOOLEAN, logicalFuncParams2, nullptr, "relational_le"); + + std::vector boolParams { OMNI_BOOLEAN, OMNI_BOOLEAN }; + std::vector nullFuncParams { lowerVal->isNull, val->isNull, this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_or", boolParams, OMNI_BOOLEAN, nullFuncParams, nullptr, "either_null"); + nullFuncParams = { lowerVal->isNull, upperVal->isNull, this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_or", boolParams, OMNI_BOOLEAN, nullFuncParams, nullptr, "either_null"); + + std::vector betweenFuncParams = { *cmpLeft, *cmpRight, this->batchCodegenContext->rowCnt }; + CallExternFunction("batch_and", boolParams, OMNI_BOOLEAN, betweenFuncParams, nullptr, "between_pass"); + this->value = std::make_shared(*cmpLeft, lowerVal->isNull); + return; + } + + LogError("Unsupported data type for BETWEEN expr %d", bExpr.value->GetReturnTypeId()); + this->value = CreateInvalidCodeGenValue(); +} + +Value *BatchExpressionCodeGen::PushAndGetNullFlagArray(const FuncExpr &fExpr, std::vector &argVals, + Value *nullFlagArray, bool needAdd) +{ + if (fExpr.function->GetNullableResultType() == INPUT_DATA_AND_NULL_AND_RETURN_NULL) { + AllocaInst *isNullArrPtr = + builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "IS_RET_NULL_PTR"); + CallExternFunction("batch_fill_null", { OMNI_BOOLEAN }, OMNI_BOOLEAN, + { isNullArrPtr, llvmTypes->CreateConstantBool(false), this->batchCodegenContext->rowCnt }, nullptr, + "fill_ret_null_array"); + argVals.push_back(isNullArrPtr); + return isNullArrPtr; + } + if (needAdd) { + argVals.push_back(nullFlagArray); + } + return nullFlagArray; +} +} diff --git a/core/src/codegen/batch_expression_codegen.h b/core/src/codegen/batch_expression_codegen.h new file mode 100644 index 0000000..79f043c --- /dev/null +++ b/core/src/codegen/batch_expression_codegen.h @@ -0,0 +1,163 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch expression codegen + */ + +#ifndef OMNI_RUNTIME_BATCH_EXPRESSION_CODEGEN_H +#define OMNI_RUNTIME_BATCH_EXPRESSION_CODEGEN_H + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "codegen_value.h" +#include "batch_codegen_context.h" +#include "expression/expressions.h" +#include "expression/parser/parser.h" +#include "expression/expr_printer.h" +#include "util/debug.h" +#include "llvm_types.h" +#include "llvm_engine.h" +#include "operator/config/operator_config.h" +#include "type/data_type.h" +#include "vector/vector_batch.h" +#include "codegen_base.h" + +namespace omniruntime::codegen { +using namespace llvm; +using namespace orc; +using namespace omniruntime; +using namespace omniruntime::expressions; +using namespace omniruntime::type; +using CodeGenValuePtr = std::shared_ptr; + +class BatchExpressionCodeGen : public ExprVisitor, public CodegenBase { +public: + BatchExpressionCodeGen(std::string name, const omniruntime::expressions::Expr &cpExpr, + op::OverflowConfig *ofConfig); + + ~BatchExpressionCodeGen() override + { + if (rt) { + eoe(rt->remove()); + } + } + + virtual intptr_t GetFunction() = 0; + + void Visit(const LiteralExpr &e) override; + + void Visit(const FieldExpr &e) override; + + void Visit(const UnaryExpr &e) override; + + void Visit(const BinaryExpr &e) override; + + void Visit(const InExpr &e) override; + + void Visit(const BetweenExpr &e) override; + + void Visit(const IfExpr &e) override; + + void Visit(const CoalesceExpr &e) override; + + void Visit(const IsNullExpr &e) override; + + void Visit(const FuncExpr &e) override; + + void Visit(const SwitchExpr &e) override; + + CodeGenValuePtr VisitExpr(const Expr &e); + + std::vector GetFunctionArgValues(const FuncExpr &fExpr, llvm::AllocaInst *isAnyNull, + bool &isInvalidExpr); + +protected: + AllocaInst *GetResultArray(DataTypeId dataTypeId, Value *rowCnt); + + virtual llvm::Function *CreateBatchFunction(); + +private: + bool InitializeBatchCodegenContext(llvm::iterator_range args); + + Value *GetDictionaryVectorValue(const DataType &dataType, llvm::Value *rowIdxArray, Value *rowCnt, + Value *dictionaryVectorPtr, AllocaInst *lengthArrayPtr); + + Value *GetVectorValue(const DataType &dataType, Value *rowIdxArray, Value *rowCnt, Value *dataVectorPtr, + Value *offsetArray, Value *lengthArrayPtr); + + CodeGenValue *BatchLiteralExprConstantHelper(const LiteralExpr &lExpr); + + void BatchBinaryExprIntLongHelper(const BinaryExpr *binaryExpr, Value *left, Value *right, Value *leftIsNull, + Value *rightIsNull); + + void BatchBinaryExprDoubleHelper(const BinaryExpr *binaryExpr, Value *left, Value *right, Value *leftIsNull, + Value *rightIsNull); + + void BatchBinaryExprStringHelper(const BinaryExpr *binaryExpr, Value *left, Value *leftLen, Value *right, + Value *rightLen, Value *leftIsNull, Value *rightIsNull); + + void BatchBinaryExprDecimalHelper(const BinaryExpr *binaryExpr, DecimalValue &left, DecimalValue &right, + Value *leftIsNull, Value *rightIsNull); + + void BatchVisitBetweenExprHelper(BetweenExpr &bExpr, const std::shared_ptr &val, + const std::shared_ptr &lowerVal, const std::shared_ptr &upperVal, + std::pair cmpPair); + + template + std::vector GetDefaultFunctionArgValues(const FuncExpr &fExpr, AllocaInst *isAnyNull, + bool &isInvalidExpr); + + std::vector GetDataArgs(const FuncExpr &fExpr, AllocaInst *isAnyNull, bool &isInvalidExpr); + + std::vector GetDataAndNullArgs(const FuncExpr &fExpr, AllocaInst *isAnyNull, bool &isInvalidExpr); + + std::vector GetDataAndNullArgsAndReturnNull(const FuncExpr &fExpr, AllocaInst *isAnyNull, + bool &isInvalidExpr); + + std::vector GetDataAndOverflowNullArgs(const FuncExpr &fExpr, AllocaInst *isAnyNull, + bool &isInvalidExpr, AllocaInst *overflowNull); + + void FuncExprOverflowNullHelper(const FuncExpr &e); + + Value *ArenaAlloc(Value *sizeInBytes); + + Value *GetTypeSize(DataTypeId dataTypeId); + + std::vector GetHiveUdfArgValues(const FuncExpr &fExpr, bool &isInvalidExpr); + + llvm::Value *CreateHiveUdfArgTypes(const FuncExpr &fExpr); + + void CallHiveUdfFunction(const FuncExpr &fExpr); + + Value *PushAndGetNullFlagArray(const FuncExpr &fExpr, std::vector &argVals, Value *nullFlagArray, + bool needAdd); +}; +} +#endif // OMNI_RUNTIME_BATCH_EXPRESSION_CODEGEN_H diff --git a/core/src/codegen/batch_filter_codegen.cpp b/core/src/codegen/batch_filter_codegen.cpp new file mode 100644 index 0000000..e9f164e --- /dev/null +++ b/core/src/codegen/batch_filter_codegen.cpp @@ -0,0 +1,87 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch filter expression codegen + */ + +#include "batch_filter_codegen.h" + +#include + +namespace omniruntime::codegen { +using namespace llvm; +using namespace orc; +using namespace omniruntime::expressions; + +namespace { +const int INPUT_INDEX = 0; +const int ARGUMENT_ONE = 1; +const int ARGUMENT_TWO = 2; +const int ARGUMENT_THREE = 3; +const int OFFSETS_INDEX = 4; +const int EXECUTION_CONTEXT_IDX = 5; +const int DICTIONARY_VECTORS_IDX = 6; +} + +intptr_t BatchFilterCodeGen::GetFunction() +{ + llvm::Function *func = CreateBatchFunction(); + if (func == nullptr) { + return 0; + } + return CreateBatchWrapper(*func); +} + +intptr_t BatchFilterCodeGen::CreateBatchWrapper(llvm::Function &filter) +{ + llvm::Function *filterFunc = &filter; + std::vector args; + args.push_back(llvmTypes->I64PtrType()); // inputData + args.push_back(llvmTypes->I32Type()); // rowCnt + args.push_back(llvmTypes->I32PtrType()); // selectedRows + args.push_back(llvmTypes->I64PtrType()); // inputBitmap + args.push_back(llvmTypes->I64PtrType()); // inputOffsets + args.push_back(llvmTypes->I64Type()); // execution_context + args.push_back(llvmTypes->I64PtrType()); // dictionary vectors + + FunctionType *funcSignature = FunctionType::get(llvmTypes->I32Type(), args, false); + llvm::Function *funcDecl = + llvm::Function::Create(funcSignature, llvm::Function::ExternalLinkage, "WRAPPER_FUNC", modulePtr); + BasicBlock *filterMain = BasicBlock::Create(*context, "FILTER_MAIN", funcDecl); + // set arg names + Argument *data = funcDecl->getArg(INPUT_INDEX); + data->setName("ARGS_ARRAY"); + Argument *numRows = funcDecl->getArg(ARGUMENT_ONE); + numRows->setName("NUM_ROWS"); + Argument *selectedRows = funcDecl->getArg(ARGUMENT_TWO); + selectedRows->setName("RESULTS"); + Argument *bitmap = funcDecl->getArg(ARGUMENT_THREE); + bitmap->setName("BITMAP"); + Argument *offsets = funcDecl->getArg(OFFSETS_INDEX); + offsets->setName("OFFSETS"); + Argument *executionContext = funcDecl->getArg(EXECUTION_CONTEXT_IDX); + executionContext->setName("EXECUTION_CONTEXT_ADDRESS"); + Argument *dictionaryVectors = funcDecl->getArg(DICTIONARY_VECTORS_IDX); + dictionaryVectors->setName("DICTIONARY_VECTORS"); + + builder->SetInsertPoint(filterMain); + AllocaInst *lengthAllocaInst = builder->CreateAlloca(llvmTypes->I32Type(), numRows, "LENGTH_PTR"); + AllocaInst *isNullPtr = builder->CreateAlloca(llvmTypes->I1Type(), numRows, "IS_NULL_PTR"); + AllocaInst *rowIdxArray = builder->CreateAlloca(llvmTypes->I32Type(), numRows, "ROW_INDEX_ARRAY"); + std::vector funcArgs { rowIdxArray, numRows }; + CallExternFunction("fill_rowIdx", { OMNI_INT, OMNI_INT }, OMNI_INT, funcArgs, nullptr, "fill_rowIdx"); + // in the form of {0, 1, 1, ...}. 1 indicates passing the filter, 0 otherwise. + auto filterResArray = builder->CreateAlloca(llvmTypes->I1Type(), numRows, "FILTER_RES_PTR"); + + std::vector filterFuncArgs { data, bitmap, offsets, numRows, + rowIdxArray, lengthAllocaInst, executionContext, dictionaryVectors, + isNullPtr, filterResArray }; + builder->CreateCall(filterFunc, filterFuncArgs, "INNER_FUNC"); + + std::vector paramTypes = { OMNI_BOOLEAN, OMNI_BOOLEAN, OMNI_INT, OMNI_INT }; + funcArgs = { filterResArray, isNullPtr, selectedRows, numRows }; + auto res = CallExternFunction("batch_and_not", paramTypes, OMNI_INT, funcArgs, nullptr, "fill_filter_result"); + builder->CreateRet(res); + OptimizeFunctionsAndModule(); + return Compile(); +} +} \ No newline at end of file diff --git a/core/src/codegen/batch_filter_codegen.h b/core/src/codegen/batch_filter_codegen.h new file mode 100644 index 0000000..7f3cfbc --- /dev/null +++ b/core/src/codegen/batch_filter_codegen.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch filter expression codegen + */ + +#ifndef OMNI_RUNTIME_BATCH_FILTER_CODEGEN_H +#define OMNI_RUNTIME_BATCH_FILTER_CODEGEN_H + +#include +#include "batch_expression_codegen.h" + +namespace omniruntime { +namespace codegen { +class BatchFilterCodeGen : public BatchExpressionCodeGen { +public: + BatchFilterCodeGen(std::string name, const omniruntime::expressions::Expr &expression, + omniruntime::op::OverflowConfig *overflowConfig) + : BatchExpressionCodeGen(std::move(name), expression, overflowConfig) + {} + + ~BatchFilterCodeGen() override = default; + + intptr_t GetFunction() override; + +private: + intptr_t CreateBatchWrapper(llvm::Function &filter); +}; +} +} + +#endif // OMNI_RUNTIME_BATCH_FILTER_CODEGEN_H diff --git a/core/src/codegen/batch_func_registry_datetime.cpp b/core/src/codegen/batch_func_registry_datetime.cpp new file mode 100644 index 0000000..4e15825 --- /dev/null +++ b/core/src/codegen/batch_func_registry_datetime.cpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: Batch Date Time Function Registry + */ +#include "batch_func_registry_datetime.h" +#include "batch_functions/batch_datetime_functions.h" + +namespace omniruntime::codegen { +using namespace omniruntime::type; +using namespace omniruntime::codegen::function; + +std::vector BatchDateTimeFunctionRegistry::GetFunctions() +{ + static std::vector batchDateTimeFunctions = { + Function(reinterpret_cast(BatchUnixTimestampFromStr), "batch_unix_timestamp", {}, + { OMNI_VARCHAR, OMNI_VARCHAR, OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_LONG, INPUT_DATA_AND_NULL_AND_RETURN_NULL), + Function(reinterpret_cast(BatchUnixTimestampFromDate), "batch_unix_timestamp", {}, + { OMNI_DATE32, OMNI_VARCHAR, OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(BatchFromUnixTime), "batch_from_unixtime", {}, + { OMNI_LONG, OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchFromUnixTimeRetNull), "batch_from_unixtime_null", {}, + { OMNI_LONG, OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_VARCHAR, INPUT_DATA_AND_OVERFLOW_NULL, true) }; + + return batchDateTimeFunctions; +} +} \ No newline at end of file diff --git a/core/src/codegen/batch_func_registry_datetime.h b/core/src/codegen/batch_func_registry_datetime.h new file mode 100644 index 0000000..62b9153 --- /dev/null +++ b/core/src/codegen/batch_func_registry_datetime.h @@ -0,0 +1,17 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: Batch Date Time Function Registry + */ + +#ifndef OMNI_RUNTIME_BATCH_FUNC_REGISTRY_DATETIME_H +#define OMNI_RUNTIME_BATCH_FUNC_REGISTRY_DATETIME_H +#include "function.h" +#include "func_registry_base.h" + +namespace omniruntime::codegen { +class BatchDateTimeFunctionRegistry : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; +} +#endif // OMNI_RUNTIME_BATCH_FUNC_REGISTRY_DATETIME_H diff --git a/core/src/codegen/batch_func_registry_decimal.cpp b/core/src/codegen/batch_func_registry_decimal.cpp new file mode 100644 index 0000000..42a2553 --- /dev/null +++ b/core/src/codegen/batch_func_registry_decimal.cpp @@ -0,0 +1,435 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: Batch Decimal Function Registry + */ +#include "batch_func_registry_decimal.h" +#include "batch_functions/batch_decimal_arithmetic_functions.h" +#include "batch_functions/batch_decimal_cast_functions.h" + +namespace omniruntime::codegen { +using namespace omniruntime::type; +using namespace function; + +namespace { +const std::string ABS_FN_STR = "batch_abs"; +const std::string ROUND_FN_STR = "batch_round"; +const std::string ROUND_NULL_FN_STR = "batch_round_null"; +const std::string CAST_FN_STR = "batch_CAST"; +const std::string CAST_NULL_FN_STR = "batch_CAST_null"; +const std::string ADD_NULL_FN_STR = "batch_add_null"; +const std::string SUBTRACT_NULL_FN_STR = "batch_subtract_null"; +const std::string MULTIPLY_NULL_FN_STR = "batch_multiply_null"; +const std::string DIVIDE_NULL_FN_STR = "batch_divide_null"; +const std::string MODULUS_NULL_FN_STR = "batch_modulus_null"; +const std::string ADD_FN_STR = "batch_add"; +const std::string SUBTRACT_FN_STR = "batch_subtract"; +const std::string MULTIPLY_FN_STR = "batch_multiply"; +const std::string DIVIDE_FN_STR = "batch_divide"; +const std::string MODULUS_FN_STR = "batch_modulus"; +const std::string LESS_THAN_FN_STR = "batch_lessThan"; +const std::string LESS_THAN_EQUAL_FN_STR = "batch_lessThanEqual"; +const std::string GREATER_THAN_FN_STR = "batch_greaterThan"; +const std::string GREATER_THAN_EQUAL_FN_STR = "batch_greaterThanEqual"; +const std::string EQUAL_FN_STR = "batch_equal"; +const std::string NOT_EQUAL_FN_STR = "batch_notEqual"; +const std::string MAKE_DECIMAL_FN_STR = "batch_MakeDecimal"; +const std::string MAKE_DECIMAL_NULL_FN_STR = "batch_MakeDecimal_null"; +const std::string BATCH_DECIMAL128_COMPARE_STR = "batch_Decimal128Compare"; +const std::string BATCH_DECIMAL64_COMPARE_STR = "batch_Decimal64Compare"; +const std::string BATCH_UNSCALED_VALUE_STR = "batch_UnscaledValue"; +const std::string GREATEST_DECIMAL_FN_STR = "batch_Greatest"; +const std::string GREATEST_DECIMAL_NULL_FN_STR = "batch_Greatest_null"; +} + +std::vector BatchDecimalFunctionRegistry::GetFunctions() +{ + std::vector paramTypes128 = { OMNI_DECIMAL128, OMNI_DECIMAL128 }; + std::vector paramTypes64 = { OMNI_DECIMAL64, OMNI_DECIMAL64 }; + std::vector paramTypes64Op128 = { OMNI_DECIMAL64, OMNI_DECIMAL128 }; + std::vector paramTypes128Op64 = { OMNI_DECIMAL128, OMNI_DECIMAL64 }; + + static std::vector batchDecimalFunctions = { + Function(reinterpret_cast(BatchDecimal128Compare), BATCH_DECIMAL128_COMPARE_STR, {}, paramTypes128, + OMNI_INT), + Function(reinterpret_cast(BatchAbsDecimal128), ABS_FN_STR, {}, { OMNI_DECIMAL128 }, OMNI_DECIMAL128, + INPUT_DATA), + Function(reinterpret_cast(BatchDecimal64Compare), BATCH_DECIMAL64_COMPARE_STR, {}, paramTypes64, + OMNI_INT), + Function(reinterpret_cast(BatchAbsDecimal64), ABS_FN_STR, {}, { OMNI_DECIMAL64 }, OMNI_DECIMAL64, + INPUT_DATA), + + Function(reinterpret_cast(BatchRoundDecimal128), ROUND_FN_STR, {}, { OMNI_DECIMAL128, OMNI_INT }, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchRoundDecimal64), ROUND_FN_STR, {}, { OMNI_DECIMAL64, OMNI_INT }, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchRoundDecimal128WithoutRound), ROUND_FN_STR, {}, { OMNI_DECIMAL128 }, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchRoundDecimal64WithoutRound), ROUND_FN_STR, {}, { OMNI_DECIMAL64 }, + OMNI_DECIMAL64, INPUT_DATA, true), + + // decimal64 compare + Function(reinterpret_cast(BatchLessThanDecimal64), LESS_THAN_FN_STR, {}, paramTypes64, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(BatchLessThanEqualDecimal64), LESS_THAN_EQUAL_FN_STR, {}, paramTypes64, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchGreaterThanDecimal64), GREATER_THAN_FN_STR, {}, paramTypes64, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchGreaterThanEqualDecimal64), GREATER_THAN_EQUAL_FN_STR, {}, paramTypes64, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchEqualDecimal64), EQUAL_FN_STR, {}, paramTypes64, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(BatchNotEqualDecimal64), NOT_EQUAL_FN_STR, {}, paramTypes64, OMNI_BOOLEAN, + INPUT_DATA), + + // decimal128 compare + Function(reinterpret_cast(BatchLessThanDecimal128), LESS_THAN_FN_STR, {}, paramTypes128, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(BatchLessThanEqualDecimal128), LESS_THAN_EQUAL_FN_STR, {}, paramTypes128, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchGreaterThanDecimal128), GREATER_THAN_FN_STR, {}, paramTypes128, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchGreaterThanEqualDecimal128), GREATER_THAN_EQUAL_FN_STR, {}, + paramTypes128, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchEqualDecimal128), EQUAL_FN_STR, {}, paramTypes128, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(BatchNotEqualDecimal128), NOT_EQUAL_FN_STR, {}, paramTypes128, OMNI_BOOLEAN, + INPUT_DATA), + + // Decimal Cast Function + Function(reinterpret_cast(BatchCastDecimal64To64), CAST_FN_STR, {}, { OMNI_DECIMAL64 }, OMNI_DECIMAL64, + INPUT_DATA, true), + Function(reinterpret_cast(BatchCastDecimal128To128), CAST_FN_STR, {}, { OMNI_DECIMAL128 }, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchCastDecimal64To128), CAST_FN_STR, {}, { OMNI_DECIMAL64 }, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchCastDecimal128To64), CAST_FN_STR, {}, { OMNI_DECIMAL128 }, + OMNI_DECIMAL64, INPUT_DATA, true), + + Function(reinterpret_cast(BatchCastIntToDecimal64), CAST_FN_STR, {}, { OMNI_INT }, OMNI_DECIMAL64, + INPUT_DATA, true), + Function(reinterpret_cast(BatchCastLongToDecimal64), CAST_FN_STR, {}, { OMNI_LONG }, OMNI_DECIMAL64, + INPUT_DATA, true), + Function(reinterpret_cast(BatchCastDoubleToDecimal64), CAST_FN_STR, {}, { OMNI_DOUBLE }, OMNI_DECIMAL64, + INPUT_DATA, true), + Function(reinterpret_cast(BatchCastIntToDecimal128), CAST_FN_STR, {}, { OMNI_INT }, OMNI_DECIMAL128, + INPUT_DATA, true), + Function(reinterpret_cast(BatchCastLongToDecimal128), CAST_FN_STR, {}, { OMNI_LONG }, OMNI_DECIMAL128, + INPUT_DATA, true), + Function(reinterpret_cast(BatchCastDoubleToDecimal128), CAST_FN_STR, {}, { OMNI_DOUBLE }, + OMNI_DECIMAL128, INPUT_DATA, true), + + Function(reinterpret_cast(BatchCastDecimal128ToLong), CAST_FN_STR, {}, { OMNI_DECIMAL128 }, OMNI_LONG, + INPUT_DATA, true), + Function(reinterpret_cast(BatchCastDecimal128ToInt), CAST_FN_STR, {}, { OMNI_DECIMAL128 }, OMNI_INT, + INPUT_DATA, true), + + // Decimal Cast Function Return Null + Function(reinterpret_cast(BatchRoundDecimal128RetNull), ROUND_NULL_FN_STR, {}, + { OMNI_DECIMAL128, OMNI_INT }, OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchRoundDecimal64RetNull), ROUND_NULL_FN_STR, {}, + { OMNI_DECIMAL64, OMNI_INT }, OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + + Function(reinterpret_cast(BatchCastDecimal64To64RetNull), CAST_NULL_FN_STR, {}, { OMNI_DECIMAL64 }, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchCastDecimal128To128RetNull), CAST_NULL_FN_STR, {}, { OMNI_DECIMAL128 }, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchCastDecimal64To128RetNull), CAST_NULL_FN_STR, {}, { OMNI_DECIMAL64 }, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchCastDecimal128To64RetNull), CAST_NULL_FN_STR, {}, { OMNI_DECIMAL128 }, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + + Function(reinterpret_cast(BatchCastDecimal64ToDoubleRetNull), CAST_NULL_FN_STR, {}, { OMNI_DECIMAL64 }, + OMNI_DOUBLE, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchCastDecimal64ToLongRetNull), CAST_NULL_FN_STR, {}, { OMNI_DECIMAL64 }, + OMNI_LONG, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchCastDecimal128ToDoubleRetNull), CAST_NULL_FN_STR, {}, + { OMNI_DECIMAL128 }, OMNI_DOUBLE, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchCastDecimal64ToIntRetNull), CAST_NULL_FN_STR, {}, { OMNI_DECIMAL64 }, + OMNI_INT, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchCastDecimal128ToIntRetNull), CAST_NULL_FN_STR, {}, { OMNI_DECIMAL128 }, + OMNI_INT, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchCastDecimal128ToLongRetNull), CAST_NULL_FN_STR, {}, { OMNI_DECIMAL128 }, + OMNI_LONG, INPUT_DATA_AND_OVERFLOW_NULL), + + Function(reinterpret_cast(BatchCastIntToDecimal64RetNull), CAST_NULL_FN_STR, {}, { OMNI_INT }, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchCastLongToDecimal64RetNull), CAST_NULL_FN_STR, {}, { OMNI_LONG }, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchCastDoubleToDecimal64RetNull), CAST_NULL_FN_STR, {}, { OMNI_DOUBLE }, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchCastIntToDecimal128RetNull), CAST_NULL_FN_STR, {}, { OMNI_INT }, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchCastLongToDecimal128RetNull), CAST_NULL_FN_STR, {}, { OMNI_LONG }, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchCastDoubleToDecimal128RetNull), CAST_NULL_FN_STR, {}, { OMNI_DOUBLE }, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + + // UnscaledValue + Function(reinterpret_cast(BatchUnscaledValue64), BATCH_UNSCALED_VALUE_STR, {}, { OMNI_DECIMAL64 }, + OMNI_LONG, INPUT_DATA), + // MakeDecimal + Function(reinterpret_cast(BatchMakeDecimal64), MAKE_DECIMAL_FN_STR, {}, { OMNI_LONG }, OMNI_DECIMAL64, + INPUT_DATA, true), + Function(reinterpret_cast(BatchMakeDecimal64RetNull), MAKE_DECIMAL_NULL_FN_STR, {}, { OMNI_LONG }, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + + // Return Null + Function(reinterpret_cast(BatchAddDec64Dec64Dec64RetNull), ADD_NULL_FN_STR, {}, paramTypes64, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchAddDec64Dec64Dec128RetNull), ADD_NULL_FN_STR, {}, paramTypes64, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchAddDec128Dec128Dec128RetNull), ADD_NULL_FN_STR, {}, paramTypes128, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchAddDec64Dec128Dec128RetNull), ADD_NULL_FN_STR, {}, paramTypes64Op128, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchAddDec128Dec64Dec128RetNull), ADD_NULL_FN_STR, {}, paramTypes128Op64, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + + Function(reinterpret_cast(BatchSubDec64Dec64Dec64RetNull), SUBTRACT_NULL_FN_STR, {}, paramTypes64, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchSubDec64Dec64Dec128RetNull), SUBTRACT_NULL_FN_STR, {}, paramTypes64, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchSubDec128Dec128Dec128RetNull), SUBTRACT_NULL_FN_STR, {}, paramTypes128, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchSubDec64Dec128Dec128RetNull), SUBTRACT_NULL_FN_STR, {}, + paramTypes64Op128, OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchSubDec128Dec64Dec128RetNull), SUBTRACT_NULL_FN_STR, {}, + paramTypes128Op64, OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + + Function(reinterpret_cast(BatchMulDec64Dec64Dec64RetNull), MULTIPLY_NULL_FN_STR, {}, paramTypes64, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchMulDec64Dec64Dec128RetNull), MULTIPLY_NULL_FN_STR, {}, paramTypes64, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchMulDec128Dec128Dec128RetNull), MULTIPLY_NULL_FN_STR, {}, paramTypes128, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchMulDec64Dec128Dec128RetNull), MULTIPLY_NULL_FN_STR, {}, + paramTypes64Op128, OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchMulDec128Dec64Dec128RetNull), MULTIPLY_NULL_FN_STR, {}, + paramTypes128Op64, OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + + Function(reinterpret_cast(BatchDivDec64Dec64Dec64RetNull), DIVIDE_NULL_FN_STR, {}, paramTypes64, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchDivDec64Dec128Dec64RetNull), DIVIDE_NULL_FN_STR, {}, paramTypes64Op128, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchDivDec128Dec64Dec64RetNull), DIVIDE_NULL_FN_STR, {}, paramTypes128Op64, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchDivDec64Dec64Dec128RetNull), DIVIDE_NULL_FN_STR, {}, paramTypes64, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchDivDec128Dec128Dec128RetNull), DIVIDE_NULL_FN_STR, {}, paramTypes128, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchDivDec64Dec128Dec128RetNull), DIVIDE_NULL_FN_STR, {}, paramTypes64Op128, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchDivDec128Dec64Dec128RetNull), DIVIDE_NULL_FN_STR, {}, paramTypes128Op64, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + + Function(reinterpret_cast(BatchModDec64Dec64Dec64RetNull), MODULUS_NULL_FN_STR, {}, paramTypes64, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchModDec64Dec128Dec64RetNull), MODULUS_NULL_FN_STR, {}, paramTypes64Op128, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchModDec128Dec64Dec64RetNull), MODULUS_NULL_FN_STR, {}, paramTypes128Op64, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchModDec128Dec64Dec128RetNull), MODULUS_NULL_FN_STR, {}, paramTypes128Op64, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchModDec128Dec128Dec128RetNull), MODULUS_NULL_FN_STR, {}, paramTypes128, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchModDec128Dec128Dec64RetNull), MODULUS_NULL_FN_STR, {}, paramTypes128, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchModDec64Dec128Dec128RetNull), MODULUS_NULL_FN_STR, {}, paramTypes64Op128, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + // BatchDecimalGreatest + Function(reinterpret_cast(BatchGreatestDecimal64), GREATEST_DECIMAL_FN_STR, {}, paramTypes64, + OMNI_DECIMAL64, INPUT_DATA_AND_NULL_AND_RETURN_NULL, true), + Function(reinterpret_cast(BatchGreatestDecimal128), GREATEST_DECIMAL_FN_STR, {}, paramTypes128, + OMNI_DECIMAL128, INPUT_DATA_AND_NULL_AND_RETURN_NULL, true), + Function(reinterpret_cast(BatchGreatestDecimal64RetNull), GREATEST_DECIMAL_NULL_FN_STR, {}, + paramTypes64, OMNI_DECIMAL64, INPUT_DATA_AND_NULL_AND_RETURN_NULL), + Function(reinterpret_cast(BatchGreatestDecimal128RetNull), GREATEST_DECIMAL_NULL_FN_STR, {}, + paramTypes128, OMNI_DECIMAL128, INPUT_DATA_AND_NULL_AND_RETURN_NULL), + }; + + return batchDecimalFunctions; +} + +std::vector BatchDecimalFunctionRegistryDown::GetFunctions() +{ + static std::vector batchDecimalFunctions = { + Function(reinterpret_cast(BatchCastDecimal64ToLongDown), CAST_FN_STR, {}, { OMNI_DECIMAL64 }, OMNI_LONG, + INPUT_DATA), + Function(reinterpret_cast(BatchCastDecimal64ToIntDown), CAST_FN_STR, {}, { OMNI_DECIMAL64 }, OMNI_INT, + INPUT_DATA, true), + Function(reinterpret_cast(BatchCastDecimal64ToDoubleDown), CAST_FN_STR, {}, { OMNI_DECIMAL64 }, + OMNI_DOUBLE, INPUT_DATA), + Function(reinterpret_cast(BatchCastDecimal128ToDoubleDown), CAST_FN_STR, {}, { OMNI_DECIMAL128 }, + OMNI_DOUBLE, INPUT_DATA), + }; + + return batchDecimalFunctions; +} + +std::vector BatchDecimalFunctionRegistryHalfUp::GetFunctions() +{ + static std::vector batchDecimalFunctions = { + Function(reinterpret_cast(BatchCastDecimal64ToLongHalfUp), CAST_FN_STR, {}, { OMNI_DECIMAL64 }, + OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(BatchCastDecimal64ToIntHalfUp), CAST_FN_STR, {}, { OMNI_DECIMAL64 }, OMNI_INT, + INPUT_DATA, true), + Function(reinterpret_cast(BatchCastDecimal64ToDoubleHalfUp), CAST_FN_STR, {}, { OMNI_DECIMAL64 }, + OMNI_DOUBLE, INPUT_DATA), + Function(reinterpret_cast(BatchCastDecimal128ToDoubleHalfUp), CAST_FN_STR, {}, { OMNI_DECIMAL128 }, + OMNI_DOUBLE, INPUT_DATA), + }; + + return batchDecimalFunctions; +} + +std::vector BatchDecimalFunctionRegistryReScale::GetFunctions() +{ + std::vector paramTypes128 = { OMNI_DECIMAL128, OMNI_DECIMAL128 }; + std::vector paramTypes64 = { OMNI_DECIMAL64, OMNI_DECIMAL64 }; + std::vector paramTypes64Op128 = { OMNI_DECIMAL64, OMNI_DECIMAL128 }; + std::vector paramTypes128Op64 = { OMNI_DECIMAL128, OMNI_DECIMAL64 }; + + static std::vector batchDecimalFunctions = { + // decimal arith function + Function(reinterpret_cast(BatchAddDec64Dec64Dec64ReScale), ADD_FN_STR, {}, paramTypes64, OMNI_DECIMAL64, + INPUT_DATA, true), + Function(reinterpret_cast(BatchAddDec64Dec64Dec128ReScale), ADD_FN_STR, {}, paramTypes64, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchAddDec128Dec128Dec128ReScale), ADD_FN_STR, {}, paramTypes128, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchAddDec64Dec128Dec128ReScale), ADD_FN_STR, {}, paramTypes64Op128, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchAddDec128Dec64Dec128ReScale), ADD_FN_STR, {}, paramTypes128Op64, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubDec64Dec64Dec64ReScale), SUBTRACT_FN_STR, {}, paramTypes64, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubDec64Dec64Dec128ReScale), SUBTRACT_FN_STR, {}, paramTypes64, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubDec128Dec128Dec128ReScale), SUBTRACT_FN_STR, {}, paramTypes128, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubDec64Dec128Dec128ReScale), SUBTRACT_FN_STR, {}, paramTypes64Op128, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubDec128Dec64Dec128ReScale), SUBTRACT_FN_STR, {}, paramTypes128Op64, + OMNI_DECIMAL128, INPUT_DATA, true), + + Function(reinterpret_cast(BatchMulDec64Dec64Dec64ReScale), MULTIPLY_FN_STR, {}, paramTypes64, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchMulDec64Dec64Dec128ReScale), MULTIPLY_FN_STR, {}, paramTypes64, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchMulDec128Dec128Dec128ReScale), MULTIPLY_FN_STR, {}, paramTypes128, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchMulDec64Dec128Dec128ReScale), MULTIPLY_FN_STR, {}, paramTypes64Op128, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchMulDec128Dec64Dec128ReScale), MULTIPLY_FN_STR, {}, paramTypes128Op64, + OMNI_DECIMAL128, INPUT_DATA, true), + + Function(reinterpret_cast(BatchDivDec64Dec64Dec64ReScale), DIVIDE_FN_STR, {}, paramTypes64, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchDivDec64Dec128Dec64ReScale), DIVIDE_FN_STR, {}, paramTypes64Op128, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchDivDec128Dec64Dec64ReScale), DIVIDE_FN_STR, {}, paramTypes128Op64, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchDivDec64Dec64Dec128ReScale), DIVIDE_FN_STR, {}, paramTypes64, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchDivDec128Dec128Dec128ReScale), DIVIDE_FN_STR, {}, paramTypes128, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchDivDec64Dec128Dec128ReScale), DIVIDE_FN_STR, {}, paramTypes64Op128, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchDivDec128Dec64Dec128ReScale), DIVIDE_FN_STR, {}, paramTypes128Op64, + OMNI_DECIMAL128, INPUT_DATA, true), + + Function(reinterpret_cast(BatchModDec64Dec64Dec64ReScale), MODULUS_FN_STR, {}, paramTypes64, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchModDec64Dec128Dec64ReScale), MODULUS_FN_STR, {}, paramTypes64Op128, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchModDec128Dec64Dec64ReScale), MODULUS_FN_STR, {}, paramTypes128Op64, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchModDec128Dec64Dec128ReScale), MODULUS_FN_STR, {}, paramTypes128Op64, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchModDec128Dec128Dec128ReScale), MODULUS_FN_STR, {}, paramTypes128, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchModDec128Dec128Dec64ReScale), MODULUS_FN_STR, {}, paramTypes128, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchModDec64Dec128Dec128ReScale), MODULUS_FN_STR, {}, paramTypes64Op128, + OMNI_DECIMAL128, INPUT_DATA, true), + }; + + return batchDecimalFunctions; +} + +std::vector BatchDecimalFunctionRegistryNotReScale::GetFunctions() +{ + std::vector paramTypes128 = { OMNI_DECIMAL128, OMNI_DECIMAL128 }; + std::vector paramTypes64 = { OMNI_DECIMAL64, OMNI_DECIMAL64 }; + std::vector paramTypes64Op128 = { OMNI_DECIMAL64, OMNI_DECIMAL128 }; + std::vector paramTypes128Op64 = { OMNI_DECIMAL128, OMNI_DECIMAL64 }; + + static std::vector batchDecimalFunctions = { + // decimal arith function + Function(reinterpret_cast(BatchAddDec64Dec64Dec64NotReScale), ADD_FN_STR, {}, paramTypes64, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchAddDec64Dec64Dec128NotReScale), ADD_FN_STR, {}, paramTypes64, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchAddDec128Dec128Dec128NotReScale), ADD_FN_STR, {}, paramTypes128, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchAddDec64Dec128Dec128NotReScale), ADD_FN_STR, {}, paramTypes64Op128, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchAddDec128Dec64Dec128NotReScale), ADD_FN_STR, {}, paramTypes128Op64, + OMNI_DECIMAL128, INPUT_DATA, true), + + Function(reinterpret_cast(BatchSubDec64Dec64Dec64NotReScale), SUBTRACT_FN_STR, {}, paramTypes64, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubDec64Dec64Dec128NotReScale), SUBTRACT_FN_STR, {}, paramTypes64, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubDec128Dec128Dec128NotReScale), SUBTRACT_FN_STR, {}, paramTypes128, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubDec64Dec128Dec128NotReScale), SUBTRACT_FN_STR, {}, paramTypes64Op128, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubDec128Dec64Dec128NotReScale), SUBTRACT_FN_STR, {}, paramTypes128Op64, + OMNI_DECIMAL128, INPUT_DATA, true), + + Function(reinterpret_cast(BatchMulDec64Dec64Dec64NotReScale), MULTIPLY_FN_STR, {}, paramTypes64, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchMulDec64Dec64Dec128NotReScale), MULTIPLY_FN_STR, {}, paramTypes64, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchMulDec128Dec128Dec128NotReScale), MULTIPLY_FN_STR, {}, paramTypes128, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchMulDec64Dec128Dec128NotReScale), MULTIPLY_FN_STR, {}, paramTypes64Op128, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchMulDec128Dec64Dec128NotReScale), MULTIPLY_FN_STR, {}, paramTypes128Op64, + OMNI_DECIMAL128, INPUT_DATA, true), + + Function(reinterpret_cast(BatchDivDec64Dec64Dec64NotReScale), DIVIDE_FN_STR, {}, paramTypes64, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchDivDec64Dec128Dec64NotReScale), DIVIDE_FN_STR, {}, paramTypes64Op128, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchDivDec128Dec64Dec64NotReScale), DIVIDE_FN_STR, {}, paramTypes128Op64, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchDivDec64Dec64Dec128NotReScale), DIVIDE_FN_STR, {}, paramTypes64, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchDivDec128Dec128Dec128NotReScale), DIVIDE_FN_STR, {}, paramTypes128, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchDivDec64Dec128Dec128NotReScale), DIVIDE_FN_STR, {}, paramTypes64Op128, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchDivDec128Dec64Dec128NotReScale), DIVIDE_FN_STR, {}, paramTypes128Op64, + OMNI_DECIMAL128, INPUT_DATA, true), + + Function(reinterpret_cast(BatchModDec64Dec64Dec64NotReScale), MODULUS_FN_STR, {}, paramTypes64, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchModDec64Dec128Dec64NotReScale), MODULUS_FN_STR, {}, paramTypes64Op128, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchModDec128Dec64Dec64NotReScale), MODULUS_FN_STR, {}, paramTypes128Op64, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchModDec128Dec64Dec128NotReScale), MODULUS_FN_STR, {}, paramTypes128Op64, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchModDec128Dec128Dec128NotReScale), MODULUS_FN_STR, {}, paramTypes128, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchModDec128Dec128Dec64NotReScale), MODULUS_FN_STR, {}, paramTypes128, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchModDec64Dec128Dec128NotReScale), MODULUS_FN_STR, {}, paramTypes64Op128, + OMNI_DECIMAL128, INPUT_DATA, true), + }; + + return batchDecimalFunctions; +} +} diff --git a/core/src/codegen/batch_func_registry_decimal.h b/core/src/codegen/batch_func_registry_decimal.h new file mode 100644 index 0000000..61881e3 --- /dev/null +++ b/core/src/codegen/batch_func_registry_decimal.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: Batch Decimal Function Registry + */ + +#ifndef OMNI_RUNTIME_BATCH_FUNC_REGISTRY_DECIMAL_H +#define OMNI_RUNTIME_BATCH_FUNC_REGISTRY_DECIMAL_H +#include "function.h" +#include "func_registry_base.h" + +namespace omniruntime::codegen { +class BatchDecimalFunctionRegistry : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class BatchDecimalFunctionRegistryDown : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class BatchDecimalFunctionRegistryHalfUp : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class BatchDecimalFunctionRegistryReScale : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class BatchDecimalFunctionRegistryNotReScale : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; +} + +#endif // OMNI_RUNTIME_BATCH_FUNC_REGISTRY_DECIMAL_H diff --git a/core/src/codegen/batch_func_registry_dictionary.cpp b/core/src/codegen/batch_func_registry_dictionary.cpp new file mode 100644 index 0000000..40f8654 --- /dev/null +++ b/core/src/codegen/batch_func_registry_dictionary.cpp @@ -0,0 +1,48 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: Batch Dictionary Function Registry + */ + +#include "batch_func_registry_dictionary.h" +#include "batch_functions/batch_dictionaryfunctions.h" + +namespace omniruntime::codegen { +using namespace omniruntime::type; +using namespace codegen::function; + +std::vector BatchDictionaryFunctionRegistry::GetFunctions() +{ + std::vector paramTypes = { OMNI_LONG }; + std::vector batchDictionaryFnRegistry = { + Function(reinterpret_cast(BatchGetIntFromDictionaryVector), "batch_GetDic", {}, paramTypes, OMNI_INT), + Function(reinterpret_cast(BatchGetIntFromDictionaryVector), "batch_GetDic", {}, paramTypes, + OMNI_DATE32), + Function(reinterpret_cast(BatchGetLongFromDictionaryVector), "batch_GetDic", {}, paramTypes, OMNI_LONG), + Function(reinterpret_cast(BatchGetLongFromDictionaryVector), "batch_GetDic", {}, paramTypes, + OMNI_DECIMAL64), + Function(reinterpret_cast(BatchGetLongFromDictionaryVector), "batch_GetDic", {}, paramTypes, + OMNI_TIMESTAMP), + Function(reinterpret_cast(BatchGetDoubleFromDictionaryVector), "batch_GetDic", {}, paramTypes, + OMNI_DOUBLE), + Function(reinterpret_cast(BatchGetBooleanFromDictionaryVector), "batch_GetDic", {}, paramTypes, + OMNI_BOOLEAN), + Function(reinterpret_cast(BatchGetVarcharFromDictionaryVector), "batch_GetDic", {}, paramTypes, + OMNI_VARCHAR), + Function(reinterpret_cast(BatchGetVarcharFromDictionaryVector), "batch_GetDic", {}, paramTypes, + OMNI_CHAR), + Function(reinterpret_cast(BatchGetDecimalFromDictionaryVector), "batch_GetDic", {}, paramTypes, + OMNI_DECIMAL128), + Function(reinterpret_cast(BatchGetIntFromVector), "batch_GetData", {}, paramTypes, OMNI_INT), + Function(reinterpret_cast(BatchGetIntFromVector), "batch_GetData", {}, paramTypes, OMNI_DATE32), + Function(reinterpret_cast(BatchGetLongFromVector), "batch_GetData", {}, paramTypes, OMNI_LONG), + Function(reinterpret_cast(BatchGetLongFromVector), "batch_GetData", {}, paramTypes, OMNI_DECIMAL64), + Function(reinterpret_cast(BatchGetLongFromVector), "batch_GetData", {}, paramTypes, OMNI_TIMESTAMP), + Function(reinterpret_cast(BatchGetDoubleFromVector), "batch_GetData", {}, paramTypes, OMNI_DOUBLE), + Function(reinterpret_cast(BatchGetBooleanFromVector), "batch_GetData", {}, paramTypes, OMNI_BOOLEAN), + Function(reinterpret_cast(BatchGetVarcharFromVector), "batch_GetData", {}, paramTypes, OMNI_VARCHAR), + Function(reinterpret_cast(BatchGetVarcharFromVector), "batch_GetData", {}, paramTypes, OMNI_CHAR), + Function(reinterpret_cast(BatchGetDecimalFromVector), "batch_GetData", {}, paramTypes, OMNI_DECIMAL128) + }; + return batchDictionaryFnRegistry; +} +} \ No newline at end of file diff --git a/core/src/codegen/batch_func_registry_dictionary.h b/core/src/codegen/batch_func_registry_dictionary.h new file mode 100644 index 0000000..f51d092 --- /dev/null +++ b/core/src/codegen/batch_func_registry_dictionary.h @@ -0,0 +1,17 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: Batch Dictionary Function Registry + */ + +#ifndef OMNI_RUNTIME_BATCH_FUNC_REGISTRY_DICTIONARY_H +#define OMNI_RUNTIME_BATCH_FUNC_REGISTRY_DICTIONARY_H +#include "function.h" +#include "func_registry_base.h" + +namespace omniruntime::codegen { +class BatchDictionaryFunctionRegistry : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; +} +#endif // OMNI_RUNTIME_BATCH_FUNC_REGISTRY_DICTIONARY_H diff --git a/core/src/codegen/batch_func_registry_hash.cpp b/core/src/codegen/batch_func_registry_hash.cpp new file mode 100644 index 0000000..cc2d9c6 --- /dev/null +++ b/core/src/codegen/batch_func_registry_hash.cpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: Batch Hash Function Registry + */ +#include "batch_func_registry_hash.h" +#include "batch_functions/batch_murmur3_hash.h" + +namespace omniruntime::codegen { +using namespace omniruntime::type; +using namespace codegen::function; + +std::vector BatchHashFunctionRegistry::GetFunctions() +{ + DataTypeId retType = OMNI_INT; + std::string batchMm3fnStr = "batch_mm3hash"; + std::vector batchHashFunctions = { Function(reinterpret_cast(BatchCombineHash), + "batch_combine_hash", {}, { OMNI_LONG, OMNI_LONG }, OMNI_LONG, INPUT_DATA_AND_NULL), + Function(reinterpret_cast(BatchMm3Int32), batchMm3fnStr, {}, { OMNI_INT, OMNI_INT }, retType, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(BatchMm3Int32), batchMm3fnStr, {}, { OMNI_DATE32, OMNI_INT }, retType, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(BatchMm3Int64), batchMm3fnStr, {}, { OMNI_LONG, OMNI_INT }, retType, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(BatchMm3Int64), batchMm3fnStr, {}, { OMNI_TIMESTAMP, OMNI_INT }, retType, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(BatchMm3Double), batchMm3fnStr, {}, { OMNI_DOUBLE, OMNI_INT }, retType, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(BatchMm3String), batchMm3fnStr, {}, { OMNI_VARCHAR, OMNI_INT }, retType, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(BatchMm3Decimal64), batchMm3fnStr, {}, { OMNI_DECIMAL64, OMNI_INT }, retType, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(BatchMm3Decimal128), batchMm3fnStr, {}, { OMNI_DECIMAL128, OMNI_INT }, + retType, INPUT_DATA_AND_NULL), + Function(reinterpret_cast(BatchMm3Boolean), batchMm3fnStr, {}, { OMNI_BOOLEAN, OMNI_INT }, retType, + INPUT_DATA_AND_NULL) }; + + return batchHashFunctions; +} +} diff --git a/core/src/codegen/batch_func_registry_hash.h b/core/src/codegen/batch_func_registry_hash.h new file mode 100644 index 0000000..d6cc04d --- /dev/null +++ b/core/src/codegen/batch_func_registry_hash.h @@ -0,0 +1,17 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: Batch Hash Function Registry + */ +#ifndef OMNI_RUNTIME_BATCH_FUNC_REGISTRY_HASH_H +#define OMNI_RUNTIME_BATCH_FUNC_REGISTRY_HASH_H + +#include "function.h" +#include "func_registry_base.h" + +namespace omniruntime::codegen { +class BatchHashFunctionRegistry : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; +} +#endif // OMNI_RUNTIME_BATCH_FUNC_REGISTRY_HASH_H diff --git a/core/src/codegen/batch_func_registry_math.cpp b/core/src/codegen/batch_func_registry_math.cpp new file mode 100644 index 0000000..c4da811 --- /dev/null +++ b/core/src/codegen/batch_func_registry_math.cpp @@ -0,0 +1,156 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch math functions registry + */ +#include "batch_func_registry_math.h" +#include "batch_functions/batch_mathfunctions.h" + +namespace omniruntime::codegen { +using namespace omniruntime::type; +using namespace codegen::function; + +namespace { +const std::string ABS_FN_STR = "batch_abs"; +const std::string CAST_FN_STR = "batch_CAST"; +const std::string ROUND_FN_STR = "batch_round"; +const std::string ADD_FN_STR = "batch_add"; +const std::string SUBTRACT_FN_STR = "batch_subtract"; +const std::string MULTIPLY_FN_STR = "batch_multiply"; +const std::string DIVIDE_FN_STR = "batch_divide"; +const std::string MODULUS_FN_STR = "batch_modulus"; +const std::string LESS_THAN_FN_STR = "batch_lessThan"; +const std::string LESS_THAN_EQUAL_FN_STR = "batch_lessThanEqual"; +const std::string GREATER_THAN_FN_STR = "batch_greaterThan"; +const std::string GREATER_THAN_EQUAL_FN_STR = "batch_greaterThanEqual"; +const std::string EQUAL_FN_STR = "batch_equal"; +const std::string NOT_EQUAL_FN_STR = "batch_notEqual"; +const std::string PMOD_FN_STR = "batch_pmod"; +const std::string NORMALIZE_ZERO_FN_STR = "batch_NormalizeNaNAndZero"; +const std::string GREATEST_NUM_FN_STR = "batch_Greatest"; +const std::string POWER_FN_STR = "batch_power"; +} + +std::vector BatchMathFunctionRegistry::GetFunctions() +{ + const std::vector doubleParams = { OMNI_DOUBLE, OMNI_DOUBLE }; + const std::vector longParams = { OMNI_LONG, OMNI_LONG }; + const std::vector intParams = { OMNI_INT, OMNI_INT }; + const std::vector boolParams = { OMNI_BOOLEAN, OMNI_BOOLEAN }; + + std::vector batchMathFunctions = { + Function(reinterpret_cast(BatchAbs), ABS_FN_STR, {}, { OMNI_INT }, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(BatchAbs), ABS_FN_STR, {}, { OMNI_LONG }, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(BatchAbs), ABS_FN_STR, {}, { OMNI_DOUBLE }, OMNI_DOUBLE, INPUT_DATA), + + Function(reinterpret_cast(BatchCastInt32ToDouble), CAST_FN_STR, {}, { OMNI_INT }, OMNI_DOUBLE, + INPUT_DATA), + Function(reinterpret_cast(BatchCastInt64ToDouble), CAST_FN_STR, {}, { OMNI_LONG }, OMNI_DOUBLE, + INPUT_DATA), + Function(reinterpret_cast(BatchCastInt32ToInt64), CAST_FN_STR, {}, { OMNI_INT }, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(BatchCastInt64ToInt32), CAST_FN_STR, {}, { OMNI_LONG }, OMNI_INT, INPUT_DATA), + + Function(reinterpret_cast(BatchAddDouble), ADD_FN_STR, {}, doubleParams, OMNI_DOUBLE, INPUT_DATA), + Function(reinterpret_cast(BatchSubtractDouble), SUBTRACT_FN_STR, {}, doubleParams, OMNI_DOUBLE, + INPUT_DATA), + Function(reinterpret_cast(BatchMultiplyDouble), MULTIPLY_FN_STR, {}, doubleParams, OMNI_DOUBLE, + INPUT_DATA), + Function(reinterpret_cast(BatchDivideDouble), DIVIDE_FN_STR, {}, doubleParams, OMNI_DOUBLE, INPUT_DATA), + Function(reinterpret_cast(BatchModulusDouble), MODULUS_FN_STR, {}, doubleParams, OMNI_DOUBLE, + INPUT_DATA), + Function(reinterpret_cast(BatchLessThanDouble), LESS_THAN_FN_STR, {}, doubleParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(BatchLessThanEqualDouble), LESS_THAN_EQUAL_FN_STR, {}, doubleParams, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchGreaterThanDouble), GREATER_THAN_FN_STR, {}, doubleParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(BatchGreaterThanEqualDouble), GREATER_THAN_EQUAL_FN_STR, {}, doubleParams, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchEqualDouble), EQUAL_FN_STR, {}, doubleParams, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchNotEqualDouble), NOT_EQUAL_FN_STR, {}, doubleParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(BatchNormalizeNaNAndZero), NORMALIZE_ZERO_FN_STR, {}, { OMNI_DOUBLE }, + OMNI_DOUBLE, INPUT_DATA), + Function(reinterpret_cast(BatchPowerDouble), POWER_FN_STR, {}, doubleParams, OMNI_DOUBLE, INPUT_DATA), + + Function(reinterpret_cast(BatchAddInt64), ADD_FN_STR, {}, longParams, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(BatchSubtractInt64), SUBTRACT_FN_STR, {}, longParams, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(BatchMultiplyInt64), MULTIPLY_FN_STR, {}, longParams, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(BatchDivideInt64), DIVIDE_FN_STR, {}, longParams, OMNI_LONG, INPUT_DATA, + true), + Function(reinterpret_cast(BatchModulusInt64), MODULUS_FN_STR, {}, longParams, OMNI_LONG, INPUT_DATA, + true), + Function(reinterpret_cast(BatchLessThanInt64), LESS_THAN_FN_STR, {}, longParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(BatchLessThanEqualInt64), LESS_THAN_EQUAL_FN_STR, {}, longParams, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchGreaterThanInt64), GREATER_THAN_FN_STR, {}, longParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(BatchGreaterThanEqualInt64), GREATER_THAN_EQUAL_FN_STR, {}, longParams, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchEqualInt64), EQUAL_FN_STR, {}, longParams, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchNotEqualInt64), NOT_EQUAL_FN_STR, {}, longParams, OMNI_BOOLEAN, + INPUT_DATA), + + Function(reinterpret_cast(BatchAddInt32), ADD_FN_STR, {}, intParams, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(BatchSubtractInt32), SUBTRACT_FN_STR, {}, intParams, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(BatchMultiplyInt32), MULTIPLY_FN_STR, {}, intParams, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(BatchDivideInt32), DIVIDE_FN_STR, {}, intParams, OMNI_INT, INPUT_DATA, true), + Function(reinterpret_cast(BatchModulusInt32), MODULUS_FN_STR, {}, intParams, OMNI_INT, INPUT_DATA, + true), + Function(reinterpret_cast(BatchLessThanInt32), LESS_THAN_FN_STR, {}, intParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(BatchLessThanEqualInt32), LESS_THAN_EQUAL_FN_STR, {}, intParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(BatchGreaterThanInt32), GREATER_THAN_FN_STR, {}, intParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(BatchGreaterThanEqualInt32), GREATER_THAN_EQUAL_FN_STR, {}, intParams, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchEqualInt32), EQUAL_FN_STR, {}, intParams, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchNotEqualInt32), NOT_EQUAL_FN_STR, {}, intParams, OMNI_BOOLEAN, + INPUT_DATA), + + Function(reinterpret_cast(BatchEqualBool), EQUAL_FN_STR, {}, boolParams, OMNI_BOOLEAN, INPUT_DATA), + + Function(reinterpret_cast(BatchPmod), PMOD_FN_STR, {}, intParams, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(BatchRound), ROUND_FN_STR, {}, intParams, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(BatchRoundLong), ROUND_FN_STR, {}, { OMNI_LONG, OMNI_INT }, OMNI_LONG, + INPUT_DATA), + Function(reinterpret_cast(BatchRound), ROUND_FN_STR, {}, { OMNI_DOUBLE, OMNI_INT }, OMNI_DOUBLE, + INPUT_DATA), + Function(reinterpret_cast(BatchGreatest), GREATEST_NUM_FN_STR, {}, { OMNI_INT, OMNI_INT }, + OMNI_INT, INPUT_DATA_AND_NULL_AND_RETURN_NULL), + Function(reinterpret_cast(BatchGreatest), GREATEST_NUM_FN_STR, {}, { OMNI_LONG, OMNI_LONG }, + OMNI_LONG, INPUT_DATA_AND_NULL_AND_RETURN_NULL), + Function(reinterpret_cast(BatchGreatest), GREATEST_NUM_FN_STR, {}, { OMNI_BOOLEAN, OMNI_BOOLEAN }, + OMNI_BOOLEAN, INPUT_DATA_AND_NULL_AND_RETURN_NULL), + Function(reinterpret_cast(BatchGreatest), GREATEST_NUM_FN_STR, {}, { OMNI_DOUBLE, OMNI_DOUBLE }, + OMNI_DOUBLE, INPUT_DATA_AND_NULL_AND_RETURN_NULL) + }; + + return batchMathFunctions; +} + +std::vector BatchMathFunctionRegistryHalfUp::GetFunctions() +{ + std::vector batchMathFunctions = { + Function(reinterpret_cast(BatchCastDoubleToInt64HalfUp), CAST_FN_STR, {}, { OMNI_DOUBLE }, OMNI_LONG, + INPUT_DATA), + Function(reinterpret_cast(BatchCastDoubleToInt32HalfUp), CAST_FN_STR, {}, { OMNI_DOUBLE }, OMNI_INT, + INPUT_DATA), + }; + + return batchMathFunctions; +} + +std::vector BatchMathFunctionRegistryDown::GetFunctions() +{ + std::vector batchMathFunctions = { + Function(reinterpret_cast(BatchCastDoubleToInt64Down), CAST_FN_STR, {}, { OMNI_DOUBLE }, OMNI_LONG, + INPUT_DATA), + Function(reinterpret_cast(BatchCastDoubleToInt32Down), CAST_FN_STR, {}, { OMNI_DOUBLE }, OMNI_INT, + INPUT_DATA), + }; + + return batchMathFunctions; +} +} \ No newline at end of file diff --git a/core/src/codegen/batch_func_registry_math.h b/core/src/codegen/batch_func_registry_math.h new file mode 100644 index 0000000..a94f37b --- /dev/null +++ b/core/src/codegen/batch_func_registry_math.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch math functions registry + */ +#ifndef OMNI_RUNTIME_BATCH_FUNC_REGISTRY_MATH_H +#define OMNI_RUNTIME_BATCH_FUNC_REGISTRY_MATH_H + +#include "function.h" +#include "func_registry_base.h" + +namespace omniruntime::codegen { +class BatchMathFunctionRegistry : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class BatchMathFunctionRegistryHalfUp : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class BatchMathFunctionRegistryDown : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; +} + +#endif // OMNI_RUNTIME_BATCH_FUNC_REGISTRY_MATH_H diff --git a/core/src/codegen/batch_func_registry_string.cpp b/core/src/codegen/batch_func_registry_string.cpp new file mode 100644 index 0000000..0451951 --- /dev/null +++ b/core/src/codegen/batch_func_registry_string.cpp @@ -0,0 +1,352 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. + * Description: Batch String Function Registry + */ +#include "batch_func_registry_string.h" +#include "batch_functions/batch_stringfunctions.h" + +namespace omniruntime::codegen { +using namespace omniruntime::type; +using namespace codegen::function; + +namespace { +const std::string LESS_THAN_FN_STR = "batch_lessThan"; +const std::string LESS_THAN_EQUAL_FN_STR = "batch_lessThanEqual"; +const std::string GREATER_THAN_FN_STR = "batch_greaterThan"; +const std::string GREATER_THAN_EQUAL_FN_STR = "batch_greaterThanEqual"; +const std::string EQUAL_FN_STR = "batch_equal"; +const std::string NOT_EQUAL_FN_STR = "batch_notEqual"; +const std::string SUBSTR_FN_STR = "batch_substr"; +const std::string CONCAT_FN_STR = "batch_concat"; +const std::string LIKE_FN_STR = "batch_LIKE"; +const std::string CAST_FN_STR = "batch_CAST"; +const std::string UPPER_FN_STR = "batch_upper"; +const std::string LOWER_FN_STR = "batch_lower"; +const std::string COMPARE_FN_STR = "batch_compare"; +const std::string LENGTH_FN_STR = "batch_length"; +const std::string REPLACE_FN_STR = "batch_replace"; +const std::string CONCAT_FN_STR_RETNULL = "batch_concat_null"; +const std::string CAST_FN_STR_RETNULL = "batch_CAST_null"; +const std::string INSTR_FN_STR = "batch_instr"; +const std::string STARTS_WITH_FN_STR = "batch_StartsWith"; +const std::string ENDS_WITH_FN_STR = "batch_EndsWith"; +const std::string MD5_STR = "batch_Md5"; +const std::string EMPTY2NULL_STR = "batch_empty2null"; +const std::string CONTAINS_FN_STR = "batch_Contains"; +const std::string GREATEST_STR_FN_STR = "batch_Greatest"; +const std::string BATCH_STATIC_INVOKE_VARCHARTYPE_CHECK_FN_STR = "batch_StaticInvokeVarcharTypeWriteSideCheck"; +const std::string BATCH_STATIC_INVOKE_CHAR_READ_PADDING_FN_STR = "batch_StaticInvokeCharReadPadding"; +} + +std::vector BatchStringFunctionRegistry::GetFunctions() +{ + std::vector batchStringFnRegistry = { Function(reinterpret_cast(BatchLessThanStr), + LESS_THAN_FN_STR, {}, { OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchLessThanEqualStr), LESS_THAN_EQUAL_FN_STR, {}, + { OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchGreaterThanStr), GREATER_THAN_FN_STR, {}, { OMNI_VARCHAR, OMNI_VARCHAR }, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchGreaterThanEqualStr), GREATER_THAN_EQUAL_FN_STR, {}, + { OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchEqualStr), EQUAL_FN_STR, {}, { OMNI_VARCHAR, OMNI_VARCHAR }, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchNotEqualStr), NOT_EQUAL_FN_STR, {}, { OMNI_VARCHAR, OMNI_VARCHAR }, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchStrCompare), COMPARE_FN_STR, {}, { OMNI_VARCHAR, OMNI_VARCHAR }, + OMNI_INT), + + // concat functions + Function(reinterpret_cast(BatchConcatStrStr), CONCAT_FN_STR, {}, { OMNI_VARCHAR, OMNI_VARCHAR }, + OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchConcatCharChar), CONCAT_FN_STR, {}, { OMNI_CHAR, OMNI_CHAR }, OMNI_CHAR, + INPUT_DATA, true), + Function(reinterpret_cast(BatchConcatCharStr), CONCAT_FN_STR, {}, { OMNI_CHAR, OMNI_VARCHAR }, + OMNI_CHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchConcatStrChar), CONCAT_FN_STR, {}, { OMNI_VARCHAR, OMNI_CHAR }, + OMNI_CHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchConcatStrStrRetNull), CONCAT_FN_STR_RETNULL, {}, + { OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_CHAR, INPUT_DATA_AND_OVERFLOW_NULL, true), + Function(reinterpret_cast(BatchConcatCharCharRetNull), CONCAT_FN_STR_RETNULL, {}, + { OMNI_CHAR, OMNI_CHAR }, OMNI_CHAR, INPUT_DATA_AND_OVERFLOW_NULL, true), + Function(reinterpret_cast(BatchConcatCharStrRetNull), CONCAT_FN_STR_RETNULL, {}, + { OMNI_CHAR, OMNI_VARCHAR }, OMNI_CHAR, INPUT_DATA_AND_OVERFLOW_NULL, true), + Function(reinterpret_cast(BatchConcatStrCharRetNull), CONCAT_FN_STR_RETNULL, {}, + { OMNI_VARCHAR, OMNI_CHAR }, OMNI_CHAR, INPUT_DATA_AND_OVERFLOW_NULL, true), + + Function(reinterpret_cast(BatchLikeStr), LIKE_FN_STR, {}, { OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(BatchLikeChar), LIKE_FN_STR, {}, { OMNI_CHAR, OMNI_VARCHAR }, OMNI_BOOLEAN, + INPUT_DATA), + + Function(reinterpret_cast(BatchToUpperStr), UPPER_FN_STR, {}, { OMNI_VARCHAR }, OMNI_VARCHAR, + INPUT_DATA, true), + Function(reinterpret_cast(BatchToUpperChar), UPPER_FN_STR, {}, { OMNI_CHAR }, OMNI_CHAR, INPUT_DATA, + true), + Function(reinterpret_cast(BatchToLowerStr), LOWER_FN_STR, {}, { OMNI_VARCHAR }, OMNI_VARCHAR, + INPUT_DATA, true), + Function(reinterpret_cast(BatchToLowerChar), LOWER_FN_STR, {}, { OMNI_CHAR }, OMNI_CHAR, INPUT_DATA, + true), + + // length functions + Function(reinterpret_cast(BatchLengthChar), LENGTH_FN_STR, {}, { OMNI_CHAR }, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(BatchLengthStr), LENGTH_FN_STR, {}, { OMNI_VARCHAR }, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(BatchLengthCharReturnInt32), LENGTH_FN_STR, {}, { OMNI_CHAR }, OMNI_INT, + INPUT_DATA), + Function(reinterpret_cast(BatchLengthStrReturnInt32), LENGTH_FN_STR, {}, { OMNI_VARCHAR }, OMNI_INT, + INPUT_DATA), + + // cast to string + Function(reinterpret_cast(BatchCastIntToString), CAST_FN_STR, {}, { OMNI_INT }, OMNI_VARCHAR, + INPUT_DATA, true), + Function(reinterpret_cast(BatchCastLongToString), CAST_FN_STR, {}, { OMNI_LONG }, OMNI_VARCHAR, + INPUT_DATA, true), + Function(reinterpret_cast(BatchCastDoubleToString), CAST_FN_STR, {}, { OMNI_DOUBLE }, OMNI_VARCHAR, + INPUT_DATA, true), + Function(reinterpret_cast(BatchCastDecimal64ToString), CAST_FN_STR, {}, { OMNI_DECIMAL64 }, + OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchCastDecimal128ToString), CAST_FN_STR, {}, { OMNI_DECIMAL128 }, + OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchCastIntToStringRetNull), CAST_FN_STR_RETNULL, {}, { OMNI_INT }, + OMNI_VARCHAR, INPUT_DATA_AND_OVERFLOW_NULL, true), + Function(reinterpret_cast(BatchCastLongToStringRetNull), CAST_FN_STR_RETNULL, {}, { OMNI_LONG }, + OMNI_VARCHAR, INPUT_DATA_AND_OVERFLOW_NULL, true), + Function(reinterpret_cast(BatchCastDoubleToStringRetNull), CAST_FN_STR_RETNULL, {}, { OMNI_DOUBLE }, + OMNI_VARCHAR, INPUT_DATA_AND_OVERFLOW_NULL, true), + Function(reinterpret_cast(BatchCastDecimal64ToStringRetNull), CAST_FN_STR_RETNULL, {}, + { OMNI_DECIMAL64 }, OMNI_VARCHAR, INPUT_DATA_AND_OVERFLOW_NULL, true), + Function(reinterpret_cast(BatchCastDecimal128ToStringRetNull), CAST_FN_STR_RETNULL, {}, + { OMNI_DECIMAL128 }, OMNI_VARCHAR, INPUT_DATA_AND_OVERFLOW_NULL, true), + + // cast string to + Function(reinterpret_cast(BatchCastStrWithDiffWidths), CAST_FN_STR, {}, { OMNI_VARCHAR }, OMNI_VARCHAR, + INPUT_DATA, true), + Function(reinterpret_cast(BatchCastStringToIntRetNull), CAST_FN_STR_RETNULL, {}, { OMNI_VARCHAR }, + OMNI_INT, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchCastStringToLongRetNull), CAST_FN_STR_RETNULL, {}, { OMNI_VARCHAR }, + OMNI_LONG, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchCastStringToDoubleRetNull), CAST_FN_STR_RETNULL, {}, { OMNI_VARCHAR }, + OMNI_DOUBLE, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchCastStringToInt), CAST_FN_STR, {}, { OMNI_VARCHAR }, OMNI_INT, + INPUT_DATA, true), + Function(reinterpret_cast(BatchCastStringToLong), CAST_FN_STR, {}, { OMNI_VARCHAR }, OMNI_LONG, + INPUT_DATA, true), + Function(reinterpret_cast(BatchCastStringToDouble), CAST_FN_STR, {}, { OMNI_VARCHAR }, OMNI_DOUBLE, + INPUT_DATA, true), + Function(reinterpret_cast(BatchCastStrWithDiffWidthsRetNull), CAST_FN_STR_RETNULL, {}, { OMNI_VARCHAR }, + OMNI_VARCHAR, INPUT_DATA_AND_OVERFLOW_NULL, true), + + Function(reinterpret_cast(BatchInStr), INSTR_FN_STR, {}, { OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_INT, + INPUT_DATA), + Function(reinterpret_cast(BatchStartsWithStr), STARTS_WITH_FN_STR, {}, { OMNI_VARCHAR, OMNI_VARCHAR }, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchEndsWithStr), ENDS_WITH_FN_STR, {}, { OMNI_VARCHAR, OMNI_VARCHAR }, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchMd5Str), MD5_STR, {}, { OMNI_VARCHAR }, OMNI_VARCHAR, + INPUT_DATA, true), + Function(reinterpret_cast(BatchContainsStr), CONTAINS_FN_STR, {}, { OMNI_VARCHAR, OMNI_VARCHAR }, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(BatchGreatestStr), GREATEST_STR_FN_STR, {}, + { OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_VARCHAR, INPUT_DATA_AND_NULL_AND_RETURN_NULL), + Function(reinterpret_cast(BatchEmptyToNull), EMPTY2NULL_STR, {}, { OMNI_VARCHAR }, OMNI_VARCHAR, + INPUT_DATA, false), + Function(reinterpret_cast(BatchStaticInvokeVarcharTypeWriteSideCheck), + BATCH_STATIC_INVOKE_VARCHARTYPE_CHECK_FN_STR, {}, { OMNI_VARCHAR, OMNI_INT }, + OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchStaticInvokeCharReadPadding), + BATCH_STATIC_INVOKE_CHAR_READ_PADDING_FN_STR, {}, { OMNI_VARCHAR, OMNI_INT }, + OMNI_VARCHAR, INPUT_DATA, true)}; + + return batchStringFnRegistry; +} + +std::vector BatchStringFunctionRegistryNotAllowReducePrecison::GetFunctions() +{ + std::vector batchStringFnRegistry = { + Function(reinterpret_cast(BatchCastStringToDateNotAllowReducePrecison), CAST_FN_STR, {}, + { OMNI_VARCHAR }, OMNI_DATE32, INPUT_DATA, true), + Function(reinterpret_cast(BatchCastStringToDateRetNullNotAllowReducePrecison), CAST_FN_STR_RETNULL, {}, + { OMNI_VARCHAR }, OMNI_DATE32, INPUT_DATA_AND_OVERFLOW_NULL), + }; + + return batchStringFnRegistry; +} + +std::vector BatchStringFunctionRegistryAllowReducePrecison::GetFunctions() +{ + std::vector batchStringFnRegistry = { + Function(reinterpret_cast(BatchCastStringToDateAllowReducePrecison), CAST_FN_STR, {}, { OMNI_VARCHAR }, + OMNI_DATE32, INPUT_DATA, true), + Function(reinterpret_cast(BatchCastStringToDateRetNullAllowReducePrecison), CAST_FN_STR_RETNULL, {}, + { OMNI_VARCHAR }, OMNI_DATE32, INPUT_DATA_AND_OVERFLOW_NULL), + }; + + return batchStringFnRegistry; +} + +std::vector BatchStringFunctionRegistryNotReplace::GetFunctions() +{ + std::vector batchStringFnRegistry = { + // replace functions + Function(reinterpret_cast(BatchReplaceStrStrStrWithRepNotReplace), REPLACE_FN_STR, {}, + { OMNI_VARCHAR, OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchReplaceStrStrWithoutRepNotReplace), REPLACE_FN_STR, {}, + { OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_VARCHAR, INPUT_DATA, true), + }; + + return batchStringFnRegistry; +} + +std::vector BatchStringFunctionRegistryReplace::GetFunctions() +{ + std::vector batchStringFnRegistry = { + // replace functions + Function(reinterpret_cast(BatchReplaceStrStrStrWithRepReplace), REPLACE_FN_STR, {}, + { OMNI_VARCHAR, OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchReplaceStrStrWithoutRepReplace), REPLACE_FN_STR, {}, + { OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_VARCHAR, INPUT_DATA, true), + }; + + return batchStringFnRegistry; +} + +std::vector BatchStringFunctionRegistrySupportNegativeAndZeroIndex::GetFunctions() +{ + std::vector batchStringFnRegistry = { + // substr functions + Function(reinterpret_cast(BatchSubstrVarchar), SUBSTR_FN_STR, {}, + { OMNI_VARCHAR, OMNI_INT, OMNI_INT }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrChar), SUBSTR_FN_STR, {}, + { OMNI_CHAR, OMNI_INT, OMNI_INT }, OMNI_CHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrVarchar), SUBSTR_FN_STR, {}, + { OMNI_VARCHAR, OMNI_LONG, OMNI_LONG }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrChar), SUBSTR_FN_STR, {}, + { OMNI_CHAR, OMNI_LONG, OMNI_LONG }, OMNI_CHAR, INPUT_DATA, true), + + // substr with start index functions + Function(reinterpret_cast(BatchSubstrVarcharWithStart), SUBSTR_FN_STR, {}, + { OMNI_VARCHAR, OMNI_INT }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrCharWithStart), SUBSTR_FN_STR, {}, + { OMNI_CHAR, OMNI_INT }, OMNI_CHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrVarcharWithStart), SUBSTR_FN_STR, {}, + { OMNI_VARCHAR, OMNI_LONG }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrCharWithStart), SUBSTR_FN_STR, {}, + { OMNI_CHAR, OMNI_LONG }, OMNI_CHAR, INPUT_DATA, true), + }; + + return batchStringFnRegistry; +} + +std::vector BatchStringFunctionRegistrySupportNotNegativeAndZeroIndex::GetFunctions() +{ + std::vector batchStringFnRegistry = { + // substr functions + Function(reinterpret_cast(BatchSubstrVarchar), SUBSTR_FN_STR, {}, + { OMNI_VARCHAR, OMNI_INT, OMNI_INT }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrChar), SUBSTR_FN_STR, {}, + { OMNI_CHAR, OMNI_INT, OMNI_INT }, OMNI_CHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrVarchar), SUBSTR_FN_STR, {}, + { OMNI_VARCHAR, OMNI_LONG, OMNI_LONG }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrChar), SUBSTR_FN_STR, {}, + { OMNI_CHAR, OMNI_LONG, OMNI_LONG }, OMNI_CHAR, INPUT_DATA, true), + + // substr with start index functions + Function(reinterpret_cast(BatchSubstrVarcharWithStart), SUBSTR_FN_STR, {}, + { OMNI_VARCHAR, OMNI_INT }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrCharWithStart), SUBSTR_FN_STR, {}, + { OMNI_CHAR, OMNI_INT }, OMNI_CHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrVarcharWithStart), SUBSTR_FN_STR, {}, + { OMNI_VARCHAR, OMNI_LONG }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrCharWithStart), SUBSTR_FN_STR, {}, + { OMNI_CHAR, OMNI_LONG }, OMNI_CHAR, INPUT_DATA, true), + }; + + return batchStringFnRegistry; +} + +std::vector BatchStringFunctionRegistrySupportNegativeAndNotZeroIndex::GetFunctions() +{ + std::vector batchStringFnRegistry = { + // substr functions + Function(reinterpret_cast(BatchSubstrVarchar), SUBSTR_FN_STR, {}, + { OMNI_VARCHAR, OMNI_INT, OMNI_INT }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrChar), SUBSTR_FN_STR, {}, + { OMNI_CHAR, OMNI_INT, OMNI_INT }, OMNI_CHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrVarchar), SUBSTR_FN_STR, {}, + { OMNI_VARCHAR, OMNI_LONG, OMNI_LONG }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrChar), SUBSTR_FN_STR, {}, + { OMNI_CHAR, OMNI_LONG, OMNI_LONG }, OMNI_CHAR, INPUT_DATA, true), + + // substr with start index functions + Function(reinterpret_cast(BatchSubstrVarcharWithStart), SUBSTR_FN_STR, {}, + { OMNI_VARCHAR, OMNI_INT }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrCharWithStart), SUBSTR_FN_STR, {}, + { OMNI_CHAR, OMNI_INT }, OMNI_CHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrVarcharWithStart), SUBSTR_FN_STR, {}, + { OMNI_VARCHAR, OMNI_LONG }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrCharWithStart), SUBSTR_FN_STR, {}, + { OMNI_CHAR, OMNI_LONG }, OMNI_CHAR, INPUT_DATA, true), + }; + + return batchStringFnRegistry; +} + +std::vector BatchStringFunctionRegistrySupportNotNegativeAndNotZeroIndex::GetFunctions() +{ + std::vector batchStringFnRegistry = { + // substr functions + Function(reinterpret_cast(BatchSubstrVarchar), SUBSTR_FN_STR, {}, + { OMNI_VARCHAR, OMNI_INT, OMNI_INT }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrChar), SUBSTR_FN_STR, {}, + { OMNI_CHAR, OMNI_INT, OMNI_INT }, OMNI_CHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrVarchar), SUBSTR_FN_STR, {}, + { OMNI_VARCHAR, OMNI_LONG, OMNI_LONG }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrChar), SUBSTR_FN_STR, {}, + { OMNI_CHAR, OMNI_LONG, OMNI_LONG }, OMNI_CHAR, INPUT_DATA, true), + + // substr with start index functions + Function(reinterpret_cast(BatchSubstrVarcharWithStart), SUBSTR_FN_STR, {}, + { OMNI_VARCHAR, OMNI_INT }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrCharWithStart), SUBSTR_FN_STR, {}, + { OMNI_CHAR, OMNI_INT }, OMNI_CHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrVarcharWithStart), SUBSTR_FN_STR, {}, + { OMNI_VARCHAR, OMNI_LONG }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(BatchSubstrCharWithStart), SUBSTR_FN_STR, {}, + { OMNI_CHAR, OMNI_LONG }, OMNI_CHAR, INPUT_DATA, true), + }; + + return batchStringFnRegistry; +} + +std::vector BatchStringToDecimalFunctionRegistryAllowRoundUp::GetFunctions() +{ + std::vector batchStringFnRegistry = { + Function(reinterpret_cast(BatchCastStringToDecimal64RoundUp), CAST_FN_STR, {}, {OMNI_VARCHAR}, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchCastStringToDecimal128RoundUp), CAST_FN_STR, {}, {OMNI_VARCHAR}, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchCastStringToDecimal64RoundUpRetNull), CAST_FN_STR_RETNULL, {}, + {OMNI_VARCHAR}, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchCastStringToDecimal128RoundUpRetNull), CAST_FN_STR_RETNULL, {}, + {OMNI_VARCHAR}, OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL) + }; + + return batchStringFnRegistry; +} + +std::vector BatchStringToDecimalFunctionRegistry::GetFunctions() +{ + std::vector batchStringFnRegistry = { + Function(reinterpret_cast(BatchCastStringToDecimal64), CAST_FN_STR, {}, {OMNI_VARCHAR}, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(BatchCastStringToDecimal128), CAST_FN_STR, {}, {OMNI_VARCHAR}, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(BatchCastStringToDecimal64RetNull), CAST_FN_STR_RETNULL, {}, {OMNI_VARCHAR}, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(BatchCastStringToDecimal128RetNull), CAST_FN_STR_RETNULL, {}, + {OMNI_VARCHAR}, OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL) + }; + + return batchStringFnRegistry; +} +} \ No newline at end of file diff --git a/core/src/codegen/batch_func_registry_string.h b/core/src/codegen/batch_func_registry_string.h new file mode 100644 index 0000000..061c091 --- /dev/null +++ b/core/src/codegen/batch_func_registry_string.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. + * Description: Batch String Function Registry + */ + +#ifndef OMNI_RUNTIME_BATCH_FUNC_REGISTRY_STRING_H +#define OMNI_RUNTIME_BATCH_FUNC_REGISTRY_STRING_H +#include "function.h" +#include "func_registry_base.h" +#include "util/type_util.h" + +namespace omniruntime::codegen { +class BatchStringFunctionRegistry : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class BatchStringFunctionRegistryNotAllowReducePrecison : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class BatchStringFunctionRegistryAllowReducePrecison : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class BatchStringFunctionRegistryNotReplace : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class BatchStringFunctionRegistryReplace : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class BatchStringFunctionRegistrySupportNegativeAndZeroIndex : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class BatchStringFunctionRegistrySupportNotNegativeAndZeroIndex : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class BatchStringFunctionRegistrySupportNegativeAndNotZeroIndex : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class BatchStringFunctionRegistrySupportNotNegativeAndNotZeroIndex : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class BatchStringToDecimalFunctionRegistryAllowRoundUp : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class BatchStringToDecimalFunctionRegistry : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; +} +#endif // OMNI_RUNTIME_BATCH_FUNC_REGISTRY_STRING_H diff --git a/core/src/codegen/batch_func_registry_util.cpp b/core/src/codegen/batch_func_registry_util.cpp new file mode 100644 index 0000000..95013bd --- /dev/null +++ b/core/src/codegen/batch_func_registry_util.cpp @@ -0,0 +1,133 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: Batch util Function Registry + */ +#include "batch_func_registry_util.h" +#include "batch_functions/batch_utilfunctions.h" + +namespace omniruntime::codegen { +using namespace omniruntime::type; +using namespace codegen::function; + + +std::vector BatchUtilFunctionRegistry::GetFunctions() +{ + std::vector utilFnRegistry = { + Function(reinterpret_cast(FillRowIndexArray), "fill_rowIdx", {}, { OMNI_INT, OMNI_INT }, OMNI_INT, + INPUT_DATA), + Function(reinterpret_cast(CopyNull), "batch_copy_null", {}, { OMNI_BOOLEAN }, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(FillNull), "batch_fill_null", {}, { OMNI_BOOLEAN }, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(FillBool), "batch_fill_literal", {}, { OMNI_BOOLEAN }, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(FillInt32), "batch_fill_literal", {}, { OMNI_INT }, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(FillInt64), "batch_fill_literal", {}, { OMNI_LONG }, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(FillInt64), "batch_fill_literal", {}, { OMNI_TIMESTAMP }, OMNI_TIMESTAMP, + INPUT_DATA), + Function(reinterpret_cast(FillDecimal128), "batch_fill_literal", {}, { OMNI_DECIMAL128 }, + OMNI_DECIMAL128, INPUT_DATA), + Function(reinterpret_cast(FillInt64), "batch_fill_literal", {}, { OMNI_DECIMAL64 }, OMNI_DECIMAL64, + INPUT_DATA), + Function(reinterpret_cast(FillInt32), "batch_fill_literal", {}, { OMNI_DATE32 }, OMNI_DATE32, + INPUT_DATA), + Function(reinterpret_cast(FillDouble), "batch_fill_literal", {}, { OMNI_DOUBLE }, OMNI_DOUBLE, + INPUT_DATA), + Function(reinterpret_cast(FillString), "batch_fill_literal", {}, { OMNI_VARCHAR }, OMNI_VARCHAR, + INPUT_DATA), + Function(reinterpret_cast(FillLength), "batch_fill_length", {}, { OMNI_INT }, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(FillLengthInFuncExpr), "batch_fill_length_literal", {}, { OMNI_INT }, + OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(CopyInt32), "batch_copy", {}, { OMNI_INT }, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(CopyInt64), "batch_copy", {}, { OMNI_LONG }, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(CopyInt64), "batch_copy", {}, { OMNI_TIMESTAMP }, OMNI_TIMESTAMP, INPUT_DATA), + Function(reinterpret_cast(CopyDecimal128), "batch_copy", {}, { OMNI_DECIMAL128 }, OMNI_DECIMAL128, + INPUT_DATA), + Function(reinterpret_cast(CopyInt32), "batch_copy", {}, { OMNI_DATE32 }, OMNI_DATE32, INPUT_DATA), + Function(reinterpret_cast(CopyInt64), "batch_copy", {}, { OMNI_DECIMAL64 }, OMNI_DECIMAL64, INPUT_DATA), + Function(reinterpret_cast(CopyDouble), "batch_copy", {}, { OMNI_DOUBLE }, OMNI_DOUBLE, INPUT_DATA), + Function(reinterpret_cast(CopyBoolean), "batch_copy", {}, { OMNI_BOOLEAN }, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(CopyString), "batch_copy", {}, { OMNI_VARCHAR }, OMNI_VARCHAR, INPUT_DATA), + Function(reinterpret_cast(CopyString), "batch_copy", {}, { OMNI_CHAR }, OMNI_CHAR, INPUT_DATA), + Function(reinterpret_cast(CreateNot), "batch_not", {}, { OMNI_BOOLEAN }, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(CreateAnd), "batch_and", {}, { OMNI_BOOLEAN, OMNI_BOOLEAN }, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(CreateOr), "batch_or", {}, { OMNI_BOOLEAN, OMNI_BOOLEAN }, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(CreateAndNot), "batch_and_not", {}, + { OMNI_BOOLEAN, OMNI_BOOLEAN, OMNI_INT, OMNI_INT }, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(CreateAndNotBool), "batch_and_not", {}, { OMNI_BOOLEAN, OMNI_BOOLEAN }, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(CreateOrExpr), "batch_or_expr", {}, { OMNI_BOOLEAN, OMNI_BOOLEAN }, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(CreateAndExpr), "batch_and_expr", {}, { OMNI_BOOLEAN, OMNI_BOOLEAN }, + OMNI_BOOLEAN, INPUT_DATA), + + Function(reinterpret_cast(Coalesce), "batch_coalesce", {}, { OMNI_BOOLEAN, OMNI_BOOLEAN }, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(Coalesce), "batch_coalesce", {}, { OMNI_INT, OMNI_INT }, OMNI_INT, + INPUT_DATA), + Function(reinterpret_cast(Coalesce), "batch_coalesce", {}, { OMNI_LONG, OMNI_LONG }, OMNI_LONG, + INPUT_DATA), + Function(reinterpret_cast(Coalesce), "batch_coalesce", {}, { OMNI_TIMESTAMP, OMNI_TIMESTAMP }, + OMNI_TIMESTAMP, INPUT_DATA), + Function(reinterpret_cast(Coalesce), "batch_coalesce", {}, { OMNI_DOUBLE, OMNI_DOUBLE }, + OMNI_DOUBLE, INPUT_DATA), + Function(reinterpret_cast(CoalesceString), "batch_coalesce", {}, { OMNI_VARCHAR, OMNI_VARCHAR }, + OMNI_VARCHAR, INPUT_DATA), + Function(reinterpret_cast(Coalesce), "batch_coalesce", {}, { OMNI_DECIMAL64, OMNI_DECIMAL64 }, + OMNI_DECIMAL64, INPUT_DATA), + Function(reinterpret_cast(Coalesce), "batch_coalesce", {}, { OMNI_DATE32, OMNI_DATE32 }, + OMNI_DECIMAL64, INPUT_DATA), + Function(reinterpret_cast(Coalesce), "batch_coalesce", {}, + { OMNI_DECIMAL128, OMNI_DECIMAL128 }, OMNI_DECIMAL128, INPUT_DATA), + + Function(reinterpret_cast(IfExpr), "batch_if", {}, { OMNI_BOOLEAN }, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(IfExpr), "batch_if", {}, { OMNI_INT }, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(IfExpr), "batch_if", {}, { OMNI_LONG }, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(IfExpr), "batch_if", {}, { OMNI_TIMESTAMP }, OMNI_TIMESTAMP, + INPUT_DATA), + Function(reinterpret_cast(IfExpr), "batch_if", {}, { OMNI_DOUBLE }, OMNI_DOUBLE, INPUT_DATA), + Function(reinterpret_cast(IfExprString), "batch_if", {}, { OMNI_VARCHAR }, OMNI_VARCHAR, INPUT_DATA), + Function(reinterpret_cast(IfExprString), "batch_if", {}, { OMNI_CHAR }, OMNI_CHAR, INPUT_DATA), + Function(reinterpret_cast(IfExpr), "batch_if", {}, { OMNI_DECIMAL64 }, OMNI_DECIMAL64, + INPUT_DATA), + Function(reinterpret_cast(IfExpr), "batch_if", {}, { OMNI_DATE32 }, OMNI_DATE32, INPUT_DATA), + Function(reinterpret_cast(IfExpr), "batch_if", {}, { OMNI_DECIMAL128 }, OMNI_DECIMAL128, + INPUT_DATA), + + Function(reinterpret_cast(SwitchExpr), "batch_switch", {}, { OMNI_BOOLEAN }, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(SwitchExpr), "batch_switch", {}, { OMNI_INT }, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(SwitchExpr), "batch_switch", {}, { OMNI_LONG }, OMNI_LONG, + INPUT_DATA), + Function(reinterpret_cast(SwitchExpr), "batch_switch", {}, { OMNI_TIMESTAMP }, OMNI_TIMESTAMP, + INPUT_DATA), + Function(reinterpret_cast(SwitchExpr), "batch_switch", {}, { OMNI_DOUBLE }, OMNI_DOUBLE, + INPUT_DATA), + Function(reinterpret_cast(SwitchExprString), "batch_switch", {}, { OMNI_CHAR }, OMNI_CHAR, INPUT_DATA), + Function(reinterpret_cast(SwitchExprString), "batch_switch", {}, { OMNI_VARCHAR }, OMNI_VARCHAR, + INPUT_DATA), + Function(reinterpret_cast(SwitchExpr), "batch_switch", {}, { OMNI_DECIMAL64 }, OMNI_DECIMAL64, + INPUT_DATA), + Function(reinterpret_cast(SwitchExpr), "batch_switch", {}, { OMNI_DATE32 }, OMNI_DATE32, + INPUT_DATA), + Function(reinterpret_cast(SwitchExpr), "batch_switch", {}, { OMNI_DECIMAL128 }, + OMNI_DECIMAL128, INPUT_DATA), + + Function(reinterpret_cast(InExpr), "batch_in", {}, { OMNI_BOOLEAN }, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(InExpr), "batch_in", {}, { OMNI_INT }, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(InExpr), "batch_in", {}, { OMNI_DATE32 }, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(InExpr), "batch_in", {}, { OMNI_LONG }, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(InExpr), "batch_in", {}, { OMNI_TIMESTAMP }, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(InExpr), "batch_in", {}, { OMNI_DOUBLE }, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(InExprString), "batch_in", {}, { OMNI_CHAR }, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(InExprString), "batch_in", {}, { OMNI_VARCHAR }, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(InExpr), "batch_in", {}, { OMNI_DECIMAL64 }, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(InExpr), "batch_in", {}, { OMNI_DECIMAL128 }, OMNI_BOOLEAN, + INPUT_DATA), + }; + + return utilFnRegistry; +} +} diff --git a/core/src/codegen/batch_func_registry_util.h b/core/src/codegen/batch_func_registry_util.h new file mode 100644 index 0000000..ca32d00 --- /dev/null +++ b/core/src/codegen/batch_func_registry_util.h @@ -0,0 +1,17 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: Batch util Function Registry + */ +#ifndef OMNI_RUNTIME_BATCH_FUNC_REGISTRY_UTIL_H +#define OMNI_RUNTIME_BATCH_FUNC_REGISTRY_UTIL_H +#include "function.h" +#include "func_registry_base.h" +#include "util/type_util.h" + +namespace omniruntime::codegen { +class BatchUtilFunctionRegistry : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; +} +#endif // OMNI_RUNTIME_BATCH_FUNC_REGISTRY_UTIL_H diff --git a/core/src/codegen/batch_func_registry_varchar_vector.cpp b/core/src/codegen/batch_func_registry_varchar_vector.cpp new file mode 100644 index 0000000..00dbc7d --- /dev/null +++ b/core/src/codegen/batch_func_registry_varchar_vector.cpp @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: Batch varchar vector Function Registry + */ + +#include "batch_func_registry_varchar_vector.h" +#include "batch_functions/batch_varcharVectorfunctions.h" + +namespace omniruntime::codegen { +using namespace omniruntime::type; +using namespace codegen::function; + +std::vector BatchVarcharVectorFunctionRegistry::GetFunctions() +{ + std::vector paramTypes = { OMNI_LONG, OMNI_VARCHAR, OMNI_INT, OMNI_INT }; + std::vector batchVarcharVectorFnRegistry = { Function(reinterpret_cast(BatchWrapVarcharVector), + "batch_WrapVarcharVector", {}, paramTypes, OMNI_INT), + Function(reinterpret_cast(BatchNullArrayToBits), "batch_NullArrayToBits", {}, { OMNI_BOOLEAN }, + OMNI_BOOLEAN), + Function(reinterpret_cast(BatchBitsToNullArray), "batch_BitsToNullArray", {}, { OMNI_BOOLEAN }, + OMNI_BOOLEAN) }; + return batchVarcharVectorFnRegistry; +} +} diff --git a/core/src/codegen/batch_func_registry_varchar_vector.h b/core/src/codegen/batch_func_registry_varchar_vector.h new file mode 100644 index 0000000..009d2e0 --- /dev/null +++ b/core/src/codegen/batch_func_registry_varchar_vector.h @@ -0,0 +1,17 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: Batch varchar vector Function Registry + */ + +#ifndef OMNI_RUNTIME_BATCH_FUNC_REGISTRY_VARCHAR_VECTOR_H +#define OMNI_RUNTIME_BATCH_FUNC_REGISTRY_VARCHAR_VECTOR_H +#include "function.h" +#include "func_registry_base.h" + +namespace omniruntime::codegen { +class BatchVarcharVectorFunctionRegistry : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; +} +#endif // OMNI_RUNTIME_BATCH_FUNC_REGISTRY_VARCHAR_VECTOR_H diff --git a/core/src/codegen/batch_functions/batch_datetime_functions.cpp b/core/src/codegen/batch_functions/batch_datetime_functions.cpp new file mode 100644 index 0000000..57b9d8c --- /dev/null +++ b/core/src/codegen/batch_functions/batch_datetime_functions.cpp @@ -0,0 +1,96 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + * Description: batch date time functions implementation + */ +#include "batch_datetime_functions.h" +#include +#include "codegen/context_helper.h" +#include "type/date32.h" +#include "codegen/time_util.h" + +namespace omniruntime::codegen::function { +extern "C" DLLEXPORT void BatchUnixTimestampFromStr(const char **timeStrs, int32_t *timeLens, bool *isNullTimeStr, + const char **fmtStrs, int32_t *fmtLens, bool *isNullFmtStr, const char **tzStrs, int32_t *tzLens, + bool *isNullTzStr, const char **policyStrs, int32_t *policyLens, bool *isNullPolStr, + bool *retIsNull, int64_t *output, int32_t rowCnt) +{ + std::string tzStr(tzStrs[0], tzLens[0]); + setenv("TZ", TimeZoneUtil::GetTZ(tzStr.c_str()), 1); + tzset(); + for (int32_t i = 0; i < rowCnt; i++) { + if (isNullTimeStr[i] || isNullFmtStr[i] || fmtLens[i] == 0 || timeLens[i] == 0) { + retIsNull[i] = true; + output[i] = 0; + continue; + } + if (!TimeUtil::IsTimeValid(timeStrs[i], timeLens[i], fmtStrs[i], fmtLens[i], policyStrs[i])) { + retIsNull[i] = true; + output[i] = 0; + continue; + } + struct tm timeInfo = { 0 }; + std::string timeStr(timeStrs[i], timeLens[i]); + std::string fmtStr(fmtStrs[i], fmtLens[i]); + strptime(timeStr.c_str(), fmtStr.c_str(), &timeInfo); + time_t timeStamp = mktime(&timeInfo); + if (TimeZoneUtil::JudgeDSTByUnixTimestampFromStr(tzStrs[i], tzLens[i], &timeInfo, + timeStrs[i], timeLens[i], fmtStrs[i], fmtLens[i])) { + timeStamp -= type::SECOND_OF_HOUR; + } + output[i] = timeStamp; + } +} + +extern "C" DLLEXPORT void BatchUnixTimestampFromDate(int32_t *dates, const char **fmtStrs, int32_t *fmtLens, + const char **tzStrs, int32_t *tzLens, const char **policyStrs, int32_t *policyLens, + bool *isAnyNull, int64_t *output, int32_t rowCnt) +{ + std::string tzStr(tzStrs[0], tzLens[0]); + setenv("TZ", TimeZoneUtil::GetTZ(tzStr.c_str()), 1); + tzset(); + for (int32_t i = 0; i < rowCnt; i++) { + if (isAnyNull[i]) { + output[i] = 0; + continue; + } + time_t desiredTime = type::SECOND_OF_DAY * dates[i]; + struct tm ltm; + localtime_r(&desiredTime, <m); + time_t result = desiredTime - ltm.tm_gmtoff; + result += TimeZoneUtil::AdjustDSTByUnixTimestampFromDate(tzStrs[i], tzLens[i], <m, desiredTime) * 3600; + output[i] = static_cast(result); + } +} + +extern "C" DLLEXPORT void BatchFromUnixTime(bool *outputNull, int64_t contextPtr, int64_t *timestamps, + const char **fmtStrs, int32_t *fmtLens, const char **tzStrs, int32_t *tzLens, + char **output, int32_t *outLens, int32_t rowCnt) +{ + std::string tzStr(tzStrs[0], tzLens[0]); + setenv("TZ", TimeZoneUtil::GetTZ(tzStr.c_str()), 1); + tzset(); + for (int32_t i = 0; i < rowCnt; i++) { + time_t timeStampVal = timestamps[i]; + struct tm ltm; + localtime_r(&timeStampVal, <m); + if (!TimeZoneUtil::JudgeDSTByFromUnixTime(tzStrs[i], tzLens[i], <m)) { + timeStampVal -= type::SECOND_OF_HOUR; + localtime_r(&timeStampVal, <m); + } + int32_t resultLen = fmtLens[i] + 3; + auto result = ArenaAllocatorMalloc(contextPtr, resultLen); + std::string fmtStr(fmtStrs[i], fmtLens[i]); + auto ret = strftime(result, resultLen, fmtStr.c_str(), <m); + outputNull[i] = (ret == 0); + output[i] = result; + outLens[i] = ret; + } +} + +extern "C" DLLEXPORT void BatchFromUnixTimeRetNull(bool *outputNull, int64_t contextPtr, int64_t *timestamps, + const char **fmtStrs, int32_t *fmtLens, const char **tzStrs, int32_t *tzLen, + char **output, int32_t *outLens, int32_t rowCnt) +{ + BatchFromUnixTime(outputNull, contextPtr, timestamps, fmtStrs, fmtLens, tzStrs, tzLen, output, outLens, rowCnt); +} +} \ No newline at end of file diff --git a/core/src/codegen/batch_functions/batch_datetime_functions.h b/core/src/codegen/batch_functions/batch_datetime_functions.h new file mode 100644 index 0000000..6de1af8 --- /dev/null +++ b/core/src/codegen/batch_functions/batch_datetime_functions.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: batch date time functions implementation + */ + +#ifndef OMNI_RUNTIME_BATCH_DATETIME_FUNCTIONS_H +#define OMNI_RUNTIME_BATCH_DATETIME_FUNCTIONS_H +#include + +// All extern functions go here temporarily +#ifdef _WIN32 +#define DLLEXPORT __declspec(dllexport) +#else +#define DLLEXPORT +#endif + +namespace omniruntime::codegen::function { +extern "C" DLLEXPORT void BatchUnixTimestampFromStr(const char **timeStrs, int32_t *timeLens, bool *isNullTimeStr, + const char **fmtStrs, int32_t *fmtLens, bool *isNullFmtStr, const char **tzStrs, int32_t *tzLens, bool *isNullTzStr, + const char **policyStrs, int32_t *policyLens, bool *isNullPolStr, bool *retIsNull, int64_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchUnixTimestampFromDate(int32_t *dates, const char **fmtStrs, int32_t *fmtLens, + const char **tzStrs, int32_t *tzLens, const char **policyStrs, int32_t *policyLens, + bool *isAnyNull, int64_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchFromUnixTime(bool *outputNull, int64_t contextPtr, int64_t *timestamps, + const char **fmtStrs, int32_t *fmtLens, const char **tzStrs, int32_t *tzLens, + char **output, int32_t *outLens, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchFromUnixTimeRetNull(bool *outputNull, int64_t contextPtr, int64_t *timestamps, + const char **fmtStrs, int32_t *fmtLens, const char **tzStrs, int32_t *tzLens, + char **output, int32_t *outLens, int32_t rowCnt); +} +#endif // OMNI_RUNTIME_BATCH_DATETIME_FUNCTIONS_H diff --git a/core/src/codegen/batch_functions/batch_decimal_arithmetic_functions.cpp b/core/src/codegen/batch_functions/batch_decimal_arithmetic_functions.cpp new file mode 100644 index 0000000..4a10e3f --- /dev/null +++ b/core/src/codegen/batch_functions/batch_decimal_arithmetic_functions.cpp @@ -0,0 +1,1842 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + * Description: batch decimal functions implementation + */ +#include "batch_decimal_arithmetic_functions.h" +#include +#include +#include "codegen/context_helper.h" +#include "type/decimal_operations.h" + +using namespace omniruntime::type; + +namespace omniruntime::codegen::function { +const std::string DECIMAL_OVERFLOW { "Decimal overflow" }; +const std::string DIVIDE_ZERO { "Division by zero" }; + +extern "C" DLLEXPORT void BatchDecimal128Compare(Decimal128 *x, int32_t xPrecision, int32_t xScale, Decimal128 *y, + int32_t yPrecision, int32_t yScale, int32_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = Decimal128Wrapper(x[i]).SetScale(xScale).Compare(Decimal128Wrapper(y[i]).SetScale(yScale)); + } +} + +extern "C" DLLEXPORT void BatchLessThanDecimal128(Decimal128 *left, int32_t xPrecision, int32_t xScale, + Decimal128 *right, int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt) +{ + auto tmp = new int32_t[rowCnt]; + BatchDecimal128Compare(left, xPrecision, xScale, right, yPrecision, yScale, tmp, rowCnt); + for (int i = 0; i < rowCnt; i++) { + output[i] = (tmp[i] < 0); + } + delete[] tmp; +} + +extern "C" DLLEXPORT void BatchLessThanEqualDecimal128(Decimal128 *left, int32_t xPrecision, int32_t xScale, + Decimal128 *right, int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt) +{ + auto tmp = new int32_t[rowCnt]; + BatchDecimal128Compare(left, xPrecision, xScale, right, yPrecision, yScale, tmp, rowCnt); + for (int i = 0; i < rowCnt; i++) { + output[i] = (tmp[i] <= 0); + } + delete[] tmp; +} + +extern "C" DLLEXPORT void BatchGreaterThanDecimal128(Decimal128 *left, int32_t xPrecision, int32_t xScale, + Decimal128 *right, int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt) +{ + auto tmp = new int32_t[rowCnt]; + BatchDecimal128Compare(left, xPrecision, xScale, right, yPrecision, yScale, tmp, rowCnt); + for (int i = 0; i < rowCnt; i++) { + output[i] = (tmp[i] > 0); + } + delete[] tmp; +} + +extern "C" DLLEXPORT void BatchGreaterThanEqualDecimal128(Decimal128 *left, int32_t xPrecision, int32_t xScale, + Decimal128 *right, int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt) +{ + auto tmp = new int32_t[rowCnt]; + BatchDecimal128Compare(left, xPrecision, xScale, right, yPrecision, yScale, tmp, rowCnt); + for (int i = 0; i < rowCnt; i++) { + output[i] = (tmp[i] >= 0); + } + delete[] tmp; +} + +extern "C" DLLEXPORT void BatchEqualDecimal128(Decimal128 *left, int32_t xPrecision, int32_t xScale, Decimal128 *right, + int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt) +{ + auto tmp = new int32_t[rowCnt]; + BatchDecimal128Compare(left, xPrecision, xScale, right, yPrecision, yScale, tmp, rowCnt); + for (int i = 0; i < rowCnt; i++) { + output[i] = (tmp[i] == 0); + } + delete[] tmp; +} + +extern "C" DLLEXPORT void BatchNotEqualDecimal128(Decimal128 *left, int32_t xPrecision, int32_t xScale, + Decimal128 *right, int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt) +{ + auto tmp = new int32_t[rowCnt]; + BatchDecimal128Compare(left, xPrecision, xScale, right, yPrecision, yScale, tmp, rowCnt); + for (int i = 0; i < rowCnt; i++) { + output[i] = (tmp[i] != 0); + } + delete[] tmp; +} + +extern "C" DLLEXPORT void BatchAbsDecimal128(Decimal128 *x, int32_t xPrecision, int32_t xScale, bool *isNull, + Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = Decimal128Wrapper(x[i]).Abs().ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchRoundDecimal128(int64_t contextPtr, Decimal128 *x, int32_t xPrecision, int32_t xScale, + int32_t *round, bool *isNull, Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + if (isNull[i]) { + continue; + } + Decimal128Wrapper input(x[i]); + input.SetScale(xScale); + DecimalOperations::Round(input, outScale, round[i]); + CHECK_OVERFLOW_CONTINUE(input, outPrecision); + output[i] = input.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchRoundDecimal64(int64_t contextPtr, int64_t *x, int32_t xPrecision, int32_t xScale, + int32_t *round, bool *isNull, int64_t *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + if (isNull[i]) { + continue; + } + Decimal64 input(x[i]); + input.SetScale(xScale); + DecimalOperations::Round(input, outScale, round[i]); + CHECK_OVERFLOW_CONTINUE(input, outPrecision); + output[i] = input.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchRoundDecimal128WithoutRound(int64_t contextPtr, Decimal128 *x, int32_t xPrecision, + int32_t xScale, bool *isNull, Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + if (isNull[i]) { + continue; + } + Decimal128Wrapper input(x[i]); + input.SetScale(xScale); + DecimalOperations::Round(input, outScale, 0); + CHECK_OVERFLOW_CONTINUE(input, outPrecision); + output[i] = input.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchRoundDecimal64WithoutRound(int64_t contextPtr, int64_t *x, int32_t xPrecision, + int32_t xScale, bool *isNull, int64_t *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + if (isNull[i]) { + continue; + } + Decimal64 input(x[i]); + input.SetScale(xScale); + DecimalOperations::Round(input, outScale, 0); + CHECK_OVERFLOW_CONTINUE(input, outPrecision); + output[i] = input.GetValue(); + } +} + +// decimal64 arith functions +extern "C" DLLEXPORT void BatchDecimal64Compare(int64_t *x, int32_t xPrecision, int32_t xScale, int64_t *y, + int32_t yPrecision, int32_t yScale, int32_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = Decimal64(x[i]).SetScale(xScale).Compare(Decimal64(y[i]).SetScale(yScale)); + } +} + +extern "C" DLLEXPORT void BatchAbsDecimal64(int64_t *x, int32_t xPrecision, int32_t xScale, bool *isNull, + int64_t *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = std::abs(x[i]); + } +} + +extern "C" DLLEXPORT void BatchLessThanDecimal64(int64_t *left, int32_t xPrecision, int32_t xScale, int64_t *right, + int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt) +{ + auto tmp = new int32_t[rowCnt]; + BatchDecimal64Compare(left, xPrecision, xScale, right, yPrecision, yScale, tmp, rowCnt); + for (int i = 0; i < rowCnt; i++) { + output[i] = (tmp[i] < 0); + } + delete[] tmp; +} + +extern "C" DLLEXPORT void BatchLessThanEqualDecimal64(int64_t *left, int32_t xPrecision, int32_t xScale, int64_t *right, + int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt) +{ + auto tmp = new int32_t[rowCnt]; + BatchDecimal64Compare(left, xPrecision, xScale, right, yPrecision, yScale, tmp, rowCnt); + for (int i = 0; i < rowCnt; i++) { + output[i] = (tmp[i] <= 0); + } + delete[] tmp; +} + +extern "C" DLLEXPORT void BatchGreaterThanDecimal64(int64_t *left, int32_t xPrecision, int32_t xScale, int64_t *right, + int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt) +{ + auto tmp = new int32_t[rowCnt]; + BatchDecimal64Compare(left, xPrecision, xScale, right, yPrecision, yScale, tmp, rowCnt); + for (int i = 0; i < rowCnt; i++) { + output[i] = (tmp[i] > 0); + } + delete[] tmp; +} + +extern "C" DLLEXPORT void BatchGreaterThanEqualDecimal64(int64_t *left, int32_t xPrecision, int32_t xScale, + int64_t *right, int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt) +{ + auto tmp = new int32_t[rowCnt]; + BatchDecimal64Compare(left, xPrecision, xScale, right, yPrecision, yScale, tmp, rowCnt); + for (int i = 0; i < rowCnt; i++) { + output[i] = (tmp[i] >= 0); + } + delete[] tmp; +} + +extern "C" DLLEXPORT void BatchEqualDecimal64(int64_t *left, int32_t xPrecision, int32_t xScale, int64_t *right, + int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt) +{ + auto tmp = new int32_t[rowCnt]; + BatchDecimal64Compare(left, xPrecision, xScale, right, yPrecision, yScale, tmp, rowCnt); + for (int i = 0; i < rowCnt; i++) { + output[i] = (tmp[i] == 0); + } + delete[] tmp; +} + +extern "C" DLLEXPORT void BatchNotEqualDecimal64(int64_t *left, int32_t xPrecision, int32_t xScale, int64_t *right, + int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt) +{ + auto tmp = new int32_t[rowCnt]; + BatchDecimal64Compare(left, xPrecision, xScale, right, yPrecision, yScale, tmp, rowCnt); + for (int i = 0; i < rowCnt; i++) { + output[i] = (tmp[i] != 0); + } + delete[] tmp; +} + +extern "C" DLLEXPORT void BatchUnscaledValue64(int64_t *x, int32_t precision, int32_t scale, bool *isAnyNull, + int64_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + output[i] = x[i]; + } +} + +extern "C" DLLEXPORT void BatchMakeDecimal64(int64_t contextPtr, int64_t *x, bool *isAnyNull, int64_t *output, + int32_t precision, int32_t scale, int32_t rowCnt) +{ + std::ostringstream errorMessage; + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + output[i] = 1; + continue; + } + if (DecimalOperations::IsUnscaledLongOverflow(x[i], precision, scale) && !HasError(contextPtr)) { + errorMessage << "Unscaled value " << x << " out of Decimal(" << precision << ", " << scale << ") range"; + SetError(contextPtr, errorMessage.str()); + output[i] = 1; + continue; + } + output[i] = x[i]; + } +} + +extern "C" DLLEXPORT void BatchMakeDecimal64RetNull(bool *isNull, int64_t *x, int64_t *output, int32_t precision, + int32_t scale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (DecimalOperations::IsUnscaledLongOverflow(x[i], precision, scale)) { + isNull[i] = true; + output[i] = 1; + continue; + } + output[i] = x[i]; + } +} + +// Decimal Add Operator ReScale +extern "C" DLLEXPORT void BatchAddDec64Dec64Dec64ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalAdd(Decimal64(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal64(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + x[i] = result.GetValue(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchAddDec64Dec64Dec128ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + output[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchAddDec128Dec128Dec128ReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + x[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchAddDec64Dec128Dec128ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + y[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchAddDec128Dec64Dec128ReScale(int64_t contextPtr, bool *isNull, Decimal128 *y, + int32_t yPrecision, int32_t yScale, int64_t *x, int32_t xPrecision, int32_t xScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + y[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +// Decimal SubOperator ReScale +extern "C" DLLEXPORT void BatchSubDec64Dec64Dec64ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalSubtract(Decimal64(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal64(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + x[i] = result.GetValue(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchSubDec64Dec64Dec128ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + output[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchSubDec128Dec128Dec128ReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + x[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchSubDec64Dec128Dec128ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + y[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchSubDec128Dec64Dec128ReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + x[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +// Decimal MulOperator ReScale +extern "C" DLLEXPORT void BatchMulDec64Dec64Dec64ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + if (isNull[i]) { + continue; + } + DecimalOperations::InternalDecimalMultiply(Decimal64(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal64(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + x[i] = result.GetValue(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchMulDec64Dec64Dec128ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + if (isNull[i]) { + continue; + } + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + output[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchMulDec128Dec128Dec128ReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + if (isNull[i]) { + continue; + } + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + x[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchMulDec64Dec128Dec128ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + if (isNull[i]) { + continue; + } + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + y[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchMulDec128Dec64Dec128ReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + if (isNull[i]) { + continue; + } + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + x[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +// Decimal DivOperation ReScale +extern "C" DLLEXPORT void BatchDivDec64Dec64Dec64ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalDivide(Decimal64(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal64(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + x[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchDivDec64Dec128Dec64ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + x[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchDivDec128Dec64Dec64ReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + y[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchDivDec64Dec64Dec128ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + output[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchDivDec128Dec128Dec128ReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + x[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchDivDec64Dec128Dec128ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + y[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchDivDec128Dec64Dec128ReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + x[i] = result.ToDecimal128(); + } +} + +// Decimal Mod Operation ReScale +extern "C" DLLEXPORT void BatchModDec64Dec64Dec64ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalMod(Decimal64(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal64(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + x[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchModDec64Dec128Dec64ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + x[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchModDec128Dec64Dec64ReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + y[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchModDec128Dec64Dec128ReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + x[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchModDec128Dec128Dec128ReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + x[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchModDec128Dec128Dec64ReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int64_t *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + output[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchModDec64Dec128Dec128ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + y[i] = result.ToDecimal128(); + } +} + +// Decimal Add Operator NotReScale +extern "C" DLLEXPORT void BatchAddDec64Dec64Dec64NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalAdd(Decimal64(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal64(y[i]).SetScale(yScale), yScale, yPrecision, result); + x[i] = result.GetValue(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchAddDec64Dec64Dec128NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + output[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchAddDec128Dec128Dec128NotReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + x[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchAddDec64Dec128Dec128NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + y[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchAddDec128Dec64Dec128NotReScale(int64_t contextPtr, bool *isNull, Decimal128 *y, + int32_t yPrecision, int32_t yScale, int64_t *x, int32_t xPrecision, int32_t xScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + y[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +// Decimal SubOperator NotReScale +extern "C" DLLEXPORT void BatchSubDec64Dec64Dec64NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalSubtract(Decimal64(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal64(y[i]).SetScale(yScale), yScale, yPrecision, result); + x[i] = result.GetValue(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchSubDec64Dec64Dec128NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + output[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchSubDec128Dec128Dec128NotReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + x[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchSubDec64Dec128Dec128NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + y[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchSubDec128Dec64Dec128NotReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + x[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +// Decimal MulOperator NotReScale +extern "C" DLLEXPORT void BatchMulDec64Dec64Dec64NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + if (isNull[i]) { + continue; + } + DecimalOperations::InternalDecimalMultiply(Decimal64(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal64(y[i]).SetScale(yScale), yScale, yPrecision, result); + x[i] = result.GetValue(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchMulDec64Dec64Dec128NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + if (isNull[i]) { + continue; + } + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + output[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchMulDec128Dec128Dec128NotReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + if (isNull[i]) { + continue; + } + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + x[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchMulDec64Dec128Dec128NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + if (isNull[i]) { + continue; + } + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + y[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchMulDec128Dec64Dec128NotReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + if (isNull[i]) { + continue; + } + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + x[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + } +} + +// Decimal DivOperation NotReScale +extern "C" DLLEXPORT void BatchDivDec64Dec64Dec64NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalDivide(Decimal64(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal64(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + x[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchDivDec64Dec128Dec64NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + x[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchDivDec128Dec64Dec64NotReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + y[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchDivDec64Dec64Dec128NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + output[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchDivDec128Dec128Dec128NotReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + x[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchDivDec64Dec128Dec128NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + y[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchDivDec128Dec64Dec128NotReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + x[i] = result.ToDecimal128(); + } +} + +// Decimal Mod Operation NotReScale +extern "C" DLLEXPORT void BatchModDec64Dec64Dec64NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalMod(Decimal64(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal64(y[i]).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + x[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchModDec64Dec128Dec64NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + x[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchModDec128Dec64Dec64NotReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + y[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchModDec128Dec64Dec128NotReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + x[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchModDec128Dec128Dec128NotReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + x[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchModDec128Dec128Dec64NotReScale(int64_t contextPtr, bool *isNull, Decimal128 *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int64_t *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + output[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchModDec64Dec128Dec128NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + CHECK_DIVIDE_BY_ZERO_CONTINUE(y[i]); + if (isNull[i]) { + x[i] = 1; + continue; + } + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW_CONTINUE(result, outPrecision); + y[i] = result.ToDecimal128(); + } +} + + +// add return null +extern "C" DLLEXPORT void BatchAddDec64Dec64Dec64RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalAdd(Decimal64(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal64(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + x[i] = result.GetValue(); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchAddDec64Dec64Dec128RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + int64_t *y, int32_t yPrecision, int32_t yScale, Decimal128 *output, int32_t outPrecision, int32_t outScale, + int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + output[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchAddDec128Dec128Dec128RetNull(bool *isNull, Decimal128 *x, int32_t xPrecision, + int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + x[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchAddDec64Dec128Dec128RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + y[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchAddDec128Dec64Dec128RetNull(bool *isNull, Decimal128 *y, int32_t yPrecision, + int32_t yScale, int64_t *x, int32_t xPrecision, int32_t xScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + y[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + } +} + +// sub ret null +extern "C" DLLEXPORT void BatchSubDec64Dec64Dec64RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalSubtract(Decimal64(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal64(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + x[i] = result.GetValue(); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchSubDec64Dec64Dec128RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + int64_t *y, int32_t yPrecision, int32_t yScale, Decimal128 *output, int32_t outPrecision, int32_t outScale, + int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + output[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchSubDec128Dec128Dec128RetNull(bool *isNull, Decimal128 *x, int32_t xPrecision, + int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + x[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchSubDec64Dec128Dec128RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + y[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchSubDec128Dec64Dec128RetNull(bool *isNull, Decimal128 *x, int32_t xPrecision, + int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + x[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + } +} + +// mul ret null +extern "C" DLLEXPORT void BatchMulDec64Dec64Dec64RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + if (isNull[i]) { + continue; + } + DecimalOperations::InternalDecimalMultiply(Decimal64(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal64(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + x[i] = result.GetValue(); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchMulDec64Dec64Dec128RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + int64_t *y, int32_t yPrecision, int32_t yScale, Decimal128 *output, int32_t outPrecision, int32_t outScale, + int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + if (isNull[i]) { + continue; + } + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + output[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchMulDec128Dec128Dec128RetNull(bool *isNull, Decimal128 *x, int32_t xPrecision, + int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + if (isNull[i]) { + continue; + } + result = Decimal128Wrapper(x[i]).MultiplyRoundUp(Decimal128Wrapper(y[i]), xScale + yScale - outScale); + x[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchMulDec64Dec128Dec128RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + if (isNull[i]) { + continue; + } + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + y[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + } +} + +extern "C" DLLEXPORT void BatchMulDec128Dec64Dec128RetNull(bool *isNull, Decimal128 *x, int32_t xPrecision, + int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + if (isNull[i]) { + continue; + } + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + x[i] = result.ToDecimal128(); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + } +} + +// div ret null +extern "C" DLLEXPORT void BatchDivDec64Dec64Dec64RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalDivide(Decimal64(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal64(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + x[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchDivDec64Dec128Dec64RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + x[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchDivDec128Dec64Dec64RetNull(bool *isNull, Decimal128 *x, int32_t xPrecision, + int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + y[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchDivDec64Dec64Dec128RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + int64_t *y, int32_t yPrecision, int32_t yScale, Decimal128 *output, int32_t outPrecision, int32_t outScale, + int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + output[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchDivDec128Dec128Dec128RetNull(bool *isNull, Decimal128 *x, int32_t xPrecision, + int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + x[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchDivDec64Dec128Dec128RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + y[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchDivDec128Dec64Dec128RetNull(bool *isNull, Decimal128 *x, int32_t xPrecision, + int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + x[i] = result.ToDecimal128(); + } +} + +// mod ret null +extern "C" DLLEXPORT void BatchModDec64Dec64Dec64RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalMod(Decimal64(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal64(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + x[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchModDec64Dec128Dec64RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + x[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchModDec128Dec64Dec64RetNull(bool *isNull, Decimal128 *x, int32_t xPrecision, + int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + y[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchModDec128Dec64Dec128RetNull(bool *isNull, Decimal128 *x, int32_t xPrecision, + int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + x[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchModDec128Dec128Dec128RetNull(bool *isNull, Decimal128 *x, int32_t xPrecision, + int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + x[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchModDec128Dec128Dec64RetNull(bool *isNull, Decimal128 *x, int32_t xPrecision, + int32_t xScale, Decimal128 *y, int32_t yPrecision, int32_t yScale, int64_t *output, int32_t outPrecision, + int32_t outScale, int32_t rowCnt) +{ + Decimal64 result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + output[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchModDec64Dec128Dec128RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + Decimal128Wrapper result; + for (int i = 0; i < rowCnt; ++i) { + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x[i]).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y[i]).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + y[i] = result.ToDecimal128(); + } +} + + +extern "C" DLLEXPORT void BatchRoundDecimal128RetNull(bool *isNull, Decimal128 *x, int32_t xPrecision, int32_t xScale, + int32_t *round, Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + Decimal128Wrapper input(x[i]); + input.SetScale(xScale); + DecimalOperations::Round(input, outScale, round[i]); + CHECK_OVERFLOW_CONTINUE_NULL(input, outPrecision); + output[i] = input.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchRoundDecimal64RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + int32_t *round, int64_t *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + Decimal64 input(x[i]); + input.SetScale(xScale); + DecimalOperations::Round(input, outScale, round[i]); + CHECK_OVERFLOW_CONTINUE_NULL(input, outPrecision); + output[i] = input.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchGreatestDecimal64(int64_t contextPtr, int64_t *xValue, int32_t xPrecision, + int32_t xScale, bool *xIsNull, int64_t *yValue, int32_t yPrecision, int32_t yScale, bool *yIsNull, bool *retIsNull, + int64_t *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt) +{ + if (xPrecision == yPrecision && xScale == yScale) { + for (int i = 0; i < rowCnt; i++) { + if (xIsNull[i] && yIsNull[i]) { + retIsNull[i] = true; + continue; + } + auto x = xValue[i]; + auto y = yValue[i]; + if (xIsNull[i] || (!yIsNull[i] && x < y)) { + output[i] = y; + continue; + } + output[i] = x; + } + } else { + for (int i = 0; i < rowCnt; i++) { + if (xIsNull[i] && yIsNull[i]) { + retIsNull[i] = true; + continue; + } + Decimal64 x(xValue[i]); + x.SetScale(xScale); + Decimal64 y(yValue[i]); + y.SetScale(yScale); + if (xIsNull[i] || (!yIsNull[i] && x.Compare(y) < 0)) { + y.ReScale(newScale); + CHECK_OVERFLOW_CONTINUE(y, newPrecision); + output[i] = y.GetValue(); + continue; + } + x.ReScale(newScale); + CHECK_OVERFLOW_CONTINUE(x, newPrecision); + output[i] = x.GetValue(); + } + } +} + +extern "C" DLLEXPORT void BatchGreatestDecimal128(int64_t contextPtr, type::Decimal128 *xValue, int32_t xPrecision, + int32_t xScale, bool *xIsNull, type::Decimal128 *yValue, int32_t yPrecision, int32_t yScale, bool *yIsNull, + bool *retIsNull, type::Decimal128 *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt) +{ + if (xPrecision == yPrecision && xScale == yScale) { + for (int i = 0; i < rowCnt; i++) { + if (xIsNull[i] && yIsNull[i]) { + retIsNull[i] = true; + continue; + } + if (xIsNull[i] || (!yIsNull[i] && xValue[i] < yValue[i])) { + output[i] = yValue[i]; + continue; + } + output[i] = xValue[i]; + } + } else { + for (int i = 0; i < rowCnt; i++) { + if (xIsNull[i] && yIsNull[i]) { + retIsNull[i] = true; + continue; + } + Decimal128Wrapper x(xValue[i]); + x.SetScale(xScale); + Decimal128Wrapper y(yValue[i]); + y.SetScale(yScale); + if (xIsNull[i] || (!yIsNull[i] && x.Compare(y) < 0)) { + y.ReScale(newScale); + CHECK_OVERFLOW_CONTINUE(y, newPrecision); + output[i] = y.ToDecimal128(); + continue; + } + x.ReScale(newScale); + CHECK_OVERFLOW_CONTINUE(x, newPrecision); + output[i] = x.ToDecimal128(); + } + } +} + +extern "C" DLLEXPORT void BatchGreatestDecimal64RetNull(bool *isNull, int64_t *xValue, int32_t xPrecision, + int32_t xScale, bool *xIsNull, int64_t *yValue, int32_t yPrecision, int32_t yScale, bool *yIsNull, bool *retIsNull, + int64_t *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt) +{ + if (xPrecision == yPrecision && xScale == yScale) { + for (int i = 0; i < rowCnt; i++) { + if (yIsNull[i] && xIsNull[i]) { + retIsNull[i] = true; + continue; + } + auto x = xValue[i]; + auto y = yValue[i]; + if (xIsNull[i] || (!yIsNull[i] && x < y)) { + output[i] = y; + continue; + } + output[i] = x; + } + } else { + for (int i = 0; i < rowCnt; i++) { + if (yIsNull[i] && xIsNull[i]) { + retIsNull[i] = true; + continue; + } + Decimal64 x(xValue[i]); + x.SetScale(xScale); + Decimal64 y(yValue[i]); + y.SetScale(yScale); + if (xIsNull[i] || (!yIsNull[i] && x.Compare(y) < 0)) { + y.ReScale(newScale); + CHECK_OVERFLOW_CONTINUE_NULL(y, newPrecision); + output[i] = y.GetValue(); + continue; + } + x.ReScale(newScale); + CHECK_OVERFLOW_CONTINUE_NULL(x, newPrecision); + output[i] = x.GetValue(); + } + } +} + +extern "C" DLLEXPORT void BatchGreatestDecimal128RetNull(bool *isNull, type::Decimal128 *xValue, int32_t xPrecision, + int32_t xScale, bool *xIsNull, type::Decimal128 *yValue, int32_t yPrecision, int32_t yScale, bool *yIsNull, + bool *retIsNull, type::Decimal128 *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt) +{ + if (xPrecision == yPrecision && xScale == yScale) { + for (int i = 0; i < rowCnt; i++) { + if (yIsNull[i] && xIsNull[i]) { + retIsNull[i] = true; + continue; + } + if (xIsNull[i] || (!yIsNull[i] && xValue[i] < yValue[i])) { + output[i] = yValue[i]; + continue; + } + output[i] = xValue[i]; + } + } else { + for (int i = 0; i < rowCnt; i++) { + if (yIsNull[i] && xIsNull[i]) { + retIsNull[i] = true; + continue; + } + Decimal128Wrapper x(xValue[i]); + x.SetScale(xScale); + Decimal128Wrapper y(yValue[i]); + y.SetScale(yScale); + if (xIsNull[i] || (!yIsNull[i] && x.Compare(y) < 0)) { + y.ReScale(newScale); + CHECK_OVERFLOW_CONTINUE_NULL(y, newPrecision); + output[i] = y.ToDecimal128(); + continue; + } + x.ReScale(newScale); + CHECK_OVERFLOW_CONTINUE_NULL(x, newPrecision); + output[i] = x.ToDecimal128(); + } + } +} +} \ No newline at end of file diff --git a/core/src/codegen/batch_functions/batch_decimal_arithmetic_functions.h b/core/src/codegen/batch_functions/batch_decimal_arithmetic_functions.h new file mode 100644 index 0000000..79c9ea7 --- /dev/null +++ b/core/src/codegen/batch_functions/batch_decimal_arithmetic_functions.h @@ -0,0 +1,467 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch decimal functions implementation + */ + +#ifndef OMNI_RUNTIME_BATCH_DECIMAL_ARITHMETIC_FUNCTIONS_H +#define OMNI_RUNTIME_BATCH_DECIMAL_ARITHMETIC_FUNCTIONS_H + +#include +#include +#include "type/decimal128.h" +#include "type/decimal_operations.h" + +namespace omniruntime::codegen::function { +#ifdef _WIN32 +#define DLLEXPORT __declspec(dllexport) +#else +#define DLLEXPORT +#endif + +// type::Decimal128 compare +extern "C" DLLEXPORT void BatchDecimal128Compare(type::Decimal128 *x, int32_t xPrecision, int32_t xScale, + type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchLessThanDecimal128(type::Decimal128 *left, int32_t xPrecision, int32_t xScale, + type::Decimal128 *right, int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchLessThanEqualDecimal128(type::Decimal128 *left, int32_t xPrecision, int32_t xScale, + type::Decimal128 *right, int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchGreaterThanDecimal128(type::Decimal128 *left, int32_t xPrecision, int32_t xScale, + type::Decimal128 *right, int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchGreaterThanEqualDecimal128(type::Decimal128 *left, int32_t xPrecision, int32_t xScale, + type::Decimal128 *right, int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchEqualDecimal128(type::Decimal128 *left, int32_t xPrecision, int32_t xScale, + type::Decimal128 *right, int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchNotEqualDecimal128(type::Decimal128 *left, int32_t xPrecision, int32_t xScale, + type::Decimal128 *right, int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchAbsDecimal128(type::Decimal128 *x, int32_t xPrecision, int32_t xScale, bool *isNull, + type::Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchRoundDecimal128(int64_t contextPtr, type::Decimal128 *x, int32_t xPrecision, + int32_t xScale, int32_t *round, bool *isNull, type::Decimal128 *output, int32_t outPrecision, int32_t outScale, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchRoundDecimal64(int64_t contextPtr, int64_t *x, int32_t xPrecision, int32_t xScale, + int32_t *round, bool *isNull, int64_t *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchRoundDecimal128WithoutRound(int64_t contextPtr, type::Decimal128 *x, int32_t xPrecision, + int32_t xScale, bool *isNull, type::Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchRoundDecimal64WithoutRound(int64_t contextPtr, int64_t *x, int32_t xPrecision, + int32_t xScale, bool *isNull, int64_t *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +// decimal64 compare +extern "C" DLLEXPORT void BatchDecimal64Compare(int64_t *x, int32_t xPrecision, int32_t xScale, int64_t *y, + int32_t yPrecision, int32_t yScale, int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchAbsDecimal64(int64_t *x, int32_t xPrecision, int32_t xScale, bool *isNull, + int64_t *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchLessThanDecimal64(int64_t *left, int32_t xPrecision, int32_t xScale, int64_t *right, + int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchLessThanEqualDecimal64(int64_t *left, int32_t xPrecision, int32_t xScale, int64_t *right, + int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchGreaterThanDecimal64(int64_t *left, int32_t xPrecision, int32_t xScale, int64_t *right, + int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchGreaterThanEqualDecimal64(int64_t *left, int32_t xPrecision, int32_t xScale, + int64_t *right, int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchEqualDecimal64(int64_t *left, int32_t xPrecision, int32_t xScale, int64_t *right, + int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchNotEqualDecimal64(int64_t *left, int32_t xPrecision, int32_t xScale, int64_t *right, + int32_t yPrecision, int32_t yScale, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchUnscaledValue64(int64_t *x, int32_t precision, int32_t scale, bool *isAnyNull, + int64_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMakeDecimal64(int64_t contextPtr, int64_t *x, bool *isAnyNull, int64_t *output, + int32_t precision, int32_t scale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMakeDecimal64RetNull(bool *isNull, int64_t *x, int64_t *output, int32_t precision, + int32_t scale, int32_t rowCnt); + +// Decimal Add Operator ReScale +extern "C" DLLEXPORT void BatchAddDec64Dec64Dec64ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchAddDec64Dec64Dec128ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, type::Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchAddDec128Dec128Dec128ReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchAddDec64Dec128Dec128ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchAddDec128Dec64Dec128ReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *y, + int32_t yPrecision, int32_t yScale, int64_t *x, int32_t xPrecision, int32_t xScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +// Decimal Sub Operation ReScale +extern "C" DLLEXPORT void BatchSubDec64Dec64Dec64ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchSubDec64Dec64Dec128ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, type::Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchSubDec128Dec128Dec128ReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchSubDec64Dec128Dec128ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchSubDec128Dec64Dec128ReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +// Decimal Mul Operation ReScale +extern "C" DLLEXPORT void BatchMulDec64Dec64Dec64ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMulDec64Dec64Dec128ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, type::Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMulDec128Dec128Dec128ReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMulDec64Dec128Dec128ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMulDec128Dec64Dec128ReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *y, + int32_t yPrecision, int32_t yScale, int64_t *x, int32_t xPrecision, int32_t xScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +// Decimal Div Operation ReScale +extern "C" DLLEXPORT void BatchDivDec64Dec64Dec64ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivDec64Dec128Dec64ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivDec128Dec64Dec64ReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivDec64Dec64Dec128ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, type::Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivDec128Dec128Dec128ReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivDec64Dec128Dec128ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivDec128Dec64Dec128ReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +// Decimal Mod Operation ReScale +extern "C" DLLEXPORT void BatchModDec64Dec64Dec64ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchModDec64Dec128Dec64ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchModDec128Dec64Dec64ReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchModDec128Dec64Dec128ReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchModDec128Dec128Dec128ReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchModDec128Dec128Dec64ReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int64_t *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchModDec64Dec128Dec128ReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +// Decimal Add Operator NotReScale +extern "C" DLLEXPORT void BatchAddDec64Dec64Dec64NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchAddDec64Dec64Dec128NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, type::Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchAddDec128Dec128Dec128NotReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchAddDec64Dec128Dec128NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchAddDec128Dec64Dec128NotReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *y, + int32_t yPrecision, int32_t yScale, int64_t *x, int32_t xPrecision, int32_t xScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +// Decimal Sub Operation NotReScale +extern "C" DLLEXPORT void BatchSubDec64Dec64Dec64NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchSubDec64Dec64Dec128NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, type::Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchSubDec128Dec128Dec128NotReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchSubDec64Dec128Dec128NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchSubDec128Dec64Dec128NotReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +// Decimal Mul Operation NotReScale +extern "C" DLLEXPORT void BatchMulDec64Dec64Dec64NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMulDec64Dec64Dec128NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, type::Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMulDec128Dec128Dec128NotReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMulDec64Dec128Dec128NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMulDec128Dec64Dec128NotReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *y, + int32_t yPrecision, int32_t yScale, int64_t *x, int32_t xPrecision, int32_t xScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +// Decimal Div Operation NotReScale +extern "C" DLLEXPORT void BatchDivDec64Dec64Dec64NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivDec64Dec128Dec64NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivDec128Dec64Dec64NotReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivDec64Dec64Dec128NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, type::Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivDec128Dec128Dec128NotReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivDec64Dec128Dec128NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivDec128Dec64Dec128NotReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +// Decimal Mod Operation NotReScale +extern "C" DLLEXPORT void BatchModDec64Dec64Dec64NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchModDec64Dec128Dec64NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchModDec128Dec64Dec64NotReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchModDec128Dec64Dec128NotReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchModDec128Dec128Dec128NotReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchModDec128Dec128Dec64NotReScale(int64_t contextPtr, bool *isNull, type::Decimal128 *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int64_t *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchModDec64Dec128Dec128NotReScale(int64_t contextPtr, bool *isNull, int64_t *x, + int32_t xPrecision, int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +// add ret null +extern "C" DLLEXPORT void BatchAddDec64Dec64Dec64RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchAddDec64Dec64Dec128RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + int64_t *y, int32_t yPrecision, int32_t yScale, type::Decimal128 *output, int32_t outPrecision, int32_t outScale, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchAddDec128Dec128Dec128RetNull(bool *isNull, type::Decimal128 *x, int32_t xPrecision, + int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchAddDec64Dec128Dec128RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchAddDec128Dec64Dec128RetNull(bool *isNull, type::Decimal128 *y, int32_t yPrecision, + int32_t yScale, int64_t *x, int32_t xPrecision, int32_t xScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt); + +// sub ret null +extern "C" DLLEXPORT void BatchSubDec64Dec64Dec64RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchSubDec64Dec64Dec128RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + int64_t *y, int32_t yPrecision, int32_t yScale, type::Decimal128 *output, int32_t outPrecision, int32_t outScale, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchSubDec128Dec128Dec128RetNull(bool *isNull, type::Decimal128 *x, int32_t xPrecision, + int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchSubDec64Dec128Dec128RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchSubDec128Dec64Dec128RetNull(bool *isNull, type::Decimal128 *x, int32_t xPrecision, + int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt); + +// mul ret null +extern "C" DLLEXPORT void BatchMulDec64Dec64Dec64RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMulDec64Dec64Dec128RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + int64_t *y, int32_t yPrecision, int32_t yScale, type::Decimal128 *output, int32_t outPrecision, int32_t outScale, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMulDec128Dec128Dec128RetNull(bool *isNull, type::Decimal128 *x, int32_t xPrecision, + int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMulDec64Dec128Dec128RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMulDec128Dec64Dec128RetNull(bool *isNull, type::Decimal128 *x, int32_t xPrecision, + int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt); + +// div ret null +extern "C" DLLEXPORT void BatchDivDec64Dec64Dec64RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivDec64Dec128Dec64RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivDec128Dec64Dec64RetNull(bool *isNull, type::Decimal128 *x, int32_t xPrecision, + int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivDec64Dec64Dec128RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + int64_t *y, int32_t yPrecision, int32_t yScale, type::Decimal128 *output, int32_t outPrecision, int32_t outScale, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivDec128Dec128Dec128RetNull(bool *isNull, type::Decimal128 *x, int32_t xPrecision, + int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivDec64Dec128Dec128RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivDec128Dec64Dec128RetNull(bool *isNull, type::Decimal128 *x, int32_t xPrecision, + int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt); + +// mod ret null +extern "C" DLLEXPORT void BatchModDec64Dec64Dec64RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchModDec64Dec128Dec64RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchModDec128Dec64Dec64RetNull(bool *isNull, type::Decimal128 *x, int32_t xPrecision, + int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchModDec128Dec64Dec128RetNull(bool *isNull, type::Decimal128 *x, int32_t xPrecision, + int32_t xScale, int64_t *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchModDec128Dec128Dec128RetNull(bool *isNull, type::Decimal128 *x, int32_t xPrecision, + int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchModDec128Dec128Dec64RetNull(bool *isNull, type::Decimal128 *x, int32_t xPrecision, + int32_t xScale, type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int64_t *output, int32_t outPrecision, + int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchModDec64Dec128Dec128RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + type::Decimal128 *y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchRoundDecimal128RetNull(bool *isNull, type::Decimal128 *x, int32_t xPrecision, + int32_t xScale, int32_t *round, type::Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchRoundDecimal64RetNull(bool *isNull, int64_t *x, int32_t xPrecision, int32_t xScale, + int32_t *round, int64_t *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchGreatestDecimal64(int64_t contextPtr, int64_t *xValue, int32_t xPrecision, + int32_t xScale, bool *xIsNull, int64_t *yValue, int32_t yPrecision, int32_t yScale, bool *yIsNull, bool *retIsNull, + int64_t *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchGreatestDecimal128(int64_t contextPtr, type::Decimal128 *xValue, int32_t xPrecision, + int32_t xScale, bool *xIsNull, type::Decimal128 *yValue, int32_t yPrecision, int32_t yScale, bool *yIsNull, + bool *retIsNull, type::Decimal128 *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchGreatestDecimal64RetNull(bool *isNull, int64_t *xValue, int32_t xPrecision, + int32_t xScale, bool *xIsNull, int64_t *yValue, int32_t yPrecision, int32_t yScale, bool *yIsNull, bool *retIsNull, + int64_t *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchGreatestDecimal128RetNull(bool *isNull, type::Decimal128 *xValue, int32_t xPrecision, + int32_t xScale, bool *xIsNull, type::Decimal128 *yValue, int32_t yPrecision, int32_t yScale, bool *yIsNull, + bool *retIsNull, type::Decimal128 *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt); +} + +#endif // OMNI_RUNTIME_BATCH_DECIMAL_ARITHMETIC_FUNCTIONS_H \ No newline at end of file diff --git a/core/src/codegen/batch_functions/batch_decimal_cast_functions.cpp b/core/src/codegen/batch_functions/batch_decimal_cast_functions.cpp new file mode 100644 index 0000000..127e824 --- /dev/null +++ b/core/src/codegen/batch_functions/batch_decimal_cast_functions.cpp @@ -0,0 +1,500 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. + * Description: batch decimal functions implementation + */ + +#include "batch_decimal_cast_functions.h" +#include +#include +#include "codegen/context_helper.h" +#include "type/decimal_operations.h" + +namespace omniruntime::codegen::function { +using namespace omniruntime::type; + +// Round towards "nearest neighbor" unless both neighbors are equidistant, in which case round up. +extern "C" DLLEXPORT void BatchCastDecimal64ToIntHalfUp(int64_t contextPtr, int64_t *x, int32_t precision, + int32_t scale, const bool *isAnyNull, int32_t *output, int32_t rowCnt) +{ + int32_t result = 0; + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + continue; + } + auto status = Decimal128Wrapper(x[i]).SetScale(scale).ToInt(result); + if (status != type::OpStatus::SUCCESS) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL64, OMNI_INT, x[i], OpStatus::SUCCESS, precision, scale)); + continue; + } + output[i] = result; + } +} + +// Round towards zero. +extern "C" DLLEXPORT void BatchCastDecimal64ToIntDown(int64_t contextPtr, int64_t *x, int32_t precision, int32_t scale, + const bool *isAnyNull, int32_t *output, int32_t rowCnt) +{ + int64_t scaledValue; + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + continue; + } + scaledValue = Decimal64(x[i]).SetScale(scale).ReScale(0, RoundingMode::ROUND_FLOOR).GetValue(); + if (scaledValue < INT_MIN || scaledValue > INT_MAX) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL64, OMNI_INT, x[i], OpStatus::SUCCESS, precision, scale)); + continue; + } + output[i] = static_cast(scaledValue); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal64ToLongDown(int64_t *x, int32_t precision, int32_t scale, + const bool *isAnyNull, int64_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + continue; + } + output[i] = Decimal64(x[i]).SetScale(scale).ReScale(0, RoundingMode::ROUND_FLOOR).GetValue(); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal64ToLongHalfUp(int64_t *x, int32_t precision, int32_t scale, + const bool *isAnyNull, int64_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + continue; + } + output[i] = static_cast(Decimal64(x[i]).SetScale(scale)); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal64ToDoubleDown(const int64_t *x, int32_t precision, int32_t scale, + const bool *isAnyNull, double *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + continue; + } + std::string doubleString = Decimal64(x[i]).SetScale(scale).ToString(); + output[i] = stod(doubleString); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal64ToDoubleHalfUp(const int64_t *x, int32_t precision, int32_t scale, + const bool *isAnyNull, double *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + continue; + } + output[i] = static_cast(Decimal64(x[i]).SetScale(scale)); + } +} + +extern "C" DLLEXPORT void BatchCastIntToDecimal64(int64_t contextPtr, int32_t *x, const bool *isAnyNull, + int64_t *output, int32_t precision, int32_t scale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + continue; + } + Decimal64 result(x[i]); + result.ReScale(scale); + if (result.IsOverflow(precision) != OpStatus::SUCCESS) { + SetError(contextPtr, CastErrorMessage(OMNI_INT, OMNI_DECIMAL64, x[i], OpStatus::SUCCESS, precision, scale)); + continue; + } + output[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchCastLongToDecimal64(int64_t contextPtr, int64_t *x, const bool *isAnyNull, + int64_t *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + continue; + } + Decimal64 result(x[i]); + result.ReScale(outScale); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + SetError(contextPtr, + CastErrorMessage(OMNI_LONG, OMNI_DECIMAL64, x[i], OpStatus::SUCCESS, outPrecision, outScale)); + continue; + } + output[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchCastDoubleToDecimal64(int64_t contextPtr, double *x, const bool *isAnyNull, + int64_t *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + continue; + } + Decimal64 result(x[i]); + result.ReScale(outScale); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + SetError(contextPtr, + CastErrorMessage(OMNI_DOUBLE, OMNI_DECIMAL64, x[i], OpStatus::SUCCESS, outPrecision, outScale)); + continue; + } + output[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal128ToInt(int64_t contextPtr, Decimal128 *x, int32_t precision, int32_t scale, + const bool *isAnyNull, int32_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + continue; + } + int32_t result = 0; + auto status = Decimal128Wrapper(x[i]).SetScale(scale).ToInt(result); + if (status != type::OpStatus::SUCCESS) { + SetError(contextPtr, + CastErrorMessage(OMNI_DECIMAL128, OMNI_INT, x[i].ToInt128(), OpStatus::OP_OVERFLOW, precision, scale)); + continue; + } + output[i] = result; + } +} + +extern "C" DLLEXPORT void BatchCastDecimal128ToLong(int64_t contextPtr, Decimal128 *x, int32_t precision, int32_t scale, + const bool *isAnyNull, int64_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + continue; + } + int64_t result = 0; + auto status = Decimal128Wrapper(x[i]).SetScale(scale).ToLong(result); + if (status != type::OpStatus::SUCCESS) { + SetError(contextPtr, + CastErrorMessage(OMNI_DECIMAL128, OMNI_LONG, x[i].ToInt128(), OpStatus::OP_OVERFLOW, precision, scale)); + continue; + } + output[i] = result; + } +} + +extern "C" DLLEXPORT void BatchCastDecimal128ToDoubleDown(Decimal128 *x, int32_t precision, int32_t scale, + const bool *isAnyNull, double *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + continue; + } + output[i] = static_cast(Decimal128Wrapper(x[i])) / DOUBLE_10_POW[scale]; + } +} + +extern "C" DLLEXPORT void BatchCastDecimal128ToDoubleHalfUp(Decimal128 *x, int32_t precision, int32_t scale, + const bool *isAnyNull, double *output, int32_t rowCnt) +{ + std::string doubleString; + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + continue; + } + doubleString = Decimal128Wrapper(x[i]).SetScale(scale).ToString(); + output[i] = stod(doubleString); + } +} + +extern "C" DLLEXPORT void BatchCastIntToDecimal128(int64_t contextPtr, int32_t *x, const bool *isAnyNull, + Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + continue; + } + Decimal128Wrapper result(x[i]); + result.ReScale(outScale); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + SetError(contextPtr, + CastErrorMessage(OMNI_INT, OMNI_DECIMAL128, x[i], OpStatus::OP_OVERFLOW, outPrecision, outScale)); + continue; + } + output[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchCastLongToDecimal128(int64_t contextPtr, int64_t *x, const bool *isAnyNull, + Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + continue; + } + Decimal128Wrapper result(x[i]); + result.ReScale(outScale); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + SetError(contextPtr, + CastErrorMessage(OMNI_LONG, OMNI_DECIMAL128, x[i], OpStatus::OP_OVERFLOW, outPrecision, outScale)); + continue; + } + output[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchCastDoubleToDecimal128(int64_t contextPtr, double *x, bool *isAnyNull, + Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + continue; + } + Decimal128Wrapper result(x[i]); + result.ReScale(outScale); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + SetError(contextPtr, + CastErrorMessage(OMNI_BOOLEAN, OMNI_DECIMAL128, *x, OpStatus::SUCCESS, outPrecision, outScale)); + continue; + } + output[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal64To64(int64_t contextPtr, int64_t *x, int32_t precision, int32_t scale, + const bool *isAnyNull, int64_t *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + continue; + } + Decimal64 result(x[i]); + result.SetScale(scale).ReScale(newScale); + if (result.IsOverflow(newPrecision) != OpStatus::SUCCESS) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL64, OMNI_DECIMAL64, x[i], OpStatus::SUCCESS, precision, + scale, newPrecision, newScale)); + continue; + } + output[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal128To128(int64_t contextPtr, Decimal128 *x, int32_t precision, int32_t scale, + const bool *isAnyNull, Decimal128 *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + continue; + } + Decimal128Wrapper result = Decimal128Wrapper(x[i]).SetScale(scale).ReScale(newScale); + if (result.IsOverflow(newPrecision) != OpStatus::SUCCESS) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL128, OMNI_DECIMAL128, x[i].ToInt128(), OpStatus::SUCCESS, + precision, scale, newPrecision, newScale)); + continue; + } + output[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal64To128(int64_t contextPtr, int64_t *x, int32_t precision, int32_t scale, + const bool *isAnyNull, Decimal128 *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + continue; + } + Decimal128Wrapper result(x[i]); + result.SetScale(scale).ReScale(newScale); + if (result.IsOverflow(newPrecision) != OpStatus::SUCCESS) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL64, OMNI_DECIMAL128, x[i], OpStatus::SUCCESS, precision, + scale, newPrecision, newScale)); + continue; + } + output[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal128To64(int64_t contextPtr, Decimal128 *x, int32_t precision, int32_t scale, + const bool *isAnyNull, int64_t *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + continue; + } + Decimal128Wrapper decimal128 = Decimal128Wrapper(x[i]).SetScale(scale).ReScale(newScale); + Decimal64 result(decimal128); + if (result.IsOverflow(newPrecision) != OpStatus::SUCCESS) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL128, OMNI_DECIMAL64, x[i].ToInt128(), OpStatus::SUCCESS, + precision, scale, newPrecision, newScale)); + continue; + } + output[i] = result.GetValue(); + } +} + +// return null +extern "C" DLLEXPORT void BatchCastDecimal64ToIntRetNull(bool *isNull, int64_t *x, int32_t precision, int32_t scale, + int32_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + output[i] = static_cast(Decimal64(x[i]).ReScale(-scale, RoundingMode::ROUND_FLOOR).GetValue()); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal64ToLongRetNull(bool *isNull, int64_t *x, int32_t precision, int32_t scale, + int64_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + output[i] = static_cast(Decimal64(x[i]).ReScale(-scale, RoundingMode::ROUND_FLOOR).GetValue()); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal64ToDoubleRetNull(bool *isNull, const int64_t *x, int32_t precision, + int32_t scale, double *output, int32_t rowCnt) +{ + double result; + for (int i = 0; i < rowCnt; ++i) { + std::string doubleString = Decimal64(x[i]).SetScale(scale).ToString(); + ConvertStringToDouble(result, doubleString); + output[i] = result; + } +} + +extern "C" DLLEXPORT void BatchCastIntToDecimal64RetNull(bool *isNull, int32_t *x, int64_t *output, int32_t precision, + int32_t scale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + Decimal64 result(x[i]); + result.ReScale(scale); + CHECK_OVERFLOW_CONTINUE_NULL(result, precision); + output[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchCastLongToDecimal64RetNull(bool *isNull, int64_t *x, int64_t *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + Decimal64 result(x[i]); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + output[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchCastDoubleToDecimal64RetNull(bool *isNull, double *x, int64_t *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + Decimal64 result(x[i]); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + output[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal128ToIntRetNull(bool *isNull, Decimal128 *x, int32_t precision, int32_t scale, + int32_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + output[i] = static_cast(Decimal128Wrapper(x[i]).ReScale(-scale, RoundingMode::ROUND_FLOOR) + .ToInt128()); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal128ToLongRetNull(bool *isNull, Decimal128 *x, int32_t precision, + int32_t scale, int64_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + output[i] = static_cast(Decimal128Wrapper(x[i]).ReScale(-scale, RoundingMode::ROUND_FLOOR) + .ToInt128()); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal128ToDoubleRetNull(bool *isNull, Decimal128 *x, int32_t precision, + int32_t scale, double *output, int32_t rowCnt) +{ + double result; + std::string doubleString; + for (int i = 0; i < rowCnt; ++i) { + doubleString = Decimal128Wrapper(x[i]).SetScale(scale).ToString(); + ConvertStringToDouble(result, doubleString); + output[i] = result; + } +} + +extern "C" DLLEXPORT void BatchCastIntToDecimal128RetNull(bool *isNull, int32_t *x, Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + Decimal128Wrapper result(x[i]); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + output[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchCastLongToDecimal128RetNull(bool *isNull, int64_t *x, Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + Decimal128Wrapper result(x[i]); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + output[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchCastDoubleToDecimal128RetNull(bool *isNull, double *x, Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + Decimal128Wrapper result(x[i]); + result.ReScale(outScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, outPrecision); + output[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal64To64RetNull(bool *isNull, int64_t *x, int32_t precision, int32_t scale, + int64_t *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + Decimal64 result(x[i]); + result.SetScale(scale).ReScale(newScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, newPrecision); + output[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal128To128RetNull(bool *isNull, Decimal128 *x, int32_t precision, int32_t scale, + Decimal128 *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + auto result = Decimal128Wrapper(x[i]).SetScale(scale).ReScale(newScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, newPrecision); + output[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal64To128RetNull(bool *isNull, int64_t *x, int32_t precision, int32_t scale, + Decimal128 *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + Decimal128Wrapper result(x[i]); + result.SetScale(scale).ReScale(newScale); + CHECK_OVERFLOW_CONTINUE_NULL(result, newPrecision); + output[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal128To64RetNull(bool *isNull, Decimal128 *x, int32_t precision, int32_t scale, + int64_t *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + auto decimal128 = Decimal128Wrapper(x[i]).SetScale(scale).ReScale(newScale); + Decimal64 result(decimal128); + CHECK_OVERFLOW_CONTINUE_NULL(result, newPrecision); + output[i] = result.GetValue(); + } +} +} \ No newline at end of file diff --git a/core/src/codegen/batch_functions/batch_decimal_cast_functions.h b/core/src/codegen/batch_functions/batch_decimal_cast_functions.h new file mode 100644 index 0000000..11b093e --- /dev/null +++ b/core/src/codegen/batch_functions/batch_decimal_cast_functions.h @@ -0,0 +1,134 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch decimal functions implementation + */ + +#ifndef OMNI_RUNTIME_BATCH_DECIMAL_CAST_FUNCTIONS_H +#define OMNI_RUNTIME_BATCH_DECIMAL_CAST_FUNCTIONS_H + +#include +#include +#include "type/decimal128.h" +#include "type/decimal_operations.h" +#include "type/data_type.h" + +#ifdef _WIN32 +#define DLLEXPORT __declspec(dllexport) +#else +#define DLLEXPORT +#endif + +namespace omniruntime::codegen::function { +using namespace omniruntime::type; + +// Cast Function +extern "C" DLLEXPORT void BatchCastDecimal64To64(int64_t contextPtr, int64_t *x, int32_t precision, int32_t scale, + const bool *isAnyNull, int64_t *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal128To128(int64_t contextPtr, Decimal128 *x, int32_t precision, int32_t scale, + const bool *isAnyNull, Decimal128 *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal64To128(int64_t contextPtr, int64_t *x, int32_t precision, int32_t scale, + const bool *isAnyNull, Decimal128 *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal128To64(int64_t contextPtr, Decimal128 *x, int32_t precision, int32_t scale, + const bool *isAnyNull, int64_t *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastIntToDecimal64(int64_t contextPtr, int32_t *x, const bool *isAnyNull, + int64_t *output, int32_t precision, int32_t scale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastLongToDecimal64(int64_t contextPtr, int64_t *x, const bool *isAnyNull, + int64_t *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDoubleToDecimal64(int64_t contextPtr, double *x, const bool *isAnyNull, + int64_t *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastIntToDecimal128(int64_t contextPtr, int32_t *x, const bool *isAnyNull, + Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastLongToDecimal128(int64_t contextPtr, int64_t *x, const bool *isAnyNull, + Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDoubleToDecimal128(int64_t contextPtr, double *x, bool *isAnyNull, + Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal64ToIntDown(int64_t contextPtr, int64_t *x, int32_t precision, int32_t scale, + const bool *isAnyNull, int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal64ToLongDown(int64_t *x, int32_t precision, int32_t scale, + const bool *isAnyNull, int64_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal64ToDoubleDown(const int64_t *x, int32_t precision, int32_t scale, + const bool *isAnyNull, double *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal64ToIntHalfUp(int64_t contextPtr, int64_t *x, int32_t precision, + int32_t scale, const bool *isAnyNull, int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal64ToLongHalfUp(int64_t *x, int32_t precision, int32_t scale, + const bool *isAnyNull, int64_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal64ToDoubleHalfUp(const int64_t *x, int32_t precision, int32_t scale, + const bool *isAnyNull, double *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal128ToInt(int64_t contextPtr, Decimal128 *x, int32_t precision, int32_t scale, + const bool *isAnyNull, int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal128ToLong(int64_t contextPtr, Decimal128 *x, int32_t precision, int32_t scale, + const bool *isAnyNull, int64_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal128ToDoubleDown(Decimal128 *x, int32_t precision, int32_t scale, + const bool *isAnyNull, double *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal128ToDoubleHalfUp(Decimal128 *x, int32_t precision, int32_t scale, + const bool *isAnyNull, double *output, int32_t rowCnt); + +// Cast Function Return Null +extern "C" DLLEXPORT void BatchCastDecimal64To64RetNull(bool *isNull, int64_t *x, int32_t precision, int32_t scale, + int64_t *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal128To128RetNull(bool *isNull, Decimal128 *x, int32_t precision, int32_t scale, + Decimal128 *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal64To128RetNull(bool *isNull, int64_t *x, int32_t precision, int32_t scale, + Decimal128 *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal128To64RetNull(bool *isNull, Decimal128 *x, int32_t precision, int32_t scale, + int64_t *output, int32_t newPrecision, int32_t newScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastIntToDecimal64RetNull(bool *isNull, int32_t *x, int64_t *output, int32_t precision, + int32_t scale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastLongToDecimal64RetNull(bool *isNull, int64_t *x, int64_t *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDoubleToDecimal64RetNull(bool *isNull, double *x, int64_t *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastIntToDecimal128RetNull(bool *isNull, int32_t *x, Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt); +extern "C" DLLEXPORT void BatchCastLongToDecimal128RetNull(bool *isNull, int64_t *x, Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDoubleToDecimal128RetNull(bool *isNull, double *x, Decimal128 *output, + int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal64ToIntRetNull(bool *isNull, int64_t *x, int32_t precision, int32_t scale, + int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal64ToLongRetNull(bool *isNull, int64_t *x, int32_t precision, int32_t scale, + int64_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal64ToDoubleRetNull(bool *isNull, const int64_t *x, int32_t precision, + int32_t scale, double *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal128ToIntRetNull(bool *isNull, Decimal128 *x, int32_t precision, int32_t scale, + int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal128ToLongRetNull(bool *isNull, Decimal128 *x, int32_t precision, + int32_t scale, int64_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal128ToDoubleRetNull(bool *isNull, Decimal128 *x, int32_t precision, + int32_t scale, double *output, int32_t rowCnt); +} + +#endif // OMNI_RUNTIME_BATCH_DECIMAL_CAST_FUNCTIONS_H diff --git a/core/src/codegen/batch_functions/batch_dictionaryfunctions.cpp b/core/src/codegen/batch_functions/batch_dictionaryfunctions.cpp new file mode 100644 index 0000000..c1008dd --- /dev/null +++ b/core/src/codegen/batch_functions/batch_dictionaryfunctions.cpp @@ -0,0 +1,126 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch dictionary functions implementation + */ + +#include "batch_dictionaryfunctions.h" +#include "vector/vector.h" +#include "codegen/context_helper.h" + +using namespace omniruntime::vec; + +namespace omniruntime::codegen::function { +extern "C" DLLEXPORT void BatchGetIntFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t *rowIdxArray, + int32_t rowCnt, int32_t *output) +{ + auto dictionaryVectorPtr = reinterpret_cast> *>(dictionaryVectorAddr); + for (int i = 0; i < rowCnt; ++i) { + output[i] = dictionaryVectorPtr->GetValue(rowIdxArray[i]); + } +} + +extern "C" DLLEXPORT void BatchGetLongFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t *rowIdxArray, + int32_t rowCnt, int64_t *output) +{ + auto dictionaryVectorPtr = reinterpret_cast> *>(dictionaryVectorAddr); + for (int i = 0; i < rowCnt; ++i) { + output[i] = dictionaryVectorPtr->GetValue(rowIdxArray[i]); + } +} + +extern "C" DLLEXPORT void BatchGetDoubleFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t *rowIdxArray, + int32_t rowCnt, double *output) +{ + auto dictionaryVectorPtr = reinterpret_cast> *>(dictionaryVectorAddr); + for (int i = 0; i < rowCnt; ++i) { + output[i] = dictionaryVectorPtr->GetValue(rowIdxArray[i]); + } +} + +extern "C" DLLEXPORT void BatchGetBooleanFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t *rowIdxArray, + int32_t rowCnt, bool *output) +{ + auto dictionaryVectorPtr = reinterpret_cast> *>(dictionaryVectorAddr); + for (int i = 0; i < rowCnt; ++i) { + output[i] = dictionaryVectorPtr->GetValue(rowIdxArray[i]); + } +} + +extern "C" DLLEXPORT void BatchGetVarcharFromDictionaryVector(int64_t contextPtr, int64_t dictionaryVectorAddr, + int32_t *rowIdxArray, int32_t rowCnt, uint8_t **str, int32_t *length) +{ + auto dictionaryVectorPtr = reinterpret_cast> *>(dictionaryVectorAddr); + for (int i = 0; i < rowCnt; ++i) { + auto stringView = dictionaryVectorPtr->GetValue(rowIdxArray[i]); + length[i] = stringView.length(); + str[i] = reinterpret_cast(const_cast(stringView.data())); + } +} + +extern "C" DLLEXPORT void BatchGetDecimalFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t *rowIdxArray, + int32_t rowCnt, Decimal128 *output) +{ + auto dictionaryVectorPtr = reinterpret_cast> *>(dictionaryVectorAddr); + for (int i = 0; i < rowCnt; ++i) { + output[i] = dictionaryVectorPtr->GetValue(rowIdxArray[i]); + } +} + +extern "C" DLLEXPORT void BatchGetIntFromVector(int32_t *vector, int32_t *rowIdxArray, int32_t rowCnt, int32_t *output) +{ + for (int i = 0; i < rowCnt; ++i) { + output[i] = vector[rowIdxArray[i]]; + } +} + +extern "C" DLLEXPORT void BatchGetLongFromVector(int64_t *vector, int32_t *rowIdxArray, int32_t rowCnt, int64_t *output) +{ + for (int i = 0; i < rowCnt; ++i) { + output[i] = vector[rowIdxArray[i]]; + } +} + +extern "C" DLLEXPORT void BatchGetDoubleFromVector(double *vector, int32_t *rowIdxArray, int32_t rowCnt, double *output) +{ + for (int i = 0; i < rowCnt; ++i) { + output[i] = vector[rowIdxArray[i]]; + } +} + +extern "C" DLLEXPORT void BatchGetBooleanFromVector(bool *vector, int32_t *rowIdxArray, int32_t rowCnt, bool *output) +{ + for (int i = 0; i < rowCnt; ++i) { + output[i] = vector[rowIdxArray[i]]; + } +} + +extern "C" DLLEXPORT void BatchGetVarcharFromVector(int64_t contextPtr, int32_t *offsetArray, const char *vector, + int32_t *rowIdxArray, int32_t rowCnt, uint8_t **str, int32_t *length) +{ + errno_t err; + char *ret; + for (int i = 0; i < rowCnt; ++i) { + length[i] = offsetArray[rowIdxArray[i] + 1] - offsetArray[rowIdxArray[i]]; + if (length[i] == 0) { + str[i] = (uint8_t *)""; + continue; + } + ret = ArenaAllocatorMalloc(contextPtr, length[i]); + err = memcpy_s(ret, length[i], vector + offsetArray[rowIdxArray[i]], length[i]); + if (err != EOK) { + SetError(contextPtr, "Get string from vector failed"); + str[i] = nullptr; + continue; + } + str[i] = reinterpret_cast(ret); + } +} + +extern "C" DLLEXPORT void BatchGetDecimalFromVector(Decimal128 *vector, int32_t *rowIdxArray, int32_t rowCnt, + Decimal128 *output) +{ + for (int i = 0; i < rowCnt; ++i) { + output[i] = vector[rowIdxArray[i]]; + } +} +} \ No newline at end of file diff --git a/core/src/codegen/batch_functions/batch_dictionaryfunctions.h b/core/src/codegen/batch_functions/batch_dictionaryfunctions.h new file mode 100644 index 0000000..8905c56 --- /dev/null +++ b/core/src/codegen/batch_functions/batch_dictionaryfunctions.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch dictionary functions implementation + */ +#ifndef OMNI_RUNTIME_BATCH_DICTIONARYFUNCTIONS_H +#define OMNI_RUNTIME_BATCH_DICTIONARYFUNCTIONS_H + +#include +#include "type/decimal128.h" +using namespace omniruntime::type; + +namespace omniruntime::codegen::function { +#ifdef _WIN32 +#define DLLEXPORT __declspec(dllexport) +#else +#define DLLEXPORT +#endif + +extern "C" DLLEXPORT void BatchGetIntFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t *rowIdxArray, + int32_t rowCnt, int32_t *output); + +extern "C" DLLEXPORT void BatchGetLongFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t *rowIdxArray, + int32_t rowCnt, int64_t *output); + +extern "C" DLLEXPORT void BatchGetDoubleFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t *rowIdxArray, + int32_t rowCnt, double *output); + +extern "C" DLLEXPORT void BatchGetBooleanFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t *rowIdxArray, + int32_t rowCnt, bool *output); + +extern "C" DLLEXPORT void BatchGetVarcharFromDictionaryVector(int64_t contextPtr, int64_t dictionaryVectorAddr, + int32_t *rowIdxArray, int32_t rowCnt, uint8_t **str, int32_t *length); + +extern "C" DLLEXPORT void BatchGetDecimalFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t *rowIdxArray, + int32_t rowCnt, Decimal128 *output); + +extern "C" DLLEXPORT void BatchGetIntFromVector(int32_t *vector, int32_t *rowIdxArray, int32_t rowCnt, int32_t *output); + +extern "C" DLLEXPORT void BatchGetLongFromVector(int64_t *vector, int32_t *rowIdxArray, int32_t rowCnt, + int64_t *output); + +extern "C" DLLEXPORT void BatchGetDoubleFromVector(double *vector, int32_t *rowIdxArray, int32_t rowCnt, + double *output); + +extern "C" DLLEXPORT void BatchGetBooleanFromVector(bool *vector, int32_t *rowIdxArray, int32_t rowCnt, bool *output); + +extern "C" DLLEXPORT void BatchGetVarcharFromVector(int64_t contextPtr, int32_t *offsetArray, const char *vector, + int32_t *rowIdxArray, int32_t rowCnt, uint8_t **str, int32_t *length); + +extern "C" DLLEXPORT void BatchGetDecimalFromVector(Decimal128 *vector, int32_t *rowIdxArray, int32_t rowCnt, + Decimal128 *output); +} + +#endif // OMNI_RUNTIME_BATCH_DICTIONARYFUNCTIONS_H diff --git a/core/src/codegen/batch_functions/batch_mathfunctions.cpp b/core/src/codegen/batch_functions/batch_mathfunctions.cpp new file mode 100644 index 0000000..67044c6 --- /dev/null +++ b/core/src/codegen/batch_functions/batch_mathfunctions.cpp @@ -0,0 +1,404 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch math functions implementation + */ +#include "batch_mathfunctions.h" +#include +#include +#include "codegen/context_helper.h" +#include "codegen/functions/mathfunctions.h" +#include "util/config_util.h" +#include "codegen/common_util.h" + +#ifdef _WIN32 +#define DLLEXPORT __declspec(dllexport) +#else +#define DLLEXPORT +#endif + +const double DOUBLE_NAN = (0.0 / 0.0); +const uint64_t DOUBLE_BIT_MASK = ((static_cast(1) << (sizeof(double) * 8 - 1)) - 1); + +namespace omniruntime::codegen::function { +static constexpr char DIVIDE_ZERO_EROR[] = "Divided by zero error!"; + +extern "C" DLLEXPORT void BatchCastInt32ToInt64(int32_t *x, bool *resIsNull, int64_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = static_cast(x[i]); + } +} + +extern "C" DLLEXPORT void BatchCastInt64ToInt32(int64_t *x, bool *resIsNull, int32_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = static_cast(x[i]); + } +} + +extern "C" DLLEXPORT void BatchCastInt32ToDouble(int32_t *x, bool *resIsNull, double *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = static_cast(x[i]); + } +} + +extern "C" DLLEXPORT void BatchCastInt64ToDouble(int64_t *x, bool *resIsNull, double *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = static_cast(x[i]); + } +} + +extern "C" DLLEXPORT void BatchCastDoubleToInt32HalfUp(double *x, bool *resIsNull, int32_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = static_cast(Round(x[i], 0)); + } +} + +extern "C" DLLEXPORT void BatchCastDoubleToInt64HalfUp(double *x, bool *resIsNull, int64_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = static_cast(Round(x[i], 0)); + } +} + +extern "C" DLLEXPORT void BatchCastDoubleToInt32Down(double *x, bool *resIsNull, int32_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = static_cast(x[i]); + } +} + +extern "C" DLLEXPORT void BatchCastDoubleToInt64Down(double *x, bool *resIsNull, int64_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = static_cast(x[i]); + } +} + + +extern "C" DLLEXPORT void BatchAddDouble(double *left, double *right, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + left[i] = left[i] + right[i]; + } +} + +extern "C" DLLEXPORT void BatchSubtractDouble(double *left, double *right, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + left[i] = left[i] - right[i]; + } +} + +extern "C" DLLEXPORT void BatchMultiplyDouble(double *left, double *right, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + left[i] = left[i] * right[i]; + } +} + +extern "C" DLLEXPORT void BatchDivideDouble(double *left, double *right, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + left[i] = left[i] / right[i]; + } +} + +extern "C" DLLEXPORT void BatchModulusDouble(double *left, double *right, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + left[i] = std::fmod(left[i], right[i]); + } +} + +extern "C" DLLEXPORT void BatchLessThanDouble(double *left, double *right, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = (left[i] < right[i]); + } +} + +extern "C" DLLEXPORT void BatchLessThanEqualDouble(double *left, double *right, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = (left[i] <= right[i]); + } +} + +extern "C" DLLEXPORT void BatchGreaterThanDouble(double *left, double *right, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = (left[i] > right[i]); + } +} + +extern "C" DLLEXPORT void BatchGreaterThanEqualDouble(double *left, double *right, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = (left[i] >= right[i]); + } +} + +extern "C" DLLEXPORT void BatchEqualDouble(double *left, double *right, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = (std::fabs(left[i] - right[i]) < DBL_EPSILON); + } +} + +extern "C" DLLEXPORT void BatchNotEqualDouble(double *left, double *right, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = std::fabs(left[i] - right[i]) >= DBL_EPSILON; + } +} + +extern "C" DLLEXPORT void BatchNormalizeNaNAndZero(double *input, bool *isAnyNull, double *output, int32_t rowCnt) +{ + for (int32_t i = 0; i < rowCnt; i++) { + auto value = input[i]; + if (std::isnan(value)) { + output[i] = DOUBLE_NAN; + continue; + } + union { + uint64_t l; + double d; + } u; + u.d = value; + if (u.l & DOUBLE_BIT_MASK) { + output[i] = value; + } else { + output[i] = 0.0; + } + } +} + +extern "C" DLLEXPORT void BatchPowerDouble(double *base, double *exponent, double *output, int32_t rowCnt) +{ + for (int32_t i = 0; i < rowCnt; i++) { + output[i] = pow(base[i], exponent[i]); + } +} + +extern "C" DLLEXPORT void BatchAddInt64(int64_t *left, int64_t *right, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + left[i] = left[i] + right[i]; + } +} + +extern "C" DLLEXPORT void BatchSubtractInt64(int64_t *left, int64_t *right, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + left[i] = left[i] - right[i]; + } +} + +extern "C" DLLEXPORT void BatchMultiplyInt64(int64_t *left, int64_t *right, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + left[i] = left[i] * right[i]; + } +} + +extern "C" DLLEXPORT void BatchDivideInt64(int64_t contextPtr, int64_t *left, int64_t *right, int32_t rowCnt, + bool *isNull) +{ + for (int i = 0; i < rowCnt; i++) { + if (isNull[i]) { + continue; + } + if (right[i] == 0) { + SetError(contextPtr, DIVIDE_ZERO_EROR); + return; + } + left[i] = left[i] / right[i]; + } +} + +extern "C" DLLEXPORT void BatchModulusInt64(int64_t contextPtr, int64_t *left, int64_t *right, int32_t rowCnt, + bool *isNull) +{ + for (int i = 0; i < rowCnt; i++) { + if (isNull[i]) { + continue; + } + if (right[i] == 0) { + SetError(contextPtr, DIVIDE_ZERO_EROR); + return; + } + left[i] = std::fmod(left[i], right[i]); + } +} + +extern "C" DLLEXPORT void BatchLessThanInt64(int64_t *left, int64_t *right, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = left[i] < right[i]; + } +} + +extern "C" DLLEXPORT void BatchLessThanEqualInt64(int64_t *left, int64_t *right, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = left[i] <= right[i]; + } +} + +extern "C" DLLEXPORT void BatchGreaterThanInt64(int64_t *left, int64_t *right, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = left[i] > right[i]; + } +} + +extern "C" DLLEXPORT void BatchGreaterThanEqualInt64(int64_t *left, int64_t *right, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = left[i] >= right[i]; + } +} + +extern "C" DLLEXPORT void BatchEqualInt64(int64_t *left, int64_t *right, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = left[i] == right[i]; + } +} + +extern "C" DLLEXPORT void BatchNotEqualInt64(int64_t *left, int64_t *right, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = left[i] != right[i]; + } +} + +extern "C" DLLEXPORT void BatchAddInt32(int32_t *left, int32_t *right, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + left[i] = left[i] + right[i]; + } +} + +extern "C" DLLEXPORT void BatchSubtractInt32(int32_t *left, int32_t *right, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + left[i] = left[i] - right[i]; + } +} + +extern "C" DLLEXPORT void BatchMultiplyInt32(int32_t *left, int32_t *right, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + left[i] = left[i] * right[i]; + } +} + +extern "C" DLLEXPORT void BatchDivideInt32(int64_t contextPtr, int32_t *left, int32_t *right, int32_t rowCnt, + bool *isNull) +{ + for (int i = 0; i < rowCnt; i++) { + if (isNull[i]) { + continue; + } + if (right[i] == 0) { + SetError(contextPtr, DIVIDE_ZERO_EROR); + return; + } + left[i] = left[i] / right[i]; + } +} + +extern "C" DLLEXPORT void BatchModulusInt32(int64_t contextPtr, int32_t *left, int32_t *right, int32_t rowCnt, + bool *isNull) +{ + for (int i = 0; i < rowCnt; i++) { + if (isNull[i]) { + continue; + } + if (right[i] == 0) { + SetError(contextPtr, DIVIDE_ZERO_EROR); + return; + } + left[i] = std::fmod(left[i], right[i]); + } +} + +extern "C" DLLEXPORT void BatchLessThanInt32(int32_t *left, int32_t *right, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = left[i] < right[i]; + } +} + +extern "C" DLLEXPORT void BatchLessThanEqualInt32(int32_t *left, int32_t *right, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = left[i] <= right[i]; + } +} + +extern "C" DLLEXPORT void BatchGreaterThanInt32(int32_t *left, int32_t *right, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = left[i] > right[i]; + } +} + +extern "C" DLLEXPORT void BatchGreaterThanEqualInt32(int32_t *left, int32_t *right, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = left[i] >= right[i]; + } +} + +extern "C" DLLEXPORT void BatchEqualInt32(int32_t *left, int32_t *right, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = left[i] == right[i]; + } +} + +extern "C" DLLEXPORT void BatchNotEqualInt32(int32_t *left, int32_t *right, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = left[i] != right[i]; + } +} + +extern "C" DLLEXPORT void BatchEqualBool(bool *left, bool toCmp, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + left[i] = (left[i] == toCmp); + } +} + +extern "C" DLLEXPORT void BatchPmod(int32_t *x, int32_t *y, bool *isAnyNull, int32_t *output, int32_t rowCnt) +{ + int32_t r; + for (int i = 0; i < rowCnt; i++) { + if (y[i] == 0) { + output[i] = 0; + continue; + } + r = x[i] % y[i]; + if (r < 0) { + output[i] = (r + y[i]) % y[i]; + } else { + output[i] = r; + } + } +} + +extern "C" DLLEXPORT void BatchRoundLong(int64_t *num, int32_t *decimals, bool *isAnyNull, int64_t *output, + int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + output[i] = RoundOperator(num[i], decimals[i]); + } +} +} \ No newline at end of file diff --git a/core/src/codegen/batch_functions/batch_mathfunctions.h b/core/src/codegen/batch_functions/batch_mathfunctions.h new file mode 100644 index 0000000..7cebfd0 --- /dev/null +++ b/core/src/codegen/batch_functions/batch_mathfunctions.h @@ -0,0 +1,164 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch math functions implementation + */ +#ifndef OMNI_RUNTIME_BATCH_MATHFUNCTIONS_H +#define OMNI_RUNTIME_BATCH_MATHFUNCTIONS_H + +#include +#include + +#ifdef _WIN32 +#define DLLEXPORT __declspec(dllexport) +#else +#define DLLEXPORT +#endif + +namespace omniruntime::codegen::function { +template extern DLLEXPORT void BatchAbs(T *x, bool *resIsNull, T *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + output[i] = std::abs(x[i]); + } +} + +extern "C" DLLEXPORT void BatchCastInt32ToDouble(int32_t *x, bool *resIsNull, double *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastInt64ToDouble(int64_t *x, bool *resIsNull, double *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastInt32ToInt64(int32_t *x, bool *resIsNull, int64_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastInt64ToInt32(int64_t *x, bool *resIsNull, int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDoubleToInt32HalfUp(double *x, bool *resIsNull, int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDoubleToInt64HalfUp(double *x, bool *resIsNull, int64_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDoubleToInt32Down(double *x, bool *resIsNull, int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDoubleToInt64Down(double *x, bool *resIsNull, int64_t *output, int32_t rowCnt); + +// double functions +extern "C" DLLEXPORT void BatchAddDouble(double *left, double *right, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchSubtractDouble(double *left, double *right, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMultiplyDouble(double *left, double *right, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivideDouble(double *left, double *right, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchModulusDouble(double *left, double *right, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchLessThanDouble(double *left, double *right, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchLessThanEqualDouble(double *left, double *right, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchGreaterThanDouble(double *left, double *right, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchGreaterThanEqualDouble(double *left, double *right, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchEqualDouble(double *left, double *right, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchNotEqualDouble(double *left, double *right, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchNormalizeNaNAndZero(double *input, bool *isAnyNull, double *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchPowerDouble(double *base, double *exponent, double *output, int32_t rowCnt); + +// long functions +extern "C" DLLEXPORT void BatchAddInt64(int64_t *left, int64_t *right, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchSubtractInt64(int64_t *left, int64_t *right, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMultiplyInt64(int64_t *left, int64_t *right, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivideInt64(int64_t contextPtr, int64_t *left, int64_t *right, int32_t rowCnt, + bool *isNull); + +extern "C" DLLEXPORT void BatchModulusInt64(int64_t contextPtr, int64_t *left, int64_t *right, int32_t rowCnt, + bool *isNull); + +extern "C" DLLEXPORT void BatchLessThanInt64(int64_t *left, int64_t *right, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchLessThanEqualInt64(int64_t *left, int64_t *right, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchGreaterThanInt64(int64_t *left, int64_t *right, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchGreaterThanEqualInt64(int64_t *left, int64_t *right, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchEqualInt64(int64_t *left, int64_t *right, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchNotEqualInt64(int64_t *left, int64_t *right, bool *output, int32_t rowCnt); + +// int functions +extern "C" DLLEXPORT void BatchAddInt32(int32_t *left, int32_t *right, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchSubtractInt32(int32_t *left, int32_t *right, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMultiplyInt32(int32_t *left, int32_t *right, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchDivideInt32(int64_t contextPtr, int32_t *left, int32_t *right, int32_t rowCnt, + bool *isNull); + +extern "C" DLLEXPORT void BatchModulusInt32(int64_t contextPtr, int32_t *left, int32_t *right, int32_t rowCnt, + bool *isNull); + +extern "C" DLLEXPORT void BatchLessThanInt32(int32_t *left, int32_t *right, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchLessThanEqualInt32(int32_t *left, int32_t *right, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchGreaterThanInt32(int32_t *left, int32_t *right, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchGreaterThanEqualInt32(int32_t *left, int32_t *right, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchEqualInt32(int32_t *left, int32_t *right, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchNotEqualInt32(int32_t *left, int32_t *right, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchEqualBool(bool *left, bool toCmp, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchPmod(int32_t *x, int32_t *y, bool *isAnyNull, int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchRoundLong(int64_t *num, int32_t *decimals, bool *isAnyNull, int64_t *output, + int32_t rowCnt); + +template +extern DLLEXPORT void BatchRound(T *num, int32_t *decimals, bool *isAnyNull, T *output, int32_t rowCnt) +{ + int32_t tenthPower = 10; + double factor = std::pow(tenthPower, decimals[0]); + + for (int i = 0; i < rowCnt; ++i) { + if (std::isnan(num[i]) || std::isinf(num[i])) { + output[i] = num[i]; + continue; + } + + if (num[i] < 0) { + output[i] = -(std::round(-num[i] * factor) / factor); + continue; + } + + output[i] = std::round(num[i] * factor) / factor; + } +} + +template +extern DLLEXPORT void BatchGreatest(T *xValue, bool *xIsNull, T *yValue, bool *yIsNull, bool *retIsNull, T *output, + int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (xIsNull[i] && yIsNull[i]) { + retIsNull[i] = true; + output[i] = xValue[i]; + continue; + } + if (xIsNull[i] || (!yIsNull[i] && yValue[i] > xValue[i])) { + output[i] = yValue[i]; + continue; + } + output[i] = xValue[i]; + } +} +} +#endif // OMNI_RUNTIME_BATCH_MATHFUNCTIONS_H diff --git a/core/src/codegen/batch_functions/batch_murmur3_hash.cpp b/core/src/codegen/batch_functions/batch_murmur3_hash.cpp new file mode 100644 index 0000000..5c2c365 --- /dev/null +++ b/core/src/codegen/batch_functions/batch_murmur3_hash.cpp @@ -0,0 +1,97 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch mmh3 functions implementation + */ +#include "codegen/functions/murmur3_hash.h" +#include "type/decimal128_utils.h" +#include "batch_murmur3_hash.h" + +namespace omniruntime::codegen::function { +static const int COMBINE_HASH_VALUE = 31; + +extern "C" DLLEXPORT void BatchMm3Int32(int32_t *val, bool *isValNull, int32_t *seed, bool *isSeedNull, bool *resIsNull, + int32_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + seed[i] = isSeedNull[i] ? 0 : seed[i]; + output[i] = static_cast( + HashInt(static_cast(val[i] * !isValNull[i]), static_cast(seed[i]))); + } +} + +extern "C" DLLEXPORT void BatchMm3Int64(int64_t *val, bool *isValNull, int32_t *seed, bool *isSeedNull, bool *resIsNull, + int32_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + seed[i] = isSeedNull[i] ? 0 : seed[i]; + output[i] = static_cast( + HashLong(static_cast(val[i] * !isValNull[i]), static_cast(seed[i]))); + } +} + +extern "C" DLLEXPORT void BatchMm3String(uint8_t **val, int32_t *valLen, bool *isValNull, int32_t *seed, + bool *isSeedNull, bool *resIsNull, int32_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + seed[i] = isSeedNull[i] ? 0 : seed[i]; + valLen[i] = valLen[i] * !isValNull[i]; + output[i] = static_cast(HashUnsafeBytes(reinterpret_cast(val[i]), + static_cast(valLen[i]), static_cast(seed[i]))); + } +} + +extern "C" DLLEXPORT void BatchMm3Double(double *val, bool *isValNull, int32_t *seed, bool *isSeedNull, bool *resIsNull, + int32_t *output, int32_t rowCnt) +{ + union { + uint64_t lVal; + double dVal; + } uVal = { 0 }; + + for (int i = 0; i < rowCnt; ++i) { + uVal.dVal = val[i] * !isValNull[i]; + seed[i] = isSeedNull[i] ? 0 : seed[i]; + output[i] = static_cast(HashLong(uVal.lVal, static_cast(seed[i]))); + } +} + +extern "C" DLLEXPORT void BatchMm3Decimal64(int64_t *val, int32_t precision, int32_t scale, bool *isValNull, + int32_t *seed, bool *isSeedNull, bool *resIsNull, int32_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + seed[i] = isSeedNull[i] ? 0 : seed[i]; + output[i] = static_cast(HashLong(val[i] * !isValNull[i], static_cast(seed[i]))); + } +} + +extern "C" DLLEXPORT void BatchMm3Decimal128(omniruntime::type::Decimal128 *x, int32_t precision, int32_t scale, + bool *isValNull, int32_t *seed, bool *isSeedNull, bool *resIsNull, int32_t *output, int32_t rowCnt) +{ + int32_t byteLen = 0; + for (int i = 0; i < rowCnt; ++i) { + auto bytes = omniruntime::type::Decimal128Utils::Decimal128ToBytes(x[i].HighBits(), x[i].LowBits(), byteLen); + output[i] = static_cast(HashUnsafeBytes(reinterpret_cast(bytes), byteLen, seed[i])); + delete[] bytes; + } +} + +extern "C" DLLEXPORT void BatchMm3Boolean(bool *val, bool *isValNull, int32_t *seed, bool *isSeedNull, bool *resIsNull, + int32_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + seed[i] = isSeedNull[i] ? 0 : seed[i]; + output[i] = static_cast( + HashInt(static_cast((val[i] ? 1 : 0) * !isValNull[i]), static_cast(seed[i]))); + } +} + +extern "C" DLLEXPORT void BatchCombineHash(int64_t *prevHashVal, bool *isPrevHashValNull, int64_t *val, bool *isValNull, + bool *resIsNull, int64_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + prevHashVal[i] = isPrevHashValNull ? 0 : prevHashVal[i]; + val[i] = isValNull[i] ? 0 : val[i]; + output[i] = COMBINE_HASH_VALUE * prevHashVal[i] + val[i]; + } +} +} \ No newline at end of file diff --git a/core/src/codegen/batch_functions/batch_murmur3_hash.h b/core/src/codegen/batch_functions/batch_murmur3_hash.h new file mode 100644 index 0000000..428b17e --- /dev/null +++ b/core/src/codegen/batch_functions/batch_murmur3_hash.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch mmh3 functions implementation + */ + +#ifndef OMNI_RUNTIME_BATCH_MURMUR3_HASH_H +#define OMNI_RUNTIME_BATCH_MURMUR3_HASH_H + +#include "type/decimal128.h" + +namespace omniruntime::codegen::function { +#ifdef _WIN32 +#define DLLEXPORT __declspec(dllexport) +#else +#define DLLEXPORT +#endif + +extern "C" DLLEXPORT void BatchCombineHash(int64_t *prevHashVal, bool *isPrevHashValNull, int64_t *val, bool *isValNull, + bool *resIsNull, int64_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMm3Int32(int32_t *val, bool *isValNull, int32_t *seed, bool *isSeedNull, bool *resIsNull, + int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMm3Int64(int64_t *val, bool *isValNull, int32_t *seed, bool *isSeedNull, bool *resIsNull, + int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMm3String(uint8_t **val, int32_t *valLen, bool *isValNull, int32_t *seed, + bool *isSeedNull, bool *resIsNull, int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMm3Double(double *val, bool *isValNull, int32_t *seed, bool *isSeedNull, bool *resIsNull, + int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMm3Decimal64(int64_t *val, int32_t precision, int32_t scale, bool *isValNull, + int32_t *seed, bool *isSeedNull, bool *resIsNull, int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMm3Decimal128(omniruntime::type::Decimal128 *x, int32_t precision, int32_t scale, + bool *isValNull, int32_t *seed, bool *isSeedNull, bool *resIsNull, int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMm3Boolean(bool *val, bool *isValNull, int32_t *seed, bool *isSeedNull, bool *resIsNull, + int32_t *output, int32_t rowCnt); +} + +#endif // OMNI_RUNTIME_BATCH_MURMUR3_HASH_H diff --git a/core/src/codegen/batch_functions/batch_stringfunctions.cpp b/core/src/codegen/batch_functions/batch_stringfunctions.cpp new file mode 100644 index 0000000..6d1190b --- /dev/null +++ b/core/src/codegen/batch_functions/batch_stringfunctions.cpp @@ -0,0 +1,1313 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. + * Description: batch string functions implementation + */ +#include "batch_stringfunctions.h" +#include +#include +#include "type/data_operations.h" +#include "type/date32.h" +#include "codegen/functions/md5.h" + +#ifdef _WIN32 +#else +#define DLLEXPORT +#endif + +namespace omniruntime::codegen::function { +extern "C" DLLEXPORT void BatchStrCompare(uint8_t **ap, int32_t *apLen, uint8_t **bp, int32_t *bpLen, int32_t *res, + int32_t rowCnt) +{ + int min = 0, result = 0; + for (int i = 0; i < rowCnt; ++i) { + min = bpLen[i]; + if (apLen[i] < min) { + min = apLen[i]; + } + + result = memcmp(ap[i], bp[i], min); + if (result != 0) { + res[i] = result; + } else { + res[i] = apLen[i] - bpLen[i]; + } + } +} + +extern "C" DLLEXPORT void BatchLessThanStr(uint8_t **ap, int32_t *apLen, uint8_t **bp, int32_t *bpLen, bool *res, + int32_t rowCnt) +{ + auto tmp = new int32_t[rowCnt]; + BatchStrCompare(ap, apLen, bp, bpLen, tmp, rowCnt); + for (int i = 0; i < rowCnt; i++) { + res[i] = (tmp[i] < 0); + } + delete[] tmp; +} + +extern "C" DLLEXPORT void BatchLessThanEqualStr(uint8_t **ap, int32_t *apLen, uint8_t **bp, int32_t *bpLen, bool *res, + int32_t rowCnt) +{ + auto tmp = new int32_t[rowCnt]; + BatchStrCompare(ap, apLen, bp, bpLen, tmp, rowCnt); + for (int i = 0; i < rowCnt; i++) { + res[i] = (tmp[i] <= 0); + } + delete[] tmp; +} + +extern "C" DLLEXPORT void BatchGreaterThanStr(uint8_t **ap, int32_t *apLen, uint8_t **bp, int32_t *bpLen, bool *res, + int32_t rowCnt) +{ + auto tmp = new int32_t[rowCnt]; + BatchStrCompare(ap, apLen, bp, bpLen, tmp, rowCnt); + for (int i = 0; i < rowCnt; i++) { + res[i] = (tmp[i] > 0); + } + delete[] tmp; +} + +extern "C" DLLEXPORT void BatchGreaterThanEqualStr(uint8_t **ap, int32_t *apLen, uint8_t **bp, int32_t *bpLen, + bool *res, int32_t rowCnt) +{ + auto tmp = new int32_t[rowCnt]; + BatchStrCompare(ap, apLen, bp, bpLen, tmp, rowCnt); + for (int i = 0; i < rowCnt; i++) { + res[i] = (tmp[i] >= 0); + } + delete[] tmp; +} + +extern "C" DLLEXPORT void BatchEqualStr(uint8_t **ap, int32_t *apLen, uint8_t **bp, int32_t *bpLen, bool *res, + int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + if (apLen[i] != bpLen[i]) { + res[i] = false; + } else { + res[i] = (memcmp(ap[i], bp[i], apLen[i]) == 0); + } + } +} + +extern "C" DLLEXPORT void BatchNotEqualStr(uint8_t **ap, int32_t *apLen, uint8_t **bp, int32_t *bpLen, bool *res, + int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + if (apLen[i] != bpLen[i]) { + res[i] = true; + } else { + res[i] = (memcmp(ap[i], bp[i], apLen[i]) != 0); + } + } +} + +extern "C" DLLEXPORT void BatchCastStringToDateNotAllowReducePrecison(int64_t contextPtr, uint8_t **str, + int32_t *strLen, bool *isAnyNull, int32_t *output, int32_t rowCnt) +{ + // Date is in the format 1996-02-28 + // Doesn't account for leap seconds or daylight savings + // Should be ok just for dates + int64_t result = 0; + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + output[i] = 0; + continue; + } + std::string s = std::string(reinterpret_cast(str[i]), strLen[i]); + StringUtil::TrimString(s); + if (!regex_match(s, std::regex(R"(\d{4}-\d{2}-\d{2}$)"))) { + SetError(contextPtr, "Only support cast date\'YYYY-MM-DD\' to integer"); + output[i] = 0; + continue; + } + if (Date32::StringToDate32(reinterpret_cast(str[i]), strLen[i], result) != Status::CONVERT_SUCCESS && + !HasError(contextPtr)) { + SetError(contextPtr, "Value cannot be cast to date: " + s); + continue; + } + output[i] = static_cast(result); + } +} + +extern "C" DLLEXPORT void BatchCastStringToDateAllowReducePrecison(int64_t contextPtr, uint8_t **str, int32_t *strLen, + bool *isAnyNull, int32_t *output, int32_t rowCnt) +{ + // Date is in the format 1996-02-28 + // Doesn't account for leap seconds or daylight savings + // Should be ok just for dates + int64_t result = 0; + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + output[i] = 0; + continue; + } + std::string s = std::string(reinterpret_cast(str[i]), strLen[i]); + StringUtil::TrimString(s); + if (Date32::StringToDate32(reinterpret_cast(str[i]), strLen[i], result) != Status::CONVERT_SUCCESS && + !HasError(contextPtr)) { + SetError(contextPtr, "Value cannot be cast to date: " + s); + continue; + } + output[i] = static_cast(result); + } +} + +extern "C" DLLEXPORT void BatchCastIntToString(int64_t contextPtr, int32_t *value, bool *isAnyNull, uint8_t **output, + int32_t *outLen, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + outLen[i] = 0; + output[i] = nullptr; + continue; + } + std::string str = std::to_string(value[i]); + outLen[i] = static_cast(str.size()); + if (outLen[i] <= 0) { + outLen[i] = 0; + output[i] = (uint8_t *)""; + continue; + } + auto ret = ArenaAllocatorMalloc(contextPtr, outLen[i]); + errno_t res = memcpy_s(ret, outLen[i], str.c_str(), outLen[i]); + if (res != EOK) { + SetError(contextPtr, "cast failed"); + output[i] = nullptr; + continue; + } + output[i] = reinterpret_cast(ret); + } +} + +extern "C" DLLEXPORT void BatchCastLongToString(int64_t contextPtr, int64_t *value, bool *isAnyNull, uint8_t **output, + int32_t *outLen, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + outLen[i] = 0; + output[i] = nullptr; + continue; + } + std::string str = std::to_string(value[i]); + outLen[i] = static_cast(strlen(str.c_str())); + if (outLen[i] <= 0) { + outLen[i] = 0; + output[i] = (uint8_t *)""; + continue; + } + auto ret = ArenaAllocatorMalloc(contextPtr, outLen[i]); + errno_t res = memcpy_s(ret, outLen[i], str.c_str(), outLen[i]); + if (res != EOK) { + SetError(contextPtr, "cast failed"); + output[i] = nullptr; + continue; + } + output[i] = reinterpret_cast(ret); + } +} + +extern "C" DLLEXPORT void BatchCastDoubleToString(int64_t contextPtr, double *value, bool *isAnyNull, uint8_t **output, + int32_t *outLen, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + auto ret = ArenaAllocatorMalloc(contextPtr, MAX_DATA_LENGTH); + outLen[i] = static_cast(DoubleToString::DoubleToStringConverter(value[i], ret)); + output[i] = reinterpret_cast(ret); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal64ToString(int64_t contextPtr, int64_t *x, int32_t precision, int32_t scale, + bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + outLen[i] = 0; + output[i] = nullptr; + continue; + } + std::string str = Decimal64(x[i]).SetScale(scale).ToString(); + outLen[i] = static_cast(str.size()); + if (outLen[i] <= 0) { + outLen[i] = 0; + output[i] = (uint8_t *)""; + continue; + } + auto ret = ArenaAllocatorMalloc(contextPtr, outLen[i]); + errno_t res = memcpy_s(ret, outLen[i], str.c_str(), outLen[i]); + if (res != EOK) { + SetError(contextPtr, "cast failed"); + output[i] = nullptr; + continue; + } + output[i] = reinterpret_cast(ret); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal128ToString(int64_t contextPtr, Decimal128 *x, int32_t precision, + int32_t scale, bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + outLen[i] = 0; + output[i] = nullptr; + continue; + } + std::string stringDecimal = Decimal128Wrapper(x[i]).SetScale(scale).ToString(); + outLen[i] = static_cast(stringDecimal.length()); + if (outLen[i] <= 0) { + outLen[i] = 0; + output[i] = (uint8_t *)""; + continue; + } + auto ret = ArenaAllocatorMalloc(contextPtr, outLen[i]); + errno_t res = memcpy_s(ret, outLen[i], stringDecimal.c_str(), outLen[i]); + if (res != EOK) { + SetError(contextPtr, "cast failed"); + output[i] = nullptr; + continue; + } + output[i] = reinterpret_cast(ret); + } +} + +extern "C" DLLEXPORT void BatchCastStringToDecimal64(int64_t contextPtr, uint8_t **str, int32_t *strLen, + bool *isAnyNull, int64_t *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + output[i] = 0; + continue; + } + std::string s = std::string(reinterpret_cast(str[i]), strLen[i]); + StringUtil::TrimString(s); + if (!regex_match(s, g_decimalRegex)) { + std::ostringstream errorMessage; + errorMessage << "Cannot cast VARCHAR '" << s << "' to DECIMAL(" << outPrecision << ", " << outScale << + "). Value too large."; + SetError(contextPtr, errorMessage.str()); + continue; + } + Decimal64 result(s); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + std::ostringstream errorMessage; + errorMessage << "Cannot cast VARCHAR '" << s << "' to DECIMAL(" << outPrecision << ", " << outScale << + "). Value too large."; + SetError(contextPtr, errorMessage.str()); + output[i] = 0; + continue; + } + output[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchCastStringToDecimal128(int64_t contextPtr, uint8_t **str, int32_t *strLen, + bool *isAnyNull, Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + output[i].SetValue(0, 0); + continue; + } + std::string s = std::string(reinterpret_cast(str[i]), strLen[i]); + if (!regex_match(s, g_decimalRegex)) { + std::ostringstream errorMessage; + errorMessage << "Cannot cast VARCHAR '" << s << "' to DECIMAL(" << outPrecision << ", " << outScale << + "). Value too large."; + SetError(contextPtr, errorMessage.str()); + continue; + } + StringUtil::TrimString(s); + Decimal128Wrapper result(s.c_str()); + result.ReScale(outScale); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + std::ostringstream errorMessage; + errorMessage << "Cannot cast VARCHAR '" << s << "' to DECIMAL(" << outPrecision << ", " << outScale << + "). Value too large."; + SetError(contextPtr, errorMessage.str()); + continue; + } + output[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchCastStringToDecimal64RoundUp(int64_t contextPtr, uint8_t **str, int32_t *strLen, + bool *isAnyNull, int64_t *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + output[i] = 0; + continue; + } + std::string s = std::string(reinterpret_cast(str[i]), strLen[i]); + StringUtil::TrimString(s); + if (!regex_match(s, g_decimalRegex)) { + std::ostringstream errorMessage; + errorMessage << "Cannot cast VARCHAR '" << s << "' to DECIMAL(" << outPrecision << ", " << outScale << + "). Value too large."; + SetError(contextPtr, errorMessage.str()); + continue; + } + Decimal64 result(s); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + std::ostringstream errorMessage; + errorMessage << "Cannot cast VARCHAR '" << s << "' to DECIMAL(" << outPrecision << ", " << outScale << + "). Value too large."; + SetError(contextPtr, errorMessage.str()); + output[i] = 0; + continue; + } + output[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchCastStringToDecimal128RoundUp(int64_t contextPtr, uint8_t **str, int32_t *strLen, + bool *isAnyNull, Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + output[i].SetValue(0, 0); + continue; + } + std::string s = std::string(reinterpret_cast(str[i]), strLen[i]); + if (!regex_match(s, g_decimalRegex)) { + std::ostringstream errorMessage; + errorMessage << "Cannot cast VARCHAR '" << s << "' to DECIMAL(" << outPrecision << ", " << outScale << + "). Value too large."; + SetError(contextPtr, errorMessage.str()); + continue; + } + StringUtil::TrimString(s); + Decimal128Wrapper result(s.c_str()); + result.ReScale(outScale); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + std::ostringstream errorMessage; + errorMessage << "Cannot cast VARCHAR '" << s << "' to DECIMAL(" << outPrecision << ", " << outScale << + "). Value too large."; + SetError(contextPtr, errorMessage.str()); + continue; + } + output[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchCastStringToInt(int64_t contextPtr, uint8_t **str, int32_t *strLen, bool *isAnyNull, + int32_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + output[i] = 0; + continue; + } + auto chars = reinterpret_cast(str[i]); + Status status = ConvertStringToInteger(output[i], chars, strLen[i]); + if (status != Status::CONVERT_SUCCESS) { + std::string s(chars, strLen[i]); + std::string reason = + status == Status::IS_NOT_A_NUMBER ? "Value is not a number." : "Value too large or too small."; + std::ostringstream errorMessage; + errorMessage << "Cannot cast '" << s << "' to INTEGER. " << reason; + SetError(contextPtr, errorMessage.str()); + output[i] = 0; + continue; + } + } +} + +extern "C" DLLEXPORT void BatchCastStringToLong(int64_t contextPtr, uint8_t **str, int32_t *strLen, bool *isAnyNull, + int64_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + output[i] = 0; + continue; + } + + auto chars = reinterpret_cast(str[i]); + Status status = ConvertStringToInteger(output[i], chars, strLen[i]); + if (status != Status::CONVERT_SUCCESS) { + std::string s(chars, strLen[i]); + std::string reason = + status == Status::IS_NOT_A_NUMBER ? "Value is not a number." : "Value too large or too small."; + std::ostringstream errorMessage; + errorMessage << "Cannot cast '" << s << "' to INTEGER. " << reason; + SetError(contextPtr, errorMessage.str()); + output[i] = 0; + continue; + } + } +} + +extern "C" DLLEXPORT void BatchCastStringToDouble(int64_t contextPtr, uint8_t **str, int32_t *strLen, bool *isAnyNull, + double *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + output[i] = 0; + continue; + } + double result; + Status status = ConvertStringToDouble(result, reinterpret_cast(str[i]), strLen[i]); + if (status == Status::IS_NOT_A_NUMBER) { + std::ostringstream errorMessage; + errorMessage << "Cannot cast '" << std::string(reinterpret_cast(str[i]), strLen[i]) + << "' to DOUBLE. Value is not a number."; + SetError(contextPtr, errorMessage.str()); + continue; + } + if (status == Status::CONVERT_OVERFLOW) { + std::ostringstream errorMessage; + errorMessage << "Cannot cast '" << std::string(reinterpret_cast(str[i]), strLen[i]) + << "' to DOUBLE. Value is not a number."; + SetError(contextPtr, errorMessage.str()); + continue; + } + output[i] = result; + } +} + +extern "C" DLLEXPORT void BatchCastStringToDateRetNullNotAllowReducePrecison(bool *isNull, uint8_t **str, + int32_t *strLen, int32_t *output, int32_t rowCnt) +{ + // Date is in the format 1996-02-28 + // Doesn't account for leap seconds or daylight savings + // Should be ok just for dates + int64_t result = 0; + for (int i = 0; i < rowCnt; ++i) { + std::string s = std::string(reinterpret_cast(str[i]), strLen[i]); + StringUtil::TrimString(s); + if (!regex_match(s, std::regex(R"(\d{4}-\d{2}-\d{2}$)"))) { + output[i] = 0; + isNull[i] = true; + continue; + } + if (Date32::StringToDate32(reinterpret_cast(str[i]), strLen[i], result) != Status::CONVERT_SUCCESS) { + output[i] = 0; + isNull[i] = true; + continue; + } + output[i] = static_cast(result); + isNull[i] = false; + } +} + +extern "C" DLLEXPORT void BatchCastStringToDateRetNullAllowReducePrecison(bool *isNull, uint8_t **str, int32_t *strLen, + int32_t *output, int32_t rowCnt) +{ + // Date is in the format 1996-02-28 + // Doesn't account for leap seconds or daylight savings + // Should be ok just for dates + int64_t result = 0; + for (int i = 0; i < rowCnt; ++i) { + std::string s = std::string(reinterpret_cast(str[i]), strLen[i]); + StringUtil::TrimString(s); + if (Date32::StringToDate32(reinterpret_cast(str[i]), strLen[i], result) != Status::CONVERT_SUCCESS) { + output[i] = 0; + isNull[i] = true; + continue; + } + output[i] = static_cast(result); + isNull[i] = false; + } +} + +extern "C" DLLEXPORT void BatchCastIntToStringRetNull(bool *isNull, int64_t contextPtr, int32_t *value, + uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + std::string str = std::to_string(value[i]); + outLen[i] = static_cast(str.size()); + auto ret = ArenaAllocatorMalloc(contextPtr, outLen[i]); + errno_t res = memcpy_s(ret, outLen[i] + 1, str.c_str(), outLen[i]); + if (res != EOK) { + output[i] = nullptr; + isNull[i] = true; + continue; + } + isNull[i] = false; + output[i] = reinterpret_cast(ret); + } +} + +extern "C" DLLEXPORT void BatchCastLongToStringRetNull(bool *isNull, int64_t contextPtr, int64_t *value, + uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + std::string str = std::to_string(value[i]); + outLen[i] = static_cast(strlen(str.c_str())); + auto ret = ArenaAllocatorMalloc(contextPtr, outLen[i]); + errno_t res = memcpy_s(ret, outLen[i], str.c_str(), outLen[i]); + if (res != EOK) { + output[i] = nullptr; + isNull[i] = true; + continue; + } + isNull[i] = false; + output[i] = reinterpret_cast(ret); + } +} + +extern "C" DLLEXPORT void BatchCastDoubleToStringRetNull(bool *isNull, int64_t contextPtr, double *value, + uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + auto ret = ArenaAllocatorMalloc(contextPtr, MAX_DATA_LENGTH); + outLen[i] = static_cast(DoubleToString::DoubleToStringConverter(value[i], ret)); + isNull[i] = false; + output[i] = reinterpret_cast(ret); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal64ToStringRetNull(bool *isNull, int64_t contextPtr, int64_t *x, + int32_t precision, int32_t scale, uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + std::string str = Decimal64(x[i]).SetScale(scale).ToString(); + outLen[i] = static_cast(str.size()); + if (outLen[i] <= 0) { + outLen[i] = 0; + output[i] = (uint8_t *)""; + continue; + } + auto ret = ArenaAllocatorMalloc(contextPtr, outLen[i]); + errno_t res = memcpy_s(ret, outLen[i], str.c_str(), outLen[i]); + if (res != EOK) { + output[i] = nullptr; + isNull[i] = true; + continue; + } + isNull[i] = false; + output[i] = reinterpret_cast(ret); + } +} + +extern "C" DLLEXPORT void BatchCastDecimal128ToStringRetNull(bool *isNull, int64_t contextPtr, Decimal128 *inputDecimal, + int32_t precision, int32_t scale, uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + std::string stringDecimal = Decimal128Wrapper(inputDecimal[i]).SetScale(scale).ToString(); + outLen[i] = static_cast(stringDecimal.length()); + if (outLen[i] <= 0) { + outLen[i] = 0; + output[i] = (uint8_t *)""; + continue; + } + auto ret = ArenaAllocatorMalloc(contextPtr, outLen[i]); + errno_t res = memcpy_s(ret, outLen[i], stringDecimal.c_str(), outLen[i]); + if (res != EOK) { + output[i] = nullptr; + isNull[i] = true; + continue; + } + isNull[i] = false; + output[i] = reinterpret_cast(ret); + } +} + +extern "C" DLLEXPORT void BatchCastStringToDecimal64RetNull(bool *isNull, uint8_t **str, int32_t *strLen, + int64_t *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + std::string s = std::string(reinterpret_cast(str[i]), strLen[i]); + StringUtil::TrimString(s); + if (!regex_match(s, g_decimalRegex)) { + output[i] = 0; + isNull[i] = true; + continue; + } + Decimal64 result(s); + result.ReScale(outScale); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + output[i] = 0; + isNull[i] = true; + continue; + } + isNull[i] = false; + output[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchCastStringToDecimal128RetNull(bool *isNull, uint8_t **str, int32_t *strLen, + Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + std::string s = std::string(reinterpret_cast(str[i]), strLen[i]); + StringUtil::TrimString(s); + if (!regex_match(s, g_decimalRegex)) { + output[i] = 0; + isNull[i] = true; + continue; + } + Decimal128Wrapper result(s.c_str()); + result.ReScale(outScale); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + output[i] = 0; + isNull[i] = true; + continue; + } + output[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchCastStringToDecimal64RoundUpRetNull(bool *isNull, uint8_t **str, int32_t *strLen, + int64_t *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + std::string s = std::string(reinterpret_cast(str[i]), strLen[i]); + StringUtil::TrimString(s); + if (!regex_match(s, g_decimalRegex)) { + output[i] = 0; + isNull[i] = true; + continue; + } + Decimal64 result(s); + result.ReScale(outScale); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + output[i] = 0; + isNull[i] = true; + continue; + } + isNull[i] = false; + output[i] = result.GetValue(); + } +} + +extern "C" DLLEXPORT void BatchCastStringToDecimal128RoundUpRetNull(bool *isNull, uint8_t **str, int32_t *strLen, + Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + std::string s = std::string(reinterpret_cast(str[i]), strLen[i]); + StringUtil::TrimString(s); + if (!regex_match(s, g_decimalRegex)) { + output[i] = 0; + isNull[i] = true; + continue; + } + Decimal128Wrapper result(s.c_str()); + result.ReScale(outScale); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + output[i] = 0; + isNull[i] = true; + continue; + } + output[i] = result.ToDecimal128(); + } +} + +extern "C" DLLEXPORT void BatchCastStringToIntRetNull(bool *isNull, uint8_t **str, int32_t *strLen, int32_t *output, + int32_t rowCnt) +{ + for (int32_t i = 0; i < rowCnt; ++i) { + Status status = ConvertStringToInteger(output[i], reinterpret_cast(str[i]), + strLen[i]); + isNull[i] = status != Status::CONVERT_SUCCESS; + if (isNull[i]) { + output[i] = 0; + } + } +} + +extern "C" DLLEXPORT void BatchCastStringToLongRetNull(bool *isNull, uint8_t **str, int32_t *strLen, int64_t *output, + int32_t rowCnt) +{ + for (int32_t i = 0; i < rowCnt; ++i) { + Status status = ConvertStringToInteger(output[i], reinterpret_cast(str[i]), + strLen[i]); + isNull[i] = status != Status::CONVERT_SUCCESS; + if (isNull[i]) { + output[i] = 0; + } + } +} + +extern "C" DLLEXPORT void BatchCastStringToDoubleRetNull(bool *isNull, uint8_t **str, int32_t *strLen, double *output, + int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + double result; + Status status = ConvertStringToDouble(result, reinterpret_cast(str[i]), strLen[i]); + if (status != Status::CONVERT_SUCCESS) { + output[i] = 0; + isNull[i] = true; + continue; + } + isNull[i] = false; + output[i] = result; + } +} + +extern "C" DLLEXPORT void BatchToUpperStr(int64_t contextPtr, uint8_t **str, int32_t *strLen, bool *isAnyNull, + uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + char *ret; + for (int j = 0; j < rowCnt; ++j) { + if (isAnyNull[j]) { + outLen[j] = 0; + output[j] = nullptr; + continue; + } + ret = ArenaAllocatorMalloc(contextPtr, strLen[j]); + for (int i = 0; i < strLen[j]; i++) { + if (*(str[j] + i) >= static_cast('a') && *(str[j] + i) <= static_cast('z')) { + *(ret + i) = *(str[j] + i) - STEP; + } else { + *(ret + i) = *(str[j] + i); + } + } + outLen[j] = strLen[j]; + output[j] = reinterpret_cast(ret); + } +} + +extern "C" DLLEXPORT void BatchToUpperChar(int64_t contextPtr, uint8_t **str, int32_t width, int32_t *strLen, + bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + BatchToUpperStr(contextPtr, str, strLen, isAnyNull, output, outLen, rowCnt); +} + +extern "C" DLLEXPORT void BatchToLowerStr(int64_t contextPtr, uint8_t **str, int32_t *strLen, bool *isAnyNull, + uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + char *ret; + char currItem; + for (int j = 0; j < rowCnt; ++j) { + if (isAnyNull[j]) { + outLen[j] = 0; + output[j] = nullptr; + continue; + } + ret = ArenaAllocatorMalloc(contextPtr, strLen[j]); + for (int32_t i = 0; i < strLen[j]; i++) { + currItem = *(reinterpret_cast(str[j]) + i); + if (currItem >= static_cast('A') && currItem <= static_cast('Z')) { + *(ret + i) = static_cast(currItem + STEP); + } else { + *(ret + i) = currItem; + } + } + outLen[j] = strLen[j]; + output[j] = reinterpret_cast(ret); + } +} + +extern "C" DLLEXPORT void BatchToLowerChar(int64_t contextPtr, uint8_t **str, int32_t width, int32_t *strLen, + bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + BatchToLowerStr(contextPtr, str, strLen, isAnyNull, output, outLen, rowCnt); +} + +extern "C" DLLEXPORT void BatchLikeStr(uint8_t **str, int32_t *strLen, uint8_t **regexToMatch, int32_t *regexLen, + bool *isAnyNull, bool *output, int32_t rowCnt) +{ + std::string s; + for (int32_t i = 0; i < rowCnt; i++) { + if (isAnyNull[i]) { + output[i] = false; + continue; + } + s = std::string(reinterpret_cast(str[i]), strLen[i]); + std::string r = std::string(reinterpret_cast(regexToMatch[i]), regexLen[i]); + std::wregex re(StringUtil::ToWideString(r)); + output[i] = regex_match(StringUtil::ToWideString(s), re); + } +} + +extern "C" DLLEXPORT void BatchLikeChar(uint8_t **str, int32_t strWidth, int32_t *strLen, uint8_t **regexToMatch, + int32_t *regexLen, bool *isAnyNull, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + if (isAnyNull[i]) { + output[i] = false; + continue; + } + int32_t paddingCount = + strWidth - omniruntime::Utf8Util::CountCodePoints(reinterpret_cast(str[i]), strLen[i]); + std::string originalStr; + originalStr.reserve(strLen[i] + paddingCount); + originalStr.append(reinterpret_cast(str[i]), strLen[i]); + originalStr.append(paddingCount, ' '); + std::string r = std::string(reinterpret_cast(regexToMatch[i]), regexLen[i]); + std::wregex re(StringUtil::ToWideString(r)); + output[i] = regex_match(StringUtil::ToWideString(originalStr), re); + } +} + +extern "C" DLLEXPORT void BatchConcatStrStr(int64_t contextPtr, uint8_t **ap, int32_t *apLen, uint8_t **bp, + int32_t *bpLen, bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + bool hasErr; + for (int i = 0; i < rowCnt; i++) { + if (isAnyNull[i]) { + outLen[i] = 0; + output[i] = nullptr; + continue; + } + hasErr = false; + auto ret = StringUtil::ConcatStrDiffWidths(contextPtr, reinterpret_cast(ap[i]), apLen[i], + reinterpret_cast(bp[i]), bpLen[i], &hasErr, outLen + i); + if (hasErr) { + SetError(contextPtr, CONCAT_ERR_MSG); + continue; + } + output[i] = reinterpret_cast(const_cast(ret)); + } +} + +extern "C" DLLEXPORT void BatchConcatStrStrRetNull(bool *isNull, int64_t contextPtr, uint8_t **ap, int32_t *apLen, + uint8_t **bp, int32_t *bpLen, uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + for (int32_t i = 0; i < rowCnt; i++) { + auto ret = StringUtil::ConcatStrDiffWidths(contextPtr, reinterpret_cast(ap[i]), apLen[i], + reinterpret_cast(bp[i]), bpLen[i], isNull + i, outLen + i); + output[i] = reinterpret_cast(const_cast(ret)); + } +} + +extern "C" DLLEXPORT void BatchConcatCharChar(int64_t contextPtr, uint8_t **ap, int32_t aWidth, int32_t *apLen, + uint8_t **bp, int32_t bWidth, int32_t *bpLen, bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + bool hasErr; + for (int32_t i = 0; i < rowCnt; i++) { + if (isAnyNull[i]) { + outLen[i] = 0; + output[i] = nullptr; + continue; + } + + hasErr = false; + auto ret = StringUtil::ConcatCharDiffWidths(contextPtr, reinterpret_cast(ap[i]), aWidth, apLen[i], + reinterpret_cast(bp[i]), bpLen[i], &hasErr, outLen + i); + if (hasErr) { + SetError(contextPtr, CONCAT_ERR_MSG); + continue; + } + output[i] = reinterpret_cast(const_cast(ret)); + } +} + +extern "C" DLLEXPORT void BatchConcatCharCharRetNull(bool *isNull, int64_t contextPtr, uint8_t **ap, int32_t aWidth, + int32_t *apLen, uint8_t **bp, int32_t bWidth, int32_t *bpLen, uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + for (int32_t i = 0; i < rowCnt; i++) { + auto ret = StringUtil::ConcatCharDiffWidths(contextPtr, reinterpret_cast(ap[i]), aWidth, apLen[i], + reinterpret_cast(bp[i]), bpLen[i], isNull + i, outLen + i); + output[i] = reinterpret_cast(const_cast(ret)); + } +} + +extern "C" DLLEXPORT void BatchConcatCharStr(int64_t contextPtr, uint8_t **ap, int32_t aWidth, int32_t *apLen, + uint8_t **bp, int32_t *bpLen, bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + bool hasErr; + for (int i = 0; i < rowCnt; i++) { + if (isAnyNull[i]) { + outLen[i] = 0; + output[i] = nullptr; + continue; + } + hasErr = false; + auto ret = StringUtil::ConcatCharDiffWidths(contextPtr, reinterpret_cast(ap[i]), aWidth, apLen[i], + reinterpret_cast(bp[i]), bpLen[i], &hasErr, outLen + i); + if (hasErr) { + SetError(contextPtr, CONCAT_ERR_MSG); + continue; + } + output[i] = reinterpret_cast(const_cast(ret)); + } +} + +extern "C" DLLEXPORT void BatchConcatCharStrRetNull(bool *isNull, int64_t contextPtr, uint8_t **ap, int32_t aWidth, + int32_t *apLen, uint8_t **bp, int32_t *bpLen, uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + for (int32_t i = 0; i < rowCnt; i++) { + auto ret = StringUtil::ConcatCharDiffWidths(contextPtr, reinterpret_cast(ap[i]), aWidth, apLen[i], + reinterpret_cast(bp[i]), bpLen[i], isNull + i, outLen + i); + + output[i] = reinterpret_cast(const_cast(ret)); + } +} + +extern "C" DLLEXPORT void BatchConcatStrChar(int64_t contextPtr, uint8_t **ap, int32_t *apLen, uint8_t **bp, + int32_t bWidth, int32_t *bpLen, bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + bool hasErr; + for (int32_t i = 0; i < rowCnt; i++) { + if (isAnyNull[i]) { + outLen[i] = 0; + output[i] = nullptr; + continue; + } + hasErr = false; + auto ret = StringUtil::ConcatStrDiffWidths(contextPtr, reinterpret_cast(ap[i]), apLen[i], + reinterpret_cast(bp[i]), bpLen[i], &hasErr, outLen + i); + if (hasErr) { + SetError(contextPtr, CONCAT_ERR_MSG); + continue; + } + output[i] = reinterpret_cast(const_cast(ret)); + } +} + +extern "C" DLLEXPORT void BatchConcatStrCharRetNull(bool *isNull, int64_t contextPtr, uint8_t **ap, int32_t *apLen, + uint8_t **bp, int32_t bWidth, int32_t *bpLen, uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + for (int32_t i = 0; i < rowCnt; i++) { + auto ret = StringUtil::ConcatStrDiffWidths(contextPtr, reinterpret_cast(ap[i]), apLen[i], + reinterpret_cast(bp[i]), bpLen[i], isNull + i, outLen + i); + output[i] = reinterpret_cast(const_cast(ret)); + } +} + +extern "C" DLLEXPORT void BatchCastStrWithDiffWidths(int64_t contextPtr, uint8_t **str, int32_t srcWidth, + int32_t *strLen, bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t dstWidth, int32_t rowCnt) +{ + for (int32_t i = 0; i < rowCnt; i++) { + if (isAnyNull[i]) { + outLen[i] = 0; + output[i] = nullptr; + continue; + } + bool hasErr = false; + const char *ret = StringUtil::CastStrStr(&hasErr, reinterpret_cast(str[i]), srcWidth, strLen[i], + outLen + i, dstWidth); + if (hasErr) { + std::ostringstream errorMessage; + errorMessage << "cast varchar[" << srcWidth << "] to varchar[" << dstWidth << "] failed."; + SetError(contextPtr, errorMessage.str()); + continue; + } + output[i] = reinterpret_cast(const_cast(ret)); + } +} + +extern "C" DLLEXPORT void BatchCastStrWithDiffWidthsRetNull(bool *isNull, int64_t contextPtr, uint8_t **srcStr, + int32_t srcWidth, int32_t *strLen, uint8_t **output, int32_t *outLen, int32_t dstWidth, int32_t rowCnt) +{ + for (int32_t i = 0; i < rowCnt; ++i) { + auto ret = StringUtil::CastStrStr(isNull + i, reinterpret_cast(srcStr[i]), srcWidth, strLen[i], + outLen + i, dstWidth); + output[i] = reinterpret_cast(const_cast(ret)); + } +} + +extern "C" DLLEXPORT void BatchLengthChar(uint8_t **str, const int32_t width, int32_t *strLen, bool *isAnyNull, + int64_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + output[i] = width; + } +} + +extern "C" DLLEXPORT void BatchLengthCharReturnInt32(uint8_t **str, const int32_t width, int32_t *strLen, + bool *isAnyNull, int32_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + output[i] = width; + } +} + +extern "C" DLLEXPORT void BatchLengthStr(uint8_t **str, int32_t *strLen, bool *isAnyNull, int64_t *output, + int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + output[i] = 0; + continue; + } + output[i] = omniruntime::Utf8Util::CountCodePoints(reinterpret_cast(str[i]), strLen[i]); + } +} + +extern "C" DLLEXPORT void BatchLengthStrReturnInt32(uint8_t **str, int32_t *strLen, bool *isAnyNull, int32_t *output, + int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + output[i] = 0; + continue; + } + output[i] = omniruntime::Utf8Util::CountCodePoints(reinterpret_cast(str[i]), strLen[i]); + } +} + +extern "C" DLLEXPORT void BatchReplaceStrStrStrWithRepNotReplace(int64_t contextPtr, uint8_t **str, int32_t *strLen, + uint8_t **searchStr, int32_t *searchLen, uint8_t **replaceStr, int32_t *replaceLen, bool *isAnyNull, + uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + ReplaceWithReplaceNotEmpty(contextPtr, str, strLen, searchStr, searchLen, replaceStr, replaceLen, isAnyNull, output, + outLen, rowCnt, [str, strLen, outLen](bool *hasErr, int32_t i) -> uint8_t * { + outLen[i] = strLen[i]; + return str[i]; + }); +} + +extern "C" DLLEXPORT void BatchReplaceStrStrWithoutRepNotReplace(int64_t contextPtr, uint8_t **str, int32_t *strLen, + uint8_t **searchStr, int32_t *searchLen, bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + ReplaceWithReplaceEmpty(contextPtr, str, strLen, searchStr, searchLen, isAnyNull, output, outLen, rowCnt, + [str, strLen, outLen](bool *hasErr, int32_t index) -> uint8_t * { + outLen[index] = strLen[index]; + return str[index]; + }); +} + +extern "C" DLLEXPORT void BatchReplaceStrStrStrWithRepReplace(int64_t contextPtr, uint8_t **str, int32_t *strLen, + uint8_t **searchStr, int32_t *searchLen, uint8_t **replaceStr, int32_t *replaceLen, bool *isAnyNull, + uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + ReplaceWithReplaceNotEmpty(contextPtr, str, strLen, searchStr, searchLen, replaceStr, replaceLen, isAnyNull, output, + outLen, rowCnt, + [contextPtr, str, strLen, replaceStr, replaceLen, outLen](bool *hasErr, int32_t index) -> uint8_t * { + auto result = StringUtil::ReplaceWithSearchEmpty(contextPtr, reinterpret_cast(str[index]), + strLen[index], reinterpret_cast(replaceStr[index]), replaceLen[index], hasErr, + outLen + index); + return reinterpret_cast(const_cast(result)); + }); +} + +extern "C" DLLEXPORT void BatchReplaceStrStrWithoutRepReplace(int64_t contextPtr, uint8_t **str, int32_t *strLen, + uint8_t **searchStr, int32_t *searchLen, bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + ReplaceWithReplaceEmpty(contextPtr, str, strLen, searchStr, searchLen, isAnyNull, output, outLen, rowCnt, + [contextPtr, str, strLen, outLen](bool *hasErr, int32_t index) -> uint8_t * { + auto result = StringUtil::ReplaceWithSearchEmpty(contextPtr, reinterpret_cast(str[index]), + strLen[index], reinterpret_cast(EMPTY), 0, hasErr, outLen + index); + return reinterpret_cast(const_cast(result)); + }); +} + +extern "C" DLLEXPORT void BatchInStr(char **srcStrs, int32_t *srcLens, char **subStrs, int32_t *subLens, + bool *isAnyNull, int32_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + auto srcLen = srcLens[i]; + auto subLen = subLens[i]; + // currently return 0 if not found that means 1-based + if (isAnyNull[i] || subLen > srcLen) { + output[i] = 0; + continue; + } + if (subLen == 0) { + output[i] = 1; + continue; + } + int32_t tailPos = srcLen - subLen; + int32_t cmpLen = subLen - 1; + auto srcStr = srcStrs[i]; + auto subStr = subStrs[i]; + int32_t result = 0; + int32_t pos = 0; + for (; pos <= tailPos; ++pos) { + if (srcStr[pos] == subStr[0] && memcmp(srcStr + pos + 1, subStr + 1, cmpLen) == 0) { + result = pos + 1; + break; + } + } + output[i] = result; + } +} + +extern "C" DLLEXPORT void BatchStartsWithStr(char **srcStrs, int32_t *srcLens, char **matchStrs, int32_t *matchLens, + bool *isAnyNull, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + auto srcLen = srcLens[i]; + auto matchLen = matchLens[i]; + if (isAnyNull[i] || matchLen > srcLen) { + output[i] = false; + continue; + } + if (matchLen == 0) { + output[i] = true; + continue; + } + output[i] = memcmp(srcStrs[i], matchStrs[i], matchLen) == 0; + } +} + +extern "C" DLLEXPORT void BatchEndsWithStr(char **srcStrs, int32_t *srcLens, char **matchStrs, int32_t *matchLens, + bool *isAnyNull, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + auto srcLen = srcLens[i]; + auto matchLen = matchLens[i]; + if (isAnyNull[i] || matchLen > srcLen) { + output[i] = false; + continue; + } + if (matchLen == 0) { + output[i] = true; + continue; + } + output[i] = memcmp(srcStrs[i] + srcLen - matchLen, matchStrs[i], matchLen) == 0; + } +} + +extern "C" DLLEXPORT void BatchMd5Str(int64_t contextPtr, uint8_t **str, int32_t *strLen, bool *isAnyNull, + uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + for (int32_t i = 0; i < rowCnt; i++) { + if (isAnyNull[i]) { + outLen[i] = 0; + output[i] = nullptr; + continue; + } + Md5Function md5(reinterpret_cast(str[i]), strLen[i]); + outLen[i] = 32; + char *mdString = ArenaAllocatorMalloc(contextPtr, 32); + md5.FinishHex(mdString); + output[i] = reinterpret_cast(const_cast(mdString)); + } +} + +extern "C" DLLEXPORT void BatchEmptyToNull(char **str, int32_t *strLen, bool *isAnyNull, char **output, int32_t *outLen, + int32_t rowCnt) +{ + for (int32_t i = 0; i < rowCnt; i++) { + if (strLen[i] == 0 || isAnyNull[i]) { + output[i] = nullptr; + outLen[i] = 0; + continue; + } + output[i] = str[i]; + outLen[i] = strLen[i]; + } +} + +extern "C" DLLEXPORT void BatchContainsStr(char **srcStrs, int32_t *srcLens, char **matchStrs, int32_t *matchLens, + bool *isAnyNull, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + auto srcLen = srcLens[i]; + auto matchLen = matchLens[i]; + if (isAnyNull[i] || matchLen > srcLen) { + output[i] = false; + continue; + } + if (matchLen == 0) { + output[i] = true; + continue; + } + output[i] = StringUtil::StrContainsStr(srcStrs[i], srcLen, matchStrs[i], matchLen); + } +} + +extern "C" DLLEXPORT void BatchGreatestStr(uint8_t **xStr, int32_t *xStrLen, bool *xIsNull, uint8_t **yStr, + int32_t *yStrLen, bool *yIsNull, bool *retIsNull, uint8_t **outStr, int32_t *outStrLen, int32_t rowCnt) +{ + int32_t cmpRet; + for (int i = 0; i < rowCnt; ++i) { + if (xIsNull[i] && yIsNull[i]) { + retIsNull[i] = true; + outStr[i] = nullptr; + outStrLen[i] = 0; + continue; + } + if (xIsNull[i]) { + outStr[i] = yStr[i]; + outStrLen[i] = yStrLen[i]; + continue; + } + if (!yIsNull[i]) { + cmpRet = memcmp(xStr[i], yStr[i], std::min(xStrLen[i], yStrLen[i])); + if (cmpRet < 0 || (cmpRet == 0 && yStrLen[i] > xStrLen[i])) { + outStr[i] = yStr[i]; + outStrLen[i] = yStrLen[i]; + continue; + } + } + outStr[i] = xStr[i]; + outStrLen[i] = xStrLen[i]; + } +} + +extern "C" DLLEXPORT void BatchStaticInvokeVarcharTypeWriteSideCheck(int64_t contextPtr, char **str, int32_t *strLen, + int32_t limit, bool *isAnyNull, char **outputStr, int32_t *outputLen, int32_t rowCnt) +{ + for (int32_t i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + outputLen[i] = 0; + outputStr[i] = nullptr; + continue; + } + char *ss = str[i]; + int32_t len = strLen[i]; + int32_t ssLen = StringUtil::NumChars(ss, len); + if (ssLen <= limit) { + outputStr[i] = ss; + outputLen[i] = len; + continue; + } + int32_t numTailSpacesToTrim = ssLen - limit; + int32_t endIdx = len - 1; + int32_t trimTo = len - numTailSpacesToTrim; + while (endIdx >= trimTo && ss[endIdx] == 0x20) { + endIdx--; + } + int32_t outByteNum = endIdx + 1; + if (ssLen > limit) { + std::ostringstream errorMessage; + errorMessage << "Exceeds varchar type length limitation: " << limit; + SetError(contextPtr, errorMessage.str()); + outputLen[i] = 0; + outputStr[i] = nullptr; + continue; + } + auto padded = ArenaAllocatorMalloc(contextPtr, outByteNum); + errno_t res = memcpy_s(padded, outByteNum, ss, outByteNum); + if (res != EOK) { + SetError(contextPtr, "varcharTypeWriteSideCheck failed:memcpy_s error"); + outputLen[i] = 0; + outputStr[i] = nullptr; + continue; + } + padded[outByteNum] = '\0'; + outputLen[i] = outByteNum; + outputStr[i] = padded; + } +} + +extern "C" DLLEXPORT void BatchStaticInvokeCharReadPadding(int64_t contextPtr, char **str, + int32_t *strLen, int32_t limit, bool *isAnyNull, char **outputStr, int32_t *outputLen, int32_t rowCnt) +{ + for (int32_t i = 0; i < rowCnt; ++i) { + if (isAnyNull[i]) { + outputLen[i] = 0; + outputStr[i] = nullptr; + continue; + } else if (strLen[i] == 0) { + outputLen[i] = 0; + outputStr[i] = ""; + continue; + } + char *ss = str[i]; + int32_t len = strLen[i]; + int32_t ssLen = StringUtil::NumChars(ss, len); + if (ssLen >= limit) { + outputStr[i] = ss; + outputLen[i] = len; + continue; + } + int32_t diff = limit - ssLen; + int32_t outByteNum = len + diff + 1; + auto padded = ArenaAllocatorMalloc(contextPtr, outByteNum); + errno_t res = memcpy_s(padded, len, ss, len); + if (res != EOK) { + SetError(contextPtr, "BatchStaticInvokeCharReadPadding failed:memcpy_s error"); + outputLen[i] = 0; + outputStr[i] = nullptr; + continue; + } + res = memset_s(padded + len, diff, ' ', diff); + if (res != EOK) { + SetError(contextPtr, "BatchStaticInvokeCharReadPadding failed:memcpy_s error"); + outputLen[i] = 0; + outputStr[i] = nullptr; + continue; + } + padded[outByteNum] = '\0'; + outputLen[i] = outByteNum - 1; + outputStr[i] = padded; + } +} +} \ No newline at end of file diff --git a/core/src/codegen/batch_functions/batch_stringfunctions.h b/core/src/codegen/batch_functions/batch_stringfunctions.h new file mode 100644 index 0000000..0aa200c --- /dev/null +++ b/core/src/codegen/batch_functions/batch_stringfunctions.h @@ -0,0 +1,430 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. + * Description: batch string functions implementation + */ +#ifndef OMNI_RUNTIME_BATCH_STRINGFUNCTIONS_H +#define OMNI_RUNTIME_BATCH_STRINGFUNCTIONS_H + +#include +#include +#include +#include +#include +#include +#include "util/utf8_util.h" +#include "codegen/context_helper.h" +#include "codegen/string_util.h" +#include "type/decimal128.h" +#include "type/decimal_operations.h" +#include "util/config_util.h" + +#ifdef _WIN32 +#define DLLEXPORT __declspec(dllexport) +#else +#define DLLEXPORT +#endif + +using namespace omniruntime::type; +namespace omniruntime::codegen::function { +// string compare functions +extern "C" DLLEXPORT void BatchStrCompare(uint8_t **ap, int32_t *apLen, uint8_t **bp, int32_t *bpLen, int32_t *res, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchLessThanStr(uint8_t **ap, int32_t *apLen, uint8_t **bp, int32_t *bpLen, bool *res, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchLessThanEqualStr(uint8_t **ap, int32_t *apLen, uint8_t **bp, int32_t *bpLen, bool *res, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchGreaterThanStr(uint8_t **ap, int32_t *apLen, uint8_t **bp, int32_t *bpLen, bool *res, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchGreaterThanEqualStr(uint8_t **ap, int32_t *apLen, uint8_t **bp, int32_t *bpLen, + bool *res, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchEqualStr(uint8_t **ap, int32_t *apLen, uint8_t **bp, int32_t *bpLen, bool *res, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchNotEqualStr(uint8_t **ap, int32_t *apLen, uint8_t **bp, int32_t *bpLen, bool *res, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastStringToDateAllowReducePrecison(int64_t contextPtr, uint8_t **str, int32_t *strLen, + bool *isAnyNull, int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastStringToDateNotAllowReducePrecison(int64_t contextPtr, uint8_t **str, + int32_t *strLen, bool *isAnyNull, int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastStringToDateRetNullNotAllowReducePrecison(bool *isNull, uint8_t **str, + int32_t *strLen, int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastStringToDateRetNullAllowReducePrecison(bool *isNull, uint8_t **str, int32_t *strLen, + int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastIntToString(int64_t contextPtr, int32_t *value, bool *isAnyNull, uint8_t **output, + int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastLongToString(int64_t contextPtr, int64_t *value, bool *isAnyNull, uint8_t **output, + int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDoubleToString(int64_t contextPtr, double *value, bool *isAnyNull, uint8_t **output, + int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal64ToString(int64_t contextPtr, int64_t *x, int32_t precision, int32_t scale, + bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal128ToString(int64_t contextPtr, Decimal128 *x, int32_t precision, + int32_t scale, bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastIntToStringRetNull(bool *isNull, int64_t contextPtr, int32_t *value, + uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastLongToStringRetNull(bool *isNull, int64_t contextPtr, int64_t *value, + uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDoubleToStringRetNull(bool *isNull, int64_t contextPtr, double *value, + uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal64ToStringRetNull(bool *isNull, int64_t contextPtr, int64_t *x, + int32_t precision, int32_t scale, uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastDecimal128ToStringRetNull(bool *isNull, int64_t contextPtr, Decimal128 *inputDecimal, + int32_t precision, int32_t scale, uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastStringToDecimal64(int64_t contextPtr, uint8_t **str, int32_t *strLen, + bool *isAnyNull, int64_t *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastStringToDecimal128(int64_t contextPtr, uint8_t **str, int32_t *strLen, + bool *isAnyNull, Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastStringToDecimal64RoundUp(int64_t contextPtr, uint8_t **str, int32_t *strLen, + bool *isAnyNull, int64_t *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastStringToDecimal128RoundUp(int64_t contextPtr, uint8_t **str, int32_t *strLen, + bool *isAnyNull, Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastStringToInt(int64_t contextPtr, uint8_t **str, int32_t *strLen, bool *isAnyNull, + int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastStringToLong(int64_t contextPtr, uint8_t **str, int32_t *strLen, bool *isAnyNull, + int64_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastStringToDouble(int64_t contextPtr, uint8_t **str, int32_t *strLen, bool *isAnyNull, + double *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastStringToDecimal64RetNull(bool *isNull, uint8_t **str, int32_t *strLen, + int64_t *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastStringToDecimal128RetNull(bool *isNull, uint8_t **str, int32_t *strLen, + Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastStringToDecimal64RoundUpRetNull(bool *isNull, uint8_t **str, int32_t *strLen, + int64_t *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastStringToDecimal128RoundUpRetNull(bool *isNull, uint8_t **str, int32_t *strLen, + Decimal128 *output, int32_t outPrecision, int32_t outScale, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastStringToIntRetNull(bool *isNull, uint8_t **str, int32_t *strLen, int32_t *output, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastStringToLongRetNull(bool *isNull, uint8_t **str, int32_t *strLen, int64_t *output, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastStringToDoubleRetNull(bool *isNull, uint8_t **str, int32_t *strLen, double *output, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchToUpperStr(int64_t contextPtr, uint8_t **str, int32_t *strLen, bool *isAnyNull, + uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchToUpperChar(int64_t contextPtr, uint8_t **str, int32_t width, int32_t *strLen, + bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchToLowerStr(int64_t contextPtr, uint8_t **str, int32_t *strLen, bool *isAnyNull, + uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchToLowerChar(int64_t contextPtr, uint8_t **str, int32_t width, int32_t *strLen, + bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchLikeStr(uint8_t **str, int32_t *strLen, uint8_t **regexToMatch, int32_t *regexLen, + bool *isAnyNull, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchLikeChar(uint8_t **str, int32_t strWidth, int32_t *strLen, uint8_t **regexToMatch, + int32_t *regexLen, bool *isAnyNull, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchConcatStrStr(int64_t contextPtr, uint8_t **ap, int32_t *apLen, uint8_t **bp, + int32_t *bpLen, bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchConcatStrStrRetNull(bool *isNull, int64_t contextPtr, uint8_t **ap, int32_t *apLen, + uint8_t **bp, int32_t *bpLen, uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchConcatCharChar(int64_t contextPtr, uint8_t **ap, int32_t aWidth, int32_t *apLen, + uint8_t **bp, int32_t bWidth, int32_t *bpLen, bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchConcatCharCharRetNull(bool *isNull, int64_t contextPtr, uint8_t **ap, int32_t aWidth, + int32_t *apLen, uint8_t **bp, int32_t bWidth, int32_t *bpLen, uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchConcatCharStr(int64_t contextPtr, uint8_t **ap, int32_t aWidth, int32_t *apLen, + uint8_t **bp, int32_t *bpLen, bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchConcatCharStrRetNull(bool *isNull, int64_t contextPtr, uint8_t **ap, int32_t aWidth, + int32_t *apLen, uint8_t **bp, int32_t *bpLen, uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchConcatStrChar(int64_t contextPtr, uint8_t **ap, int32_t *apLen, uint8_t **bp, + int32_t bWidth, int32_t *bpLen, bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchConcatStrCharRetNull(bool *isNull, int64_t contextPtr, uint8_t **ap, int32_t *apLen, + uint8_t **bp, int32_t bWidth, int32_t *bpLen, uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastStrWithDiffWidths(int64_t contextPtr, uint8_t **str, int32_t srcWidth, + int32_t *strLen, bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t dstWidth, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchCastStrWithDiffWidthsRetNull(bool *isNull, int64_t contextPtr, uint8_t **str, + int32_t srcWidth, int32_t *strLen, uint8_t **output, int32_t *outLen, int32_t dstWidth, int32_t rowCnt); + +/** + * If isSupportNegativeIndex is false,the result of substr is "" when start index is negative + * If isSupportNegativeIndex is true,the substr rule is as follows: + * e.g., str="apple", strLength=5, startIndex=-7, subStringLength=3, Result="a". + * If isSupportZeroIndex is false,the result of substr is "" when start index is 0 + * If isSupportZeroIndex is true,it refers to the first element when the start index is 0 + */ +template +extern DLLEXPORT void BatchSubstrVarchar(int64_t contextPtr, uint8_t **str, int32_t *strLen, T *startIdx, T *length, + bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + for (int32_t i = 0; i < rowCnt; i++) { + if (isAnyNull[i]) { + outLen[i] = 0; + output[i] = nullptr; + continue; + } + + int64_t endIdx; + int64_t startIndex; + if constexpr (isSupportZeroIndex) { + startIdx[i] = (startIdx[i] == 0) ? 1 : startIdx[i]; + } + int64_t startCodePoint = startIdx[i]; + int64_t lengthCodePoint = length[i]; + int32_t len = strLen[i]; + output[i] = const_cast(EMPTY); + if (startCodePoint == 0 || (lengthCodePoint <= 0) || (len == 0) || startCodePoint > len) { + outLen[i] = 0; + continue; + } + const char *charStr = reinterpret_cast(str[i]); + if (startCodePoint > 0) { + startIndex = omniruntime::Utf8Util::OffsetOfCodePoint(charStr, len, startCodePoint - 1); + if (startIndex < 0) { + // before beginning of string + outLen[i] = 0; + continue; + } + endIdx = omniruntime::Utf8Util::OffsetOfCodePoint(charStr, len, startIndex, lengthCodePoint); + if (endIdx < 0) { + // after end of string + endIdx = len; + } + } else { + // negative start is relative to end of string + int32_t codePoints = omniruntime::Utf8Util::CountCodePoints(charStr, len); + startCodePoint += codePoints; + // before beginning of string + if (startCodePoint < 0) { + if constexpr (!isSupportNegativeIndex) { + outLen[i] = 0; + continue; + } else { + lengthCodePoint += startCodePoint; + startCodePoint = 0; + } + } + startIndex = omniruntime::Utf8Util::OffsetOfCodePoint(charStr, len, startCodePoint); + endIdx = startCodePoint + lengthCodePoint < codePoints ? + omniruntime::Utf8Util::OffsetOfCodePoint(charStr, len, startIndex, lengthCodePoint) : + len; + } + + outLen[i] = endIdx - startIndex; + output[i] = str[i] + startIndex; + } +} + +template +extern DLLEXPORT void BatchSubstrChar(int64_t contextPtr, uint8_t **str, int32_t width, int32_t *strLen, T *startIdx, + T *length, bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + BatchSubstrVarchar(contextPtr, str, strLen, startIdx, length, + isAnyNull, output, outLen, rowCnt); +} + +template +extern DLLEXPORT void BatchSubstrVarcharWithStart(int64_t contextPtr, uint8_t **str, int32_t *strLen, T *startIdx, + bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + int64_t startIndex; + for (int32_t i = 0; i < rowCnt; i++) { + if (isAnyNull[i]) { + outLen[i] = 0; + output[i] = nullptr; + continue; + } + if constexpr (isSupportZeroIndex) { + startIdx[i] = (startIdx[i] == 0) ? 1 : startIdx[i]; + } + int64_t startCodePoint = startIdx[i]; + int32_t len = strLen[i]; + output[i] = const_cast(EMPTY); + if (startCodePoint == 0 || len == 0 || startCodePoint > len) { + outLen[i] = 0; + continue; + } + + const char *charStr = reinterpret_cast(str[i]); + if (startCodePoint > 0) { + startIndex = omniruntime::Utf8Util::OffsetOfCodePoint(charStr, len, startCodePoint - 1); + if (startIndex < 0) { + outLen[i] = 0; + continue; + } + } else { + // negative start is relative to end of string + int32_t codePoints = omniruntime::Utf8Util::CountCodePoints(charStr, len); + startCodePoint += codePoints; + if (startCodePoint < 0) { + if constexpr (!isSupportNegativeIndex) { + outLen[i] = 0; + continue; + } else { + startCodePoint = 0; + } + } + + startIndex = omniruntime::Utf8Util::OffsetOfCodePoint(charStr, len, startCodePoint); + } + + outLen[i] = len - startIndex; + output[i] = str[i] + startIndex; + } +} + +template +extern DLLEXPORT void BatchSubstrCharWithStart(int64_t contextPtr, uint8_t **str, int32_t width, int32_t *strLen, + T *startIdx, bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt) +{ + BatchSubstrVarcharWithStart(contextPtr, str, strLen, startIdx, + isAnyNull, output, outLen, rowCnt); +} + +extern "C" DLLEXPORT void BatchLengthChar(uint8_t **str, const int32_t width, int32_t *strLen, bool *isAnyNull, + int64_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchLengthCharReturnInt32(uint8_t **str, const int32_t width, int32_t *strLen, + bool *isAnyNull, int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchLengthStr(uint8_t **str, int32_t *strLen, bool *isAnyNull, int64_t *output, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchLengthStrReturnInt32(uint8_t **str, int32_t *strLen, bool *isAnyNull, int32_t *output, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchReplaceStrStrStrWithRepNotReplace(int64_t contextPtr, uint8_t **str, int32_t *strLen, + uint8_t **searchStr, int32_t *searchLen, uint8_t **replaceStr, int32_t *replaceLen, bool *isAnyNull, + uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchReplaceStrStrWithoutRepNotReplace(int64_t contextPtr, uint8_t **str, int32_t *strLen, + uint8_t **searchStr, int32_t *searchLen, bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchReplaceStrStrStrWithRepReplace(int64_t contextPtr, uint8_t **str, int32_t *strLen, + uint8_t **searchStr, int32_t *searchLen, uint8_t **replaceStr, int32_t *replaceLen, bool *isAnyNull, + uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchReplaceStrStrWithoutRepReplace(int64_t contextPtr, uint8_t **str, int32_t *strLen, + uint8_t **searchStr, int32_t *searchLen, bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt); + +template +static inline void ReplaceWithReplaceNotEmpty(int64_t contextPtr, uint8_t **str, int32_t *strLen, uint8_t **searchStr, + int32_t *searchLen, uint8_t **replaceStr, int32_t *replaceLen, bool *isAnyNull, uint8_t **output, int32_t *outLen, + int32_t rowCnt, L lambda) +{ + bool hasErr; + uint8_t *ret; + for (int32_t i = 0; i < rowCnt; i++) { + if (isAnyNull[i]) { + outLen[i] = 0; + output[i] = nullptr; + continue; + } + hasErr = false; + if (searchLen[i] == 0) { + ret = lambda(&hasErr, i); + } else { + auto result = StringUtil::ReplaceWithSearchNotEmpty(contextPtr, reinterpret_cast(str[i]), + strLen[i], reinterpret_cast(searchStr[i]), searchLen[i], + reinterpret_cast(replaceStr[i]), replaceLen[i], &hasErr, outLen + i); + ret = reinterpret_cast(const_cast(result)); + } + + if (hasErr) { + SetError(contextPtr, REPLACE_ERR_MSG); + } + output[i] = ret; + } +} + +template +static inline void ReplaceWithReplaceEmpty(int64_t contextPtr, uint8_t **str, int32_t *strLen, uint8_t **searchStr, + int32_t *searchLen, bool *isAnyNull, uint8_t **output, int32_t *outLen, int32_t rowCnt, L lambda) +{ + bool hasErr; + uint8_t *ret; + for (int32_t i = 0; i < rowCnt; i++) { + if (isAnyNull[i]) { + outLen[i] = 0; + output[i] = nullptr; + continue; + } + hasErr = false; + if (searchLen[i] == 0) { + ret = lambda(&hasErr, i); + } else { + auto result = StringUtil::ReplaceWithSearchNotEmpty(contextPtr, reinterpret_cast(str[i]), strLen[i], + reinterpret_cast(searchStr[i]), searchLen[i], reinterpret_cast(EMPTY), 0, &hasErr, + outLen + i); + ret = reinterpret_cast(const_cast(result)); + } + + if (hasErr) { + SetError(contextPtr, REPLACE_ERR_MSG); + } + output[i] = ret; + } +} + +extern "C" DLLEXPORT void BatchInStr(char **srcStrs, int32_t *srcLens, char **subStrs, int32_t *subLens, + bool *isAnyNull, int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchStartsWithStr(char **srcStrs, int32_t *srcLens, char **matchStrs, int32_t *matchLens, + bool *isAnyNull, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchEndsWithStr(char **srcStrs, int32_t *srcLens, char **matchStrs, int32_t *matchLens, + bool *isAnyNull, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchMd5Str(int64_t contextPtr, uint8_t **str, int32_t *strLen, bool *isAnyNull, + uint8_t **output, int32_t *outLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchEmptyToNull(char **str, int32_t *strLen, bool *isAnyNull, char **output, int32_t *outLen, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchContainsStr(char **srcStrs, int32_t *srcLens, char **matchStrs, int32_t *matchLens, + bool *isAnyNull, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchGreatestStr(uint8_t **xStr, int32_t *xStrLen, bool *xIsNull, uint8_t **yStr, + int32_t *yStrLen, bool *yIsNull, bool *retIsNull, uint8_t **outStr, int32_t *outStrLen, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchStaticInvokeVarcharTypeWriteSideCheck(int64_t contextPtr, char **str, int32_t *strLen, + int32_t limit, bool *isAnyNull, char **outputStr, int32_t *outputLen, int32_t rowCnt); +} + +extern "C" DLLEXPORT void BatchStaticInvokeCharReadPadding(int64_t contextPtr, char **str, + int32_t *strLen, int32_t limit, bool *isAnyNull, char **outputStr, int32_t *outputLen, int32_t rowCnt); +#endif // OMNI_RUNTIME_BATCH_STRINGFUNCTIONS_H \ No newline at end of file diff --git a/core/src/codegen/batch_functions/batch_utilfunctions.cpp b/core/src/codegen/batch_functions/batch_utilfunctions.cpp new file mode 100644 index 0000000..e3d3d0c --- /dev/null +++ b/core/src/codegen/batch_functions/batch_utilfunctions.cpp @@ -0,0 +1,299 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch util functions implementation + */ + +#include "batch_utilfunctions.h" +#include "codegen/context_helper.h" +#include "codegen/functions/stringfunctions.h" + +namespace omniruntime::codegen::function { +extern "C" DLLEXPORT void FillRowIndexArray(int32_t *dataArray, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + dataArray[i] = i; + } +} + +extern "C" DLLEXPORT void FillNull(bool *nullArray, bool isNull, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + nullArray[i] = isNull; + } +} + +extern "C" DLLEXPORT void FillBool(int32_t *dataArray, bool *nullArray, bool literal, bool isNull, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + dataArray[i] = literal; + nullArray[i] = isNull; + } +} + +extern "C" DLLEXPORT void FillInt32(int32_t *dataArray, bool *nullArray, int32_t literal, bool isNull, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + dataArray[i] = literal; + nullArray[i] = isNull; + } +} + +extern "C" DLLEXPORT void FillInt64(int64_t *dataArray, bool *nullArray, int64_t literal, bool isNull, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + dataArray[i] = literal; + nullArray[i] = isNull; + } +} + +extern "C" DLLEXPORT void FillDouble(double *dataArray, bool *nullArray, double literal, bool isNull, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + dataArray[i] = literal; + nullArray[i] = isNull; + } +} + +extern "C" DLLEXPORT void FillDecimal128(Decimal128 *dataArray, bool *nullArray, __int128_t literal, bool isNull, + int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + dataArray[i].SetValue(literal); + nullArray[i] = isNull; + } +} + +extern "C" DLLEXPORT void FillString(int64_t contextPtr, uint8_t **dataArray, bool *nullArray, int32_t *lengthArray, + uint8_t *literal, bool isNull, int32_t length, int32_t rowCnt) +{ + errno_t err; + char *ret; + for (int i = 0; i < rowCnt; i++) { + ret = ArenaAllocatorMalloc(contextPtr, length + 1); + err = memcpy_s(ret, length + 1, literal, length); + if (err != EOK) { + SetError(contextPtr, "Fill string failed"); + dataArray[i] = nullptr; + nullArray[i] = true; + lengthArray[i] = 0; + continue; + } + dataArray[i] = reinterpret_cast(ret); + nullArray[i] = isNull; + lengthArray[i] = length; + } +} + +extern "C" DLLEXPORT void FillLength(int32_t *offsets, int32_t *rowIdxArray, int32_t *lengthArray, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + lengthArray[i] = offsets[rowIdxArray[i] + 1] - offsets[i]; + } +} + +extern "C" DLLEXPORT void FillLengthInFuncExpr(int32_t *lengthArray, int32_t length, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + lengthArray[i] = length; + } +} + +extern "C" DLLEXPORT void CreateNot(bool *val, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + val[i] = !val[i]; + } +} + +extern "C" DLLEXPORT void CreateOr(bool *left, bool *right, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + left[i] = left[i] || right[i]; + } +} + +extern "C" DLLEXPORT void CreateAnd(bool *left, bool *right, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + left[i] = left[i] && right[i]; + } +} + +extern "C" DLLEXPORT void CreateAndNotBool(bool *dataArray, bool *nullArray, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + dataArray[i] = dataArray[i] && (!nullArray[i]); + } +} + +extern "C" DLLEXPORT int32_t CreateAndNot(bool *dataArray, bool *nullArray, int32_t *rowIdxArray, int32_t rowCnt) +{ + int selectedCnt = 0; + for (int i = 0; i < rowCnt; i++) { + if (dataArray[i] && (!nullArray[i])) { + rowIdxArray[selectedCnt] = i; + selectedCnt++; + } + } + return selectedCnt; +} + +extern "C" DLLEXPORT void CreateOrExpr(bool *left, bool *leftNull, bool *right, bool *rightNull, int32_t rowCnt) +{ + bool res; + for (int i = 0; i < rowCnt; ++i) { + res = left[i] || right[i]; + leftNull[i] = (leftNull[i] && rightNull[i]) || (leftNull[i] && !right[i]) || (rightNull[i] && !left[i]); + left[i] = res; + } +} + +extern "C" DLLEXPORT void CreateAndExpr(bool *left, bool *leftNull, bool *right, bool *rightNull, int32_t rowCnt) +{ + bool tmpVal; + for (int i = 0; i < rowCnt; i++) { + tmpVal = left[i] && right[i]; + leftNull[i] = (leftNull[i] && rightNull[i]) || (leftNull[i] && right[i]) || (left[i] && rightNull[i]); + left[i] = tmpVal; + } +} + +extern "C" DLLEXPORT void CopyNull(bool *dataArray, bool *output, int32_t *rowIdxArray, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + dataArray[i] = output[rowIdxArray[i]]; + } +} + +extern "C" DLLEXPORT void CopyBoolean(bool *dataArray, bool *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + dataArray[i] = output[i]; + } +} + +extern "C" DLLEXPORT void CopyInt32(int32_t *dataArray, int32_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + dataArray[i] = output[i]; + } +} + +extern "C" DLLEXPORT void CopyInt64(int64_t *dataArray, int64_t *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + dataArray[i] = output[i]; + } +} + +extern "C" DLLEXPORT void CopyDouble(double *dataArray, double *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + dataArray[i] = output[i]; + } +} + +extern "C" DLLEXPORT void CopyDecimal128(Decimal128 *dataArray, Decimal128 *output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + dataArray[i] = output[i]; + } +} + +extern "C" DLLEXPORT void CopyString(uint8_t **dataArray, uint8_t **output, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + dataArray[i] = output[i]; + } +} + +extern "C" DLLEXPORT void IfExprString(bool *ifCond, bool *ifNull, uint8_t **trueValue, bool *trueNull, + int32_t *trueLength, uint8_t **falseValue, bool *falseNull, int32_t *falseLength, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (!(ifCond[i] && !ifNull[i])) { + trueValue[i] = falseValue[i]; + trueNull[i] = falseNull[i]; + trueLength[i] = falseLength[i]; + } + } +} + +extern "C" DLLEXPORT void SwitchExprString(int32_t whenCnt, int64_t *whenClauses, int64_t *whenBools, + int64_t *resultValues, int64_t *resultNulls, int64_t *resultLengths, uint8_t **elseValue, bool *elseNull, + int32_t *elseLength, uint8_t **finalResult, bool *finalNull, int32_t *finalLength, int32_t rowCnt) +{ + std::vector whenValues; + std::vector whenNulls; + + std::vector resValues; + std::vector resNulls; + std::vector resLengths; + + for (int i = 0; i < whenCnt; ++i) { + whenValues.push_back(reinterpret_cast(reinterpret_cast(whenClauses[i]))); + whenNulls.push_back(reinterpret_cast(reinterpret_cast(whenBools[i]))); + resValues.push_back(reinterpret_cast(reinterpret_cast(resultValues[i]))); + resNulls.push_back(reinterpret_cast(reinterpret_cast(resultNulls[i]))); + resLengths.push_back(reinterpret_cast(reinterpret_cast(resultLengths[i]))); + } + + for (int i = 0; i < rowCnt; ++i) { + bool hasSet = false; + for (int j = 0; j < whenCnt; ++j) { + if (whenValues[j][i] && !whenNulls[j][i]) { + finalResult[i] = resValues[j][i]; + finalNull[i] = resNulls[j][i]; + finalLength[i] = resLengths[j][i]; + hasSet = true; + break; + } + } + if (!hasSet) { + finalResult[i] = elseValue[i]; + finalNull[i] = elseNull[i]; + finalLength[i] = elseLength[i]; + } + } +} + +extern "C" DLLEXPORT void CoalesceString(uint8_t **lArray, bool *lIsNull, int32_t *lLength, uint8_t **rArray, + bool *rIsNull, int32_t *rLength, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + if (lIsNull[i]) { + lArray[i] = rArray[i]; + lIsNull[i] = rIsNull[i]; + lLength[i] = rLength[i]; + } + } +} + +extern "C" DLLEXPORT void InExprString(int32_t cmpCnt, int64_t *cmpValues, int64_t *cmpBools, int64_t *cmpLengths, + uint8_t **toCmpValue, bool *toCmpBool, int32_t *toCmpLength, bool *finalResult, bool *finalNull, int32_t rowCnt) +{ + std::vector cmpValuesList; + std::vector cmpNullsList; + std::vector cmpLengthsList; + + for (int32_t i = 0; i < cmpCnt; ++i) { + cmpValuesList.push_back(reinterpret_cast(reinterpret_cast(cmpValues[i]))); + cmpNullsList.push_back(reinterpret_cast(reinterpret_cast(cmpBools[i]))); + cmpLengthsList.push_back(reinterpret_cast(reinterpret_cast(cmpLengths[i]))); + } + + for (int32_t i = 0; i < rowCnt; ++i) { + finalResult[i] = false; + finalNull[i] = false; + for (int32_t j = 0; j < cmpCnt; ++j) { + if (!toCmpBool[i] && !cmpNullsList[j][i]) { + if (StrCompare(reinterpret_cast(toCmpValue[i]), toCmpLength[i], + reinterpret_cast(cmpValuesList[j][i]), cmpLengthsList[j][i]) == 0) { + finalResult[i] = true; + break; + } + } + } + } +} +} diff --git a/core/src/codegen/batch_functions/batch_utilfunctions.h b/core/src/codegen/batch_functions/batch_utilfunctions.h new file mode 100644 index 0000000..cd44c12 --- /dev/null +++ b/core/src/codegen/batch_functions/batch_utilfunctions.h @@ -0,0 +1,172 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch util functions implementation + */ + +#ifndef OMNI_RUNTIME_BATCH_UTILFUNCTIONS_H +#define OMNI_RUNTIME_BATCH_UTILFUNCTIONS_H +#include +#include +#include +#include +#include "type/decimal128.h" + +using namespace omniruntime::type; + +#ifdef _WIN32 +#define DLLEXPORT __declspec(dllexport) +#else +#define DLLEXPORT +#endif + +namespace omniruntime::codegen::function { +// make an array like {0, 1, 2, 3, ..., rowCnt - 1} +extern "C" DLLEXPORT void FillRowIndexArray(int32_t *dataArray, int32_t rowCnt); + +// convert literal to an array +extern "C" DLLEXPORT void FillNull(bool *nullArray, bool isNull, int32_t rowCnt); + +extern "C" DLLEXPORT void FillBool(int32_t *dataArray, bool *nullArray, bool literal, bool isNull, int32_t rowCnt); + +extern "C" DLLEXPORT void FillInt32(int32_t *dataArray, bool *nullArray, int32_t literal, bool isNull, int32_t rowCnt); + +extern "C" DLLEXPORT void FillInt64(int64_t *dataArray, bool *nullArray, int64_t literal, bool isNull, int32_t rowCnt); + +extern "C" DLLEXPORT void FillDouble(double *dataArray, bool *nullArray, double literal, bool isNull, int32_t rowCnt); + +extern "C" DLLEXPORT void FillDecimal128(Decimal128 *dataArray, bool *nullArray, __int128_t literal, bool isNull, + int32_t rowCnt); + +extern "C" DLLEXPORT void FillString(int64_t contextPtr, uint8_t **dataArray, bool *nullArray, int32_t *lengthArray, + uint8_t *literal, bool isNull, int32_t length, int32_t rowCnt); + +// fill varchar length array according to offset array +extern "C" DLLEXPORT void FillLength(int32_t *offsets, int32_t *rowIdxArray, int32_t *lengthArray, int32_t rowCnt); + +extern "C" DLLEXPORT void FillLengthInFuncExpr(int32_t *lengthArray, int32_t length, int32_t rowCnt); + +// logical operations for boolean +extern "C" DLLEXPORT int32_t CreateAndNot(bool *dataArray, bool *nullArray, int32_t *rowIdxArray, int32_t rowCnt); + +extern "C" DLLEXPORT void CreateNot(bool *val, int32_t rowCnt); + +extern "C" DLLEXPORT void CreateOr(bool *left, bool *right, int32_t rowCnt); + +extern "C" DLLEXPORT void CreateAnd(bool *left, bool *right, int32_t rowCnt); + +extern "C" DLLEXPORT void CreateAndNotBool(bool *dataArray, bool *nullArray, int32_t rowCnt); + +// process AND/OR expression +extern "C" DLLEXPORT void CreateOrExpr(bool *left, bool *leftNull, bool *right, bool *rightNull, int32_t rowCnt); + +extern "C" DLLEXPORT void CreateAndExpr(bool *left, bool *leftNull, bool *right, bool *rightNull, int32_t rowCnt); + +// copy result to output vector +extern "C" DLLEXPORT void CopyNull(bool *dataArray, bool *output, int32_t *rowIdxArray, int32_t rowCnt); + +extern "C" DLLEXPORT void CopyBoolean(bool *dataArray, bool *output, int32_t rowCnt); + +extern "C" DLLEXPORT void CopyInt32(int32_t *dataArray, int32_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void CopyInt64(int64_t *dataArray, int64_t *output, int32_t rowCnt); + +extern "C" DLLEXPORT void CopyDouble(double *dataArray, double *output, int32_t rowCnt); + +extern "C" DLLEXPORT void CopyDecimal128(Decimal128 *dataArray, Decimal128 *output, int32_t rowCnt); + +extern "C" DLLEXPORT void CopyString(uint8_t **dataArray, uint8_t **output, int32_t rowCnt); + +template extern DLLEXPORT void Coalesce(T *lArray, bool *lIsNull, T *rArray, bool *rIsNull, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (lIsNull[i]) { + lArray[i] = rArray[i]; + lIsNull[i] = rIsNull[i]; + } + } +} + +extern "C" DLLEXPORT void CoalesceString(uint8_t **lArray, bool *lIsNull, int32_t *lLength, uint8_t **rArray, + bool *rIsNull, int32_t *rLength, int32_t rowCnt); + +template +extern DLLEXPORT void IfExpr(bool *ifCond, bool *ifNull, T *trueValue, bool *trueNull, T *falseValue, bool *falseNull, + int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; ++i) { + if (!(ifCond[i] && !ifNull[i])) { + trueValue[i] = falseValue[i]; + trueNull[i] = falseNull[i]; + } + } +} + +extern "C" DLLEXPORT void IfExprString(bool *ifCond, bool *ifNull, uint8_t **trueValue, bool *trueNull, + int32_t *trueLength, uint8_t **falseValue, bool *falseNull, int32_t *falseLength, int32_t rowCnt); + +template +extern DLLEXPORT void SwitchExpr(int32_t whenCnt, int64_t *whenClauses, int64_t *whenBools, int64_t *resultValues, + int64_t *resultNulls, T *elseValue, bool *elseNull, T *finalResult, bool *finalNull, int32_t rowCnt) +{ + std::vector whenValues; + std::vector whenNulls; + std::vector resValues; + std::vector resNulls; + + for (int i = 0; i < whenCnt; ++i) { + whenValues.push_back(reinterpret_cast(reinterpret_cast(whenClauses[i]))); + whenNulls.push_back(reinterpret_cast(reinterpret_cast(whenBools[i]))); + resValues.push_back(reinterpret_cast(reinterpret_cast(resultValues[i]))); + resNulls.push_back(reinterpret_cast(reinterpret_cast(resultNulls[i]))); + } + + for (int i = 0; i < rowCnt; ++i) { + bool hasSet = false; + for (int j = 0; j < whenCnt; ++j) { + if (whenValues[j][i] && !whenNulls[j][i]) { + finalResult[i] = resValues[j][i]; + finalNull[i] = resNulls[j][i]; + hasSet = true; + break; + } + } + if (!hasSet) { + finalResult[i] = elseValue[i]; + finalNull[i] = elseNull[i]; + } + } +} + +extern "C" DLLEXPORT void SwitchExprString(int32_t whenCnt, int64_t *whenClauses, int64_t *whenBools, + int64_t *resultValues, int64_t *resultNulls, int64_t *resultLengths, uint8_t **elseValue, bool *elseNull, + int32_t *elseLength, uint8_t **finalResult, bool *finalNull, int32_t *finalLength, int32_t rowCnt); + +template +extern DLLEXPORT void InExpr(int32_t cmpCnt, int64_t *cmpValues, int64_t *cmpBools, T *toCmpValue, bool *toCmpBool, + bool *finalResult, bool *finalNull, int32_t rowCnt) +{ + std::vector cmpValuesList; + std::vector cmpNullsList; + + for (int i = 0; i < cmpCnt; ++i) { + cmpValuesList.push_back(reinterpret_cast(reinterpret_cast(cmpValues[i]))); + cmpNullsList.push_back(reinterpret_cast(reinterpret_cast(cmpBools[i]))); + } + + for (int i = 0; i < rowCnt; ++i) { + finalResult[i] = false; + finalNull[i] = false; + for (int j = 0; j < cmpCnt; ++j) { + if (!toCmpBool[i] && !cmpNullsList[j][i] && toCmpValue[i] == cmpValuesList[j][i]) { + finalResult[i] = true; + break; + } + } + } +} + +extern "C" DLLEXPORT void InExprString(int32_t cmpCnt, int64_t *cmpValues, int64_t *cmpBools, int64_t *cmpLengths, + uint8_t **toCmpValue, bool *toCmpBool, int32_t *toCmpLength, bool *finalResult, bool *finalNull, int32_t rowCnt); +} + +#endif // OMNI_RUNTIME_BATCH_UTILFUNCTIONS_H diff --git a/core/src/codegen/batch_functions/batch_varcharVectorfunctions.cpp b/core/src/codegen/batch_functions/batch_varcharVectorfunctions.cpp new file mode 100644 index 0000000..c1fe6bf --- /dev/null +++ b/core/src/codegen/batch_functions/batch_varcharVectorfunctions.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch varcharVector functions implementation + */ + +#include "batch_varcharVectorfunctions.h" +#include "vector/vector.h" + +using namespace omniruntime::vec; + +namespace omniruntime::codegen::function { +extern "C" DLLEXPORT int32_t BatchWrapVarcharVector(int64_t vectorAddr, uint8_t **data, int32_t *dataLen, + int32_t rowCnt) +{ + auto *varcharVectorPtr = reinterpret_cast> *>(vectorAddr); + for (int i = 0; i < rowCnt; ++i) { + if (data[i] == nullptr) { + varcharVectorPtr->SetNull(i); + } else { + std::string_view strView(reinterpret_cast(data[i]), dataLen[i]); + varcharVectorPtr->SetValue(i, strView); + } + } + return 0; +} + +extern "C" DLLEXPORT void BatchNullArrayToBits(int32_t *dstBits, bool *srcArray, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + BitUtil::SetBit(dstBits, i, srcArray[i]); + } +} + +extern "C" DLLEXPORT void BatchBitsToNullArray(bool *dstArray, int32_t *srcBits, int32_t *rowIdxArray, int32_t rowCnt) +{ + for (int i = 0; i < rowCnt; i++) { + dstArray[i] = BitUtil::IsBitSet(srcBits, rowIdxArray[i]); + } +} +} \ No newline at end of file diff --git a/core/src/codegen/batch_functions/batch_varcharVectorfunctions.h b/core/src/codegen/batch_functions/batch_varcharVectorfunctions.h new file mode 100644 index 0000000..5a6c5cd --- /dev/null +++ b/core/src/codegen/batch_functions/batch_varcharVectorfunctions.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch varcharVector functions implementation + */ + +#ifndef OMNI_RUNTIME_BATCH_VARCHARVECTORFUNCTIONS_H +#define OMNI_RUNTIME_BATCH_VARCHARVECTORFUNCTIONS_H + +#ifdef _WIN32 +#define DLLEXPORT __declspec(dllexport) +#else +#define DLLEXPORT +#endif + +#include + +namespace omniruntime::codegen::function { +extern "C" DLLEXPORT int32_t BatchWrapVarcharVector(int64_t vectorAddr, uint8_t **data, int32_t *dataLen, + int32_t rowCnt); + +extern "C" DLLEXPORT void BatchNullArrayToBits(int32_t *dstBits, bool *srcArray, int32_t rowCnt); + +extern "C" DLLEXPORT void BatchBitsToNullArray(bool *dstArray, int32_t *srcBits, int32_t *rowIdxArray, int32_t rowCnt); +} +#endif // OMNI_RUNTIME_BATCH_VARCHARVECTORFUNCTIONS_H diff --git a/core/src/codegen/batch_projection_codegen.cpp b/core/src/codegen/batch_projection_codegen.cpp new file mode 100644 index 0000000..a55377c --- /dev/null +++ b/core/src/codegen/batch_projection_codegen.cpp @@ -0,0 +1,131 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch projection expression codegen + */ + +#include "batch_projection_codegen.h" + +namespace omniruntime::codegen { +using namespace llvm; +using namespace orc; +using namespace omniruntime::expressions; +using namespace omniruntime::type; +using namespace omniruntime; + +namespace { +const int INPUT_TABLE_INDEX = 0; +const int NUM_ROWS_INDEX = 1; +const int OUTPUT_ADDRESS_INDEX = 2; +const int SELECTED = 3; +const int NUM_SELECTED = 4; +const int BITMAP = 5; +const int OFFSETS_INDEX = 6; +const int NEW_NULL_VALUES_INDEX = 7; +const int OUTPUT_OFFSETS_INDEX = 8; +const int EXECUTION_CONTEXT_IDX = 9; +const int DICTIONARY_VECTORS_IDX = 10; +} +intptr_t BatchProjectionCodeGen::GetFunction() +{ + llvm::Function *func = this->CreateBatchFunction(); + if (func == nullptr) { + return 0; + } + return this->CreateBatchWrapper(*func); +} + +intptr_t BatchProjectionCodeGen::CreateBatchWrapper(llvm::Function &projFunc) +{ + llvm::Function *proj = &projFunc; + + // The args indicates the type of the function parameter list. + std::vector args; + args.push_back(llvmTypes->I64PtrType()); // data address array + args.push_back(llvmTypes->I32Type()); // the num of rows + args.push_back(llvmTypes->I64Type()); // output array address + args.push_back(llvmTypes->I32PtrType()); // selected array + args.push_back(llvmTypes->I32Type()); // the num of selected rows + args.push_back(llvmTypes->I64PtrType()); // bitmap address array + args.push_back(llvmTypes->I64PtrType()); // offset address array + args.push_back(llvmTypes->I32PtrType()); // output null values array + args.push_back(llvmTypes->I32PtrType()); // output offset array + args.push_back(llvmTypes->I64Type()); // execution content address + args.push_back(llvmTypes->I64PtrType()); // dictionary address array + + FunctionType *funcSignature = FunctionType::get(llvmTypes->I32Type(), args, false); + llvm::Function *funcDecl = + llvm::Function::Create(funcSignature, llvm::Function::ExternalLinkage, "WRAPPER_FUNC", modulePtr); + BasicBlock *projectionMain = BasicBlock::Create(*context, "PROJECTION_MAIN", funcDecl); + + // set args names + Argument *input = funcDecl->getArg(INPUT_TABLE_INDEX); + input->setName("INPUT_TABLE"); + Argument *numRows = funcDecl->getArg(NUM_ROWS_INDEX); + numRows->setName("NUM_ROWS"); + Argument *outputAddress = funcDecl->getArg(OUTPUT_ADDRESS_INDEX); + outputAddress->setName("OUTPUT_ADDRESS"); + // Only use these values if filter enabled + Argument *selected = nullptr; + Argument *numSelected = nullptr; + if (filter) { + selected = funcDecl->getArg(SELECTED); + selected->setName("SELECTED_ARRAY"); + numSelected = funcDecl->getArg(NUM_SELECTED); + numSelected->setName("NUM_SELECTED"); + } + Argument *bitmap = funcDecl->getArg(BITMAP); + bitmap->setName("BITMAP"); + Argument *offsets = funcDecl->getArg(OFFSETS_INDEX); + offsets->setName("OFFSETS"); + Argument *nullValuesAddress = funcDecl->getArg(NEW_NULL_VALUES_INDEX); + nullValuesAddress->setName("NULL_VALUES_ADDRESS"); + Argument *outputOffsetsAddress = funcDecl->getArg(OUTPUT_OFFSETS_INDEX); + outputOffsetsAddress->setName("OUTPUT_OFFSETS_ADDRESS"); + Argument *executionContext = funcDecl->getArg(EXECUTION_CONTEXT_IDX); + executionContext->setName("EXECUTION_CONTEXT_ADDRESS"); + Argument *dictionaryVectors = funcDecl->getArg(DICTIONARY_VECTORS_IDX); + dictionaryVectors->setName("DICTIONARY_VECTORS"); + + builder->SetInsertPoint(projectionMain); + Type *outPtrType = llvmTypes->ToPointerType(expr->GetReturnTypeId()); + if (outPtrType == nullptr) { + return 0; + } + Value *outColPtr = builder->CreateIntToPtr(outputAddress, outPtrType); + + AllocaInst *rowIdxArray; + if (filter) { + rowIdxArray = reinterpret_cast(selected); + numRows = numSelected; + } else { + rowIdxArray = builder->CreateAlloca(llvmTypes->I32Type(), numRows, "ROW_IDX_ARRAY"); + CallExternFunction("fill_rowIdx", { OMNI_INT, OMNI_INT }, OMNI_INT, { rowIdxArray, numRows }, nullptr, + "fill_rowIdx"); + } + // generate output array for inner function + AllocaInst *outputLenPtr = builder->CreateAlloca(llvmTypes->I32Type(), numRows, "OUTPUT_LENGTH"); + auto isNullPtr = builder->CreateAlloca(llvmTypes->I1Type(), numRows, "IS_NULL"); + auto resArray = this->GetResultArray(this->expr->GetReturnTypeId(), numRows); + + std::vector projFuncArgs { input, bitmap, offsets, numRows, rowIdxArray, + outputLenPtr, executionContext, dictionaryVectors, isNullPtr, resArray }; + builder->CreateCall(proj, projFuncArgs, "INNER_FUNC"); + + std::vector funcArgs; + if (TypeUtil::IsStringType(expr->GetReturnTypeId())) { + std::vector paramTypes = { OMNI_LONG, OMNI_VARCHAR, OMNI_INT, OMNI_INT }; + funcArgs = { outColPtr, resArray, outputLenPtr, numRows }; + CallExternFunction("batch_WrapVarcharVector", paramTypes, OMNI_INT, funcArgs, nullptr, "copy_varchar_result"); + } else { + funcArgs = { outColPtr, resArray, numRows }; + CallExternFunction("batch_copy", { this->expr->GetReturnTypeId() }, this->expr->GetReturnTypeId(), funcArgs, + nullptr, "copy_result"); + } + + funcArgs = { nullValuesAddress, isNullPtr, numRows }; + CallExternFunction("batch_NullArrayToBits", { OMNI_BOOLEAN }, OMNI_BOOLEAN, funcArgs, nullptr, "copy_null"); + builder->CreateRet(numRows); + OptimizeFunctionsAndModule(); + return Compile(); +} +} \ No newline at end of file diff --git a/core/src/codegen/batch_projection_codegen.h b/core/src/codegen/batch_projection_codegen.h new file mode 100644 index 0000000..d3ecbee --- /dev/null +++ b/core/src/codegen/batch_projection_codegen.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch projection expression codegen + */ +#ifndef OMNI_RUNTIME_BATCH_PROJECTION_CODEGEN_H +#define OMNI_RUNTIME_BATCH_PROJECTION_CODEGEN_H + +#include + +#include "batch_expression_codegen.h" + +namespace omniruntime { +namespace codegen { +class BatchProjectionCodeGen : public BatchExpressionCodeGen { +public: + BatchProjectionCodeGen(std::string name, const omniruntime::expressions::Expr &expr, bool filter, + omniruntime::op::OverflowConfig *overflowConfig) + : BatchExpressionCodeGen(std::move(name), expr, overflowConfig), filter(filter) + {} + + ~BatchProjectionCodeGen() override = default; + + intptr_t GetFunction() override; + +private: + intptr_t CreateBatchWrapper(llvm::Function &projFunc); + + bool filter; +}; +} +} +#endif // OMNI_RUNTIME_BATCH_PROJECTION_CODEGEN_H diff --git a/core/src/codegen/bloom_filter.cpp b/core/src/codegen/bloom_filter.cpp new file mode 100644 index 0000000..b2c8b54 --- /dev/null +++ b/core/src/codegen/bloom_filter.cpp @@ -0,0 +1,145 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: BloomFilter operator source file + */ + +#include "bloom_filter.h" +#include "util/type_util.h" +#include "codegen/functions/murmur3_hash.h" + +namespace omniruntime { +namespace op { +using namespace std; +using namespace omniruntime::codegen::function; +using namespace omniruntime::type; + +BloomFilter::BloomFilter(int8_t *in, int32_t versionJava) : version(versionJava) +{ + int32_t versionIn = (reinterpret_cast(in))[0]; // version 4 Bytes + if (version != versionIn) { + throw omniruntime::exception::OmniException("ILLEGAL_INPUT", "wrong version for bloom filter"); + } + + numHashFunctions = (reinterpret_cast(in))[1]; // numHashFunctions 4 Bytes + bits = new BitArray(in + 8); // offset is 8 Bytes +} + +BloomFilter::~BloomFilter() +{ + delete bits; +} + +/* + * @Func : put long data into BloomFilter struct + * @param item : long type data + * @return : Returns true if the bit slot is reversed after insertion, Otherwise, false is returned. + */ +bool BloomFilter::PutLong(int64_t item) +{ + int32_t h1 = Mm3Int64(item, false, 0, false); + int32_t h2 = Mm3Int64(item, false, h1, false); + + uint64_t bitSize = bits->GetBitSize(); + bool bitsChanged = false; + for (int32_t i = 1; i <= numHashFunctions; i++) { + int32_t combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + bitsChanged |= bits->Set(combinedHash % bitSize); + } + return bitsChanged; +} + +/* + * @Func : Check whether a item is in the filter. + * @param item : long type data + * @return : Returns true if the item in the filter, Otherwise, false is returned. + */ +bool BloomFilter::MightContainLong(int64_t item) +{ + int32_t h1 = Mm3Int64(item, false, 0, false); + int32_t h2 = Mm3Int64(item, false, h1, false); + + uint64_t bitSize = bits->GetBitSize(); + for (int32_t i = 1; i <= numHashFunctions; i++) { + int32_t combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + if (!bits->Get(combinedHash % bitSize)) { + return false; + } + } + return true; +} + +int32_t BloomFilter::GetNumHashFunctions() +{ + return numHashFunctions; +} + +// function implements for class BloomFilterOperatorFactory +BloomFilterOperatorFactory::BloomFilterOperatorFactory(int32_t version) : version(version) {} + +BloomFilterOperatorFactory::~BloomFilterOperatorFactory() = default; + +BloomFilterOperatorFactory *BloomFilterOperatorFactory::CreateBloomFilterOperatorFactory(int32_t version) +{ + auto pOperatorFactory = new BloomFilterOperatorFactory(version); + return pOperatorFactory; +} + +Operator *BloomFilterOperatorFactory::CreateOperator() +{ + auto pSortOperator = new BloomFilterOperator(version); + return pSortOperator; +} + +// function implements for class BloomFilterOperator +BloomFilterOperator::BloomFilterOperator(int32_t version) : version(version) {} + +BloomFilterOperator::~BloomFilterOperator() +{ + VectorHelper::FreeVecBatch(inputVecBatch); + delete bloomFilterAddress; +} + +int32_t BloomFilterOperator::AddInput(VectorBatch *vecBatch) +{ + if (vecBatch == nullptr) { + throw omniruntime::exception::OmniException("ILLEGAL_INPUT", "BloomFilterOperator AddInput can't be nullptr!"); + } + + inputVecBatch = vecBatch; + int32_t vectorCount = inputVecBatch->GetVectorCount(); + if (vectorCount != 1) { + throw omniruntime::exception::OmniException("ILLEGAL_INPUT", "vecBatch col should be 1 for bloom filter"); + } + BaseVector *colVec = inputVecBatch->Get(0); + auto valuesAddress = reinterpret_cast(VectorHelper::UnsafeGetValues(colVec)); + + // init BloomFilter + bloomFilterAddress = new BloomFilter(reinterpret_cast((uintptr_t)valuesAddress), version); + return 0; +} + +int32_t BloomFilterOperator::GetOutput(VectorBatch **blOutPut) +{ + auto outPut = new VectorBatch(1); + auto *col = new Vector(1); + col->SetValue(0, reinterpret_cast(bloomFilterAddress)); + outPut->Append(col); + *blOutPut = outPut; + SetStatus(OMNI_STATUS_FINISHED); + return 0; +} + +OmniStatus BloomFilterOperator::Close() +{ + return OMNI_STATUS_NORMAL; +} +} // end of op +} // end of omniruntime \ No newline at end of file diff --git a/core/src/codegen/bloom_filter.h b/core/src/codegen/bloom_filter.h new file mode 100644 index 0000000..fafca0b --- /dev/null +++ b/core/src/codegen/bloom_filter.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: BloomFilter operator header + */ +#ifndef OMNI_RUNTIME_BLOOM_FILTER_H +#define OMNI_RUNTIME_BLOOM_FILTER_H + +#include "util/bit_array.h" +#include "operator/operator_factory.h" + +namespace omniruntime { +namespace op { +class BloomFilter { +public: + explicit BloomFilter(int8_t *in, int32_t versionJava); + + ~BloomFilter(); + + bool PutLong(int64_t item); + + bool MightContainLong(int64_t item); + + int32_t GetNumHashFunctions(); + + BitArray *GetBits() + { + return bits; + } + +private: + int32_t numHashFunctions; + BitArray *bits; + int32_t version; +}; + +class BloomFilterOperatorFactory : public OperatorFactory { +public: + explicit BloomFilterOperatorFactory(int32_t version); + + ~BloomFilterOperatorFactory() override; + + static BloomFilterOperatorFactory *CreateBloomFilterOperatorFactory(int32_t version); + + Operator *CreateOperator() override; + +private: + int32_t version; +}; + +/* + * BloomFilterOperator is only used by spark runtimeFilter Feature, It is used independently and cannot cooperate with + * other operators. AddInput must be IntVector GetOutput must be LongVector + */ +class BloomFilterOperator : public Operator { +public: + explicit BloomFilterOperator(int32_t version); + + ~BloomFilterOperator() override; + + int32_t AddInput(omniruntime::vec::VectorBatch *vecBatch) override; + + int32_t GetOutput(omniruntime::vec::VectorBatch **blOutPut) override; + + OmniStatus Close() override; + +private: + int32_t version; + BloomFilter *bloomFilterAddress; + VectorBatch *inputVecBatch; +}; +} // end of op +} // end of omniruntime +#endif // OMNI_RUNTIME_BLOOM_FILTER_H \ No newline at end of file diff --git a/core/src/codegen/codegen_base.cpp b/core/src/codegen/codegen_base.cpp new file mode 100644 index 0000000..58f2b80 --- /dev/null +++ b/core/src/codegen/codegen_base.cpp @@ -0,0 +1,101 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + * Description: Base codegen generator + */ + +#include "codegen_base.h" + +namespace omniruntime::codegen { +CodegenBase::CodegenBase(std::string name, const omniruntime::expressions::Expr &cpExpr, + omniruntime::op::OverflowConfig *overflowConfig) + : funcName(std::move(name)), expr(&cpExpr), overflowConfig(overflowConfig) +{} + +/** + * Usage example: std::vector values; + * values.push_back(value1); + * values.push_back(value2); + * PrintValues("LLVM DEBUG: %d, %d\n", values); + */ +void CodegenBase::PrintValues(std::string format, const std::vector &values) +{ + // Return a cast to an i8* + auto formatPtr = CreateConstantString(std::move(format)); + std::vector args; + args.push_back(formatPtr); + for (auto v : values) { + args.push_back(v); + } + + builder->CreateCall(codegenContext->print, args, "printfCall"); +} + +std::string CodegenBase::DumpCode() +{ + std::string ir; + llvm::raw_string_ostream stream(ir); + modulePtr->print(stream, nullptr); + std::cout << " Generated code::" << ir; + return ir; +} + +Value *CodegenBase::GetPtrTypeFromInt(omniruntime::type::DataTypeId dataTypeId, Value *elementAddr) +{ + Value *elementPtr = nullptr; + // Convert the column address to array of proper datatype. + switch (dataTypeId) { + case OMNI_BOOLEAN: + elementPtr = builder->CreateIntToPtr(elementAddr, llvmTypes->I1PtrType()); + break; + case OMNI_SHORT: + elementPtr = builder->CreateIntToPtr(elementAddr, llvmTypes->I16PtrType()); + break; + case OMNI_INT: + case OMNI_DATE32: + elementPtr = builder->CreateIntToPtr(elementAddr, llvmTypes->I32PtrType()); + break; + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DECIMAL64: + elementPtr = builder->CreateIntToPtr(elementAddr, llvmTypes->I64PtrType()); + break; + case OMNI_DOUBLE: + elementPtr = builder->CreateIntToPtr(elementAddr, llvmTypes->DoublePtrType()); + break; + case OMNI_CHAR: + case OMNI_VARCHAR: + elementPtr = builder->CreateIntToPtr(elementAddr, llvmTypes->I8PtrType()); + break; + case OMNI_DECIMAL128: + elementPtr = builder->CreateIntToPtr(elementAddr, llvmTypes->I128PtrType()); + break; + default: + LLVM_DEBUG_LOG("Unsupported column data type %d", dataTypeId); + elementPtr = builder->CreateIntToPtr(elementAddr, llvmTypes->I64PtrType()); + break; + } + return elementPtr; +} + +llvm::Constant *CodegenBase::CreateConstantString(std::string s) +{ + auto charType = Type::getInt8Ty(*context); + std::vector chars(s.size()); + for (unsigned int i = 0; i < s.size(); i++) { + chars[i] = ConstantInt::get(charType, s[i]); + } + chars.push_back(llvm::ConstantInt::get(charType, 0)); + auto stringType = llvm::ArrayType::get(charType, chars.size()); + + this->numGlobalValues++; + auto globalDeclaration = static_cast( + modulePtr->getOrInsertGlobal("string" + std::to_string(this->numGlobalValues), stringType)); + globalDeclaration->setInitializer(llvm::ConstantArray::get(stringType, chars)); + globalDeclaration->setConstant(true); + globalDeclaration->setLinkage(llvm::GlobalValue::LinkageTypes::PrivateLinkage); + globalDeclaration->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global); + + auto stringPtr = llvm::ConstantExpr::getBitCast(globalDeclaration, charType->getPointerTo()); + return stringPtr; +} +} diff --git a/core/src/codegen/codegen_base.h b/core/src/codegen/codegen_base.h new file mode 100644 index 0000000..ca996c4 --- /dev/null +++ b/core/src/codegen/codegen_base.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + * Description: Base codegen generator + */ +#ifndef OMNI_RUNTIME_CODEGEN_BASE_H +#define OMNI_RUNTIME_CODEGEN_BASE_H + +#include +#include +#include +#include +#include +#include +#include + +#include "codegen/llvm_engine.h" +#include "codegen_context.h" +#include "batch_codegen_context.h" +#include "type/decimal128_utils.h" +#include "util/type_util.h" + +namespace omniruntime::codegen { +using CodeGenValuePtr = std::shared_ptr; + +// The base class for the code generator +class CodegenBase : public LLVMEngine { +public: + CodegenBase(std::string name, const omniruntime::expressions::Expr &cpExpr, + omniruntime::op::OverflowConfig *overflowConfig); + + Value *GetPtrTypeFromInt(omniruntime::type::DataTypeId dataTypeId, Value *elementAddr); + +protected: + void PrintValues(std::string format, const std::vector &values); + + std::string DumpCode(); + + llvm::Constant *CreateConstantString(std::string s); + + static CodeGenValuePtr CreateInvalidCodeGenValue() + { + return std::make_shared(nullptr, nullptr, nullptr); + } + + std::string funcName; + const omniruntime::expressions::Expr *expr; + llvm::Function *func = nullptr; + + int numGlobalValues = 0; + std::unique_ptr codegenContext; + std::unique_ptr batchCodegenContext; + CodeGenValuePtr value = nullptr; + omniruntime::op::OverflowConfig *overflowConfig; +}; +} + +#endif diff --git a/core/src/codegen/codegen_context.h b/core/src/codegen/codegen_context.h new file mode 100644 index 0000000..339ea1b --- /dev/null +++ b/core/src/codegen/codegen_context.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: Expression code generator + */ +#ifndef OMNI_RUNTIME_CODEGEN_CONTEXT_H +#define OMNI_RUNTIME_CODEGEN_CONTEXT_H + +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Value.h" + +namespace omniruntime::codegen { +class CodegenContext { +public: + explicit CodegenContext() + : data(nullptr), + nullBitmap(nullptr), + offsets(nullptr), + rowIdx(nullptr), + executionContext(nullptr), + dictionaryVectors(nullptr), + print(nullptr) + {} + + explicit CodegenContext(llvm::Value *data, llvm::Value *nullBitmap, llvm::Value *offsets, llvm::Value *rowIdx, + llvm::Value *executionContext, llvm::Value *dictionaryVectors) + : data(data), + nullBitmap(nullBitmap), + offsets(offsets), + rowIdx(rowIdx), + executionContext(executionContext), + dictionaryVectors(dictionaryVectors), + print(nullptr) + {} + + ~CodegenContext() = default; + + friend class ExpressionCodeGen; + + friend class SimpleFilterCodeGen; + + friend class CodegenBase; + +private: + llvm::Value *data; + llvm::Value *nullBitmap; + llvm::Value *offsets; + llvm::Value *rowIdx; + llvm::Value *executionContext; + llvm::Value *dictionaryVectors; + llvm::FunctionCallee print; +}; +} + +#endif // OMNI_RUNTIME_CODEGEN_CONTEXT_H diff --git a/core/src/codegen/codegen_value.h b/core/src/codegen/codegen_value.h new file mode 100644 index 0000000..1a1c6d7 --- /dev/null +++ b/core/src/codegen/codegen_value.h @@ -0,0 +1,82 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: Value object contains reference to data, isNull and varchar length + */ +#ifndef OMNI_RUNTIME_CODEGEN_VALUE_H +#define OMNI_RUNTIME_CODEGEN_VALUE_H + +#include + +namespace omniruntime::codegen { +class CodeGenValue { +public: + explicit CodeGenValue(llvm::Value *data, llvm::Value *isNull, llvm::Value *length = nullptr) + : data(data), isNull(isNull), length(length) + {} + + virtual ~CodeGenValue() = default; + + bool IsValidValue() + { + return this->data != nullptr; + } + + friend class ExpressionCodeGen; + + friend class SimpleFilterCodeGen; + + friend class BatchExpressionCodeGen; + + friend class CodegenBase; + +private: + llvm::Value *data; + llvm::Value *isNull; + llvm::Value *length; +}; + +class DecimalValue : public CodeGenValue { +public: + explicit DecimalValue(llvm::Value *data, llvm::Value *isNull, llvm::Value *precision, llvm::Value *scale) + : CodeGenValue(data, isNull), precision(precision), scale(scale) + {} + + virtual ~DecimalValue() = default; + + llvm::Value *GetPrecision() const + { + return precision; + } + + llvm::Value *GetScale() const + { + return scale; + } + +private: + llvm::Value *precision; + llvm::Value *scale; +}; + +class DecimalSplitValue : public DecimalValue { +public: + explicit DecimalSplitValue(llvm::Value *high, llvm::Value *low, llvm::Value *isNull = nullptr, + llvm::Value *precision = nullptr, llvm::Value *scale = nullptr) + : DecimalValue(nullptr, isNull, precision, scale), high(high), low(low) + {} + virtual ~DecimalSplitValue() = default; + const llvm::Value *GetHigh() + { + return high; + } + const llvm::Value *GetLow() + { + return low; + } + +private: + llvm::Value *high; + llvm::Value *low; +}; +} +#endif // OMNI_RUNTIME_CODEGEN_VALUE_H diff --git a/core/src/codegen/common_util.h b/core/src/codegen/common_util.h new file mode 100644 index 0000000..89d33b6 --- /dev/null +++ b/core/src/codegen/common_util.h @@ -0,0 +1,153 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * Description: timezone util + */ + +#ifndef OMNI_RUNTIME_COMMON_UTIL_H +#define OMNI_RUNTIME_COMMON_UTIL_H + +#include + +namespace omniruntime::codegen::function { +static const int32_t ROUND_LONG_MIN_DECIMALS = -19; +static const int64_t ROUND_OVER_MAX_LONG = -8446744073709551616L; +static const int64_t ROUND_OVER_MIN_LONG = 8446744073709551616L; +static const int64_t ROUND_OVER_MEDIAN_LONG = 5000000000000000000L; + +static int64_t LONG_POWERS_TABLE[] = { + 1L, // 0 / 10^0 + 10L, // 1 / 10^1 + 100L, // 2 / 10^2 + 1000L, // 3 / 10^3 + 10000L, // 4 / 10^4 + 100000L, // 5 / 10^5 + 1000000L, // 6 / 10^6 + 10000000L, // 7 / 10^7 + 100000000L, // 8 / 10^8 + 1000000000L, // 9 / 10^9 + 10000000000L, // 10 / 10^10 + 100000000000L, // 11 / 10^11 + 1000000000000L, // 12 / 10^12 + 10000000000000L, // 13 / 10^13 + 100000000000000L, // 14 / 10^14 + 1000000000000000L, // 15 / 10^15 + 10000000000000000L, // 16 / 10^16 + 100000000000000000L, // 17 / 10^17 + 1000000000000000000L // 18 / 10^18 +}; + +static inline int64_t RoundOperator(int64_t num, int32_t decimals) +{ + if (decimals >= 0) { + return num; + } + if (decimals < ROUND_LONG_MIN_DECIMALS) { + return 0; + } + if (decimals == ROUND_LONG_MIN_DECIMALS) { + if (num < -ROUND_OVER_MEDIAN_LONG) { + return ROUND_OVER_MIN_LONG; + } + if (num > ROUND_OVER_MEDIAN_LONG) { + return ROUND_OVER_MAX_LONG; + } + return 0; + } + int64_t power = LONG_POWERS_TABLE[-decimals]; + int64_t base = (num / power) * power; + int64_t remain = num % power; + int64_t half = 2; + if (std::abs(remain) >= power / half) { + return base + (num < 0 ? -power : power); + } + return base; +} + +static const int32_t ROUND_INT16_MIN_DECIMALS = -4; +static const int16_t ROUND_INT16_MAX = INT16_MAX; +static const int16_t ROUND_INT16_MIN = INT16_MIN; + +static int64_t INT16_POWERS_TABLE[] = { + 1L, // 0 / 10^0 + 10L, // 1 / 10^1 + 100L, // 2 / 10^2 + 1000L, // 3 / 10^3 + 10000L // 4 / 10^4 +}; + +static inline int16_t RoundOperatorInt16(int16_t num, int32_t decimals) +{ + if (decimals >= 0) { + return num; + } + if (decimals < ROUND_INT16_MIN_DECIMALS) { + return 0; + } + int64_t power = INT16_POWERS_TABLE[-decimals]; + int64_t base = (num / power) * power; + int64_t remain = num % power; + int64_t half = 2; + if (std::abs(remain) >= power / half) { + int64_t result = base + (num < 0 ? -power : power); + if (result > ROUND_INT16_MAX) { + return ROUND_INT16_MAX; + } else if (result < ROUND_INT16_MIN) { + return ROUND_INT16_MIN; + } else { + return static_cast(result); + } + } + int64_t result = base; + if (result > ROUND_INT16_MAX) { + return ROUND_INT16_MAX; + } else if (result < ROUND_INT16_MIN) { + return ROUND_INT16_MIN; + } else { + return static_cast(result); + } +} + +static const int32_t ROUND_INT8_MIN_DECIMALS = -2; +static const int8_t ROUND_INT8_MAX = INT8_MAX; +static const int8_t ROUND_INT8_MIN = INT8_MIN; + +static int64_t INT8_POWERS_TABLE[] = { + 1L, // 0 / 10^0 + 10L, // 1 / 10^1 + 100L // 2 / 10^2 +}; + +static inline int8_t RoundOperatorInt8(int8_t num, int32_t decimals) +{ + if (decimals >= 0) { + return num; + } + if (decimals < ROUND_INT8_MIN_DECIMALS) { + return 0; + } + int64_t power = INT8_POWERS_TABLE[-decimals]; + int64_t base = (num / power) * power; + int64_t remain = num % power; + int64_t half = 2; + if (std::abs(remain) >= power / half) { + int64_t result = base + (num < 0 ? -power : power); + if (result > ROUND_INT8_MAX) { + return ROUND_INT8_MAX; + } else if (result < ROUND_INT8_MIN) { + return ROUND_INT8_MIN; + } else { + return static_cast(result); + } + } + int64_t result = base; + if (result > ROUND_INT8_MAX) { + return ROUND_INT8_MAX; + } else if (result < ROUND_INT8_MIN) { + return ROUND_INT8_MIN; + } else { + return static_cast(result); + } +} +} + +#endif // OMNI_RUNTIME_COMMON_UTIL_H diff --git a/core/src/codegen/context_helper.cpp b/core/src/codegen/context_helper.cpp new file mode 100644 index 0000000..81514fa --- /dev/null +++ b/core/src/codegen/context_helper.cpp @@ -0,0 +1,87 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: registry function implementation + */ +#include "context_helper.h" + +using namespace omniruntime::op; +using namespace omniruntime::type; + +namespace omniruntime::codegen { +#ifdef _WIN32 +#define DLLEXPORT __declspec(dllexport) +#else +#define DLLEXPORT +#endif + +extern "C" DLLEXPORT +{ + char *ArenaAllocatorMalloc(int64_t contextPtr, int32_t size) + { + auto context = reinterpret_cast(contextPtr); + return reinterpret_cast(context->GetArena()->Allocate(size)); + } + + bool ArenaAllocatorReset(int64_t contextPtr) + { + auto context = reinterpret_cast(contextPtr); + context->GetArena()->Reset(); + return true; + } + + bool SetError(int64_t contextPtr, std::string errorMessage) + { + auto context = reinterpret_cast(contextPtr); + if (!context->HasError()) { + context->SetError(errorMessage); + } + return true; + } + + bool HasError(int64_t contextPtr) + { + auto context = reinterpret_cast(contextPtr); + return context->HasError(); + } + + std::string GetDataString(DataTypeId type, int count, ...) + { + va_list v; + va_start(v, count); + std::ostringstream errorMessage; + switch (type) { + case OMNI_CHAR: + case OMNI_VARCHAR: + errorMessage << "VARCHAR"; + break; + case OMNI_BYTE: + errorMessage << "TINYINT"; + break; + case OMNI_SHORT: + errorMessage << "SMALLINT"; + break; + case OMNI_INT: + errorMessage << "INTEGER"; + break; + case OMNI_LONG: + errorMessage << "BIGINT"; + break; + case OMNI_TIMESTAMP: + errorMessage << "TIMESTAMP"; + break; + case OMNI_DOUBLE: + errorMessage << "DOUBLE"; + break; + case OMNI_DECIMAL64: + case OMNI_DECIMAL128: + errorMessage << "DECIMAL(" << va_arg(v, int32_t) << ", " << va_arg(v, int32_t) << ")"; + break; + default: + errorMessage << "No Support data type"; + break; + } + va_end(v); + return errorMessage.str(); + } +} +} \ No newline at end of file diff --git a/core/src/codegen/context_helper.h b/core/src/codegen/context_helper.h new file mode 100644 index 0000000..591aa9d --- /dev/null +++ b/core/src/codegen/context_helper.h @@ -0,0 +1,139 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + * Description: registry function implementation + */ +#ifndef OMNI_RUNTIME_CONTEXT_HELPER_H +#define OMNI_RUNTIME_CONTEXT_HELPER_H + +#include +#include "operator/execution_context.h" +#include "type/data_type.h" +#include "type/decimal_operations.h" +#include "util/config_util.h" +#include "type/data_operations.h" + +#ifdef _WIN32 +#define DLLEXPORT __declspec(dllexport) +#else +#define DLLEXPORT +#endif + +namespace omniruntime::codegen { +#define CHECK_OVERFLOW_RETURN(DECIMAL, PRECISION) \ + do { \ + if ((DECIMAL).IsOverflow((PRECISION)) != OpStatus::SUCCESS) { \ + SetError(contextPtr, DECIMAL_OVERFLOW); \ + return 1; \ + } \ + } while (false) + +#define CHECK_OVERFLOW(DECIMAL, PRECISION) \ + do { \ + if ((DECIMAL).IsOverflow((PRECISION)) != OpStatus::SUCCESS) { \ + SetError(contextPtr, DECIMAL_OVERFLOW); \ + return; \ + } \ + } while (false) + +#define CHECK_DIVIDE_BY_ZERO_RETURN(dividend) \ + do { \ + if ((dividend) == 0) { \ + SetError(contextPtr, DIVIDE_ZERO); \ + return 0; \ + } \ + } while (false) + +#define CHECK_DIVIDE_BY_ZERO(dividend) \ + do { \ + if ((dividend) == 0) { \ + SetError(contextPtr, DIVIDE_ZERO); \ + return; \ + } \ + } while (false) + +#define CHECK_OVERFLOW_RETURN_NULL(DECIMAL, PRECISION) \ + do { \ + if ((DECIMAL).IsOverflow((PRECISION)) != OpStatus::SUCCESS) { \ + *isNull = true; \ + return 0; \ + } \ + } while (false) + +#define CHECK_OVERFLOW_VOID_RETURN_NULL(DECIMAL, PRECISION) \ + do { \ + if ((DECIMAL).IsOverflow((PRECISION)) != OpStatus::SUCCESS) { \ + *isNull = true; \ + return; \ + } \ + } while (false) + +#define CHECK_OVERFLOW_CONTINUE_NULL(DECIMAL, PRECISION) \ + do { \ + if ((DECIMAL).IsOverflow((PRECISION)) != OpStatus::SUCCESS) { \ + isNull[i] = true; \ + continue; \ + } \ + } while (false) + +#define CHECK_OVERFLOW_CONTINUE(DECIMAL, PRECISION) \ + do { \ + if ((DECIMAL).IsOverflow((PRECISION)) != OpStatus::SUCCESS && !HasError(contextPtr)) { \ + SetError(contextPtr, DECIMAL_OVERFLOW); \ + continue; \ + } \ + } while (false) + +#define CHECK_DIVIDE_BY_ZERO_CONTINUE(dividend) \ + do { \ + if ((dividend) == 0) { \ + SetError(contextPtr, DIVIDE_ZERO); \ + continue; \ + } \ + } while (false) + +extern "C" DLLEXPORT +{ + char *ArenaAllocatorMalloc(int64_t contextPtr, int32_t size); + bool ArenaAllocatorReset(int64_t contextPtr); + bool SetError(int64_t contextPtr, std::string errorMessage); + bool HasError(int64_t contextPtr); + std::string GetDataString(type::DataTypeId type, int count, ...); +} + +template +std::string CastErrorMessage(type::DataTypeId from, type::DataTypeId to, T value, type::OpStatus reason, ...) +{ + va_list v; + va_start(v, reason); + std::ostringstream errorMessage; + if (from == type::OMNI_DECIMAL128 || from == type::OMNI_DECIMAL64) { + int32_t precision = va_arg(v, int32_t); + int32_t scale = va_arg(v, int32_t); + errorMessage << "Cannot cast " << GetDataString(from, 2, precision, scale) << " '"; + errorMessage << type::Decimal128Wrapper(value).SetScale(scale).ToString(); + } else { + errorMessage << "Cannot cast " << GetDataString(from, 1) << " '"; + if constexpr (std::is_same_v) { + errorMessage << omniruntime::type::ToString(value); + } else { + errorMessage << value; + } + } + if (to == type::OMNI_DECIMAL128 || to == type::OMNI_DECIMAL64) { + int32_t precision = va_arg(v, int32_t); + int32_t scale = va_arg(v, int32_t); + errorMessage << "' to " << GetDataString(to, 2, precision, scale); + } else { + errorMessage << "' to " << GetDataString(to, 1); + } + if (reason == type::OpStatus::OP_OVERFLOW) { + errorMessage << ". Value too large."; + } + if (reason == type::OpStatus::FAIL) { + errorMessage << ". Value is not a number."; + } + va_end(v); + return errorMessage.str(); +} +} +#endif \ No newline at end of file diff --git a/core/src/codegen/expr_evaluator.cpp b/core/src/codegen/expr_evaluator.cpp new file mode 100644 index 0000000..f898321 --- /dev/null +++ b/core/src/codegen/expr_evaluator.cpp @@ -0,0 +1,579 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + * Description: Expression evaluator + */ +#include "expr_evaluator.h" + +namespace omniruntime::codegen { +int64_t GetRawAddr(const DataTypes &types, int32_t i, BaseVector *colVec) +{ + switch (types.GetIds()[i]) { + case OMNI_INT: + case OMNI_DATE32: + return reinterpret_cast( + unsafe::UnsafeVector::GetRawValues(reinterpret_cast *>(colVec))); + case OMNI_SHORT: + return reinterpret_cast( + unsafe::UnsafeVector::GetRawValues(reinterpret_cast *>(colVec))); + case OMNI_BYTE: + return reinterpret_cast( + unsafe::UnsafeVector::GetRawValues(reinterpret_cast *>(colVec))); + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DECIMAL64: + return reinterpret_cast( + unsafe::UnsafeVector::GetRawValues(reinterpret_cast *>(colVec))); + case OMNI_DOUBLE: + return reinterpret_cast( + unsafe::UnsafeVector::GetRawValues(reinterpret_cast *>(colVec))); + case OMNI_BOOLEAN: + return reinterpret_cast( + unsafe::UnsafeVector::GetRawValues(reinterpret_cast *>(colVec))); + case OMNI_DECIMAL128: + return reinterpret_cast( + unsafe::UnsafeVector::GetRawValues(reinterpret_cast *>(colVec))); + case OMNI_VARCHAR: + case OMNI_CHAR: + return reinterpret_cast(unsafe::UnsafeStringVector::GetValues( + reinterpret_cast> *>(colVec))); + default: + LogError("Do not support such vector type %d", types.GetIds()[i]); + return 0; + } +} + +void GetAddr(VectorBatch &vecBatch, intptr_t valueAddrs[], intptr_t nullAddrs[], intptr_t offsetAddrs[], + intptr_t dictionaries[], const DataTypes &types) +{ + intptr_t valuesAddress; + intptr_t dictVecAddress; + int32_t vectorCount = vecBatch.GetVectorCount(); + for (int32_t i = 0; i < vectorCount; i++) { + auto colVec = vecBatch.Get(i); + dictVecAddress = 0; + valuesAddress = 0; + if (colVec->GetEncoding() == OMNI_DICTIONARY) { + dictVecAddress = reinterpret_cast(reinterpret_cast(colVec)); + } else { + valuesAddress = GetRawAddr(types, i, colVec); + } + + // data handling + dictionaries[i] = dictVecAddress; + valueAddrs[i] = valuesAddress; + + // nulls handling + nullAddrs[i] = reinterpret_cast(unsafe::UnsafeBaseVector::GetNulls(colVec)); + + // offsets handling + offsetAddrs[i] = reinterpret_cast(VectorHelper::UnsafeGetOffsetsAddr(colVec)); + } +} + +Filter::Filter(const Expr &expression, const DataTypes &inputDataTypes, OverflowConfig *overflowConfig) +{ +#ifdef DEBUG + std::cout << "String expression in Filter: " << std::endl; + ExprPrinter printExprTree; + expression.Accept(printExprTree); + std::cout << std::endl; +#endif + intptr_t f; + if (!ConfigUtil::IsEnableBatchExprEvaluate()) { + this->codeGen = std::make_unique("filterFunc", expression, overflowConfig); + f = this->codeGen->GetFunction(inputDataTypes); + } else { + this->batchCodeGen = std::make_unique("filterFunc", expression, overflowConfig); + f = this->batchCodeGen->GetFunction(); + } + if (f == 0) { + this->isSupported = false; + this->apply = nullptr; + } else { + this->isSupported = true; + void *function = &f; + this->apply = *static_cast(function); + } +} + +bool Projection::SetLiteralValue(const LiteralExpr *literalExpr) +{ + if (literalExpr->isNull) { + literalVal.isNull = true; + return true; + } + switch (outType->GetId()) { + case OMNI_INT: + case OMNI_DATE32: { + literalVal.value.intVal = literalExpr->intVal; + break; + } + case OMNI_SHORT: { + literalVal.value.shortVal = literalExpr->shortVal; + break; + } + case OMNI_BYTE: { + literalVal.value.byteVal = literalExpr->byteVal; + break; + } + case OMNI_LONG: + case OMNI_DECIMAL64: + case OMNI_TIMESTAMP: { + literalVal.value.longVal = literalExpr->longVal; + break; + } + case OMNI_DOUBLE: { + literalVal.value.doubleVal = literalExpr->doubleVal; + break; + } + case OMNI_BOOLEAN: { + literalVal.value.boolVal = literalExpr->boolVal; + break; + } + case OMNI_DECIMAL128: { + std::string dec128String = *literalExpr->stringVal; + __uint128_t dec128 = Decimal128Utils::StrToUint128_t(dec128String.c_str()); + literalVal.value.decimal128Val.SetValue(static_cast(dec128)); + break; + } + case OMNI_VARCHAR: + case OMNI_CHAR: { + literalVal.value.stringVal = std::string_view(*(literalExpr->stringVal)); + break; + } + default: + LogError("Do not support such vector type %d", outType->GetId()); + return false; + } + return true; +} + +bool Projection::Initialize(bool filter, const DataTypes &inputDataTypes, OverflowConfig *overflowConfig) +{ + // short-circuit logic for column projections + // no need to go through codegen + if (expr->GetType() == ExprType::FIELD_E) { + this->isColumnProjection = true; + this->columnProjectionIndex = static_cast(expr)->colVal; + return true; + } + + // short-circuit logic for literal expression + if (expr->GetType() == ExprType::LITERAL_E) { + this->isConstantProjection = true; + return SetLiteralValue(static_cast(expr)); + } + + intptr_t f; + if (!ConfigUtil::IsEnableBatchExprEvaluate()) { + this->codeGen = std::make_unique("proj_func", *(this->expr), filter, overflowConfig); + f = this->codeGen->GetFunction(inputDataTypes); + } else { + this->batchCodeGen = + std::make_unique("proj_func", *(this->expr), filter, overflowConfig); + f = this->batchCodeGen->GetFunction(); + } + + if (f == 0) { + return false; + } + + void *function = &f; + auto cfunction = static_cast(function); + this->projector = *cfunction; + return true; +} + +Projection::Projection(const Expr &expr, bool filter, DataTypePtr outType, const DataTypes &inputDataTypes, + OverflowConfig *overflowConfig) + : expr(&expr), outType(std::move(outType)), projector(nullptr) +{ +#ifdef DEBUG + std::cout << "Expression in projection:" << std::endl; + ExprPrinter printExprTree; + expr.Accept(printExprTree); + std::cout << std::endl; +#endif + bool initialized = this->Initialize(filter, inputDataTypes, overflowConfig); + if (!initialized) { + this->isSupported = false; + } +} + +/* for supporting cast(null as string) or NULL as col_name */ +bool Projection::NullColumnProjection(ExecutionContext *context, BaseVector *outVec) +{ + auto outNulls = unsafe::UnsafeBaseVector::GetNulls(outVec); + auto outNullsSize = BitUtil::Nbytes(outVec->GetSize()); + auto result = memset_s(outNulls, outNullsSize, -1, outNullsSize); + if (result != EOK) { + std::string errorMessage = "Memset failed, ret " + std::to_string(result) + " destMax " + + std::to_string(outNullsSize) + " count " + std::to_string(outNullsSize); + context->SetError(errorMessage); + return false; + } + return true; +} + +template void Projection::SetConstantValues(T &value, BaseVector *outVec) +{ + auto rowCount = outVec->GetSize(); + if constexpr (std::is_same_v) { + auto outputVector = static_cast> *>(outVec); + for (int32_t i = 0; i < rowCount; i++) { + outputVector->SetValue(i, value); + } + } else { + auto outputVector = static_cast *>(outVec); + for (int32_t i = 0; i < rowCount; i++) { + outputVector->SetValue(i, value); + } + } +} + +bool Projection::ConstantColumnProjection(ExecutionContext *context, BaseVector *outVec) +{ + if (literalVal.isNull) { + return NullColumnProjection(context, outVec); + } + auto outputTypeId = this->outType->GetId(); + switch (outputTypeId) { + case OMNI_INT: + case OMNI_DATE32: + SetConstantValues(literalVal.value.intVal, outVec); + break; + case OMNI_SHORT: + SetConstantValues(literalVal.value.shortVal, outVec); + break; + case OMNI_BYTE: + SetConstantValues(literalVal.value.byteVal, outVec); + break; + case OMNI_LONG: + case OMNI_DECIMAL64: + case OMNI_TIMESTAMP: + SetConstantValues(literalVal.value.longVal, outVec); + break; + case OMNI_DOUBLE: + SetConstantValues(literalVal.value.doubleVal, outVec); + break; + case OMNI_BOOLEAN: + SetConstantValues(literalVal.value.boolVal, outVec); + break; + case OMNI_DECIMAL128: + SetConstantValues(literalVal.value.decimal128Val, outVec); + break; + case OMNI_VARCHAR: + case OMNI_CHAR: + SetConstantValues(literalVal.value.stringVal, outVec); + break; + default: + std::string errorMessage = "Do not support such vector type " + std::to_string(outputTypeId); + context->SetError(errorMessage); + return false; + } + return true; +} + +BaseVector *Projection::Project(VectorBatch *vecBatch, int32_t selectedRows[], int32_t numSelectedRows, + int64_t *valueAddrs, int64_t *nullAddrs, int64_t *offsetAddrs, ExecutionContext *context, + int64_t *dictionaryVectors, const int32_t *typeIds) +{ + if (this->isColumnProjection) { + return ColumnProjectionProxy(vecBatch, selectedRows, numSelectedRows, typeIds); + } else if (this->isConstantProjection) { + BaseVector *outVec = VectorHelper::CreateFlatVector(outType->GetId(), numSelectedRows); + if (!ConstantColumnProjection(context, outVec)) { + delete outVec; + context->GetArena()->Reset(); + return nullptr; + } else { + context->GetArena()->Reset(); + return outVec; + } + } else { + DataTypeId outTypeId = outType->GetId(); + BaseVector *outVec = VectorHelper::CreateFlatVector(outTypeId, numSelectedRows); + if (outTypeId == OMNI_VARCHAR || outTypeId == OMNI_CHAR) { + ProjectHelperVarWidth(*vecBatch, valueAddrs, nullAddrs, offsetAddrs, outVec, numSelectedRows, selectedRows, + context, dictionaryVectors, outTypeId); + } else { + ProjectHelperFixedWidth(*vecBatch, valueAddrs, nullAddrs, offsetAddrs, outVec, numSelectedRows, + selectedRows, context, dictionaryVectors, outTypeId); + } + context->GetArena()->Reset(); + return outVec; + } +} + +template +BaseVector *Projection::ColumnProjectionFlatVectorCopyPositionsHelper(const int32_t *selectedRows, + int32_t numSelectedRows, BaseVector *colVec) const +{ + return reinterpret_cast *>(colVec)->CopyPositions(selectedRows, 0, numSelectedRows); +} + +template +BaseVector *Projection::ColumnProjectionDictionaryVectorCopyPositionsHelper(const int32_t *selectedRows, + int32_t numSelectedRows, BaseVector *colVec) const +{ + return reinterpret_cast> *>(colVec)->CopyPositions(selectedRows, 0, numSelectedRows); +} + +template +BaseVector *Projection::ColumnProjectionFlatVectorSliceHelper(int32_t numSelectedRows, BaseVector *colVec) const +{ + return reinterpret_cast *>(colVec)->Slice(0, numSelectedRows); +} + +template +BaseVector *Projection::ColumnProjectionDictionaryVectorSliceHelper(int32_t numSelectedRows, BaseVector *colVec) const +{ + return reinterpret_cast> *>(colVec)->Slice(0, numSelectedRows); +} + +void Projection::ProjectHelperVarWidth(VectorBatch &vecBatch, int64_t *valueAddrs, int64_t *nullAddrs, + int64_t *offsetAddrs, BaseVector *outVec, int32_t numSelectedRows, int32_t selectedRows[], + ExecutionContext *context, int64_t *dictionaryVectors, DataTypeId &outTypeId) const +{ + this->projector(valueAddrs, vecBatch.GetRowCount(), reinterpret_cast(outVec), selectedRows, + numSelectedRows, nullAddrs, offsetAddrs, + reinterpret_cast(unsafe::UnsafeBaseVector::GetNulls(outVec)), nullptr, + reinterpret_cast(context), dictionaryVectors); +} + +void Projection::ProjectHelperFixedWidth(VectorBatch &vecBatch, int64_t *valueAddrs, int64_t *nullAddrs, + int64_t *offsetAddrs, BaseVector *outVec, int32_t numSelectedRows, int32_t selectedRows[], + ExecutionContext *context, int64_t *dictionaryVectors, DataTypeId &outTypeId) const +{ + auto outValueAddr = reinterpret_cast(VectorHelper::UnsafeGetValues(outVec)); + this->projector(valueAddrs, vecBatch.GetRowCount(), outValueAddr, selectedRows, numSelectedRows, nullAddrs, + offsetAddrs, reinterpret_cast(unsafe::UnsafeBaseVector::GetNulls(outVec)), nullptr, + reinterpret_cast(context), dictionaryVectors); +} + +BaseVector *Projection::Project(VectorBatch *vecBatch, int64_t *valueAddrs, int64_t *nullAddrs, int64_t *offsetAddrs, + ExecutionContext *context, int64_t *dictionaryVectors, const int32_t *typeIds) +{ + return this->Project(vecBatch, nullptr, vecBatch->GetRowCount(), valueAddrs, nullAddrs, offsetAddrs, context, + dictionaryVectors, typeIds); +} + +BaseVector *Projection::ColumnProjectionProxy(VectorBatch *vecBatch, int32_t selectedRows[], int32_t numSelectedRows, + const int32_t *typeIds) const +{ + switch (typeIds[columnProjectionIndex]) { + case OMNI_INT: + case OMNI_DATE32: + return ColumnProjectionHelper(vecBatch, selectedRows, numSelectedRows); + case OMNI_SHORT: + return ColumnProjectionHelper(vecBatch, selectedRows, numSelectedRows); + case OMNI_BYTE: + return ColumnProjectionHelper(vecBatch, selectedRows, numSelectedRows); + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DECIMAL64: + return ColumnProjectionHelper(vecBatch, selectedRows, numSelectedRows); + case OMNI_DOUBLE: + return ColumnProjectionHelper(vecBatch, selectedRows, numSelectedRows); + case OMNI_BOOLEAN: + return ColumnProjectionHelper(vecBatch, selectedRows, numSelectedRows); + case OMNI_DECIMAL128: + return ColumnProjectionHelper(vecBatch, selectedRows, numSelectedRows); + case OMNI_VARCHAR: + case OMNI_CHAR: + return ColumnProjectionVarCharVectorHelper(vecBatch, selectedRows, numSelectedRows); + default: + LogError("Do not support such vector type %d", typeIds[columnProjectionIndex]); + return nullptr; + } +} + +template +BaseVector *Projection::ColumnProjectionHelper(VectorBatch *vecBatch, const int32_t *selectedRows, + int32_t numSelectedRows) const +{ + auto colVec = vecBatch->Get(this->columnProjectionIndex); + auto rowCnt = vecBatch->GetRowCount(); + if (numSelectedRows != 0 && numSelectedRows == rowCnt) { + if (colVec->GetEncoding() == OMNI_DICTIONARY) { + return ColumnProjectionDictionaryVectorSliceHelper(numSelectedRows, colVec); + } else { + return ColumnProjectionFlatVectorSliceHelper(numSelectedRows, colVec); + } + } + + if (selectedRows != nullptr && numSelectedRows != 0) { + if (colVec->GetEncoding() == OMNI_DICTIONARY) { + return ColumnProjectionDictionaryVectorCopyPositionsHelper(selectedRows, numSelectedRows, colVec); + } else { + return ColumnProjectionFlatVectorCopyPositionsHelper(selectedRows, numSelectedRows, colVec); + } + } + return nullptr; +} + +template +BaseVector *Projection::ColumnProjectionVarCharVectorHelper(VectorBatch *vecBatch, const int32_t *selectedRows, + int32_t numSelectedRows) const +{ + auto colVec = vecBatch->Get(this->columnProjectionIndex); + auto rowCnt = vecBatch->GetRowCount(); + if (numSelectedRows != 0 && numSelectedRows == rowCnt) { + if (colVec->GetEncoding() == OMNI_DICTIONARY) { + return ColumnProjectionDictionaryVectorSliceHelper(numSelectedRows, colVec); + } else { + return ColumnProjectionFlatVectorSliceHelper>(numSelectedRows, + colVec); + } + } + + if (selectedRows != nullptr && numSelectedRows != 0) { + if (colVec->GetEncoding() == OMNI_DICTIONARY) { + return ColumnProjectionDictionaryVectorCopyPositionsHelper(selectedRows, numSelectedRows, colVec); + } else { + // copyPosition for string vectors is 10 to 30 times slower than creating a dictionary + // so here rather than calling copyPosition with create a dictionary view of original input vector + return VectorHelper::CreateStringDictionary(selectedRows, numSelectedRows, + reinterpret_cast> *>(colVec)); + } + } + return nullptr; +} +ExpressionEvaluator::ExpressionEvaluator(Expr *filterExpression, const std::vector &projectionExprs, + const DataTypes &inputDataTypes, OverflowConfig *ofConfig) + : inputTypes(const_cast(inputDataTypes)) +{ + hasFilter = true; + filterExpr = filterExpression; + for (auto &projectionExpr : projectionExprs) { + projExprs.emplace_back(projectionExpr); + } + overflowConfig = std::make_unique(*ofConfig); + projectVecCount = static_cast(projectionExprs.size()); + + for (int i = 0; isSupportedExpr && i < projectVecCount; ++i) { + outputTypes.emplace_back(projExprs[i]->GetReturnType()); + } +} + +ExpressionEvaluator::ExpressionEvaluator(const std::vector &projectionExprs, const DataTypes &inputDataTypes, + OverflowConfig *ofConfig) + : inputTypes(const_cast(inputDataTypes)) +{ + hasFilter = false; + for (auto &projectionExpr : projectionExprs) { + projExprs.emplace_back(projectionExpr); + } + overflowConfig = std::make_unique(*ofConfig); + projectVecCount = static_cast(projectionExprs.size()); + + for (int i = 0; isSupportedExpr && i < projectVecCount; ++i) { + outputTypes.emplace_back(projExprs[i]->GetReturnType()); + } +} + +bool ExpressionEvaluator::IsSupportedExpr() const +{ + return isSupportedExpr; +} + +VectorBatch *ExpressionEvaluator::Evaluate(VectorBatch *vecBatch, ExecutionContext *context, + AlignedBuffer *selectedRowsBuffer) +{ + const int vectorCount = vecBatch->GetVectorCount(); + intptr_t valueAddrs[vectorCount]; + intptr_t nullAddrs[vectorCount]; + intptr_t offsetAddrs[vectorCount]; + intptr_t dictionaries[vectorCount]; + GetAddr(*vecBatch, valueAddrs, nullAddrs, offsetAddrs, dictionaries, this->inputTypes); + if (hasFilter) { + return ProcessFilterAndProject(vecBatch, context, selectedRowsBuffer, valueAddrs, nullAddrs, offsetAddrs, + dictionaries); + } else { + return ProcessProject(vecBatch, context, valueAddrs, nullAddrs, offsetAddrs, dictionaries); + } +} + +void ExpressionEvaluator::FilterFuncGeneration() +{ + filter = std::make_unique(*filterExpr, GetInputDataTypes(), overflowConfig.get()); + if (!this->filter->IsSupported()) { + this->isSupportedExpr = false; + } + for (auto &projExpr : projExprs) { + auto projection = std::make_unique(*projExpr, true, projExpr->GetReturnType(), GetInputDataTypes(), + overflowConfig.get()); + if (!projection->IsSupported()) { + this->isSupportedExpr = false; + break; + } + projections.emplace_back(move(projection)); + } +} + +void ExpressionEvaluator::ProjectFuncGeneration() +{ + for (auto &projExpr : projExprs) { + auto projection = std::make_unique(*projExpr, false, projExpr->GetReturnType(), GetInputDataTypes(), + overflowConfig.get()); + if (!projection->IsSupported()) { + this->isSupportedExpr = false; + break; + } + projections.emplace_back(move(projection)); + } +} + +VectorBatch *ExpressionEvaluator::ProcessProject(VectorBatch *vecBatch, ExecutionContext *context, intptr_t *valueAddrs, + intptr_t *nullAddrs, intptr_t *offsetAddrs, intptr_t *dictionaries) +{ + auto rowCount = vecBatch->GetRowCount(); + auto projectedVecs = std::make_unique(rowCount); + for (int32_t i = 0; i < projectVecCount; i++) { + BaseVector *outCol = projections[i]->Project(vecBatch, valueAddrs, nullAddrs, offsetAddrs, context, + dictionaries, GetInputDataTypes().GetIds()); + if (context->HasError()) { + context->GetArena()->Reset(); + std::string errorMessage = context->GetError(); + throw OmniException("OPERATOR_RUNTIME_ERROR", errorMessage); + } + projectedVecs->Append(outCol); + } + context->GetArena()->Reset(); + return projectedVecs.release(); +} + +VectorBatch *ExpressionEvaluator::ProcessFilterAndProject(VectorBatch *vecBatch, ExecutionContext *context, + AlignedBuffer *selectedRowsBuffer, intptr_t *valueAddrs, intptr_t *nullAddrs, intptr_t *offsetAddrs, + intptr_t *dictionaries) +{ + const int rowCount = vecBatch->GetRowCount(); + auto selectedRows = selectedRowsBuffer->AllocateReuse(rowCount, false); + int32_t numSelectedRows = filter->GetFilterFunc()(valueAddrs, rowCount, selectedRows, nullAddrs, offsetAddrs, + reinterpret_cast(context), dictionaries); + if (context->HasError()) { + context->GetArena()->Reset(); + std::string errorMessage = context->GetError(); + throw OmniException("OPERATOR_RUNTIME_ERROR", errorMessage); + } + if (numSelectedRows <= 0) { + context->GetArena()->Reset(); + return nullptr; + } + + auto projectedVecs = std::make_unique(numSelectedRows); + for (int32_t i = 0; i < projectVecCount; i++) { + BaseVector *col = projections[i]->Project(vecBatch, selectedRows, numSelectedRows, valueAddrs, nullAddrs, + offsetAddrs, context, dictionaries, GetInputDataTypes().GetIds()); + if (context->HasError()) { + context->GetArena()->Reset(); + std::string errorMessage = context->GetError(); + throw OmniException("OPERATOR_RUNTIME_ERROR", errorMessage); + } + projectedVecs->Append(col); + } + + context->GetArena()->Reset(); + return projectedVecs.release(); +} +} \ No newline at end of file diff --git a/core/src/codegen/expr_evaluator.h b/core/src/codegen/expr_evaluator.h new file mode 100644 index 0000000..0ecb602 --- /dev/null +++ b/core/src/codegen/expr_evaluator.h @@ -0,0 +1,238 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + * Description: Expression evaluator + */ +#ifndef OMNI_RUNTIME_EXPR_EVALUATOR_H +#define OMNI_RUNTIME_EXPR_EVALUATOR_H + +#include "operator/execution_context.h" +#include "codegen/filter_codegen.h" +#include "codegen/batch_filter_codegen.h" +#include "codegen/projection_codegen.h" +#include "codegen/batch_projection_codegen.h" +#include "util/config_util.h" +#include "vector/vector_helper.h" + +namespace omniruntime::codegen { +using namespace omniruntime::expressions; +using namespace omniruntime::op; +using namespace omniruntime::vec; +using namespace omniruntime::type; +using namespace omniruntime::exception; + +using FilterFunc = int32_t (*)(int64_t *, int32_t, int32_t *, int64_t *, int64_t *, int64_t, int64_t *); +using ProjFunc = int32_t (*)(int64_t const *, int32_t, int64_t, int32_t *, int32_t, int64_t const *, int64_t const *, + int32_t *, int32_t *, int64_t, int64_t *); + +typedef struct LiteralValue { + bool isNull = false; + union Value { + bool boolVal; + int8_t byteVal; + int16_t shortVal; + int32_t intVal; + int64_t longVal; + double doubleVal; + Decimal128 decimal128Val; + std::string_view stringVal; + + Value() {} + } value; + LiteralValue() {} +} LiteralValue; + +void GetAddr(VectorBatch &vecBatch, intptr_t valueAddrs[], intptr_t nullAddrs[], intptr_t offsetAddrs[], + intptr_t dictionaries[], const DataTypes &types); + +class Filter { +public: + explicit Filter(const Expr &expression, const DataTypes &inputDataTypes, OverflowConfig *overflowConfig); + + ~Filter() = default; + + FilterFunc GetFilterFunc() + { + return apply; + } + + bool IsSupported() + { + return isSupported; + } + +private: + std::unique_ptr codeGen; + std::unique_ptr batchCodeGen; + bool isSupported; + FilterFunc apply; +}; + +class Projection { +public: + Projection(const Expr &expr, bool filter, DataTypePtr outType, const DataTypes &inputDataTypes, + OverflowConfig *overflowConfig); + + ~Projection() = default; + + void ProjectHelperFixedWidth(VectorBatch &vecBatch, int64_t *valueAddrs, int64_t *nullAddrs, int64_t *offsetAddrs, + BaseVector *outVec, int32_t numSelectedRows, int32_t selectedRows[], ExecutionContext *context, + int64_t *dictionaryVectors, DataTypeId &typeIds) const; + + void ProjectHelperVarWidth(VectorBatch &vecBatch, int64_t *valueAddrs, int64_t *nullAddrs, int64_t *offsetAddrs, + BaseVector *outVec, int32_t numSelectedRows, int32_t selectedRows[], ExecutionContext *context, + int64_t *dictionaryVectors, DataTypeId &typeIds) const; + + BaseVector *Project(VectorBatch *vecBatch, int32_t selectedRows[], int32_t numSelectedRows, int64_t *valueAddrs, + int64_t *nullAddrs, int64_t *offsetAddrs, ExecutionContext *context, int64_t *dictionaryVectors, + const int32_t *typeIds); + + BaseVector *Project(VectorBatch *vecBatch, int64_t *valueAddrs, int64_t *nullAddrs, int64_t *offsetAddrs, + ExecutionContext *context, int64_t *dictionaryVectors, const int32_t *typeIds); + + omniruntime::type::DataType &GetOutputType() const + { + return *(this->outType); + } + + ProjFunc GetProjector() const + { + return projector; + } + + bool IsSupported() + { + return isSupported; + } + + bool IsColumnProjection() const + { + return isColumnProjection; + } + + int GetColumnProjectionIndex() const + { + return columnProjectionIndex; + } + +private: + const omniruntime::expressions::Expr *expr; + std::unique_ptr codeGen { nullptr }; + std::unique_ptr batchCodeGen { nullptr }; + bool isSupported = true; + bool isColumnProjection = false; + int columnProjectionIndex = -1; + bool isConstantProjection = false; + LiteralValue literalVal; + DataTypePtr outType; + ProjFunc projector; + + bool Initialize(bool filter, const DataTypes &inputDataTypes, OverflowConfig *overflowConfig); + BaseVector *ColumnProjectionProxy(VectorBatch *vecBatch, int32_t selectedRows[], int32_t numSelectedRows, + const int32_t *typeIds) const; + + template + BaseVector *ColumnProjectionHelper(VectorBatch *vecBatch, const int32_t selectedRows[], + int32_t numSelectedRows) const; + + template + BaseVector *ColumnProjectionVarCharVectorHelper(VectorBatch *vecBatch, const int32_t *selectedRows, + int32_t numSelectedRows) const; + + template + BaseVector *ColumnProjectionFlatVectorSliceHelper(int32_t numSelectedRows, BaseVector *colVec) const; + + template + BaseVector *ColumnProjectionDictionaryVectorSliceHelper(int32_t numSelectedRows, BaseVector *colVec) const; + + template + BaseVector *ColumnProjectionDictionaryVectorCopyPositionsHelper(const int32_t *selectedRows, + int32_t numSelectedRows, BaseVector *colVec) const; + + template + BaseVector *ColumnProjectionFlatVectorCopyPositionsHelper(const int32_t *selectedRows, int32_t numSelectedRows, + BaseVector *colVec) const; + + bool SetLiteralValue(const LiteralExpr *literalExpr); + + bool NullColumnProjection(ExecutionContext *context, BaseVector *outVec); + + bool ConstantColumnProjection(ExecutionContext *context, BaseVector *outVec); + + template void SetConstantValues(T &value, BaseVector *outVec); +}; + +class ExpressionEvaluator { +public: + ExpressionEvaluator(Expr *filterExpression, const std::vector &projectionExprs, + const DataTypes &inputDataTypes, OverflowConfig *ofConfig); + + ExpressionEvaluator(const std::vector &projectionExprs, const DataTypes &inputDataTypes, + OverflowConfig *ofConfig); + + ~ExpressionEvaluator() + { + if (filterExpr) { + delete filterExpr; + } + for (size_t i = 0; i < projExprs.size(); ++i) { + delete projExprs[i]; + } + projExprs.clear(); + } + + DataTypes &GetInputDataTypes() + { + return inputTypes; + } + + std::vector &GetOutputDataTypes() + { + return outputTypes; + } + + FilterFunc GetFilterFunc() + { + return filter->GetFilterFunc(); + } + + std::vector> &GetProjections() + { + return projections; + } + + int32_t GetProjectVecCount() + { + return projectVecCount; + } + + VectorBatch *Evaluate(VectorBatch *vecBatch, ExecutionContext *context, + AlignedBuffer *selectedRowsBuffer = nullptr); + + bool IsSupportedExpr() const; + + void FilterFuncGeneration(); + + void ProjectFuncGeneration(); + +private: + Expr *filterExpr = nullptr; + std::vector projExprs; + int32_t projectVecCount = 0; + std::unique_ptr overflowConfig; + bool isSupportedExpr = true; + bool hasFilter = false; + DataTypes inputTypes; + std::vector outputTypes; + + std::unique_ptr filter; + std::vector> projections; + + VectorBatch *ProcessFilterAndProject(VectorBatch *vecBatch, ExecutionContext *context, + AlignedBuffer *selectedRowsBuffer, intptr_t *valueAddrs, intptr_t *nullAddrs, intptr_t *offsetAddrs, + intptr_t *dictionaries); + + VectorBatch *ProcessProject(VectorBatch *vecBatch, ExecutionContext *context, intptr_t *valueAddrs, + intptr_t *nullAddrs, intptr_t *offsetAddrs, intptr_t *dictionaries); +}; +} +#endif // OMNI_RUNTIME_EXPR_EVALUATOR_H diff --git a/core/src/codegen/expr_function.cpp b/core/src/codegen/expr_function.cpp new file mode 100644 index 0000000..eacebdb --- /dev/null +++ b/core/src/codegen/expr_function.cpp @@ -0,0 +1,185 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + * Description: Generated Expression Function + */ + +#include "expr_function.h" + +#include + +namespace omniruntime::codegen { +ExprFunction::ExprFunction(std::string funcName, const Expr &e, CodegenBase &codegen, const DataTypes &inputDataTypes) + : funcName(std::move(funcName)), + expr(const_cast(e)), + codegen(codegen), + columnTypes(const_cast(inputDataTypes)) +{ + CreateFunction(); +} + +Argument *ExprFunction::GetColumnArgument(int i) +{ + return llvmFunc->getArg(arguments.size() + i); +} + +Argument *ExprFunction::GetDicArgument(int i) +{ + return llvmFunc->getArg(arguments.size() + columnTypes.GetSize() + i); +} + +Argument *ExprFunction::GetNullArgument(int i) +{ + return llvmFunc->getArg(arguments.size() + columnTypes.GetSize() * 2 + i); +} + +Argument *ExprFunction::GetOffsetArgument(int i) +{ + return llvmFunc->getArg(arguments.size() + columnTypes.GetSize() * 3 + i); +} + +int32_t ExprFunction::GetInputColumnCount() +{ + return columnTypes.GetSize(); +} + +size_t ExprFunction::GetArgumentCount() +{ + return arguments.size(); +} + +std::vector ExprFunction::GetArguments() +{ + std::vector args; + for (const auto &arg : arguments) { + args.push_back(arg.type); + } + + for (int32_t i = 0; i < columnTypes.GetSize(); i++) { + args.push_back(codegen.GetTypes()->ToPointerType(columnTypes.GetType(i)->GetId())); + } + + for (int32_t i = 0; i < columnTypes.GetSize(); i++) { + args.push_back(codegen.GetTypes()->I64Type()); + } + + for (int32_t i = 0; i < columnTypes.GetSize(); i++) { + args.push_back(codegen.GetTypes()->I1PtrType()); + } + + for (int32_t i = 0; i < columnTypes.GetSize(); i++) { + args.push_back(codegen.GetTypes()->I32PtrType()); + } + + return args; +} + +Type *ExprFunction::GetReturnType() +{ + return codegen.GetTypes()->GetFunctionReturnType(expr.GetReturnTypeId()); +} + +llvm::Function *ExprFunction::GetFunction() +{ + return llvmFunc; +} + +std::vector ExprFunction::ToColumnArgs(Value *data) +{ + std::vector result; + for (int32_t i = 0; i < columnTypes.GetSize(); i++) { + auto colAddr = codegen.GetIRBuilder()->CreateGEP(codegen.GetTypes()->I64Type(), data, + codegen.GetTypes()->CreateConstantInt(i), "column_addr_" + itostr(i)); + std::string colName = "column_"; + auto col = + codegen.GetIRBuilder()->CreateLoad(codegen.GetTypes()->I64Type(), colAddr, colName.append(itostr(i))); + auto columnPtr = codegen.GetPtrTypeFromInt(columnTypes.GetType(i)->GetId(), col); + result.push_back(columnPtr); + } + return result; +} + +std::vector ExprFunction::ToDicArgs(Value *dictionary) +{ + std::vector result; + for (int32_t i = 0; i < columnTypes.GetSize(); i++) { + auto dicAddr = codegen.GetIRBuilder()->CreateGEP(codegen.GetTypes()->I64Type(), dictionary, + codegen.GetTypes()->CreateConstantInt(i), "dic_addr_" + itostr(i)); + std::string dicName = "dic_"; + auto dic = + codegen.GetIRBuilder()->CreateLoad(codegen.GetTypes()->I64Type(), dicAddr, dicName.append(itostr(i))); + auto dicPtr = codegen.GetIRBuilder()->CreateIntToPtr(dic, codegen.GetTypes()->I64Type()); + result.push_back(dicPtr); + } + return result; +} + +std::vector ExprFunction::ToNullArgs(Value *bitmap) +{ + std::vector result; + for (int32_t i = 0; i < columnTypes.GetSize(); i++) { + auto bitmapAddr = codegen.GetIRBuilder()->CreateGEP(codegen.GetTypes()->I64Type(), bitmap, + codegen.GetTypes()->CreateConstantInt(i), "bitmap_addr_" + itostr(i)); + std::string bitmapName = "bitmap_"; + auto bitmapValue = + codegen.GetIRBuilder()->CreateLoad(codegen.GetTypes()->I64Type(), bitmapAddr, bitmapName.append(itostr(i))); + auto bitmapPtr = codegen.GetIRBuilder()->CreateIntToPtr(bitmapValue, codegen.GetTypes()->I32PtrType()); + result.push_back(bitmapPtr); + } + return result; +} + +std::vector ExprFunction::ToOffsetArgs(Value *offset) +{ + std::vector result; + for (int32_t i = 0; i < columnTypes.GetSize(); i++) { + auto offsetAddr = codegen.GetIRBuilder()->CreateGEP(codegen.GetTypes()->I64Type(), offset, + codegen.GetTypes()->CreateConstantInt(i), "offset_addr_" + itostr(i)); + std::string offsetName = "offset_"; + auto offsetValue = + codegen.GetIRBuilder()->CreateLoad(codegen.GetTypes()->I64Type(), offsetAddr, offsetName.append(itostr(i))); + auto offsetPtr = codegen.GetIRBuilder()->CreateIntToPtr(offsetValue, codegen.GetTypes()->I32PtrType()); + result.push_back(offsetPtr); + } + return result; +} + +void ExprFunction::CreateFunction() +{ + FunctionType *prototype = FunctionType::get(GetReturnType(), GetArguments(), false); + + llvmFunc = llvm::Function::Create(prototype, llvm::Function::ExternalLinkage, funcName, codegen.GetModule()); + + for (size_t i = 0; i < arguments.size(); i++) { + auto arg = llvmFunc->getArg(i); + arg->setName(arguments.at(i).name); + } + + for (int32_t i = 0; i < columnTypes.GetSize(); i++) { + std::string colName = "column_"; + size_t idx = arguments.size() + i; + auto arg = llvmFunc->getArg(idx); + arg->setName(colName.append(itostr(i))); + } + + for (int32_t i = 0; i < columnTypes.GetSize(); i++) { + std::string dicName = "dic_"; + size_t idx = arguments.size() + columnTypes.GetSize() + i; + auto arg = llvmFunc->getArg(idx); + arg->setName(dicName.append(itostr(i))); + } + + for (int32_t i = 0; i < columnTypes.GetSize(); i++) { + std::string bitmapName = "bitmap_"; + size_t idx = arguments.size() + columnTypes.GetSize() * 2 + i; + auto arg = llvmFunc->getArg(idx); + arg->setName(bitmapName.append(itostr(i))); + } + + for (int32_t i = 0; i < columnTypes.GetSize(); i++) { + std::string offsetName = "offset_"; + size_t idx = arguments.size() + columnTypes.GetSize() * 3 + i; + auto arg = llvmFunc->getArg(idx); + arg->setName(offsetName.append(itostr(i))); + } +} +} diff --git a/core/src/codegen/expr_function.h b/core/src/codegen/expr_function.h new file mode 100644 index 0000000..fb9aeaa --- /dev/null +++ b/core/src/codegen/expr_function.h @@ -0,0 +1,120 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + * Description: Generated Expression Function + */ +#ifndef OMNI_RUNTIME_EXPR_FUNCTION_H +#define OMNI_RUNTIME_EXPR_FUNCTION_H + +#include +#include +#include + +#include "expression/expr_visitor.h" +#include "expression/expressions.h" +#include "codegen/llvm_types.h" +#include "codegen/codegen_base.h" +#include "vector/vector_batch.h" + +namespace omniruntime::codegen { +using namespace llvm; +using namespace omniruntime::expressions; + +struct ExprArgument { + std::string name; + llvm::Type *type; +}; + +/** + * This class encapsulates the generated expression function so that we don't have to expose the details on handling the + * argument index, invoking the function, etc. + * We invoke this function to process every row of data, hence it's performance can dramatically impact the performance + * of expression evaluation + */ +class ExprFunction { +public: + /** + * The ExprFunction should be in a "complete" state once created, e.g. access to this object should + * provide valid and consistent information. It is also reentrant, e.g. all functions can be invoked + * as many times as need without impacting the consistency of the information contained in this class + */ + ExprFunction(std::string funcName, const Expr &e, CodegenBase &codegen, const DataTypes &inputDataTypes); + + Argument *GetColumnArgument(int i); + + Argument *GetDicArgument(int i); + + Argument *GetNullArgument(int i); + + Argument *GetOffsetArgument(int i); + + int32_t GetInputColumnCount(); + + size_t GetArgumentCount(); + + /** + * @return return all arguments to the generated function + */ + std::vector GetArguments(); + + /** + * @return returns the return type of the generated function + */ + Type *GetReturnType(); + + /** + * @return returns the LLVM function + */ + llvm::Function *GetFunction(); + + /** + * Convert the data from each column in the table into individual arguments + * @param table the pointer of data address array + * @return the vector containing data pointer for each column + */ + std::vector ToColumnArgs(Value *data); + + /** + * Convert the dictionary from each column in the table into individual arguments + * @param dictionary the pointer of dictionary address array + * @return the vector containing dictionary pointer for each column + */ + std::vector ToDicArgs(Value *dictionary); + + /** + * Convert the bitmap from each column in the table into individual arguments + * @param bitmap the pointer of bitmap address array + * @return the vector containing bitmap pointer for each column + */ + std::vector ToNullArgs(Value *bitmap); + + /** + * Convert the offset from each column in the table into individual arguments + * @param offset the pointer of offset address array + * @return the vector containing offset pointer for each column + */ + std::vector ToOffsetArgs(Value *offset); + +private: + std::string funcName; + Expr &expr; + CodegenBase &codegen; + llvm::Function *llvmFunc = nullptr; + DataTypes columnTypes; + + /** + * Predefined argument list + * The input columns will be appended to the predefined argument list + */ + std::vector arguments { { "rowIdx", Type::getInt32Ty(*codegen.GetContext()) }, + { "dataLength", Type::getInt32PtrTy(*codegen.GetContext()) }, + { "executionContext", Type::getInt64Ty(*codegen.GetContext()) }, + { "isNullPtr", Type::getInt1PtrTy(*codegen.GetContext()) } }; + + /** + * Method to create prototype of function + */ + void CreateFunction(); +}; +} + +#endif // OMNI_RUNTIME_EXPR_FUNCTION_H diff --git a/core/src/codegen/expr_info_extractor.cpp b/core/src/codegen/expr_info_extractor.cpp new file mode 100644 index 0000000..c2d9990 --- /dev/null +++ b/core/src/codegen/expr_info_extractor.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: Extract essential information from the expression tree + */ +#include +#include "expr_info_extractor.h" + +namespace omniruntime::codegen { +using namespace omniruntime::expressions; +using namespace std; + +void ExprInfoExtractor::Visit(const LiteralExpr &e) {} + +void ExprInfoExtractor::Visit(const FieldExpr &e) +{ + this->vectorIndexes.insert(e.colVal); +} + +void ExprInfoExtractor::Visit(const BinaryExpr &e) +{ + e.left->Accept(*this); + e.right->Accept(*this); +} + +void ExprInfoExtractor::Visit(const UnaryExpr &e) +{ + e.exp->Accept(*this); +} + +void ExprInfoExtractor::Visit(const IfExpr &e) +{ + e.condition->Accept(*this); + e.trueExpr->Accept(*this); + e.falseExpr->Accept(*this); +} + +void ExprInfoExtractor::Visit(const InExpr &e) +{ + for (auto arg : e.arguments) { + arg->Accept(*this); + } +} + +void ExprInfoExtractor::Visit(const BetweenExpr &e) +{ + e.value->Accept(*this); + e.lowerBound->Accept(*this); + e.upperBound->Accept(*this); +} + +void ExprInfoExtractor::Visit(const CoalesceExpr &e) +{ + e.value1->Accept(*this); + e.value2->Accept(*this); +} +void ExprInfoExtractor::Visit(const SwitchExpr &e) +{ + e.falseExpr->Accept(*this); + for (auto &when : e.whenClause) { + when.first->Accept(*this); + when.second->Accept(*this); + } +} +void ExprInfoExtractor::Visit(const IsNullExpr &e) +{ + e.value->Accept(*this); +} + +void ExprInfoExtractor::Visit(const FuncExpr &e) +{ + for (auto arg : e.arguments) { + arg->Accept(*this); + } +} + +std::set ExprInfoExtractor::GetVectorIndexes() +{ + return this->vectorIndexes; +} +} diff --git a/core/src/codegen/expr_info_extractor.h b/core/src/codegen/expr_info_extractor.h new file mode 100644 index 0000000..26b6d2d --- /dev/null +++ b/core/src/codegen/expr_info_extractor.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: Extract essential information from the expression tree + */ +#ifndef __OMNI_RUNTIME_EXPR_INFO_EXTRACTOR__ +#define __OMNI_RUNTIME_EXPR_INFO_EXTRACTOR__ + +#include "expression/expr_visitor.h" +#include "func_registry_string.h" +#include "func_registry_dictionary.h" +#include "func_registry_decimal.h" +#include "func_registry_context.h" +#include "function.h" + +namespace omniruntime::codegen { +class ExprInfoExtractor : public ExprVisitor { +public: + void Visit(const omniruntime::expressions::LiteralExpr &e) override; + + void Visit(const omniruntime::expressions::FieldExpr &e) override; + + void Visit(const omniruntime::expressions::UnaryExpr &e) override; + + void Visit(const omniruntime::expressions::BinaryExpr &e) override; + + void Visit(const omniruntime::expressions::InExpr &e) override; + + void Visit(const omniruntime::expressions::BetweenExpr &e) override; + + void Visit(const omniruntime::expressions::IfExpr &e) override; + + void Visit(const omniruntime::expressions::CoalesceExpr &e) override; + + void Visit(const omniruntime::expressions::IsNullExpr &e) override; + + void Visit(const omniruntime::expressions::FuncExpr &e) override; + + void Visit(const omniruntime::expressions::SwitchExpr &e) override; + + std::set GetVectorIndexes(); + +private: + std::set vectorIndexes; +}; +} +#endif diff --git a/core/src/codegen/expression_codegen.cpp b/core/src/codegen/expression_codegen.cpp new file mode 100644 index 0000000..a0546b1 --- /dev/null +++ b/core/src/codegen/expression_codegen.cpp @@ -0,0 +1,2295 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: Expression code generator + */ +#include "expression_codegen.h" + +#include +#include +#include +#include + +#include +#include +#include + +#include "expr_info_extractor.h" +#include "codegen_context.h" +#include "function.h" + +namespace omniruntime::codegen { +using namespace llvm; +using namespace orc; +using namespace omniruntime; +using namespace omniruntime::expressions; +using namespace omniruntime::type; +using namespace std; + +namespace { +const int INT32_VALUE = 32; +const int INT64_VALUE = 64; +const int EXPRFUNC_OUT_LENGTH_ARG_INDEX = 1; +const int EXPRFUNC_OUT_NULL_INDEX = 3; +} + +ExpressionCodeGen::ExpressionCodeGen(std::string name, const Expr &cpExpr, op::OverflowConfig *overflowConfig) + : CodegenBase(name, cpExpr, overflowConfig) +{} + +ExpressionCodeGen::~ExpressionCodeGen() +{ + if (rt) { + eoe(rt->remove()); + } +} + +bool ExpressionCodeGen::InitializeCodegenContext(iterator_range args) +{ + this->codegenContext = std::make_unique(); + for (auto &arg : args) { + auto argName = arg.getName().str(); + if (argName == "data") { + codegenContext->data = &arg; + } else if (argName == "nullBitmap") { + codegenContext->nullBitmap = &arg; + } else if (argName == "offsets") { + codegenContext->offsets = &arg; + } else if (argName == "rowIdx") { + codegenContext->rowIdx = &arg; + } else if (argName == "dataLength" || argName == "isNullPtr") { + continue; + } else if (argName == "executionContext") { + codegenContext->executionContext = &arg; + } else if (argName == "dictionaryVectors") { + codegenContext->dictionaryVectors = &arg; + } else if (argName.find("column_") == 0 || argName.find("dic_") == 0 || argName.find("bitmap_") == 0 || + argName.find("offset_") == 0) { + continue; + } else { + LogWarn("Invalid argument %s", argName.c_str()); + return false; + } + } + + codegenContext->print = modulePtr->getOrInsertFunction("printf", + FunctionType::get(IntegerType::getInt32Ty(*context), PointerType::get(Type::getInt8Ty(*context), 0), true)); + + return true; +} + +llvm::Function *ExpressionCodeGen::CreateFunction(const DataTypes &inputDataTypes) +{ + exprFunc = make_shared(funcName, *expr, *this, inputDataTypes); + func = exprFunc->GetFunction(); + + // Fill the function body + BasicBlock *body = BasicBlock::Create(*context, "CREATED_FUNC_BODY", func); + builder->SetInsertPoint(body); + + if (!InitializeCodegenContext(func->args())) { + return nullptr; + } + + auto result = VisitExpr(*expr); + if (result->data == nullptr) { + return nullptr; + } + + // Update final string length of output + if (result->length != nullptr) { + Argument *outputLength = func->getArg(EXPRFUNC_OUT_LENGTH_ARG_INDEX); + Value *lengthGep = builder->CreateGEP(llvmTypes->I32Type(), outputLength, llvmTypes->CreateConstantInt(0), + "OUTPUT_LENGTH_ADDRESS"); + builder->CreateStore(result->length, lengthGep); + } + + // Update final isNull of output + builder->CreateStore(result->isNull, func->getArg(EXPRFUNC_OUT_NULL_INDEX)); + + if (expr->GetReturnTypeId() == DataTypeId::OMNI_VARCHAR) { + result->data = builder->CreatePtrToInt(result->data, llvmTypes->I64Type()); + } + builder->CreateRet(result->data); + verifyFunction(*func); + return func; +} + +CodeGenValuePtr ExpressionCodeGen::VisitExpr(const omniruntime::expressions::Expr &e) +{ + e.Accept(*this); + return this->value; +} + +void ExpressionCodeGen::Visit(const LiteralExpr &lExpr) +{ + this->value.reset(LiteralExprConstantHelper(lExpr)); +} + +void ExpressionCodeGen::Visit(const FieldExpr &fExpr) +{ + Value *rowIdx = this->codegenContext->rowIdx; + Value *length = nullptr; + + // Get dictionary address of this column + Value *dictionaryVectorPtr = exprFunc->GetDicArgument(fExpr.colVal); + Type *dataType = llvmTypes->ToLLVMType(fExpr.GetReturnTypeId()); + auto condition = builder->CreateIsNotNull(dictionaryVectorPtr); + + BasicBlock *trueBlock = BasicBlock::Create(*context, "DICTIONARY_NOT_NULL", func); + BasicBlock *falseBlock = BasicBlock::Create(*context, "DICTIONARY_IS_NULL"); + BasicBlock *mergeBlock = BasicBlock::Create(*context, "ifcont"); + + builder->CreateCondBr(condition, trueBlock, falseBlock); + + // If dictionary vector is present, call DictionaryVector methods + // to get encoded values and length if varchar type + builder->SetInsertPoint(trueBlock); + + AllocaInst *lengthAllocaInst = nullptr; + Value *dictionaryValue = + this->GetDictionaryVectorValue(*(fExpr.GetReturnType()), rowIdx, dictionaryVectorPtr, lengthAllocaInst); + if (dictionaryValue == nullptr) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + Value *dictionaryLength = nullptr; + if (TypeUtil::IsStringType(fExpr.GetReturnTypeId())) { + dictionaryLength = builder->CreateLoad(llvmTypes->I32Type(), lengthAllocaInst, "varchar_length"); + } + + builder->CreateBr(mergeBlock); + trueBlock = builder->GetInsertBlock(); + func->getBasicBlockList().push_back(falseBlock); + + // If dictionary vector is not present, get vector values + // using valuesAddress and length using offsets if varchar type + builder->SetInsertPoint(falseBlock); + // Load the address value. + + // Get data address of this column + Value *columnPtr = exprFunc->GetColumnArgument(fExpr.colVal); + Value *dataValue = nullptr; + if (TypeUtil::IsStringType(fExpr.GetReturnTypeId())) { + Value *offsetPtr = exprFunc->GetOffsetArgument(fExpr.colVal); + auto colOffsetGEP = builder->CreateGEP(llvmTypes->I32Type(), offsetPtr, rowIdx); + Value *startOffset = builder->CreateLoad(llvmTypes->I32Type(), colOffsetGEP); + colOffsetGEP = builder->CreateGEP(llvmTypes->I32Type(), offsetPtr, + builder->CreateAdd(rowIdx, llvmTypes->CreateConstantInt(1))); + Value *endOffset = builder->CreateLoad(llvmTypes->I32Type(), colOffsetGEP); + // Get length for varchar + length = builder->CreateSub(endOffset, startOffset); + // Find the address of the row to be processed. + dataValue = builder->CreateGEP(llvmTypes->I8Type(), columnPtr, startOffset); + } else { + // Find the address of the row to be processed. + auto rowValuePtr = builder->CreateGEP(dataType, columnPtr, rowIdx); + // Value to be processed. + dataValue = builder->CreateLoad(dataType, rowValuePtr); + } + + builder->CreateBr(mergeBlock); + falseBlock = builder->GetInsertBlock(); + + // Get merged data value and length + int32_t numReservedValues = 2; + func->getBasicBlockList().push_back(mergeBlock); + builder->SetInsertPoint(mergeBlock); + + PHINode *phiValue = builder->CreatePHI(dataType, numReservedValues, "iftmp"); + phiValue->addIncoming(dictionaryValue, trueBlock); + phiValue->addIncoming(dataValue, falseBlock); + + // Length is only valid for varchar type + PHINode *phiLength = nullptr; + if (TypeUtil::IsStringType(fExpr.GetReturnTypeId())) { + phiLength = builder->CreatePHI(llvmTypes->I32Type(), numReservedValues, "length"); + phiLength->addIncoming(dictionaryLength, trueBlock); + phiLength->addIncoming(length, falseBlock); + } + + FunctionSignature isBitNullFuncSignature = FunctionSignature("WrapIsBitNull", { OMNI_INT }, OMNI_BOOLEAN); + llvm::Function *isBitNullFunc = + modulePtr->getFunction(FunctionRegistry::LookupFunction(&isBitNullFuncSignature)->GetId()); + // Get bitmap address of this column + Value *bitmapPtr = exprFunc->GetNullArgument(fExpr.colVal); + auto isNullRet = builder->CreateCall(isBitNullFunc, { bitmapPtr, rowIdx }, "wrap_is_bit_null"); + InlineFunctionInfo inlineIsNullFuncInfo; + InlineFunction(*isNullRet, inlineIsNullFuncInfo); + + if (TypeUtil::IsDecimalType(fExpr.GetReturnTypeId())) { + Value *precision = + llvmTypes->CreateConstantInt(static_cast(fExpr.GetReturnType().get())->GetPrecision()); + Value *scale = + llvmTypes->CreateConstantInt(static_cast(fExpr.GetReturnType().get())->GetScale()); + this->value = make_shared(phiValue, isNullRet, precision, scale); + } else { + this->value = make_shared(phiValue, isNullRet, phiLength); + } +} + +void ExpressionCodeGen::Visit(const BinaryExpr &binaryExpr) +{ + auto *bExpr = const_cast(&binaryExpr); + + CodeGenValuePtr left = VisitExpr(*(bExpr->left)); + if (!left->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + CodeGenValuePtr right = VisitExpr(*(bExpr->right)); + if (!right->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + if (bExpr->op == omniruntime::expressions::Operator::AND) { + this->value = make_shared(builder->CreateAnd(left->data, right->data, "logical_and"), + builder->CreateOr(builder->CreateAnd(left->isNull, right->isNull), builder->CreateOr( + builder->CreateAnd(left->isNull, right->data), builder->CreateAnd(right->isNull, left->data)))); + return; + } + if (bExpr->op == omniruntime::expressions::Operator::OR) { + this->value = make_shared(builder->CreateOr(left->data, right->data, "logical_or"), + builder->CreateOr(builder->CreateAnd(left->isNull, right->isNull), + builder->CreateOr(builder->CreateAnd(left->isNull, builder->CreateNot(right->data)), + builder->CreateAnd(right->isNull, builder->CreateNot(left->data))))); + return; + } + if (bExpr->left->GetReturnTypeId() == OMNI_BYTE) { + Value *nullFlag = builder->CreateAlloca(Type::getInt1Ty(*context), nullptr, "null_flag"); + builder->CreateStore(ConstantInt::get(IntegerType::getInt1Ty(*context), 0), nullFlag); + this->value = make_shared( + this->BinaryExprByteHelper(bExpr, left->data, right->data, left->isNull, right->isNull, nullFlag), + builder->CreateOr(builder->CreateOr(left->isNull, right->isNull), + builder->CreateLoad(llvmTypes->I1Type(), nullFlag))); + return; + } else if (bExpr->left->GetReturnTypeId() == OMNI_SHORT) { + Value *nullFlag = builder->CreateAlloca(Type::getInt1Ty(*context), nullptr, "null_flag"); + builder->CreateStore(ConstantInt::get(IntegerType::getInt1Ty(*context), 0), nullFlag); + this->value = make_shared( + this->BinaryExprShortHelper(bExpr, left->data, right->data, left->isNull, right->isNull, nullFlag), + builder->CreateOr(builder->CreateOr(left->isNull, right->isNull), + builder->CreateLoad(llvmTypes->I1Type(), nullFlag))); + return; + } else if (bExpr->left->GetReturnTypeId() == OMNI_INT || bExpr->left->GetReturnTypeId() == OMNI_DATE32) { + Value *nullFlag = builder->CreateAlloca(Type::getInt1Ty(*context), nullptr, "null_flag"); + builder->CreateStore(ConstantInt::get(IntegerType::getInt1Ty(*context), 0), nullFlag); + this->value = make_shared( + this->BinaryExprIntHelper(bExpr, left->data, right->data, left->isNull, right->isNull, nullFlag), + builder->CreateOr(builder->CreateOr(left->isNull, right->isNull), + builder->CreateLoad(llvmTypes->I1Type(), nullFlag))); + return; + } else if (bExpr->left->GetReturnTypeId() == OMNI_LONG || bExpr->left->GetReturnTypeId() == OMNI_TIMESTAMP) { + Value *nullFlag = builder->CreateAlloca(Type::getInt1Ty(*context), nullptr, "null_flag"); + builder->CreateStore(ConstantInt::get(IntegerType::getInt1Ty(*context), 0), nullFlag); + this->value = make_shared( + this->BinaryExprLongHelper(bExpr, left->data, right->data, left->isNull, right->isNull, nullFlag), + builder->CreateOr(builder->CreateOr(left->isNull, right->isNull), + builder->CreateLoad(llvmTypes->I1Type(), nullFlag))); + return; + } else if (bExpr->left->GetReturnTypeId() == OMNI_DECIMAL64) { + auto decimalLeft = dynamic_cast(*left.get()); + auto decimalRight = dynamic_cast(*right.get()); + if (decimalLeft.GetScale() == decimalRight.GetScale() && + (bExpr->op == omniruntime::expressions::Operator::LT || + bExpr->op == omniruntime::expressions::Operator::LTE || + bExpr->op == omniruntime::expressions::Operator::GT || + bExpr->op == omniruntime::expressions::Operator::GTE || + bExpr->op == omniruntime::expressions::Operator::EQ || + bExpr->op == omniruntime::expressions::Operator::NEQ)) { + auto output = this->BinaryExprLongHelper(bExpr, left->data, right->data, left->isNull, right->isNull); + this->value = + BuildDecimalValue(output, *(bExpr->GetReturnType()), builder->CreateOr(left->isNull, right->isNull)); + return; + } + this->BinaryExprDecimal64Helper(bExpr, decimalLeft, decimalRight, left->isNull, right->isNull); + return; + } else if (bExpr->left->GetReturnTypeId() == OMNI_DOUBLE) { + Value *nullFlag = builder->CreateAlloca(Type::getInt1Ty(*context), nullptr, "null_flag"); + builder->CreateStore(ConstantInt::get(IntegerType::getInt1Ty(*context), 0), nullFlag); + this->value = make_shared( + this->BinaryExprDoubleHelper(bExpr, left->data, right->data, left->isNull, right->isNull, nullFlag), + builder->CreateOr(builder->CreateOr(left->isNull, right->isNull), + builder->CreateLoad(llvmTypes->I1Type(), nullFlag))); + return; + } else if (TypeUtil::IsStringType(bExpr->left->GetReturnTypeId())) { + this->value = make_shared(this->BinaryExprStringHelper(bExpr, left->data, left->length, + right->data, right->length, left->isNull, right->isNull), + builder->CreateOr(left->isNull, right->isNull)); + return; + } else if (bExpr->left->GetReturnTypeId() == OMNI_DECIMAL128) { + this->BinaryExprDecimal128Helper(bExpr, dynamic_cast(*left.get()), + dynamic_cast(*right.get()), left->isNull, right->isNull); + return; + } + LogWarn("Unsupported binary operator %u", static_cast(bExpr->op)); + this->value = CreateInvalidCodeGenValue(); +} + +void ExpressionCodeGen::Visit(const UnaryExpr &uExpr) +{ + auto val = VisitExpr(*(uExpr.exp)); + if (!val->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + switch (uExpr.op) { + case omniruntime::expressions::Operator::NOT: { + Value *notValue = builder->CreateNot(val->data, "logical_not"); + this->value = make_shared(notValue, val->isNull); + break; + } + default: { + // Ignore the unary operator if it is invalid + this->value = CreateInvalidCodeGenValue(); + break; + } + } +} + +void ExpressionCodeGen::Visit(const SwitchExpr &switchExpr) +{ + Type *switchDataType = llvmTypes->VectorToLLVMType(*(switchExpr.GetReturnType())); + Expr *elseExpr = switchExpr.falseExpr; + std::vector> whenClause = switchExpr.whenClause; + const size_t size = whenClause.size(); + + std::vector condBlockList; + std::vector trueBlockList; + BasicBlock *falseBlock = BasicBlock::Create(*context, "FALSE_BLOCK"); + BasicBlock *mergeBlock = BasicBlock::Create(*context, "ifcont"); + int32_t numReservedValues = 2; + + AllocaInst *resultValuePtr = builder->CreateAlloca(switchDataType, numReservedValues, nullptr, "temp_result_value"); + AllocaInst *resultNullPtr = + builder->CreateAlloca(Type::getInt1Ty(*context), numReservedValues, nullptr, "temp_result_null"); + AllocaInst *resultLengthPtr = + builder->CreateAlloca(Type::getInt32Ty(*context), numReservedValues, nullptr, "temp_result_length"); + + AllocaInst *resultPrecisionPtr = + builder->CreateAlloca(Type::getInt32Ty(*context), numReservedValues, nullptr, "temp_result_precision"); + + AllocaInst *resultScalePtr = + builder->CreateAlloca(Type::getInt32Ty(*context), numReservedValues, nullptr, "temp_result_scale"); + + condBlockList.push_back(BasicBlock::Create(*context, "Condition" + std::to_string(0), func)); + trueBlockList.push_back(BasicBlock::Create(*context, "TRUE_BLOCK" + std::to_string(0), func)); + + for (size_t i = 1; i < size; i++) { // Generate block lists used in the next loop to evaluate conditions + condBlockList.push_back(BasicBlock::Create(*context, "Condition" + std::to_string(i))); + trueBlockList.push_back(BasicBlock::Create(*context, "TRUE_BLOCK" + std::to_string(i), func)); + } + for (size_t i = 0; i < size; i++) { // Evaluate condition in the whenClause + Expr *cond = whenClause[i].first; + Expr *resExpr = whenClause[i].second; + + // If cond evaluates to true, control flow goes to trueBlock, save evTrue to temp value + // Otherwise goes to next Block in the list and keeps evaluating next cond in the whenClause + // If last cond evaluates to false, control flow goes to falseBlock and save evFalse to temp value + if (i == 0) { // Create the entry of the block + builder->CreateBr(condBlockList[i]); + } + if (i > 0) { + func->getBasicBlockList().push_back(condBlockList[i]); + } + + auto elseBranch = falseBlock; + if (i < size - 1) { + elseBranch = condBlockList[i + 1]; + } + builder->SetInsertPoint(condBlockList[i]); + CodeGenValuePtr evCond = VisitExpr(*cond); + if (!evCond->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + builder->CreateCondBr(builder->CreateAnd(builder->CreateNot(evCond->isNull), evCond->data), trueBlockList[i], + elseBranch); + + builder->SetInsertPoint(trueBlockList[i]); + auto evTrue = VisitExpr(*resExpr); + if (!evTrue->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + builder->CreateStore(evTrue->data, resultValuePtr); + builder->CreateStore(evTrue->isNull, resultNullPtr); + if (TypeUtil::IsStringType(switchExpr.GetReturnTypeId())) { + builder->CreateStore(evTrue->length, resultLengthPtr); + } else if (TypeUtil::IsDecimalType(switchExpr.GetReturnTypeId())) { + builder->CreateStore(dynamic_cast(evTrue.get())->GetPrecision(), resultPrecisionPtr); + builder->CreateStore(dynamic_cast(evTrue.get())->GetScale(), resultScalePtr); + } + builder->CreateBr(mergeBlock); + } + + func->getBasicBlockList().push_back(falseBlock); + builder->SetInsertPoint(falseBlock); + auto evFalse = VisitExpr(*elseExpr); + if (!evFalse->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + builder->CreateStore(evFalse->data, resultValuePtr); + builder->CreateStore(evFalse->isNull, resultNullPtr); + if (TypeUtil::IsStringType(switchExpr.GetReturnTypeId())) { + builder->CreateStore(evFalse->length, resultLengthPtr); + } else if (TypeUtil::IsDecimalType(switchExpr.GetReturnTypeId())) { + builder->CreateStore(dynamic_cast(evFalse.get())->GetPrecision(), resultPrecisionPtr); + builder->CreateStore(dynamic_cast(evFalse.get())->GetScale(), resultScalePtr); + } + builder->CreateBr(mergeBlock); + + func->getBasicBlockList().push_back(mergeBlock); + builder->SetInsertPoint(mergeBlock); + if (TypeUtil::IsStringType(switchExpr.GetReturnTypeId())) { + this->value = make_shared(builder->CreateLoad(switchDataType, resultValuePtr), + builder->CreateLoad(llvmTypes->I1Type(), resultNullPtr), + builder->CreateLoad(llvmTypes->I32Type(), resultLengthPtr)); + } else if (TypeUtil::IsDecimalType(switchExpr.GetReturnTypeId())) { + this->value = make_shared(builder->CreateLoad(switchDataType, resultValuePtr), + builder->CreateLoad(llvmTypes->I1Type(), resultNullPtr), + builder->CreateLoad(llvmTypes->I32Type(), resultPrecisionPtr), + builder->CreateLoad(llvmTypes->I32Type(), resultScalePtr)); + } else { + this->value = std::make_shared(builder->CreateLoad(switchDataType, resultValuePtr), + builder->CreateLoad(llvmTypes->I1Type(), resultNullPtr)); + } +} + +void ExpressionCodeGen::Visit(const IfExpr &ifExpr) +{ + Expr *cond = ifExpr.condition; + Expr *ifTrue = ifExpr.trueExpr; + Expr *ifFalse = ifExpr.falseExpr; + + BasicBlock *trueBlock = BasicBlock::Create(*context, "TRUE_BLOCK", func); + BasicBlock *falseBlock = BasicBlock::Create(*context, "FALSE_BLOCK"); + BasicBlock *mergeBlock = BasicBlock::Create(*context, "ifcont"); + + CodeGenValuePtr evCond = VisitExpr(*cond); + if (!evCond->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + // If cond evaluates to true, control flow goes to trueBlock, returning evTrue + // Otherwise goes to falseBlock and returns evFalse + builder->CreateCondBr(builder->CreateAnd(builder->CreateNot(evCond->isNull), evCond->data), trueBlock, falseBlock); + builder->SetInsertPoint(trueBlock); + auto evTrue = VisitExpr(*ifTrue); + if (!evTrue->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + builder->CreateBr(mergeBlock); + // Codegen of 'true' can change the current block, update trueBlock for the PHI. + trueBlock = builder->GetInsertBlock(); + + func->getBasicBlockList().push_back(falseBlock); + builder->SetInsertPoint(falseBlock); + auto evFalse = VisitExpr(*ifFalse); + if (!evFalse->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + builder->CreateBr(mergeBlock); + // Codegen of 'false' can change the current block, update falseBlock for the PHI. + falseBlock = builder->GetInsertBlock(); + int32_t numReservedValues = 2; + // Emit merge block. + Type *phiType = llvmTypes->VectorToLLVMType(*(ifExpr.GetReturnType())); + func->getBasicBlockList().push_back(mergeBlock); + builder->SetInsertPoint(mergeBlock); + PHINode *pn = builder->CreatePHI(phiType, numReservedValues, "iftmp"); + PHINode *phiNull = builder->CreatePHI(evTrue->isNull->getType(), numReservedValues, "iftmpNull"); + + pn->addIncoming(evTrue->data, trueBlock); + pn->addIncoming(evFalse->data, falseBlock); + phiNull->addIncoming(evTrue->isNull, trueBlock); + phiNull->addIncoming(evFalse->isNull, falseBlock); + + PHINode *lengthPhi = nullptr; + if (TypeUtil::IsStringType(ifExpr.GetReturnTypeId())) { + lengthPhi = builder->CreatePHI(Type::getInt32Ty(*context), numReservedValues, "length"); + lengthPhi->addIncoming(evTrue->length, trueBlock); + lengthPhi->addIncoming(evFalse->length, falseBlock); + } + + PHINode *precisionPhi = nullptr; + PHINode *scalePhi = nullptr; + if (TypeUtil::IsDecimalType(ifExpr.GetReturnTypeId())) { + precisionPhi = builder->CreatePHI(Type::getInt32Ty(*context), numReservedValues, "precision"); + auto evTruePrecision = (Value *)dynamic_cast(evTrue.get())->GetPrecision(); + auto evFalsePrecision = (Value *)dynamic_cast(evFalse.get())->GetPrecision(); + precisionPhi->addIncoming(evTruePrecision, trueBlock); + precisionPhi->addIncoming(evFalsePrecision, falseBlock); + + scalePhi = builder->CreatePHI(Type::getInt32Ty(*context), numReservedValues, "scale"); + auto evTrueScale = (Value *)dynamic_cast(evTrue.get())->GetScale(); + auto evFalseScale = (Value *)dynamic_cast(evFalse.get())->GetScale(); + scalePhi->addIncoming(evTrueScale, trueBlock); + scalePhi->addIncoming(evFalseScale, falseBlock); + + this->value = std::make_shared(pn, phiNull, precisionPhi, scalePhi); + return; + } + + this->value = std::make_shared(pn, phiNull, lengthPhi); +} + +void ExpressionCodeGen::Visit(const InExpr &inExpr) +{ + auto size = inExpr.arguments.size(); + CodeGenValuePtr argiValue; + auto inArray = builder->CreateAlloca(Type::getInt1Ty(*context), nullptr, "res"); + builder->CreateStore(llvmTypes->CreateConstantBool(false), inArray); + auto isNull = builder->CreateAlloca(Type::getInt1Ty(*context), nullptr, "res_null"); + builder->CreateStore(llvmTypes->CreateConstantBool(false), isNull); + auto hasnull = builder->CreateAlloca(Type::getInt1Ty(*context), nullptr, "has_null"); + builder->CreateStore(llvmTypes->CreateConstantBool(false), hasnull); + Type *retType = llvmTypes->ToLLVMType(inExpr.GetReturnTypeId()); + + std::vector condBlockList; + BasicBlock *trueBlock = BasicBlock::Create(*context, "TRUE_BLOCK"); + BasicBlock *falseBlock = BasicBlock::Create(*context, "FALSE_BLOCK"); + BasicBlock *mergeBlock = BasicBlock::Create(*context, "MERGE_BLOCK"); + + condBlockList.push_back(nullptr); + for (size_t i = 1; i < size; i++) { + condBlockList.push_back(BasicBlock::Create(*context, "Condition" + std::to_string(i))); + } + + Expr *toCompare = inExpr.arguments[0]; + auto valueToCompare = VisitExpr(*toCompare); + builder->CreateStore(valueToCompare->isNull, isNull); + if (!valueToCompare->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + for (size_t i = 1; i < size; i++) { + if (AreInvalidDataTypes(toCompare->GetReturnTypeId(), inExpr.arguments[i]->GetReturnTypeId())) { + LogError("Arg 1 and arg %d have different data types", i + 1); + this->value = CreateInvalidCodeGenValue(); + return; + } + + if (i == 1) { + builder->CreateBr(condBlockList[i]); + } + auto elseBranch = falseBlock; + if (i < size - 1) { + elseBranch = condBlockList[i + 1]; + } + + func->getBasicBlockList().push_back(condBlockList[i]); + builder->SetInsertPoint(condBlockList[i]); + + Value *tmpCmpData = llvmTypes->CreateConstantBool(false); + Value *tmpCmpNull = llvmTypes->CreateConstantBool(false); + + argiValue = VisitExpr(*(inExpr.arguments[i])); + if (!argiValue->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + switch (inExpr.arguments[0]->GetReturnTypeId()) { + case OMNI_INT: + case OMNI_DATE32: + case OMNI_TIMESTAMP: + case OMNI_LONG: { + InExprIntegerHelper(valueToCompare, argiValue, tmpCmpData, tmpCmpNull); + break; + } + case OMNI_DECIMAL64: { + DecimalValue &left = static_cast(*valueToCompare); + DecimalValue &right = static_cast(*argiValue); + if (left.GetScale() == right.GetScale()) { + InExprIntegerHelper(valueToCompare, argiValue, tmpCmpData, tmpCmpNull); + } else { + InExprDecimal64Helper(valueToCompare, argiValue, tmpCmpData, tmpCmpNull, retType); + } + break; + } + case OMNI_DOUBLE: { + InExprDoubleHelper(valueToCompare, argiValue, tmpCmpData, tmpCmpNull); + break; + } + case OMNI_CHAR: + case OMNI_VARCHAR: { + InExprStringHelper(valueToCompare, argiValue, tmpCmpData, tmpCmpNull); + break; + } + case OMNI_DECIMAL128: { + InExprDecimal128Helper(valueToCompare, argiValue, tmpCmpData, tmpCmpNull, retType); + break; + } + default: { + LogWarn("Unsupported data type in IN expr %d", inExpr.arguments[0]->GetReturnTypeId()); + this->value = CreateInvalidCodeGenValue(); + return; + } + } + builder->CreateStore(builder->CreateOr(argiValue->isNull, builder->CreateLoad(llvmTypes->I1Type(), hasnull)), + hasnull); + builder->CreateCondBr(builder->CreateAnd(builder->CreateNot(tmpCmpNull), tmpCmpData), trueBlock, elseBranch); + } + + func->getBasicBlockList().push_back(trueBlock); + builder->SetInsertPoint(trueBlock); + builder->CreateStore(llvmTypes->CreateConstantBool(true), inArray); + builder->CreateBr(mergeBlock); + + func->getBasicBlockList().push_back(falseBlock); + builder->SetInsertPoint(falseBlock); + builder->CreateStore(builder->CreateOr(builder->CreateLoad(llvmTypes->I1Type(), hasnull), + builder->CreateLoad(llvmTypes->I1Type(), isNull)), isNull); + builder->CreateBr(mergeBlock); + + func->getBasicBlockList().push_back(mergeBlock); + builder->SetInsertPoint(mergeBlock); + this->value = std::make_shared(builder->CreateLoad(llvmTypes->I1Type(), inArray), + builder->CreateLoad(llvmTypes->I1Type(), isNull)); +} + +void ExpressionCodeGen::Visit(const BetweenExpr &btExpr) +{ + auto bExpr = const_cast(&btExpr); + + auto val = VisitExpr(*(bExpr->value)); + if (!val->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + DataTypeId valueTypeId = bExpr->value->GetReturnTypeId(); + if (AreInvalidDataTypes(valueTypeId, bExpr->lowerBound->GetReturnTypeId()) && + AreInvalidDataTypes(valueTypeId, bExpr->upperBound->GetReturnTypeId())) { + LogError("Value, lower bound, and upper bound must have the same type"); + this->value = CreateInvalidCodeGenValue(); + return; + } + + auto valNull = val->isNull; + auto lowerVal = VisitExpr(*(bExpr->lowerBound)); + if (!lowerVal->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + auto lowerValNull = lowerVal->isNull; + auto upperVal = VisitExpr(*(bExpr->upperBound)); + if (!upperVal->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + auto upperValNull = upperVal->isNull; + auto isAnyNull = builder->CreateOr(builder->CreateOr(valNull, lowerValNull), upperValNull); + auto isNeitherNull = builder->CreateNot(isAnyNull); + Value *cmpLeft, *cmpRight; + std::pair cmpPair = std::make_pair(&cmpLeft, &cmpRight); + bool supportedType = VisitBetweenExprHelper(*bExpr, val, lowerVal, upperVal, cmpPair); + if (supportedType) { + std::vector andValues; + andValues.push_back(isNeitherNull); + andValues.push_back(cmpLeft); + andValues.push_back(cmpRight); + Value *result = builder->CreateAnd(andValues); + this->value = make_shared(result, isAnyNull); + return; + } + + LogError("Unsupported data type for between %d", valueTypeId); + this->value = CreateInvalidCodeGenValue(); +} + +void ExpressionCodeGen::Visit(const CoalesceExpr &cExpr) +{ + Expr *value1Expr = cExpr.value1; + Expr *value2Expr = cExpr.value2; + CodeGenValuePtr value1 = VisitExpr(*value1Expr); + if (!value1->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + BasicBlock *isNullBlock = BasicBlock::Create(*context, "coalesceVal1IsNull", func); + BasicBlock *isNotNullBlock = BasicBlock::Create(*context, "coalesceVal1IsNotNull"); + BasicBlock *mergeBlock = BasicBlock::Create(*context, "coalesceCont"); + + // If cond evaluates to true, control flow goes to trueBlock, returning evTrue + // Otherwise goes to falseBlock and returns evFalse + builder->CreateCondBr(value1->isNull, isNullBlock, isNotNullBlock); + + builder->SetInsertPoint(isNullBlock); + auto value2 = VisitExpr(*value2Expr); + if (!value2->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + + builder->CreateBr(mergeBlock); + // Codegen of 'true' can change the current block, update trueBlock for the PHI. + isNullBlock = builder->GetInsertBlock(); + + func->getBasicBlockList().push_back(isNotNullBlock); + builder->SetInsertPoint(isNotNullBlock); + + builder->CreateBr(mergeBlock); + // Codegen of 'false' can change the current block, update falseBlock for the PHI. + isNotNullBlock = builder->GetInsertBlock(); + int32_t numReservedValues = 2; + + // Emit merge block. + Type *phiType = llvmTypes->VectorToLLVMType(*(cExpr.GetReturnType())); + func->getBasicBlockList().push_back(mergeBlock); + builder->SetInsertPoint(mergeBlock); + PHINode *pn = builder->CreatePHI(phiType, numReservedValues, "iftmp"); + PHINode *pnNull = builder->CreatePHI(value1->isNull->getType(), numReservedValues, "iftmp"); + + pn->addIncoming(value1->data, isNotNullBlock); + pn->addIncoming(value2->data, isNullBlock); + pnNull->addIncoming(value1->isNull, isNotNullBlock); + pnNull->addIncoming(value2->isNull, isNullBlock); + + PHINode *lengthPhi = nullptr; + if (TypeUtil::IsStringType(cExpr.GetReturnTypeId())) { + lengthPhi = builder->CreatePHI(Type::getInt32Ty(*context), numReservedValues, "length"); + lengthPhi->addIncoming(value1->length, isNotNullBlock); + lengthPhi->addIncoming(value2->length, isNullBlock); + } + + if (TypeUtil::IsDecimalType(cExpr.GetReturnTypeId())) { + CoalesceExprDecimalHelper(*value1.get(), *value2.get(), *isNotNullBlock, *isNullBlock, *pn, *pnNull); + return; + } + + this->value = make_shared(pn, pnNull, lengthPhi); +} + +void ExpressionCodeGen::Visit(const IsNullExpr &isNullExpr) +{ + Expr *valueExpr = isNullExpr.value; + auto value = VisitExpr(*valueExpr); + if (!value->IsValidValue()) { + this->value = CreateInvalidCodeGenValue(); + return; + } + Value *isNullValue = value->isNull; + + Value *result = builder->CreateICmpEQ(isNullValue, llvmTypes->CreateConstantBool(true), "isNullCompare"); + this->value = make_shared(result, llvmTypes->CreateConstantBool(false)); +} + +template +std::vector ExpressionCodeGen::GetDefaultFunctionArgValues( + const FuncExpr &fExpr, Value **isAnyNull, bool &isInvalidExpr) +{ + std::vector argVals; + CodeGenValuePtr resultPtr; + auto numArgs = fExpr.arguments.size(); + if (fExpr.function->IsExecutionContextSet()) { + argVals.push_back(this->codegenContext->executionContext); + } + for (size_t i = 0; i < numArgs; i++) { + Expr *argN = fExpr.arguments[i]; + resultPtr = VisitExpr(*argN); + if (!resultPtr->IsValidValue()) { + isInvalidExpr = true; + return argVals; + } + argVals.push_back(resultPtr->data); + if constexpr (isNeedVerifyResult) { + *isAnyNull = builder->CreateOr(*isAnyNull, resultPtr->isNull); + } + if ((TypeUtil::IsStringType(fExpr.arguments[i]->GetReturnTypeId()))) { + if (fExpr.arguments[i]->GetReturnTypeId() == OMNI_CHAR) { + argVals.push_back(llvmTypes->CreateConstantInt( + dynamic_cast(fExpr.arguments[i]->GetReturnType().get())->GetWidth())); + } + argVals.push_back(this->value->length); + if (FuncExpr::IsCastStrStr(fExpr)) { + argVals.push_back(llvmTypes->CreateConstantInt( + dynamic_cast(fExpr.arguments[i]->GetReturnType().get())->GetWidth())); + } + } + if (TypeUtil::IsDecimalType(argN->GetReturnTypeId())) { + argVals.push_back(llvmTypes->CreateConstantInt( + dynamic_cast(fExpr.arguments[i]->GetReturnType().get())->GetPrecision())); + argVals.push_back(llvmTypes->CreateConstantInt( + dynamic_cast(fExpr.arguments[i]->GetReturnType().get())->GetScale())); + } + if constexpr (isNeedVerifyVal) { + argVals.push_back(this->value->isNull); + } + } + return argVals; +} + +inline std::vector ExpressionCodeGen::GetDataArgs( + const omniruntime::expressions::FuncExpr &fExpr, + llvm::Value **isAnyNull, + bool &isInvalidExpr) +{ + return GetDefaultFunctionArgValues(fExpr, isAnyNull, isInvalidExpr); +} + +inline std::vector ExpressionCodeGen::GetDataAndNullArgs( + const omniruntime::expressions::FuncExpr &fExpr, + llvm::Value **isAnyNull, + bool &isInvalidExpr) +{ + return GetDefaultFunctionArgValues(fExpr, isAnyNull, isInvalidExpr); +} + +inline std::vector ExpressionCodeGen::GetDataAndNullArgsAndReturnNull( + const omniruntime::expressions::FuncExpr &fExpr, + llvm::Value **isAnyNull, + bool &isInvalidExpr) +{ + return GetDefaultFunctionArgValues(fExpr, isAnyNull, isInvalidExpr); +} + +std::vector ExpressionCodeGen::GetFunctionArgValues(const omniruntime::expressions::FuncExpr &fExpr, + llvm::Value **isAnyNull, bool &isInvalidExpr) +{ + switch (fExpr.function->GetNullableResultType()) { + case INPUT_DATA: + return GetDataArgs(fExpr, isAnyNull, isInvalidExpr); + case INPUT_DATA_AND_NULL: + return GetDataAndNullArgs(fExpr, isAnyNull, isInvalidExpr); + case INPUT_DATA_AND_NULL_AND_RETURN_NULL: + return GetDataAndNullArgsAndReturnNull(fExpr, isAnyNull, isInvalidExpr); + default: + return GetDataArgs(fExpr, isAnyNull, isInvalidExpr); + } +} + +Value *ExpressionCodeGen::CreateHiveUdfArgTypes(const FuncExpr &fExpr) +{ + auto elementSize = static_cast(fExpr.arguments.size()); + auto alloca = builder->CreateAlloca(llvmTypes->I32Type(), llvmTypes->CreateConstantInt(elementSize)); + for (int32_t i = 0; i < elementSize; i++) { + auto ptr = builder->CreateGEP(llvmTypes->I32Type(), alloca, llvmTypes->CreateConstantInt(i)); + builder->CreateStore(llvmTypes->CreateConstantInt(fExpr.arguments[i]->GetReturnTypeId()), ptr); + } + return alloca; +} + +static bool GetValueOffsets(const FuncExpr &fExpr, std::vector &valueOffsets) +{ + int32_t valueSize = 0; + for (auto argExpr : fExpr.arguments) { + valueOffsets.emplace_back(valueSize); + + auto argReturnType = argExpr->GetReturnTypeId(); + switch (argReturnType) { + case OMNI_INT: + case OMNI_DATE32: + valueSize += sizeof(int32_t); + break; + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DECIMAL64: + valueSize += sizeof(int64_t); + break; + case OMNI_DOUBLE: + valueSize += sizeof(double); + break; + case OMNI_BOOLEAN: + valueSize += sizeof(bool); + break; + case OMNI_SHORT: + valueSize += sizeof(int16_t); + break; + case OMNI_DECIMAL128: + valueSize += 2 * sizeof(int64_t); + break; + case OMNI_VARCHAR: + case OMNI_CHAR: + valueSize += sizeof(uint8_t *); + break; + default: + LogWarn("Unsupported data type in Data Expr %d", argReturnType); + return false; + } + } + valueOffsets.emplace_back(valueSize); + return true; +} + +std::vector ExpressionCodeGen::GetHiveUdfArgValues(const FuncExpr &fExpr, bool &isInvalid) +{ + std::vector argVals; + std::vector valueOffsets; + if (!GetValueOffsets(fExpr, valueOffsets)) { + isInvalid = true; + return argVals; + } + + // Create array for value, null and length of all arguments + auto argSize = static_cast(fExpr.arguments.size()); + auto valueArray = builder->CreateAlloca(llvmTypes->I8Type(), llvmTypes->CreateConstantInt(valueOffsets[argSize])); + auto nullArray = builder->CreateAlloca(llvmTypes->I8Type(), llvmTypes->CreateConstantInt(argSize)); + auto lengthArray = builder->CreateAlloca(llvmTypes->I32Type(), llvmTypes->CreateConstantInt(argSize)); + + for (int32_t i = 0; i < argSize; i++) { + auto argExpr = fExpr.arguments[i]; + auto argExprResult = VisitExpr(*argExpr); + if (!argExprResult->IsValidValue()) { + isInvalid = true; + return argVals; + } + + // Get pointer for value, null and length + auto valuePtr = + builder->CreateGEP(llvmTypes->I8Type(), valueArray, llvmTypes->CreateConstantInt(valueOffsets[i])); + auto nullPtr = builder->CreateGEP(llvmTypes->I8Type(), nullArray, llvmTypes->CreateConstantInt(i)); + auto lengthPtr = builder->CreateGEP(llvmTypes->I32Type(), lengthArray, llvmTypes->CreateConstantInt(i)); + + builder->CreateStore(argExprResult->data, valuePtr); + builder->CreateStore(argExprResult->isNull, nullPtr); + if (TypeUtil::IsStringType(argExpr->GetReturnTypeId())) { + builder->CreateStore(argExprResult->length, lengthPtr); + } else { + builder->CreateStore(llvmTypes->CreateConstantInt(0), lengthPtr); + } + } + + argVals.emplace_back(valueArray); + argVals.emplace_back(nullArray); + argVals.emplace_back(lengthArray); + + return argVals; +} + +void ExpressionCodeGen::CallHiveUdfFunction(const FuncExpr &fExpr) +{ + std::vector argVals; + argVals.emplace_back(this->codegenContext->executionContext); + argVals.emplace_back(CreateConstantString(fExpr.funcName)); // for udf class name + argVals.emplace_back(CreateHiveUdfArgTypes(fExpr)); // for inputTypes + argVals.emplace_back(llvmTypes->CreateConstantInt(fExpr.GetReturnTypeId())); // for ret type + argVals.emplace_back(llvmTypes->CreateConstantInt(fExpr.arguments.size())); // for vec count + + bool isInvalidExpr = false; + auto inputArgs = GetHiveUdfArgValues(fExpr, isInvalidExpr); + if (isInvalidExpr) { + this->value = CreateInvalidCodeGenValue(); + return; + } + argVals.insert(argVals.end(), inputArgs.begin(), + inputArgs.end()); // for inputValues, inputNulls, inputLength + + Value *outputValuePtr; + Value *outputLenPtr; + Type *ty = llvmTypes->ToLLVMType(fExpr.GetReturnTypeId()); + + if (TypeUtil::IsStringType(fExpr.GetReturnTypeId())) { + auto valueSize = llvmTypes->CreateConstantInt(200); + std::vector paramsVec = { OMNI_LONG, OMNI_INT }; + outputValuePtr = CallExternFunction("ArenaAllocatorMalloc", paramsVec, OMNI_CHAR, + { this->codegenContext->executionContext, valueSize }, nullptr); + outputLenPtr = builder->CreateAlloca(Type::getInt32Ty(*context), nullptr, "outputLength"); + builder->CreateStore(llvmTypes->CreateConstantInt(0), outputLenPtr); + } else { + outputValuePtr = builder->CreateAlloca(ty, nullptr, "outputValue"); + outputLenPtr = llvmTypes->CreateConstantLong(0); + } + argVals.emplace_back(outputValuePtr); + auto outputNullPtr = builder->CreateAlloca(Type::getInt8Ty(*context), nullptr, "outputNull"); + argVals.emplace_back(outputNullPtr); + argVals.emplace_back(outputLenPtr); + + auto signature = FunctionSignature("EvaluateHiveUdfSingle", std::vector {}, OMNI_INT); + auto function = FunctionRegistry::LookupFunction(&signature); + auto f = modulePtr->getFunction(function->GetId()); + if (f) { + auto ret = CreateCall(f, argVals, "call_evaluate_hive_udf"); + InlineFunctionInfo inlineFunctionInfo; + llvm::InlineFunction(*((CallInst *)ret), inlineFunctionInfo); + Value *outputValue = outputValuePtr; + Value *outputLen = nullptr; + if (TypeUtil::IsStringType(fExpr.GetReturnTypeId())) { + outputLen = builder->CreateLoad(llvmTypes->I32Type(), outputLenPtr); + } else { + outputValue = builder->CreateLoad(ty, outputValuePtr); + } + auto outputNull = builder->CreateLoad(llvmTypes->I1Type(), outputNullPtr); + this->value = make_shared(outputValue, outputNull, outputLen); + } else { + LogWarn("Unable to generate udf function : %s", fExpr.funcName.c_str()); + this->value = CreateInvalidCodeGenValue(); + } +} + +// Handles all functions +void ExpressionCodeGen::Visit(const FuncExpr &fExpr) +{ + if (fExpr.functionType == HIVE_UDF) { + CallHiveUdfFunction(fExpr); + return; + } + + if (this->overflowConfig != nullptr && + this->overflowConfig->GetOverflowConfigId() == omniruntime::op::OVERFLOW_CONFIG_NULL) { + auto signature = fExpr.function->GetSignatures()[0]; + if (FunctionRegistry::LookupNullFunction(&signature)) { + FuncExprOverflowNullHelper(fExpr); + return; + } + } + Value *isAnyNull = llvmTypes->CreateConstantBool(false); + auto res = std::find_if(fExpr.arguments.begin(), fExpr.arguments.end(), + [](Expr *exp) { return exp->GetReturnTypeId() == OMNI_DECIMAL128; }); + bool isDecimalFunction = res != fExpr.arguments.end(); + DataTypeId funcRetType = fExpr.GetReturnTypeId(); + bool isInvalidExpr = false; + + auto argVals = GetFunctionArgValues(fExpr, &isAnyNull, isInvalidExpr); + if (isInvalidExpr) { + this->value = CreateInvalidCodeGenValue(); + return; + } + Value *isNull = PushAndGetNullFlag(fExpr, argVals, isAnyNull, true); + Value *ret = nullptr; + Value *outputLen = nullptr; + AllocaInst *outputLenPtr = nullptr; + // Call Decimal IR Generator for decimal functions + if (TypeUtil::IsDecimalType(funcRetType)) { + argVals.push_back( + llvmTypes->CreateConstantInt(dynamic_cast(fExpr.GetReturnType().get())->GetPrecision())); + argVals.push_back( + llvmTypes->CreateConstantInt(dynamic_cast(fExpr.GetReturnType().get())->GetScale())); + auto outputValuePtr = BuildDecimalValue(nullptr, *(fExpr.GetReturnType())); + ret = CallDecimalFunction(fExpr.function->GetId(), llvmTypes->ToLLVMType(funcRetType), argVals); + outputValuePtr->data = ret; + outputValuePtr->isNull = LoadNullFlag(fExpr, isNull); + outputValuePtr->length = outputLen; + this->value = std::move(outputValuePtr); + return; + } else { + if (TypeUtil::IsStringType(funcRetType)) { + if (FuncExpr::IsCastStrStr(fExpr)) { + argVals.push_back(llvmTypes->CreateConstantInt( + dynamic_cast(fExpr.GetReturnType().get())->GetWidth())); + } + outputLenPtr = builder->CreateAlloca(Type::getInt32Ty(*context), nullptr, "output_len"); + builder->CreateStore(llvmTypes->CreateConstantInt(0), outputLenPtr); + argVals.push_back(outputLenPtr); + } + + auto f = modulePtr->getFunction(fExpr.function->GetId()); + if (f) { + ret = isDecimalFunction ? + CallDecimalFunction(fExpr.function->GetId(), llvmTypes->ToLLVMType(funcRetType), argVals) : + CreateCall(f, argVals, fExpr.function->GetId()); + InlineFunctionInfo inlineFunctionInfo; + llvm::InlineFunction(*((CallInst *)ret), inlineFunctionInfo); + outputLen = (outputLenPtr == nullptr) ? nullptr : builder->CreateLoad(llvmTypes->I32Type(), outputLenPtr); + } else { + LogWarn("Unable to generate function : %s", fExpr.funcName.c_str()); + this->value = make_shared(nullptr, nullptr, nullptr); + return; + } + } + this->value = std::make_shared(ret, LoadNullFlag(fExpr, isNull), outputLen); +} + +static std::string ChangeFuncNameToNull(const FuncExpr &fExpr) +{ + auto typeSize = static_cast(fExpr.arguments.size() + 1); + auto originalFuncName = fExpr.function->GetId(); + auto originalFuncChars = originalFuncName.c_str(); + int32_t separatorIdx = 0; + auto pos = static_cast(originalFuncName.length() - 1); + for (; pos >= 0; pos--) { + if (originalFuncChars[pos] == '_') { + separatorIdx++; + if (separatorIdx == typeSize) { + break; + } + } + } + return originalFuncName.insert(pos, "_null"); +} + +std::vector ExpressionCodeGen::GetDataAndOverflowNullArgs( + const omniruntime::expressions::FuncExpr &fExpr, llvm::Value **isAnyNull, bool &isInvalidExpr, + llvm::Value *overflowNull) +{ + std::vector argVals; + auto signature = fExpr.function->GetSignatures()[0]; + if (FunctionRegistry::IsNullExecutionContextSet(&signature)) { + argVals.push_back(this->codegenContext->executionContext); + } + argVals.push_back(overflowNull); + CodeGenValuePtr resultPtr; + auto numArgs = fExpr.arguments.size(); + + for (size_t i = 0; i < numArgs; i++) { + Expr *argN = fExpr.arguments[i]; + resultPtr = VisitExpr(*argN); + if (!resultPtr->IsValidValue()) { + isInvalidExpr = true; + return argVals; + } + argVals.push_back(resultPtr->data); + *isAnyNull = builder->CreateOr(*isAnyNull, resultPtr->isNull); + if ((TypeUtil::IsStringType(fExpr.arguments[i]->GetReturnTypeId()))) { + if (fExpr.arguments[i]->GetReturnTypeId() == OMNI_CHAR) { + argVals.push_back(llvmTypes->CreateConstantInt( + dynamic_cast(fExpr.arguments[i]->GetReturnType().get())->GetWidth())); + } + argVals.push_back(this->value->length); + if (FuncExpr::IsCastStrStr(fExpr)) { + argVals.push_back(llvmTypes->CreateConstantInt( + dynamic_cast(fExpr.arguments[i]->GetReturnType().get())->GetWidth())); + } + } + if (TypeUtil::IsDecimalType(argN->GetReturnTypeId())) { + argVals.push_back(llvmTypes->CreateConstantInt( + dynamic_cast(fExpr.arguments[i]->GetReturnType().get())->GetPrecision())); + argVals.push_back(llvmTypes->CreateConstantInt( + dynamic_cast(fExpr.arguments[i]->GetReturnType().get())->GetScale())); + } + if (fExpr.function->GetNullableResultType() == INPUT_DATA_AND_NULL_AND_RETURN_NULL) { + argVals.push_back(this->value->isNull); + } + } + return argVals; +} + +void ExpressionCodeGen::FuncExprOverflowNullHelper(const FuncExpr &fExpr) +{ + Value *isAnyNull = llvmTypes->CreateConstantBool(false); + auto res = std::find_if(fExpr.arguments.begin(), fExpr.arguments.end(), + [](Expr *exp) { return exp->GetReturnTypeId() == OMNI_DECIMAL128; }); + bool isDecimalFunction = res != fExpr.arguments.end(); + DataTypeId funcRetType = fExpr.GetReturnTypeId(); + bool isInvalidExpr = false; + + AllocaInst *overflowNull = builder->CreateAlloca(Type::getInt1Ty(*context), nullptr, "overflow_null"); + builder->CreateStore(ConstantInt::get(IntegerType::getInt1Ty(*context), 0), overflowNull); + auto argVals = GetDataAndOverflowNullArgs(fExpr, &isAnyNull, isInvalidExpr, overflowNull); + if (isInvalidExpr) { + this->value = CreateInvalidCodeGenValue(); + return; + } + auto isNull = PushAndGetNullFlag(fExpr, argVals, isAnyNull, false); + Value *ret = nullptr; + Value *outputLen = nullptr; + AllocaInst *outputLenPtr = nullptr; + std::string functionName = ChangeFuncNameToNull(fExpr); + + // Call Decimal IR Generator for decimal functions + if (TypeUtil::IsDecimalType(funcRetType)) { + argVals.push_back( + llvmTypes->CreateConstantInt(dynamic_cast(fExpr.GetReturnType().get())->GetPrecision())); + argVals.push_back( + llvmTypes->CreateConstantInt(dynamic_cast(fExpr.GetReturnType().get())->GetScale())); + + auto outputValuePtr = BuildDecimalValue(nullptr, *(fExpr.GetReturnType())); + ret = CallDecimalFunction(functionName, llvmTypes->ToLLVMType(funcRetType), argVals); + outputValuePtr->data = ret; + outputValuePtr->isNull = + builder->CreateOr(LoadNullFlag(fExpr, isNull), builder->CreateLoad(llvmTypes->I1Type(), overflowNull)); + outputValuePtr->length = outputLen; + this->value = std::move(outputValuePtr); + return; + } else { + if (TypeUtil::IsStringType(funcRetType)) { + if (FuncExpr::IsCastStrStr(fExpr)) { + argVals.push_back(llvmTypes->CreateConstantInt( + dynamic_cast(fExpr.GetReturnType().get())->GetWidth())); + } + outputLenPtr = builder->CreateAlloca(Type::getInt32Ty(*context), nullptr, "output_len"); + builder->CreateStore(llvmTypes->CreateConstantInt(0), outputLenPtr); + argVals.push_back(outputLenPtr); + } + auto f = modulePtr->getFunction(functionName); + if (f) { + ret = isDecimalFunction ? CallDecimalFunction(functionName, llvmTypes->ToLLVMType(funcRetType), argVals) : + CreateCall(f, argVals, functionName); + InlineFunctionInfo inlineFunctionInfo; + llvm::InlineFunction(*((CallInst *)ret), inlineFunctionInfo); + outputLen = (outputLenPtr == nullptr) ? nullptr : builder->CreateLoad(llvmTypes->I32Type(), outputLenPtr); + Value *finalNull = + builder->CreateOr(LoadNullFlag(fExpr, isNull), builder->CreateLoad(llvmTypes->I1Type(), overflowNull)); + this->value = std::make_shared(ret, finalNull, outputLen); + return; + } else { + LogError("Unable to generate function : %s", fExpr.funcName.c_str()); + this->value = std::make_shared(nullptr, nullptr, nullptr); + return; + } + } +} + +void ExpressionCodeGen::ExtractVectorIndexes() +{ + ExprInfoExtractor exprInfoExtractor; + this->expr->Accept(exprInfoExtractor); + this->vectorIndexes = exprInfoExtractor.GetVectorIndexes(); +} + +Value *ExpressionCodeGen::StringEqual(Value *lhs, Value *lLen, Value *rhs, Value *rLen, Value *isNull) +{ + BasicBlock *lenEqualBlock; + BasicBlock *lenNotEqualBlock; + BasicBlock *mergeBlock; + Value *lenCond = builder->CreateAnd(builder->CreateICmpEQ(lLen, rLen), builder->CreateNot(isNull)); + lenEqualBlock = BasicBlock::Create(*context, "lenEqualBlock", builder->GetInsertBlock()->getParent()); + lenNotEqualBlock = BasicBlock::Create(*context, "lenNotEqualBlock", builder->GetInsertBlock()->getParent()); + mergeBlock = BasicBlock::Create(*context, "ifcont", builder->GetInsertBlock()->getParent()); + builder->CreateCondBr(lenCond, lenEqualBlock, lenNotEqualBlock); + builder->SetInsertPoint(lenEqualBlock); + + std::vector argVals { lhs, lLen, rhs, rLen }; + auto signature = + FunctionSignature(strEqualStr, std::vector { OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_BOOLEAN); + auto f = modulePtr->getFunction(FunctionRegistry::LookupFunction(&signature)->GetId()); + auto ret = CreateCall(f, argVals, "call_str_eq"); + InlineFunctionInfo inlineFunctionInfo; + llvm::InlineFunction(*ret, inlineFunctionInfo); + + builder->CreateBr(mergeBlock); + + builder->SetInsertPoint(lenNotEqualBlock); + builder->CreateBr(mergeBlock); + + builder->SetInsertPoint(mergeBlock); + + PHINode *phiValue = builder->CreatePHI(llvmTypes->I1Type(), 2, "ifequal"); + phiValue->addIncoming(ret, lenEqualBlock); + phiValue->addIncoming(lenCond, lenNotEqualBlock); + return phiValue; +} + +// Other operations which require externed functions +Value *ExpressionCodeGen::StringCmp(Value *lhs, Value *lLen, Value *rhs, Value *rLen) +{ + // call function + std::vector argVals { lhs, lLen, rhs, rLen }; + auto signature = FunctionSignature(strCompareStr, std::vector { OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_INT); + auto f = modulePtr->getFunction(FunctionRegistry::LookupFunction(&signature)->GetId()); + auto ret = CreateCall(f, argVals, "call_str_cmp"); + InlineFunctionInfo inlineFunctionInfo; + llvm::InlineFunction(*ret, inlineFunctionInfo); + return ret; +} + +void ExpressionCodeGen::BinaryExprNullHelper(const BinaryExpr *binaryExpr, Value *left, Value *right, Value *leftIsNull, + Value *rightIsNull, PHINode **leftPhi, PHINode **rightPhi) +{ + BasicBlock *incomingBlock; + BasicBlock *nullBlock; + BasicBlock *nextInst; + Value *nullCond; + Value *leftZero; + Value *rightOne; + auto op = binaryExpr->op; + + if (op == omniruntime::expressions::Operator::ADD || op == omniruntime::expressions::Operator::SUB || + op == omniruntime::expressions::Operator::MUL || op == omniruntime::expressions::Operator::DIV || + op == omniruntime::expressions::Operator::MOD || op == omniruntime::expressions::Operator::TRY_ADD || + op == omniruntime::expressions::Operator::TRY_SUB || op == omniruntime::expressions::Operator::TRY_MUL || + op == omniruntime::expressions::Operator::TRY_DIV) { + incomingBlock = builder->GetInsertBlock(); + nullBlock = BasicBlock::Create(*context, "nullBlock", builder->GetInsertBlock()->getParent()); + nextInst = BasicBlock::Create(*context, "nextInst", builder->GetInsertBlock()->getParent()); + nullCond = builder->CreateOr(leftIsNull, rightIsNull); + builder->CreateCondBr(nullCond, nullBlock, nextInst); + builder->SetInsertPoint(nullBlock); + switch (binaryExpr->left->GetReturnType()->GetId()) { + case OMNI_INT: + case OMNI_DATE32: + leftZero = llvmTypes->CreateConstantInt(0); + break; + case OMNI_LONG: + case OMNI_DECIMAL64: + leftZero = llvmTypes->CreateConstantLong(0); + break; + case OMNI_DOUBLE: + leftZero = llvmTypes->CreateConstantDouble(0); + break; + case OMNI_DECIMAL128: + leftZero = llvmTypes->CreateConstant128(0); + break; + default: + // Unsupported data-types left as-is + leftZero = left; + break; + } + switch (binaryExpr->right->GetReturnType()->GetId()) { + case OMNI_INT: + case OMNI_DATE32: + rightOne = llvmTypes->CreateConstantInt(1); + break; + case OMNI_LONG: + case OMNI_DECIMAL64: + rightOne = llvmTypes->CreateConstantLong(1); + break; + case OMNI_DOUBLE: + rightOne = llvmTypes->CreateConstantDouble(1); + break; + case OMNI_DECIMAL128: + rightOne = llvmTypes->CreateConstant128(1); + break; + default: + // Unsupported data-types left as-is + rightOne = right; + break; + } + builder->CreateBr(nextInst); + builder->SetInsertPoint(nextInst); + int numberOfPaths = 2; + *leftPhi = builder->CreatePHI(left->getType(), numberOfPaths, "iftmp"); + *rightPhi = builder->CreatePHI(right->getType(), numberOfPaths, "iftmp"); + (*leftPhi)->addIncoming(leftZero, nullBlock); + (*leftPhi)->addIncoming(left, incomingBlock); + (*rightPhi)->addIncoming(rightOne, nullBlock); + (*rightPhi)->addIncoming(right, incomingBlock); + } +} + +// Helper methods to parse binary expressions +llvm::Value *ExpressionCodeGen::BinaryExprByteHelper(const BinaryExpr *binaryExpr, Value *left, Value *right, + Value *leftIsNull, Value *rightIsNull, Value *nullFlag) +{ + PHINode *leftPhi; + PHINode *rightPhi; + Value *isNeitherNull = builder->CreateNot(builder->CreateOr(leftIsNull, rightIsNull)); + std::vector byteParams = { OMNI_BYTE, OMNI_BYTE }; + BinaryExprNullHelper(binaryExpr, left, right, leftIsNull, rightIsNull, &leftPhi, &rightPhi); + switch (binaryExpr->op) { + case omniruntime::expressions::Operator::LT: + return builder->CreateAnd(isNeitherNull, + CallExternFunction("lessThan", byteParams, OMNI_BOOLEAN, { left, right }, nullptr, "relational_lt")); + case omniruntime::expressions::Operator::GT: + return builder->CreateAnd(isNeitherNull, + CallExternFunction("greaterThan", byteParams, OMNI_BOOLEAN, { left, right }, nullptr, "relational_gt")); + case omniruntime::expressions::Operator::LTE: + return builder->CreateAnd(isNeitherNull, CallExternFunction("lessThanEqual", byteParams, OMNI_BOOLEAN, + { left, right }, nullptr, "relational_le")); + case omniruntime::expressions::Operator::GTE: + return builder->CreateAnd(isNeitherNull, CallExternFunction("greaterThanEqual", byteParams, OMNI_BOOLEAN, + { left, right }, nullptr, "relational_ge")); + case omniruntime::expressions::Operator::EQ: + return builder->CreateAnd(isNeitherNull, + CallExternFunction("equal", byteParams, OMNI_BOOLEAN, { left, right }, nullptr, "relational_eq")); + case omniruntime::expressions::Operator::NEQ: + return builder->CreateAnd(isNeitherNull, + CallExternFunction("notEqual", byteParams, OMNI_BOOLEAN, { left, right }, nullptr, "relational_neq")); + case omniruntime::expressions::Operator::ADD: + return CallExternFunction("add", byteParams, OMNI_BYTE, { leftPhi, rightPhi }, nullptr, "arithmetic_add"); + case omniruntime::expressions::Operator::SUB: + return CallExternFunction("subtract", byteParams, OMNI_BYTE, { leftPhi, rightPhi }, nullptr, + "arithmetic_sub"); + case omniruntime::expressions::Operator::MUL: + return CallExternFunction("multiply", byteParams, OMNI_BYTE, { leftPhi, rightPhi }, nullptr, + "arithmetic_mul"); + case omniruntime::expressions::Operator::DIV: + return CallExternFunction("divide", byteParams, OMNI_BYTE, {nullFlag, leftPhi, rightPhi }, + nullptr, "arithmetic_div"); + case omniruntime::expressions::Operator::MOD: + return CallExternFunction("modulus", byteParams, OMNI_BYTE, {nullFlag, leftPhi, rightPhi }, + nullptr, "arithmetic_mod"); + case omniruntime::expressions::Operator::TRY_ADD: + return CallExternFunction("try_add", byteParams, OMNI_BYTE, {nullFlag, leftPhi, rightPhi }, + nullptr, "arithmetic_try_add"); + case omniruntime::expressions::Operator::TRY_SUB: + return CallExternFunction("try_subtract", byteParams, OMNI_BYTE, {nullFlag, leftPhi, rightPhi }, + nullptr, "arithmetic_try_sub"); + case omniruntime::expressions::Operator::TRY_MUL: + return CallExternFunction("try_multiply", byteParams, OMNI_BYTE, {nullFlag, leftPhi, rightPhi }, + nullptr, "arithmetic_try_mul"); + case omniruntime::expressions::Operator::TRY_DIV: + return CallExternFunction("divide", byteParams, OMNI_BYTE, {nullFlag, leftPhi, rightPhi }, + nullptr, "arithmetic_try_div"); + default: { + LogError("Unsupported byte binary operator %u", static_cast(binaryExpr->op)); + return nullptr; + } + } +} + +llvm::Value *ExpressionCodeGen::BinaryExprShortHelper(const BinaryExpr *binaryExpr, Value *left, Value *right, + Value *leftIsNull, Value *rightIsNull, Value *nullFlag) +{ + PHINode *leftPhi; + PHINode *rightPhi; + Value *isNeitherNull = builder->CreateNot(builder->CreateOr(leftIsNull, rightIsNull)); + std::vector shortParams = { OMNI_SHORT, OMNI_SHORT }; + BinaryExprNullHelper(binaryExpr, left, right, leftIsNull, rightIsNull, &leftPhi, &rightPhi); + switch (binaryExpr->op) { + case omniruntime::expressions::Operator::LT: + return builder->CreateAnd(isNeitherNull, + CallExternFunction("lessThan", shortParams, OMNI_BOOLEAN, { left, right }, nullptr, "relational_lt")); + case omniruntime::expressions::Operator::GT: + return builder->CreateAnd(isNeitherNull, + CallExternFunction("greaterThan", shortParams, OMNI_BOOLEAN, { left, right }, nullptr, "relational_gt")); + case omniruntime::expressions::Operator::LTE: + return builder->CreateAnd(isNeitherNull, CallExternFunction("lessThanEqual", shortParams, OMNI_BOOLEAN, + { left, right }, nullptr, "relational_le")); + case omniruntime::expressions::Operator::GTE: + return builder->CreateAnd(isNeitherNull, CallExternFunction("greaterThanEqual", shortParams, OMNI_BOOLEAN, + { left, right }, nullptr, "relational_ge")); + case omniruntime::expressions::Operator::EQ: + return builder->CreateAnd(isNeitherNull, + CallExternFunction("equal", shortParams, OMNI_BOOLEAN, { left, right }, nullptr, "relational_eq")); + case omniruntime::expressions::Operator::NEQ: + return builder->CreateAnd(isNeitherNull, + CallExternFunction("notEqual", shortParams, OMNI_BOOLEAN, { left, right }, nullptr, "relational_neq")); + case omniruntime::expressions::Operator::ADD: + return CallExternFunction("add", shortParams, OMNI_SHORT, { leftPhi, rightPhi }, nullptr, "arithmetic_add"); + case omniruntime::expressions::Operator::SUB: + return CallExternFunction("subtract", shortParams, OMNI_SHORT, { leftPhi, rightPhi }, nullptr, + "arithmetic_sub"); + case omniruntime::expressions::Operator::MUL: + return CallExternFunction("multiply", shortParams, OMNI_SHORT, { leftPhi, rightPhi }, nullptr, + "arithmetic_mul"); + case omniruntime::expressions::Operator::DIV: + return CallExternFunction("divide", shortParams, OMNI_SHORT, {nullFlag, leftPhi, rightPhi }, + nullptr, "arithmetic_div"); + case omniruntime::expressions::Operator::MOD: + return CallExternFunction("modulus", shortParams, OMNI_SHORT, {nullFlag, leftPhi, rightPhi }, + nullptr, "arithmetic_mod"); + case omniruntime::expressions::Operator::TRY_ADD: + return CallExternFunction("try_add", shortParams, OMNI_SHORT, {nullFlag, leftPhi, rightPhi }, + nullptr, "arithmetic_try_add"); + case omniruntime::expressions::Operator::TRY_SUB: + return CallExternFunction("try_subtract", shortParams, OMNI_SHORT, {nullFlag, leftPhi, rightPhi }, + nullptr, "arithmetic_try_sub"); + case omniruntime::expressions::Operator::TRY_MUL: + return CallExternFunction("try_multiply", shortParams, OMNI_SHORT, {nullFlag, leftPhi, rightPhi }, + nullptr, "arithmetic_try_mul"); + case omniruntime::expressions::Operator::TRY_DIV: + return CallExternFunction("divide", shortParams, OMNI_SHORT, {nullFlag, leftPhi, rightPhi }, + nullptr, "arithmetic_try_div"); + default: { + LogError("Unsupported short binary operator %u", static_cast(binaryExpr->op)); + return nullptr; + } + } +} + +llvm::Value *ExpressionCodeGen::BinaryExprIntHelper(const BinaryExpr *binaryExpr, Value *left, Value *right, + Value *leftIsNull, Value *rightIsNull, Value *nullFlag) +{ + PHINode *leftPhi; + PHINode *rightPhi; + Value *isNeitherNull = builder->CreateNot(builder->CreateOr(leftIsNull, rightIsNull)); + std::vector intParams = { OMNI_INT, OMNI_INT }; + BinaryExprNullHelper(binaryExpr, left, right, leftIsNull, rightIsNull, &leftPhi, &rightPhi); + switch (binaryExpr->op) { + case omniruntime::expressions::Operator::LT: + return builder->CreateAnd(isNeitherNull, + CallExternFunction("lessThan", intParams, OMNI_BOOLEAN, { left, right }, nullptr, "relational_lt")); + case omniruntime::expressions::Operator::GT: + return builder->CreateAnd(isNeitherNull, + CallExternFunction("greaterThan", intParams, OMNI_BOOLEAN, { left, right }, nullptr, "relational_gt")); + case omniruntime::expressions::Operator::LTE: + return builder->CreateAnd(isNeitherNull, CallExternFunction("lessThanEqual", intParams, OMNI_BOOLEAN, + { left, right }, nullptr, "relational_le")); + case omniruntime::expressions::Operator::GTE: + return builder->CreateAnd(isNeitherNull, CallExternFunction("greaterThanEqual", intParams, OMNI_BOOLEAN, + { left, right }, nullptr, "relational_ge")); + case omniruntime::expressions::Operator::EQ: + return builder->CreateAnd(isNeitherNull, + CallExternFunction("equal", intParams, OMNI_BOOLEAN, { left, right }, nullptr, "relational_eq")); + case omniruntime::expressions::Operator::NEQ: + return builder->CreateAnd(isNeitherNull, + CallExternFunction("notEqual", intParams, OMNI_BOOLEAN, { left, right }, nullptr, "relational_neq")); + case omniruntime::expressions::Operator::ADD: + return CallExternFunction("add", intParams, OMNI_INT, { leftPhi, rightPhi }, nullptr, "arithmetic_add"); + case omniruntime::expressions::Operator::SUB: + return CallExternFunction("subtract", intParams, OMNI_INT, { leftPhi, rightPhi }, nullptr, + "arithmetic_sub"); + case omniruntime::expressions::Operator::MUL: + return CallExternFunction("multiply", intParams, OMNI_INT, { leftPhi, rightPhi }, nullptr, + "arithmetic_mul"); + case omniruntime::expressions::Operator::DIV: + return CallExternFunction("divide", intParams, OMNI_INT, {nullFlag, leftPhi, rightPhi }, + nullptr, "arithmetic_div"); + case omniruntime::expressions::Operator::MOD: + return CallExternFunction("modulus", intParams, OMNI_INT, {nullFlag, leftPhi, rightPhi }, + nullptr, "arithmetic_mod"); + case omniruntime::expressions::Operator::TRY_ADD: + return CallExternFunction("try_add", intParams, OMNI_INT, {nullFlag, leftPhi, rightPhi }, + nullptr, "arithmetic_try_add"); + case omniruntime::expressions::Operator::TRY_SUB: + return CallExternFunction("try_subtract", intParams, OMNI_INT, {nullFlag, leftPhi, rightPhi }, + nullptr, "arithmetic_try_sub"); + case omniruntime::expressions::Operator::TRY_MUL: + return CallExternFunction("try_multiply", intParams, OMNI_INT, {nullFlag, leftPhi, rightPhi }, + nullptr, "arithmetic_try_mul"); + case omniruntime::expressions::Operator::TRY_DIV: + return CallExternFunction("divide", intParams, OMNI_INT, {nullFlag, leftPhi, rightPhi }, + nullptr, "arithmetic_try_div"); + default: { + LogError("Unsupported int binary operator %u", static_cast(binaryExpr->op)); + return nullptr; + } + } +} + +Value *ExpressionCodeGen::BinaryExprLongHelper(const BinaryExpr *binaryExpr, Value *left, Value *right, + Value *leftIsNull, Value *rightIsNull, Value *nullFlag) +{ + PHINode *leftPhi; + PHINode *rightPhi; + Value *isNeitherNull = builder->CreateNot(builder->CreateOr(leftIsNull, rightIsNull)); + std::vector longParams = { OMNI_LONG, OMNI_LONG }; + BinaryExprNullHelper(binaryExpr, left, right, leftIsNull, rightIsNull, &leftPhi, &rightPhi); + switch (binaryExpr->op) { + case omniruntime::expressions::Operator::LT: + return builder->CreateAnd(isNeitherNull, + CallExternFunction("lessThan", longParams, OMNI_BOOLEAN, { left, right }, nullptr, "lrelational_lt")); + case omniruntime::expressions::Operator::LTE: + return builder->CreateAnd(isNeitherNull, CallExternFunction("lessThanEqual", longParams, OMNI_BOOLEAN, + { left, right }, nullptr, "lrelational_le")); + case omniruntime::expressions::Operator::GT: + return builder->CreateAnd(isNeitherNull, CallExternFunction("greaterThan", longParams, OMNI_BOOLEAN, + { left, right }, nullptr, "lrelational_gt")); + case omniruntime::expressions::Operator::GTE: + return builder->CreateAnd(isNeitherNull, CallExternFunction("greaterThanEqual", longParams, OMNI_BOOLEAN, + { left, right }, nullptr, "lrelational_ge")); + case omniruntime::expressions::Operator::EQ: + return builder->CreateAnd(isNeitherNull, + CallExternFunction("equal", longParams, OMNI_BOOLEAN, { left, right }, nullptr, "larithmetic_eq")); + case omniruntime::expressions::Operator::NEQ: + return builder->CreateAnd(isNeitherNull, + CallExternFunction("notEqual", longParams, OMNI_BOOLEAN, { left, right }, nullptr, "larithmetic_neq")); + case omniruntime::expressions::Operator::ADD: + return CallExternFunction("add", longParams, OMNI_LONG, { leftPhi, rightPhi }, nullptr, "larithmetic_add"); + case omniruntime::expressions::Operator::SUB: + return CallExternFunction("subtract", longParams, OMNI_LONG, { leftPhi, rightPhi }, nullptr, + "larithmetic_sub"); + case omniruntime::expressions::Operator::MUL: + return CallExternFunction("multiply", longParams, OMNI_LONG, { leftPhi, rightPhi }, nullptr, + "larithmetic_mul"); + case omniruntime::expressions::Operator::DIV: + return CallExternFunction("divide", longParams, OMNI_LONG, {nullFlag, leftPhi, rightPhi }, + nullptr, "larithmetic_divide"); + case omniruntime::expressions::Operator::MOD: + return CallExternFunction("modulus", longParams, OMNI_LONG, {nullFlag, leftPhi, rightPhi }, + nullptr, "larithmetic_mod"); + case omniruntime::expressions::Operator::TRY_ADD: + return CallExternFunction("try_add", longParams, OMNI_LONG, {nullFlag, leftPhi, rightPhi }, + nullptr, "larithmetic_try_add"); + case omniruntime::expressions::Operator::TRY_SUB: + return CallExternFunction("try_subtract", longParams, OMNI_LONG, {nullFlag, leftPhi, rightPhi }, + nullptr, "larithmetic_try_sub"); + case omniruntime::expressions::Operator::TRY_MUL: + return CallExternFunction("try_multiply", longParams, OMNI_LONG, {nullFlag, leftPhi, rightPhi }, + nullptr, "larithmetic_try_mul"); + case omniruntime::expressions::Operator::TRY_DIV: + return CallExternFunction("divide", longParams, OMNI_LONG, {nullFlag, leftPhi, rightPhi }, + nullptr, "larithmetic_try_divide"); + default: { + LogWarn("Unsupported long binary operator %u", static_cast(binaryExpr->op)); + return nullptr; + } + } +} + +void ExpressionCodeGen::BinaryExprDecimal64Helper(const BinaryExpr *binaryExpr, DecimalValue &left, DecimalValue &right, + Value *leftIsNull, Value *rightIsNull) +{ + PHINode *leftPhi; + PHINode *rightPhi; + Value *isNeitherNull = builder->CreateNot(builder->CreateOr(leftIsNull, rightIsNull)); + Value *output = nullptr; + auto leftType = binaryExpr->left->GetReturnType(); + auto rightType = binaryExpr->right->GetReturnType(); + auto binaryReturnType = binaryExpr->GetReturnType(); + BinaryExprNullHelper(binaryExpr, left.data, right.data, leftIsNull, rightIsNull, &leftPhi, &rightPhi); + std::vector params { leftType->GetId(), rightType->GetId() }; + std::shared_ptr returnDecimalValue = BuildDecimalValue(nullptr, *binaryReturnType, nullptr); + std::vector argVals { leftPhi, + const_cast(left.GetPrecision()), + const_cast(left.GetScale()), + rightPhi, + const_cast(right.GetPrecision()), + const_cast(right.GetScale()), + const_cast(returnDecimalValue->GetPrecision()), + const_cast(returnDecimalValue->GetScale()) }; + std::vector argValsCmp { + left.data, const_cast(left.GetPrecision()), const_cast(left.GetScale()), + right.data, const_cast(right.GetPrecision()), const_cast(right.GetScale()) + }; + + llvm::Type *returnType = llvmTypes->ToLLVMType(binaryExpr->GetReturnTypeId()); + DataTypeId returnTypeId = binaryExpr->GetReturnTypeId(); + std::string decimal64CmpFuncId = FunctionSignature(decimal64CompareStr, params, OMNI_INT).ToString(); + AllocaInst *overflowNull = builder->CreateAlloca(Type::getInt1Ty(*context), nullptr, "overflow_null"); + builder->CreateStore(ConstantInt::get(IntegerType::getInt1Ty(*context), 0), overflowNull); + + bool isTryExpr = false; + switch (binaryExpr->op) { + case omniruntime::expressions::Operator::LT: + output = builder->CreateAnd(isNeitherNull, builder->CreateICmpSLT( + CallDecimalFunction(decimal64CmpFuncId, returnType, argValsCmp), llvmTypes->CreateConstantInt(0))); + break; + case omniruntime::expressions::Operator::GT: + output = builder->CreateAnd(isNeitherNull, builder->CreateICmpSGT( + CallDecimalFunction(decimal64CmpFuncId, returnType, argValsCmp), llvmTypes->CreateConstantInt(0))); + break; + case omniruntime::expressions::Operator::LTE: + output = builder->CreateAnd(isNeitherNull, builder->CreateICmpSLE( + CallDecimalFunction(decimal64CmpFuncId, returnType, argValsCmp), llvmTypes->CreateConstantInt(0))); + break; + case omniruntime::expressions::Operator::GTE: + output = builder->CreateAnd(isNeitherNull, builder->CreateICmpSGE( + CallDecimalFunction(decimal64CmpFuncId, returnType, argValsCmp), llvmTypes->CreateConstantInt(0))); + break; + case omniruntime::expressions::Operator::EQ: { + output = builder->CreateAnd(isNeitherNull, builder->CreateICmpEQ( + CallDecimalFunction(decimal64CmpFuncId, returnType, argValsCmp), llvmTypes->CreateConstantInt(0))); + break; + } + case omniruntime::expressions::Operator::NEQ: + output = builder->CreateAnd(isNeitherNull, builder->CreateICmpNE( + CallDecimalFunction(decimal64CmpFuncId, returnType, argValsCmp), llvmTypes->CreateConstantInt(0))); + break; + case omniruntime::expressions::Operator::ADD: { + std::string funcId = FunctionSignature(addDec64Str, params, returnTypeId).ToString(this->overflowConfig); + output = CallDecimalFunction(funcId, returnType, argVals, codegenContext->executionContext, + this->overflowConfig, overflowNull); + break; + } + case omniruntime::expressions::Operator::SUB: { + std::string funcId = FunctionSignature(subDec64Str, params, returnTypeId).ToString(this->overflowConfig); + output = CallDecimalFunction(funcId, returnType, argVals, codegenContext->executionContext, + this->overflowConfig, overflowNull); + break; + } + case omniruntime::expressions::Operator::MUL: { + std::string funcId = FunctionSignature(mulDec64Str, params, returnTypeId).ToString(this->overflowConfig); + output = CallDecimalFunction(funcId, returnType, argVals, codegenContext->executionContext, + this->overflowConfig, overflowNull); + break; + } + case omniruntime::expressions::Operator::DIV: { + std::string funcId = FunctionSignature(divDec64Str, params, returnTypeId).ToString(this->overflowConfig); + output = CallDecimalFunction(funcId, returnType, argVals, codegenContext->executionContext, + this->overflowConfig, overflowNull); + break; + } + case omniruntime::expressions::Operator::MOD: { + std::string funcId = FunctionSignature(modDec64Str, params, returnTypeId).ToString(this->overflowConfig); + output = CallDecimalFunction(funcId, returnType, argVals, codegenContext->executionContext, + this->overflowConfig, overflowNull); + break; + } + case omniruntime::expressions::Operator::TRY_ADD:{ + isTryExpr = true; + auto ptr = std::make_unique(omniruntime::op::OverflowConfigId::OVERFLOW_CONFIG_NULL); + std::string funcId = FunctionSignature(tryAddDecimal64FnStr, params, returnTypeId).ToString(); + output= CallDecimalFunction(funcId, returnType, argVals, codegenContext->executionContext, ptr.get(), overflowNull); + break; + } + case omniruntime::expressions::Operator::TRY_SUB: { + isTryExpr = true; + auto ptr = std::make_unique(omniruntime::op::OverflowConfigId::OVERFLOW_CONFIG_NULL); + std::string funcId = FunctionSignature(trySubDecimal64FnStr, params, returnTypeId).ToString(); + + output = CallDecimalFunction(funcId, returnType, argVals, codegenContext->executionContext, ptr.get(), overflowNull); + break; + } + case omniruntime::expressions::Operator::TRY_MUL: { + isTryExpr = true; + auto ptr = std::make_unique(omniruntime::op::OverflowConfigId::OVERFLOW_CONFIG_NULL); + std::string funcId = FunctionSignature(tryMulDecimal64FnStr, params, returnTypeId).ToString(); + output = CallDecimalFunction(funcId, returnType, argVals, codegenContext->executionContext, ptr.get(), overflowNull); + break; + } + case omniruntime::expressions::Operator::TRY_DIV: { + isTryExpr = true; + auto ptr = std::make_unique(omniruntime::op::OverflowConfigId::OVERFLOW_CONFIG_NULL); + std::string funcId = FunctionSignature(tryDivDecimal64FnStr, params, returnTypeId).ToString(); + output = CallDecimalFunction(funcId, returnType, argVals, codegenContext->executionContext, ptr.get(), overflowNull); + break; + } + default: { + LogWarn("Unsupported decimal64 binary operator %u", static_cast(binaryExpr->op)); + output = nullptr; + break; + } + } + CodeGenValuePtr valuePtr = nullptr; + if (TypeUtil::IsDecimalType(binaryExpr->GetReturnTypeId())) { + valuePtr = + BuildDecimalValue(output, *(binaryExpr->GetReturnType()), builder->CreateOr(leftIsNull, rightIsNull)); + } else { + valuePtr = std::make_shared(output, builder->CreateOr(leftIsNull, rightIsNull)); + } + + if (isTryExpr || (overflowConfig != nullptr && + overflowConfig->GetOverflowConfigId() == omniruntime::op::OVERFLOW_CONFIG_NULL)) { + valuePtr->isNull = builder->CreateOr(valuePtr->isNull, builder->CreateLoad(llvmTypes->I1Type(), overflowNull)); + } + this->value = valuePtr; +} + +Value *ExpressionCodeGen::BinaryExprDoubleHelper(const BinaryExpr *binaryExpr, Value *left, Value *right, + Value *leftIsNull, Value *rightIsNull, Value *nullFlag) +{ + PHINode *leftPhi; + PHINode *rightPhi; + Value *isNeitherNull = builder->CreateNot(builder->CreateOr(leftIsNull, rightIsNull)); + std::vector doubleParams = { OMNI_DOUBLE, OMNI_DOUBLE }; + BinaryExprNullHelper(binaryExpr, left, right, leftIsNull, rightIsNull, &leftPhi, &rightPhi); + switch (binaryExpr->op) { + case omniruntime::expressions::Operator::LT: + return builder->CreateAnd(isNeitherNull, + CallExternFunction("lessThan", doubleParams, OMNI_BOOLEAN, { left, right }, nullptr, "frelational_lt")); + case omniruntime::expressions::Operator::LTE: + return builder->CreateAnd(isNeitherNull, CallExternFunction("lessThanEqual", doubleParams, OMNI_BOOLEAN, + { left, right }, nullptr, "frelational_le")); + case omniruntime::expressions::Operator::GT: + return builder->CreateAnd(isNeitherNull, CallExternFunction("greaterThan", doubleParams, OMNI_BOOLEAN, + { left, right }, nullptr, "frelational_gt")); + case omniruntime::expressions::Operator::GTE: + return builder->CreateAnd(isNeitherNull, CallExternFunction("greaterThanEqual", doubleParams, OMNI_BOOLEAN, + { left, right }, nullptr, "frelational_ge")); + case omniruntime::expressions::Operator::EQ: + return builder->CreateAnd(isNeitherNull, + CallExternFunction("equal", doubleParams, OMNI_BOOLEAN, { left, right }, nullptr, "farithmetic_eq")); + case omniruntime::expressions::Operator::NEQ: + return builder->CreateAnd(isNeitherNull, CallExternFunction("notEqual", doubleParams, OMNI_BOOLEAN, + { left, right }, nullptr, "farithmetic_neq")); + case omniruntime::expressions::Operator::ADD: + return CallExternFunction("add", doubleParams, OMNI_DOUBLE, { leftPhi, rightPhi }, nullptr, + "farithmetic_add"); + case omniruntime::expressions::Operator::SUB: + return CallExternFunction("subtract", doubleParams, OMNI_DOUBLE, { leftPhi, rightPhi }, nullptr, + "farithmetic_sub"); + case omniruntime::expressions::Operator::MUL: + return CallExternFunction("multiply", doubleParams, OMNI_DOUBLE, { leftPhi, rightPhi }, nullptr, + "farithmetic_mul"); + case omniruntime::expressions::Operator::DIV: + return CallExternFunction("divide", doubleParams, OMNI_DOUBLE, { nullFlag, leftPhi, rightPhi }, nullptr, + "farithmetic_divide"); + case omniruntime::expressions::Operator::MOD: + return CallExternFunction("modulus", doubleParams, OMNI_DOUBLE, { nullFlag, leftPhi, rightPhi }, nullptr, + "farithmetic_mod"); + case omniruntime::expressions::Operator::TRY_ADD: + return CallExternFunction("add", doubleParams, OMNI_DOUBLE, { leftPhi, rightPhi }, nullptr, + "farithmetic_add"); + case omniruntime::expressions::Operator::TRY_SUB: + return CallExternFunction("subtract", doubleParams, OMNI_DOUBLE, { leftPhi, rightPhi }, nullptr, + "farithmetic_sub"); + case omniruntime::expressions::Operator::TRY_MUL: + return CallExternFunction("multiply", doubleParams, OMNI_DOUBLE, { leftPhi, rightPhi }, nullptr, + "farithmetic_mul"); + case omniruntime::expressions::Operator::TRY_DIV: + return CallExternFunction("divide", doubleParams, OMNI_DOUBLE, { nullFlag, leftPhi, rightPhi }, nullptr, + "farithmetic_divide"); + default: { + LogWarn("Unsupported double binary operator %u", static_cast(binaryExpr->op)); + return nullptr; + } + } +} + +Value *ExpressionCodeGen::BinaryExprStringHelper(const BinaryExpr *binaryExpr, Value *leftVal, Value *leftLen, + Value *rightVal, Value *rightLen, Value *leftIsNull, Value *rightIsNull) +{ + PHINode *leftPhi; + PHINode *rightPhi; + Value *isNull = builder->CreateOr(leftIsNull, rightIsNull); + BinaryExprNullHelper(binaryExpr, leftVal, rightVal, leftIsNull, rightIsNull, &leftPhi, &rightPhi); + switch (binaryExpr->op) { + case omniruntime::expressions::Operator::LT: + return builder->CreateAnd(builder->CreateNot(isNull), builder->CreateICmpSLT( + this->StringCmp(leftVal, leftLen, rightVal, rightLen), llvmTypes->CreateConstantInt(0))); + case omniruntime::expressions::Operator::GT: + return builder->CreateAnd(builder->CreateNot(isNull), builder->CreateICmpSGT( + this->StringCmp(leftVal, leftLen, rightVal, rightLen), llvmTypes->CreateConstantInt(0))); + case omniruntime::expressions::Operator::LTE: + return builder->CreateAnd(builder->CreateNot(isNull), builder->CreateICmpSLE( + this->StringCmp(leftVal, leftLen, rightVal, rightLen), llvmTypes->CreateConstantInt(0))); + case omniruntime::expressions::Operator::GTE: + return builder->CreateAnd(builder->CreateNot(isNull), builder->CreateICmpSGE( + this->StringCmp(leftVal, leftLen, rightVal, rightLen), llvmTypes->CreateConstantInt(0))); + case omniruntime::expressions::Operator::EQ: { + return this->StringEqual(leftVal, leftLen, rightVal, rightLen, isNull); + } + case omniruntime::expressions::Operator::NEQ: { + return builder->CreateNot(this->StringEqual(leftVal, leftLen, rightVal, rightLen, isNull)); + } + + default: { + LogWarn("Unsupported string binary operator %u", static_cast(binaryExpr->op)); + return nullptr; + } + } +} + +void ExpressionCodeGen::BinaryExprDecimal128Helper(const BinaryExpr *binaryExpr, DecimalValue &left, + DecimalValue &right, Value *leftIsNull, Value *rightIsNull) +{ + PHINode *leftPhi; + PHINode *rightPhi; + Value *isNeitherNull = builder->CreateNot(builder->CreateOr(leftIsNull, rightIsNull)); + Value *output = nullptr; + auto leftType = binaryExpr->left->GetReturnType(); + auto rightType = binaryExpr->right->GetReturnType(); + auto binaryReturnType = binaryExpr->GetReturnType(); + BinaryExprNullHelper(binaryExpr, left.data, right.data, leftIsNull, rightIsNull, &leftPhi, &rightPhi); + std::vector params { leftType->GetId(), rightType->GetId() }; + std::shared_ptr returnDecimalValue = BuildDecimalValue(nullptr, *binaryReturnType, nullptr); + std::vector argVals { leftPhi, + const_cast(left.GetPrecision()), + const_cast(left.GetScale()), + rightPhi, + const_cast(right.GetPrecision()), + const_cast(right.GetScale()), + const_cast(returnDecimalValue->GetPrecision()), + const_cast(returnDecimalValue->GetScale()) }; + std::vector argValsCmp { + left.data, const_cast(left.GetPrecision()), const_cast(left.GetScale()), + right.data, const_cast(right.GetPrecision()), const_cast(right.GetScale()) + }; + DataTypeId returnTypeId = binaryExpr->GetReturnTypeId(); + Type *returnType = llvmTypes->ToLLVMType(binaryExpr->GetReturnTypeId()); + std::string decimal128CmpFuncId = FunctionSignature(decimal128CompareStr, params, OMNI_INT).ToString(); + AllocaInst *overflowNull = builder->CreateAlloca(Type::getInt1Ty(*context), nullptr, "overflow_null"); + builder->CreateStore(ConstantInt::get(IntegerType::getInt1Ty(*context), 0), overflowNull); + + bool isTryExpr = false; + switch (binaryExpr->op) { + case omniruntime::expressions::Operator::LT: + output = builder->CreateAnd(isNeitherNull, builder->CreateICmpSLT( + CallDecimalFunction(decimal128CmpFuncId, returnType, argValsCmp), llvmTypes->CreateConstantInt(0))); + break; + case omniruntime::expressions::Operator::GT: + output = builder->CreateAnd(isNeitherNull, builder->CreateICmpSGT( + CallDecimalFunction(decimal128CmpFuncId, returnType, argValsCmp), llvmTypes->CreateConstantInt(0))); + break; + case omniruntime::expressions::Operator::LTE: + output = builder->CreateAnd(isNeitherNull, builder->CreateICmpSLE( + CallDecimalFunction(decimal128CmpFuncId, returnType, argValsCmp), llvmTypes->CreateConstantInt(0))); + break; + case omniruntime::expressions::Operator::GTE: + output = builder->CreateAnd(isNeitherNull, builder->CreateICmpSGE( + CallDecimalFunction(decimal128CmpFuncId, returnType, argValsCmp), llvmTypes->CreateConstantInt(0))); + break; + case omniruntime::expressions::Operator::EQ: { + output = builder->CreateAnd(isNeitherNull, builder->CreateICmpEQ( + CallDecimalFunction(decimal128CmpFuncId, returnType, argValsCmp), llvmTypes->CreateConstantInt(0))); + break; + } + case omniruntime::expressions::Operator::NEQ: + output = builder->CreateAnd(isNeitherNull, builder->CreateICmpNE( + CallDecimalFunction(decimal128CmpFuncId, returnType, argValsCmp), llvmTypes->CreateConstantInt(0))); + break; + case omniruntime::expressions::Operator::ADD: { + std::string funcId = FunctionSignature(addDec128Str, params, returnTypeId).ToString(this->overflowConfig); + output = CallDecimalFunction(funcId, returnType, argVals, codegenContext->executionContext, + this->overflowConfig, overflowNull); + break; + } + case omniruntime::expressions::Operator::SUB: { + std::string funcId = FunctionSignature(subDec128Str, params, returnTypeId).ToString(this->overflowConfig); + output = CallDecimalFunction(funcId, returnType, argVals, codegenContext->executionContext, + this->overflowConfig, overflowNull); + break; + } + case omniruntime::expressions::Operator::MUL: { + std::string funcId = FunctionSignature(mulDec128Str, params, returnTypeId).ToString(this->overflowConfig); + output = CallDecimalFunction(funcId, returnType, argVals, codegenContext->executionContext, + this->overflowConfig, overflowNull); + break; + } + case omniruntime::expressions::Operator::DIV: { + std::string funcId = FunctionSignature(divDec128Str, params, returnTypeId).ToString(this->overflowConfig); + output = CallDecimalFunction(funcId, returnType, argVals, codegenContext->executionContext, + this->overflowConfig, overflowNull); + break; + } + case omniruntime::expressions::Operator::MOD: { + std::string funcId = FunctionSignature(modDec128Str, params, returnTypeId).ToString(this->overflowConfig); + output = CallDecimalFunction(funcId, returnType, argVals, codegenContext->executionContext, + this->overflowConfig, overflowNull); + break; + } + case omniruntime::expressions::Operator::TRY_ADD:{ + isTryExpr = true; + auto ptr = std::make_unique(omniruntime::op::OverflowConfigId::OVERFLOW_CONFIG_NULL); + std::string funcId = FunctionSignature(tryAddDecimal128FnStr, params, returnTypeId).ToString(); + output= CallDecimalFunction(funcId, returnType, argVals, codegenContext->executionContext, ptr.get(), overflowNull); + break; + } + case omniruntime::expressions::Operator::TRY_SUB: { + isTryExpr = true; + auto ptr = std::make_unique(omniruntime::op::OverflowConfigId::OVERFLOW_CONFIG_NULL); + std::string funcId = FunctionSignature(trySubDecimal128FnStr, params, returnTypeId).ToString(); + output = CallDecimalFunction(funcId, returnType, argVals, codegenContext->executionContext,ptr.get(), overflowNull); + break; + } + case omniruntime::expressions::Operator::TRY_MUL: { + isTryExpr = true; + auto ptr = std::make_unique(omniruntime::op::OverflowConfigId::OVERFLOW_CONFIG_NULL); + std::string funcId = FunctionSignature(tryMulDecimal128FnStr, params, returnTypeId).ToString(); + output = CallDecimalFunction(funcId, returnType, argVals, codegenContext->executionContext, ptr.get(), overflowNull); + break; + } + case omniruntime::expressions::Operator::TRY_DIV: { + isTryExpr = true; + auto ptr = std::make_unique(omniruntime::op::OverflowConfigId::OVERFLOW_CONFIG_NULL); + std::string funcId = FunctionSignature(tryDivDecimal128FnStr, params, returnTypeId).ToString(); + output = CallDecimalFunction(funcId, returnType, argVals, codegenContext->executionContext, ptr.get(), overflowNull); + break; + } + default: { + LogWarn("Unsupported decimal128 binary operator %u", static_cast(binaryExpr->op)); + output = nullptr; + break; + } + } + CodeGenValuePtr valuePtr = nullptr; + if (TypeUtil::IsDecimalType(binaryExpr->GetReturnTypeId())) { + valuePtr = + BuildDecimalValue(output, *(binaryExpr->GetReturnType()), builder->CreateOr(leftIsNull, rightIsNull)); + } else { + valuePtr = std::make_shared(output, builder->CreateOr(leftIsNull, rightIsNull)); + } + + if (isTryExpr || (overflowConfig != nullptr && + overflowConfig->GetOverflowConfigId() == omniruntime::op::OVERFLOW_CONFIG_NULL)) { + valuePtr->isNull = builder->CreateOr(valuePtr->isNull, builder->CreateLoad(llvmTypes->I1Type(), overflowNull)); + } + this->value = valuePtr; +} + +CodeGenValue *ExpressionCodeGen::LiteralExprConstantHelper(const LiteralExpr &lExpr) +{ + CodeGenValue *codeGenValue = nullptr; + bool isNullLiteral = lExpr.isNull; + switch (lExpr.GetReturnTypeId()) { + case OMNI_INT: + case OMNI_DATE32: { + codeGenValue = new CodeGenValue(llvmTypes->CreateConstantInt(lExpr.intVal), + llvmTypes->CreateConstantBool(isNullLiteral)); + break; + } + case OMNI_BYTE:{ + codeGenValue = new CodeGenValue(llvmTypes->CreateConstantByte(lExpr.byteVal), + llvmTypes->CreateConstantBool(isNullLiteral)); + break; + } + case OMNI_SHORT:{ + codeGenValue = new CodeGenValue(llvmTypes->CreateConstantShort(lExpr.shortVal), + llvmTypes->CreateConstantBool(isNullLiteral)); + break; + } + case OMNI_TIMESTAMP: + case OMNI_LONG: { + codeGenValue = new CodeGenValue(llvmTypes->CreateConstantLong(lExpr.longVal), + llvmTypes->CreateConstantBool(isNullLiteral)); + break; + } + case OMNI_DOUBLE: { + codeGenValue = new CodeGenValue(llvmTypes->CreateConstantDouble(lExpr.doubleVal), + llvmTypes->CreateConstantBool(isNullLiteral)); + break; + } + case OMNI_CHAR: + case OMNI_VARCHAR: { + Constant *strValConst = CreateConstantString(*(lExpr.stringVal)); + Constant *strLenConst = + ConstantInt::get(*context, APInt(INT32_VALUE, static_cast(lExpr.stringVal->length()))); + codeGenValue = new CodeGenValue(strValConst, llvmTypes->CreateConstantBool(isNullLiteral), strLenConst); + break; + } + case OMNI_BOOLEAN: { + codeGenValue = new CodeGenValue(llvmTypes->CreateConstantBool(lExpr.boolVal), + llvmTypes->CreateConstantBool(isNullLiteral)); + break; + } + case OMNI_DECIMAL64: { + Value *precision = llvmTypes->CreateConstantInt( + static_cast(lExpr.GetReturnType().get())->GetPrecision()); + Value *scale = + llvmTypes->CreateConstantInt(static_cast(lExpr.GetReturnType().get())->GetScale()); + codeGenValue = new DecimalValue(llvmTypes->CreateConstantLong(lExpr.longVal), + llvmTypes->CreateConstantBool(isNullLiteral), precision, scale); + break; + } + case OMNI_DECIMAL128: { + std::string dec128String = isNullLiteral ? "0" : *lExpr.stringVal; + __uint128_t dec128 = Decimal128Utils::StrToUint128_t(dec128String.c_str()); + dec128String = Decimal128Utils::Uint128_tToStr(dec128); + Value *precision = llvmTypes->CreateConstantInt( + static_cast(lExpr.GetReturnType().get())->GetPrecision()); + Value *scale = llvmTypes->CreateConstantInt( + static_cast(lExpr.GetReturnType().get())->GetScale()); + auto const128Val = llvm::ConstantInt::get(llvm::Type::getInt128Ty(*context), dec128String, 10); + codeGenValue = + new DecimalValue(const128Val, llvmTypes->CreateConstantBool(isNullLiteral), precision, scale); + break; + } + case OMNI_NONE: { + codeGenValue = + new CodeGenValue(llvmTypes->CreateConstantInt(lExpr.intVal), llvmTypes->CreateConstantBool(true)); + break; + } + default: { + LogWarn("Unsupported data type in Data Expr %d", lExpr.GetReturnTypeId()); + codeGenValue = + new CodeGenValue(llvmTypes->CreateConstantBool(lExpr.boolVal), llvmTypes->CreateConstantBool(false)); + break; + } + } + return codeGenValue; +} + +bool ExpressionCodeGen::AreInvalidDataTypes(DataTypeId type1, DataTypeId type2) +{ + return type1 != type2 && !(TypeUtil::IsStringType(type1) && TypeUtil::IsStringType(type2)); +} + +void ExpressionCodeGen::InExprIntegerHelper(CodeGenValuePtr &valueToCompare, CodeGenValuePtr &argiValue, + Value *&tmpCmpData, Value *&tmpCmpNull) +{ + tmpCmpData = builder->CreateICmpEQ(valueToCompare->data, argiValue->data); + tmpCmpNull = builder->CreateOr(valueToCompare->isNull, argiValue->isNull); +} + +void ExpressionCodeGen::InExprDecimal64Helper(CodeGenValuePtr &valueToCompare, CodeGenValuePtr &argiValue, + Value *&tmpCmpData, Value *&tmpCmpNull, llvm::Type *retType) +{ + std::vector params { OMNI_DECIMAL64, OMNI_DECIMAL64 }; + std::string funcId = FunctionSignature(decimal64CompareStr, params, OMNI_INT).ToString(); + DecimalValue &left = static_cast(*valueToCompare); + DecimalValue &right = static_cast(*argiValue); + std::vector argValsCmp { + left.data, const_cast(left.GetPrecision()), const_cast(left.GetScale()), + right.data, const_cast(right.GetPrecision()), const_cast(right.GetScale()) + }; + tmpCmpData = + builder->CreateICmpEQ(CallDecimalFunction(funcId, retType, argValsCmp), llvmTypes->CreateConstantInt(0)); + + tmpCmpNull = builder->CreateOr(valueToCompare->isNull, argiValue->isNull); +} + +void ExpressionCodeGen::InExprDecimal128Helper(CodeGenValuePtr &valueToCompare, CodeGenValuePtr &argiValue, + Value *&tmpCmpData, Value *&tmpCmpNull, llvm::Type *retType) +{ + std::vector params { OMNI_DECIMAL128, OMNI_DECIMAL128 }; + std::string funcId = FunctionSignature(decimal128CompareStr, params, OMNI_INT).ToString(); + DecimalValue &left = static_cast(*valueToCompare); + DecimalValue &right = static_cast(*argiValue); + std::vector argValsCmp { + left.data, const_cast(left.GetPrecision()), const_cast(left.GetScale()), + right.data, const_cast(right.GetPrecision()), const_cast(right.GetScale()) + }; + tmpCmpData = + builder->CreateICmpEQ(CallDecimalFunction(funcId, retType, argValsCmp), llvmTypes->CreateConstantInt(0)); + + tmpCmpNull = builder->CreateOr(valueToCompare->isNull, argiValue->isNull); +} + +void ExpressionCodeGen::InExprStringHelper(CodeGenValuePtr &valueToCompare, CodeGenValuePtr &argiValue, + Value *&tmpCmpData, Value *&tmpCmpNull) +{ + tmpCmpNull = builder->CreateOr(valueToCompare->isNull, argiValue->isNull); + tmpCmpData = StringEqual(valueToCompare->data, valueToCompare->length, argiValue->data, value->length, tmpCmpNull); +} + +void ExpressionCodeGen::InExprDoubleHelper(CodeGenValuePtr &valueToCompare, CodeGenValuePtr &argiValue, + Value *&tmpCmpData, Value *&tmpCmpNull) +{ + tmpCmpData = builder->CreateFCmpOEQ(valueToCompare->data, argiValue->data); + tmpCmpNull = builder->CreateOr(valueToCompare->isNull, argiValue->isNull); +} + +bool ExpressionCodeGen::VisitBetweenExprHelper(BetweenExpr &bExpr, const std::shared_ptr &val, + const std::shared_ptr &lowerVal, const std::shared_ptr &upperVal, + std::pair cmpPair) +{ + llvm::Type *retType = llvmTypes->ToLLVMType(bExpr.GetReturnTypeId()); + auto cmpLeft = cmpPair.first; + auto cmpRight = cmpPair.second; + if (bExpr.value->GetReturnTypeId() == OMNI_INT || bExpr.value->GetReturnTypeId() == OMNI_LONG || + bExpr.value->GetReturnTypeId() == OMNI_DATE32 || bExpr.value->GetReturnTypeId() == OMNI_TIMESTAMP) { + *cmpLeft = builder->CreateICmpSLE(lowerVal->data, val->data, "between_cmpleft"); + *cmpRight = builder->CreateICmpSLE(val->data, upperVal->data, "between_cmpright"); + return true; + } else if (bExpr.value->GetReturnTypeId() == OMNI_DOUBLE) { + *cmpLeft = builder->CreateFCmpULE(lowerVal->data, val->data, "between_cmpleft"); + *cmpRight = builder->CreateFCmpULE(val->data, upperVal->data, "between_cmpright"); + return true; + } else if (TypeUtil::IsStringType(bExpr.value->GetReturnTypeId())) { + *cmpLeft = builder->CreateICmpSLE(this->StringCmp(lowerVal->data, lowerVal->length, val->data, val->length), + llvmTypes->CreateConstantInt(0)); + *cmpRight = builder->CreateICmpSLE(this->StringCmp(val->data, val->length, upperVal->data, upperVal->length), + llvmTypes->CreateConstantInt(0)); + return true; + } else if (TypeUtil::IsDecimalType(bExpr.value->GetReturnTypeId())) { + auto retTypeId = bExpr.value->GetReturnTypeId(); + if (retTypeId == OMNI_DECIMAL64) { + std::vector params { OMNI_DECIMAL64, OMNI_DECIMAL64 }; + auto &cmpLower = static_cast(*lowerVal); + auto &cmpVal = static_cast(*val); + auto &cmpUpper = static_cast(*upperVal); + if (cmpVal.GetScale() == cmpLower.GetScale() && cmpVal.GetScale() == cmpUpper.GetScale()) { + *cmpLeft = builder->CreateICmpSLE(lowerVal->data, val->data, "between_cmpleft"); + *cmpRight = builder->CreateICmpSLE(val->data, upperVal->data, "between_cmpright"); + return true; + } + + std::vector argValsCmpLeft { + cmpLower.data, const_cast(cmpLower.GetPrecision()), const_cast(cmpLower.GetScale()), + cmpVal.data, const_cast(cmpVal.GetPrecision()), const_cast(cmpVal.GetScale()) + }; + std::vector argValsCmpRight { + cmpVal.data, const_cast(cmpVal.GetPrecision()), const_cast(cmpVal.GetScale()), + cmpUpper.data, const_cast(cmpUpper.GetPrecision()), const_cast(cmpUpper.GetScale()) + }; + std::string funcId = FunctionSignature(decimal64CompareStr, params, OMNI_INT).ToString(); + + *cmpLeft = builder->CreateICmpSLE(CallDecimalFunction(funcId, retType, argValsCmpLeft), + llvmTypes->CreateConstantInt(0)); + *cmpRight = builder->CreateICmpSLE(CallDecimalFunction(funcId, retType, argValsCmpRight), + llvmTypes->CreateConstantInt(0)); + } else if (retTypeId == OMNI_DECIMAL128) { + std::vector params { OMNI_DECIMAL128, OMNI_DECIMAL128 }; + auto &cmpLower = static_cast(*lowerVal); + auto &cmpVal = static_cast(*val); + auto &cmpUpper = static_cast(*upperVal); + std::vector argValsCmpLeft { + cmpLower.data, const_cast(cmpLower.GetPrecision()), const_cast(cmpLower.GetScale()), + cmpVal.data, const_cast(cmpVal.GetPrecision()), const_cast(cmpVal.GetScale()) + }; + std::vector argValsCmpRight { + cmpVal.data, const_cast(cmpVal.GetPrecision()), const_cast(cmpVal.GetScale()), + cmpUpper.data, const_cast(cmpUpper.GetPrecision()), const_cast(cmpUpper.GetScale()) + }; + std::string funcId = FunctionSignature(decimal128CompareStr, params, OMNI_INT).ToString(); + + *cmpLeft = builder->CreateICmpSLE(CallDecimalFunction(funcId, retType, argValsCmpLeft), + llvmTypes->CreateConstantInt(0)); + *cmpRight = builder->CreateICmpSLE(CallDecimalFunction(funcId, retType, argValsCmpRight), + llvmTypes->CreateConstantInt(0)); + } + return true; + } + return false; +} + +Value *ExpressionCodeGen::GetDictionaryVectorValue(const omniruntime::type::DataType &dataType, Value *rowIdx, + Value *dictionaryVectorPtr, AllocaInst *&lengthAllocaInst) +{ + std::vector paramTypes = { OMNI_LONG, OMNI_INT }; + DataTypeId typeId = dataType.GetId(); + FunctionSignature dictionaryFuncSignature; + switch (typeId) { + case OMNI_BYTE: + dictionaryFuncSignature = FunctionSignature(dictionaryGetByteStr, paramTypes, OMNI_BYTE); + break; + case OMNI_SHORT: + dictionaryFuncSignature = FunctionSignature(dictionaryGetShortStr, paramTypes, OMNI_SHORT); + break; + case OMNI_INT: + case OMNI_DATE32: + dictionaryFuncSignature = FunctionSignature(dictionaryGetIntStr, paramTypes, OMNI_INT); + break; + case OMNI_LONG: + case OMNI_TIMESTAMP: + case OMNI_DECIMAL64: + dictionaryFuncSignature = FunctionSignature(dictionaryGetLongStr, paramTypes, OMNI_LONG); + break; + case OMNI_DECIMAL128: + dictionaryFuncSignature = FunctionSignature(dictionaryGetDecimalStr, paramTypes, OMNI_DECIMAL128); + break; + case OMNI_DOUBLE: + dictionaryFuncSignature = FunctionSignature(dictionaryGetDoubleStr, paramTypes, OMNI_DOUBLE); + break; + case OMNI_BOOLEAN: + dictionaryFuncSignature = FunctionSignature(dictionaryGetBooleanStr, paramTypes, OMNI_BOOLEAN); + break; + case OMNI_CHAR: + case OMNI_VARCHAR: + dictionaryFuncSignature = FunctionSignature(dictionaryGetVarcharStr, paramTypes, OMNI_VARCHAR); + break; + default: + LogWarn("Unsupported dictionary value type: %d", typeId); + return nullptr; + } + auto dictionaryFunc = modulePtr->getFunction(FunctionRegistry::LookupFunction(&dictionaryFuncSignature)->GetId()); + std::vector funcArgs; + funcArgs.push_back(dictionaryVectorPtr); + funcArgs.push_back(rowIdx); + if (TypeUtil::IsStringType(typeId)) { + lengthAllocaInst = builder->CreateAlloca(llvmTypes->I32Type(), nullptr, "varchar_length"); + builder->CreateStore(llvmTypes->CreateConstantInt(0), lengthAllocaInst); + funcArgs.push_back(lengthAllocaInst); + } + Value *result = nullptr; + if (typeId == OMNI_DECIMAL128) { + funcArgs.push_back( + llvmTypes->CreateConstantInt(static_cast(dataType).GetPrecision())); + funcArgs.push_back(llvmTypes->CreateConstantInt(static_cast(dataType).GetScale())); + result = CallDecimalFunction(FunctionRegistry::LookupFunction(&dictionaryFuncSignature)->GetId(), + llvmTypes->ToLLVMType(typeId), funcArgs); + } else { + result = CreateCall(dictionaryFunc, funcArgs, "get_dictionary_value"); + InlineFunctionInfo inlineFunctionInfo; + llvm::InlineFunction(*((CallInst *)result), inlineFunctionInfo); + } + return result; +} + +void ExpressionCodeGen::CoalesceExprDecimalHelper(CodeGenValue &v1, CodeGenValue &v2, BasicBlock &isNotNullBlock, + BasicBlock &isNullBlock, PHINode &pn, PHINode &pnNull) +{ + int32_t numReservedValues = 2; + auto precisionPhi = builder->CreatePHI(Type::getInt32Ty(*context), numReservedValues, "precision"); + auto value1Precision = (Value *)static_cast(v1).GetPrecision(); + auto value2Precision = (Value *)static_cast(v2).GetPrecision(); + precisionPhi->addIncoming(value1Precision, &isNotNullBlock); + precisionPhi->addIncoming(value2Precision, &isNullBlock); + + auto scalePhi = builder->CreatePHI(Type::getInt32Ty(*context), numReservedValues, "scale"); + auto value1Scale = (Value *)static_cast(v1).GetScale(); + auto value2Scale = (Value *)static_cast(v2).GetScale(); + scalePhi->addIncoming(value1Scale, &isNotNullBlock); + scalePhi->addIncoming(value2Scale, &isNullBlock); + + this->value = std::make_shared(&pn, &pnNull, precisionPhi, scalePhi); +} + +Value *ExpressionCodeGen::PushAndGetNullFlag(const FuncExpr &fExpr, std::vector &argVals, + Value *nullFlag, bool needAdd) +{ + if (fExpr.function->GetNullableResultType() == INPUT_DATA_AND_NULL_AND_RETURN_NULL) { + AllocaInst *isNullPtr = builder->CreateAlloca(builder->getInt1Ty(), nullptr, "is_null"); + builder->CreateStore(llvmTypes->CreateConstantBool(false), isNullPtr); + argVals.push_back(isNullPtr); + return isNullPtr; + } + if (needAdd) { + argVals.push_back(nullFlag); + } + return nullFlag; +} + +Value *ExpressionCodeGen::LoadNullFlag(const FuncExpr &fExpr, Value *nullFlag) +{ + if (fExpr.function->GetNullableResultType() == INPUT_DATA_AND_NULL_AND_RETURN_NULL) { + return builder->CreateLoad(builder->getInt1Ty(), nullFlag); + } + return nullFlag; +} +} diff --git a/core/src/codegen/expression_codegen.h b/core/src/codegen/expression_codegen.h new file mode 100644 index 0000000..9f4eac4 --- /dev/null +++ b/core/src/codegen/expression_codegen.h @@ -0,0 +1,219 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: Expression code generator + */ +#ifndef __EXPRESSION_CODEGEN_H__ +#define __EXPRESSION_CODEGEN_H__ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "codegen_value.h" +#include "codegen_context.h" +#include "expression/expressions.h" +#include "expression/parser/parser.h" +#include "expression/expr_printer.h" +#include "util/debug.h" +#include "llvm_types.h" +#include "llvm_engine.h" +#include "codegen_base.h" +#include "vector/vector_batch.h" +#include "expr_function.h" + +namespace omniruntime::codegen { +using CodeGenValuePtr = std::shared_ptr; + +class ExpressionCodeGen : public ExprVisitor, public CodegenBase { +public: + /** + * Method to initialize a ExpressionCodeGen instance + * @param name ExpressionCodeGen module name + * @param cpExpr the expression to code generation + * @param ofConfig config of overflow + */ + ExpressionCodeGen(std::string name, const Expr &cpExpr, op::OverflowConfig *ofConfig); + + ~ExpressionCodeGen() override; + + /** + * Method to get function of processing expression + * @param inputDataTypes is used to provide data type when preload data + * @return the address of function + */ + virtual intptr_t GetFunction(const DataTypes &inputDataTypes) + { + return 0; + } + + std::set vectorIndexes{}; + +protected: + /** + * Method to get function of processing expression for a single line + * @param inputDataTypes is used to provide data type when preload data + * @return the address of function + */ + virtual llvm::Function *CreateFunction(const DataTypes &inputDataTypes); + + /** + * Visitor methods + * @param e expression to visit + */ + void Visit(const LiteralExpr &e) override; + + void Visit(const FieldExpr &e) override; + + void Visit(const UnaryExpr &e) override; + + void Visit(const BinaryExpr &e) override; + + void Visit(const InExpr &e) override; + + void Visit(const BetweenExpr &e) override; + + void Visit(const IfExpr &e) override; + + void Visit(const CoalesceExpr &e) override; + + void Visit(const IsNullExpr &e) override; + + void Visit(const FuncExpr &e) override; + + void Visit(const SwitchExpr &e) override; + + /** + * Method to get LLVM value ptr of expression + * @param e expression to visit + * @return llvm value ptr of expression + */ + CodeGenValuePtr VisitExpr(const Expr &e); + + void ExtractVectorIndexes(); + + std::vector GetFunctionArgValues(const FuncExpr &fExpr, Value **isAnyNull, bool &isInvalidExpr); + + bool InitializeCodegenContext(iterator_range args); + + Value *GetDictionaryVectorValue(const omniruntime::type::DataType &dataType, Value *rowIdx, + Value *dictionaryVectorPtr, AllocaInst *&lengthAllocaInst); + + // Represents the generated expression function + std::shared_ptr exprFunc; + +private: + template + std::vector GetDefaultFunctionArgValues(const FuncExpr &fExpr, Value **isAnyNull, bool &isInvalidExpr); + + std::vector GetDataArgs(const omniruntime::expressions::FuncExpr &fExpr, llvm::Value **isAnyNull, + bool &isInvalidExpr); + + std::vector GetDataAndNullArgs(const omniruntime::expressions::FuncExpr &fExpr, + llvm::Value **isAnyNull, bool &isInvalidExpr); + + std::vector GetDataAndNullArgsAndReturnNull(const omniruntime::expressions::FuncExpr &fExpr, + llvm::Value **isAnyNull, bool &isInvalidExpr); + + std::vector GetDataAndOverflowNullArgs(const omniruntime::expressions::FuncExpr &fExpr, + llvm::Value **isAnyNull, bool &isInvalidExpr, llvm::Value *overflowNull); + + llvm::Value *CreateHiveUdfArgTypes(const omniruntime::expressions::FuncExpr &fExpr); + + std::vector GetHiveUdfArgValues(const omniruntime::expressions::FuncExpr &fExpr, bool &isInvalid); + + void CallHiveUdfFunction(const omniruntime::expressions::FuncExpr &fExpr); + + void FuncExprOverflowNullHelper(const omniruntime::expressions::FuncExpr &e); + + Value *StringCmp(Value *lhs, Value *lLen, Value *rhs, Value *rLen); + + Value *StringEqual(Value *lhs, Value *lLen, Value *rhs, Value *rLen, Value *isNull); + + void BinaryExprNullHelper(const BinaryExpr *binaryExpr, Value *left, Value *right, Value *leftIsNull, + Value *rightIsNull, PHINode **leftPhi, PHINode **rightPhi); + + llvm::Value *BinaryExprByteHelper(const BinaryExpr *binaryExpr, Value *left, Value *right, Value *leftIsNull, + Value *rightIsNull, Value *nullFlag = nullptr); + + llvm::Value *BinaryExprShortHelper(const BinaryExpr *binaryExpr, Value *left, Value *right, Value *leftIsNull, + Value *rightIsNull, Value *nullFlag = nullptr); + + llvm::Value *BinaryExprIntHelper(const BinaryExpr *binaryExpr, Value *left, Value *right, Value *leftIsNull, + Value *rightIsNull, Value *nullFlag = nullptr); + + Value *BinaryExprLongHelper(const BinaryExpr *binaryExpr, Value *left, Value *right, Value *leftIsNull, + Value *rightIsNull, Value *nullFlag = nullptr); + + void BinaryExprDecimal64Helper(const BinaryExpr *binaryExpr, DecimalValue &left, DecimalValue &right, + Value *leftIsNull, Value *rightIsNull); + + Value *BinaryExprDoubleHelper(const BinaryExpr *binaryExpr, Value *left, Value *right, Value *leftIsNull, + Value *rightIsNull, Value *nullFlag = nullptr); + + Value *BinaryExprStringHelper(const BinaryExpr *binaryExpr, Value *leftVal, Value *leftLen, Value *rightVal, + Value *rightLen, Value *leftIsNull, Value *rightIsNull); + + void BinaryExprDecimal128Helper(const BinaryExpr *binaryExpr, DecimalValue &left, DecimalValue &right, + Value *leftIsNull, Value *rightIsNull); + + CodeGenValue *LiteralExprConstantHelper(const LiteralExpr &lExpr); + + bool AreInvalidDataTypes(DataTypeId type1, DataTypeId type2); + + void InExprIntegerHelper(CodeGenValuePtr &valueToCompare, CodeGenValuePtr &argiValue, Value *&tmpCmpData, + Value *&tmpCmpNull); + + void InExprDecimal64Helper(CodeGenValuePtr &valueToCompare, CodeGenValuePtr &argiValue, Value *&tmpCmpData, + Value *&tmpCmpNull, Type *retType); + + void InExprDoubleHelper(CodeGenValuePtr &valueToCompare, CodeGenValuePtr &argiValue, Value *&tmpCmpData, + Value *&tmpCmpNull); + + void InExprStringHelper(CodeGenValuePtr &valueToCompare, CodeGenValuePtr &argiValue, Value *&tmpCmpData, + Value *&tmpCmpNull); + + void InExprDecimal128Helper(CodeGenValuePtr &valueToCompare, CodeGenValuePtr &argiValue, Value *&tmpCmpData, + Value *&tmpCmpNull, llvm::Type *retType); + + bool VisitBetweenExprHelper(BetweenExpr &bExpr, const std::shared_ptr &val, + const std::shared_ptr &lowerVal, const std::shared_ptr &upperVal, + std::pair cmpPair); + + void CoalesceExprDecimalHelper(CodeGenValue &v1, CodeGenValue &v2, BasicBlock &isNotNullBlock, + BasicBlock &isNullBlock, PHINode &pn, PHINode &pnNull); + + Value *PushAndGetNullFlag(const FuncExpr &fExpr, std::vector &argVals, Value *nullFlag, + bool needAdd); + + Value *LoadNullFlag(const FuncExpr &fExpr, Value *nullFlag); +}; +} + +#endif \ No newline at end of file diff --git a/core/src/codegen/filter_codegen.cpp b/core/src/codegen/filter_codegen.cpp new file mode 100644 index 0000000..4286026 --- /dev/null +++ b/core/src/codegen/filter_codegen.cpp @@ -0,0 +1,175 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description:filter code generation methods + */ +#include "filter_codegen.h" + +namespace omniruntime::codegen { +using namespace llvm; +using namespace orc; +using namespace omniruntime::expressions; + +namespace { +const int ARGS_ARRAY_INDEX = 0; +const int NUM_ROWS_INDEX = 1; +const int RESULTS_INDEX = 2; +const int BITMAP_INDEX = 3; +const int OFFSETS_INDEX = 4; +const int EXECUTION_CONTEXT_IDX = 5; +const int DICTIONARY_VECTORS_IDX = 6; +} + +intptr_t FilterCodeGen::GetFunction(const DataTypes &inputDataTypes) +{ + llvm::Function *func = CreateFunction(inputDataTypes); + if (func == nullptr) { + return 0; + } + return CreateWrapper(); +} + +intptr_t FilterCodeGen::CreateWrapper() +{ + // The args indicates the type of the function parameter list. + std::vector args { + llvmTypes->I64PtrType(), // data address array + llvmTypes->I32Type(), // the num of rows + llvmTypes->I32PtrType(), // output array + llvmTypes->I64PtrType(), // bitmap address array + llvmTypes->I64PtrType(), // offset address array + llvmTypes->I64Type(), // execution content address + llvmTypes->I64PtrType() // dictionary address array + }; + + FunctionType *funcSignature = FunctionType::get(llvmTypes->I32Type(), args, false); + llvm::Function *funcDecl = + llvm::Function::Create(funcSignature, llvm::Function::ExternalLinkage, "WRAPPER_FUNC", modulePtr); + BasicBlock *preLoop = BasicBlock::Create(*context, "PRE_LOOP", funcDecl); + BasicBlock *loopBody = BasicBlock::Create(*context, "LOOP_BODY", funcDecl); + BasicBlock *filterPassed = BasicBlock::Create(*context, "FILTER_PASSED", funcDecl); + BasicBlock *incrementCounter = BasicBlock::Create(*context, "INCREMENT_COUNTER", funcDecl); + BasicBlock *endBlock = BasicBlock::Create(*context, "END_BLOCK", funcDecl); + // preprocessing + Argument *data = funcDecl->getArg(ARGS_ARRAY_INDEX); + data->setName("ARGS_ARRAY"); + + Argument *numRows = funcDecl->getArg(NUM_ROWS_INDEX); + numRows->setName("NUM_ROWS"); + + Argument *resultsArray = funcDecl->getArg(RESULTS_INDEX); + resultsArray->setName("RESULTS"); + + Argument *bitmap = funcDecl->getArg(BITMAP_INDEX); + bitmap->setName("BITMAP"); + + Argument *offsets = funcDecl->getArg(OFFSETS_INDEX); + offsets->setName("OFFSETS"); + + Argument *executionContext = funcDecl->getArg(EXECUTION_CONTEXT_IDX); + executionContext->setName("EXECUTION_CONTEXT_ADDRESS"); + + Argument *dictionaryVectors = funcDecl->getArg(DICTIONARY_VECTORS_IDX); + dictionaryVectors->setName("DICTIONARY_VECTORS"); + + RecordMainFunction(funcDecl); + + Value *zero = llvmTypes->CreateConstantInt(0); + Value *one = llvmTypes->CreateConstantInt(1); + + // filterFuncArgs contains the values of the arguments to the filter function + std::vector filterFuncArgs; + int32_t argsSize = exprFunc->GetArgumentCount() + exprFunc->GetInputColumnCount() * 4; + filterFuncArgs.reserve(argsSize); + + // pre loop body + builder->SetInsertPoint(preLoop); + // Pointer to the current row index to be processed. + AllocaInst *indexStore = builder->CreateAlloca(llvmTypes->I32Type(), nullptr, "INDEX_COUNTER"); + // Initialize row index to 0. + builder->CreateStore(zero, indexStore); + // Value of the current row index to be processed. + Value *curIndexVal; + // Temp value for next row index. + Value *nextIndexVal; + // Pointer to the index of the selected positions array to be filled next. + AllocaInst *selectedIndexStore = builder->CreateAlloca(llvmTypes->I32Type(), nullptr, "SELECTED_INDEX_PTR"); + // Initialize index to 0. + builder->CreateStore(zero, selectedIndexStore); + // Value of the selected positions index. + Value *selectedIndexVal; + // Address of the selected index for writing. + Value *selectedAddress; + // Temp value for next selected index. + Value *nextSelectedIndexVal; + + // Create a int pointer to store data length + AllocaInst *lengthAllocaInst = builder->CreateAlloca(llvmTypes->I32Type(), nullptr, "DATA_LENGTH"); + auto isNullPtr = builder->CreateAlloca(llvmTypes->I1Type(), nullptr, "IS_NULL_PTR"); + + auto columnArgs = exprFunc->ToColumnArgs(data); + auto dicArgs = exprFunc->ToDicArgs(dictionaryVectors); + auto nullArgs = exprFunc->ToNullArgs(bitmap); + auto offsetArgs = exprFunc->ToOffsetArgs(offsets); + + builder->CreateBr(loopBody); + // loop body + builder->SetInsertPoint(loopBody); + // Get the value of the current row index to process. + curIndexVal = builder->CreateLoad(llvmTypes->I32Type(), indexStore, "CUR_INDEX"); + + // initialize data_length to 0; + builder->CreateStore(llvmTypes->CreateConstantInt(0), lengthAllocaInst); + + // initialize isNullPtr to false + builder->CreateStore(llvmTypes->CreateConstantBool(false), isNullPtr); + + filterFuncArgs.push_back(curIndexVal); + filterFuncArgs.push_back(lengthAllocaInst); + filterFuncArgs.push_back(executionContext); + filterFuncArgs.push_back(isNullPtr); + + filterFuncArgs.insert(filterFuncArgs.end(), columnArgs.begin(), columnArgs.end()); + filterFuncArgs.insert(filterFuncArgs.end(), dicArgs.begin(), dicArgs.end()); + filterFuncArgs.insert(filterFuncArgs.end(), nullArgs.begin(), nullArgs.end()); + filterFuncArgs.insert(filterFuncArgs.end(), offsetArgs.begin(), offsetArgs.end()); + + // Get the boolean response for this row from the filter function. + CallInst *ret = builder->CreateCall(func, filterFuncArgs, "ROW_EVAL"); + + ret = static_cast( + builder->CreateAnd(builder->CreateNot(builder->CreateLoad(llvmTypes->I1Type(), isNullPtr)), ret)); + // If true, add row index to selected array, otherwise, process next row. + builder->CreateCondBr(ret, filterPassed, incrementCounter); + + // Add row index to results array + builder->SetInsertPoint(filterPassed); + // Get value of selected index. + selectedIndexVal = builder->CreateLoad(llvmTypes->I32Type(), selectedIndexStore, "SELECTED_INDEX"); + // Get address of selected index. + selectedAddress = builder->CreateGEP(llvmTypes->I32Type(), resultsArray, selectedIndexVal, "SELECTED_ADDRESS"); + // Set the selected value to the current row index. + builder->CreateStore(curIndexVal, selectedAddress); + // Increment the selected index. + nextSelectedIndexVal = builder->CreateAdd(selectedIndexVal, one, "NEXT_SELECTED_INDEX"); + builder->CreateStore(nextSelectedIndexVal, selectedIndexStore); + + // Increment counter and process next row. + builder->CreateBr(incrementCounter); + // Increment loop counter + builder->SetInsertPoint(incrementCounter); + // Increment counter. + nextIndexVal = builder->CreateAdd(curIndexVal, one, "NEXT_INDEX"); + builder->CreateStore(nextIndexVal, indexStore); + // If there are rows remaining, repeat, otherwise, exit. + Value *cond = builder->CreateICmpSLT(nextIndexVal, numRows, "END_LOOP_COND"); + builder->CreateCondBr(cond, loopBody, endBlock); + + // Return results + builder->SetInsertPoint(endBlock); + nextSelectedIndexVal = builder->CreateLoad(llvmTypes->I32Type(), selectedIndexStore); + builder->CreateRet(nextSelectedIndexVal); + OptimizeFunctionsAndModule(); + + return Compile(); +} +} \ No newline at end of file diff --git a/core/src/codegen/filter_codegen.h b/core/src/codegen/filter_codegen.h new file mode 100644 index 0000000..7f1e064 --- /dev/null +++ b/core/src/codegen/filter_codegen.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: filter code generation methods + */ +#ifndef FILTER_CODEGEN_H +#define FILTER_CODEGEN_H + +#include "expression_codegen.h" + +namespace omniruntime { +namespace codegen { +class FilterCodeGen : public ExpressionCodeGen { +public: + /** + * Method to initialize a FilterCodeGen instance + * @param name FilterCodeGen module name + * @param expression the filter expression to code generation + * @param overflowConfig config of overflow + */ + FilterCodeGen(std::string name, const omniruntime::expressions::Expr &expression, + omniruntime::op::OverflowConfig *overflowConfig) + : ExpressionCodeGen(std::move(name), expression, overflowConfig) + {} + + ~FilterCodeGen() override = default; + + /** + * Method to get function of processing filter expression + * @param inputDataTypes is used to provide data type when preload data + * @return the address of function + */ + intptr_t GetFunction(const DataTypes &inputDataTypes) override; + +private: + /** + * Method to generate function by using LLVM API which processes filter expression line by line + * @return the address of function + */ + intptr_t CreateWrapper(); +}; +} +} +#endif \ No newline at end of file diff --git a/core/src/codegen/func_registry.cpp b/core/src/codegen/func_registry.cpp new file mode 100644 index 0000000..fe29977 --- /dev/null +++ b/core/src/codegen/func_registry.cpp @@ -0,0 +1,267 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved. + * Description: registry function + */ +#include "func_registry.h" +#include +#include "util/debug.h" +#include "util/config_util.h" + +namespace omniruntime::codegen { +using namespace std; + +const std::string INVALID_HIVE_UDF = ""; + +vector FunctionRegistry::registeredBatchFunctions = InitializeBatchFunc(); +vector FunctionRegistry::registeredRowFunctions = InitializeRowFunc(); +FunctionMapPtr FunctionRegistry::functionRegistry; +FunctionMapPtr FunctionRegistry::functionNullRegistry; +HiveUdfMapPtr FunctionRegistry::hiveUdfMap; +std::once_flag FunctionRegistry::initHiveUdfMap; + +vector> FunctionRegistry::GetRowFunctionRegistries() +{ + vector> functionRegistries; + functionRegistries.push_back(make_unique()); + functionRegistries.push_back(make_unique()); + functionRegistries.push_back(make_unique()); + functionRegistries.push_back(make_unique()); + functionRegistries.push_back(make_unique()); + functionRegistries.push_back(make_unique()); + functionRegistries.push_back(make_unique()); + functionRegistries.push_back(make_unique()); + functionRegistries.push_back(make_unique()); + functionRegistries.push_back(make_unique()); + + auto policy = GetProperties().GetPolicy(); + if (policy->GetRoundingRule() == RoundingRule::HALF_UP) { + functionRegistries.push_back(make_unique()); + functionRegistries.push_back(make_unique()); + } else { + functionRegistries.push_back(make_unique()); + functionRegistries.push_back(make_unique()); + } + + if (policy->GetCheckReScaleRule() == CheckReScaleRule::NOT_CHECK_RESCALE) { + functionRegistries.push_back(make_unique()); + } else { + functionRegistries.push_back(make_unique()); + } + + if (policy->GetEmptySearchStrReplaceRule() == EmptySearchStrReplaceRule::REPLACE) { + functionRegistries.push_back(make_unique()); + } else { + functionRegistries.push_back(make_unique()); + } + + if (policy->GetStringToDateFormatRule() == StringToDateFormatRule::NOT_ALLOW_REDUCED_PRECISION) { + functionRegistries.push_back(make_unique()); + } else { + functionRegistries.push_back(make_unique()); + } + + if (policy->GetStringToDecimalRule() == StringToDecimalRule::OVERFLOW_AS_ROUND_UP) { + functionRegistries.push_back(make_unique()); + } else { + functionRegistries.push_back(make_unique()); + } + + if (policy->GetNegativeStartIndexOutOfBoundsRule() == NegativeStartIndexOutOfBoundsRule::INTERCEPT_FROM_BEYOND && + policy->GetZeroStartIndexSupportRule() == ZeroStartIndexSupportRule::IS_SUPPORT) { + functionRegistries.push_back(make_unique()); + } else if (policy->GetNegativeStartIndexOutOfBoundsRule() == NegativeStartIndexOutOfBoundsRule::EMPTY_STRING && + policy->GetZeroStartIndexSupportRule() == ZeroStartIndexSupportRule::IS_SUPPORT) { + functionRegistries.push_back(make_unique()); + } else if (policy->GetNegativeStartIndexOutOfBoundsRule() == + NegativeStartIndexOutOfBoundsRule::INTERCEPT_FROM_BEYOND && + policy->GetZeroStartIndexSupportRule() == ZeroStartIndexSupportRule::IS_NOT_SUPPORT) { + functionRegistries.push_back(make_unique()); + } else if (policy->GetNegativeStartIndexOutOfBoundsRule() == NegativeStartIndexOutOfBoundsRule::EMPTY_STRING && + policy->GetZeroStartIndexSupportRule() == ZeroStartIndexSupportRule::IS_NOT_SUPPORT) { + functionRegistries.push_back(make_unique()); + } + + return functionRegistries; +} + +vector> FunctionRegistry::GetBatchFunctionRegistries() +{ + vector> functionRegistries; + functionRegistries.push_back(make_unique()); + functionRegistries.push_back(make_unique()); + functionRegistries.push_back(make_unique()); + functionRegistries.push_back(make_unique()); + functionRegistries.push_back(make_unique()); + functionRegistries.push_back(make_unique()); + functionRegistries.push_back(make_unique()); + functionRegistries.push_back(make_unique()); + + auto policy = GetProperties().GetPolicy(); + if (policy->GetRoundingRule() == RoundingRule::HALF_UP) { + functionRegistries.push_back(make_unique()); + functionRegistries.push_back(make_unique()); + } else { + functionRegistries.push_back(make_unique()); + functionRegistries.push_back(make_unique()); + } + + if (policy->GetCheckReScaleRule() == CheckReScaleRule::NOT_CHECK_RESCALE) { + functionRegistries.push_back(make_unique()); + } else { + functionRegistries.push_back(make_unique()); + } + + if (policy->GetStringToDateFormatRule() == StringToDateFormatRule::NOT_ALLOW_REDUCED_PRECISION) { + functionRegistries.push_back(make_unique()); + } else { + functionRegistries.push_back(make_unique()); + } + + if (policy->GetEmptySearchStrReplaceRule() == EmptySearchStrReplaceRule::REPLACE) { + functionRegistries.push_back(make_unique()); + } else { + functionRegistries.push_back(make_unique()); + } + + if (policy->GetStringToDecimalRule() == StringToDecimalRule::OVERFLOW_AS_ROUND_UP) { + functionRegistries.push_back(make_unique()); + } else { + functionRegistries.push_back(make_unique()); + } + + if (policy->GetNegativeStartIndexOutOfBoundsRule() == NegativeStartIndexOutOfBoundsRule::INTERCEPT_FROM_BEYOND && + policy->GetZeroStartIndexSupportRule() == ZeroStartIndexSupportRule::IS_SUPPORT) { + functionRegistries.push_back(make_unique()); + } else if (policy->GetNegativeStartIndexOutOfBoundsRule() == NegativeStartIndexOutOfBoundsRule::EMPTY_STRING && + policy->GetZeroStartIndexSupportRule() == ZeroStartIndexSupportRule::IS_SUPPORT) { + functionRegistries.push_back(make_unique()); + } else if (policy->GetNegativeStartIndexOutOfBoundsRule() == + NegativeStartIndexOutOfBoundsRule::INTERCEPT_FROM_BEYOND && + policy->GetZeroStartIndexSupportRule() == ZeroStartIndexSupportRule::IS_NOT_SUPPORT) { + functionRegistries.push_back(make_unique()); + } else if (policy->GetNegativeStartIndexOutOfBoundsRule() == NegativeStartIndexOutOfBoundsRule::EMPTY_STRING && + policy->GetZeroStartIndexSupportRule() == ZeroStartIndexSupportRule::IS_NOT_SUPPORT) { + functionRegistries.push_back(make_unique()); + } + + return functionRegistries; +} + +std::vector FunctionRegistry::InitializeRowFunc() +{ + hiveUdfMap = std::make_unique>(); + + std::vector allFunctions; + functionRegistry = + std::make_unique>(); + functionNullRegistry = + std::make_unique>(); + + auto registries = GetRowFunctionRegistries(); + for (auto const & registry : registries) { + auto functions = registry->GetFunctions(); + allFunctions.insert(std::end(allFunctions), functions.begin(), functions.end()); + } + for (auto &function : allFunctions) { + for (auto &signature : function.GetSignatures()) { + if (functionRegistry->find(&signature) != functionRegistry->end()) { + LogWarn("Trying to register functions with same signature: %s", signature.ToString().c_str()); + } + functionRegistry->insert(std::make_pair(&signature, &function)); + if (function.GetNullableResultType() == INPUT_DATA_AND_OVERFLOW_NULL || + (function.GetNullableResultType() == INPUT_DATA_AND_NULL_AND_RETURN_NULL && + signature.GetName().find("_null") != string::npos)) { + functionNullRegistry->insert(std::make_pair(&signature, &function)); + } + } + } + + return allFunctions; +} + +std::vector FunctionRegistry::InitializeBatchFunc() +{ + hiveUdfMap = std::make_unique>(); + + std::vector allFunctions; + functionRegistry = + std::make_unique>(); + functionNullRegistry = + std::make_unique>(); + + auto registries = GetBatchFunctionRegistries(); + for (auto const & registry : registries) { + auto functions = registry->GetFunctions(); + allFunctions.insert(std::end(allFunctions), functions.begin(), functions.end()); + } + for (auto &function : allFunctions) { + for (auto &signature : function.GetSignatures()) { + if (functionRegistry->find(&signature) != functionRegistry->end()) { + LogWarn("Trying to register functions with same signature: %s", signature.ToString().c_str()); + } + functionRegistry->insert(std::make_pair(&signature, &function)); + if (function.GetNullableResultType() == INPUT_DATA_AND_OVERFLOW_NULL || + (function.GetNullableResultType() == INPUT_DATA_AND_NULL_AND_RETURN_NULL && + signature.GetName().find("_null") != string::npos)) { + functionNullRegistry->insert(std::make_pair(&signature, &function)); + } + } + } + + return allFunctions; +} + +FunctionRegistry::~FunctionRegistry() = default; + +const Function *FunctionRegistry::LookupFunction(FunctionSignature *signature) +{ + auto result = functionRegistry->find(signature); + if (result == functionRegistry->end()) { + return nullptr; + } + return result->second; +} + +bool FunctionRegistry::LookupNullFunction(FunctionSignature *signature) +{ + auto signatureNull = FunctionSignature(signature->GetName() + "_null", signature->GetParams(), + signature->GetReturnType(), signature->GetFunctionAddress()); + auto result = functionNullRegistry->find(&signatureNull); + return result != functionNullRegistry->end(); +} + +// Some functions such as CastDecimal128ToStringRetNull(), it needs both contextPtr and overflowConfig as parameters. +// The purpose of the below function is to find this functions. +bool FunctionRegistry::IsNullExecutionContextSet(FunctionSignature *signature) +{ + auto signatureNull = FunctionSignature(signature->GetName() + "_null", signature->GetParams(), + signature->GetReturnType(), signature->GetFunctionAddress()); + auto result = functionNullRegistry->find(&signatureNull); + return result->second->IsExecutionContextSet(); +} + +void FunctionRegistry::InitHiveUdfMap() +{ + HiveUdfRegistry::GenerateHiveUdfMap(*(hiveUdfMap.get())); +} + +const std::string &FunctionRegistry::LookupHiveUdf(const std::string &udfName) +{ + std::call_once(initHiveUdfMap, InitHiveUdfMap); + auto result = hiveUdfMap->find(udfName); + if (result == hiveUdfMap->end()) { + return INVALID_HIVE_UDF; + } + return result->second; +} + +std::vector &FunctionRegistry::GetRowFunctions() +{ + return registeredRowFunctions; +} + +std::vector &FunctionRegistry::GetBatchFunctions() +{ + return registeredBatchFunctions; +} +} diff --git a/core/src/codegen/func_registry.h b/core/src/codegen/func_registry.h new file mode 100644 index 0000000..9884f53 --- /dev/null +++ b/core/src/codegen/func_registry.h @@ -0,0 +1,84 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: registry function + */ +#ifndef OMNI_RUNTIME_FUNC_REGISTRY_H +#define OMNI_RUNTIME_FUNC_REGISTRY_H + +#include +#include +#include + +#include "function.h" +#include "func_registry_context.h" +#include "func_registry_decimal.h" +#include "func_registry_dictionary.h" +#include "func_registry_math.h" +#include "func_registry_hash.h" +#include "func_registry_might_contain.h" +#include "func_registry_string.h" +#include "func_registry_varchar_vector.h" +#include "func_registry_hive_udf.h" +#include "func_registry_datetime.h" + +#include "batch_func_registry_decimal.h" +#include "batch_func_registry_dictionary.h" +#include "batch_func_registry_math.h" +#include "batch_func_registry_hash.h" +#include "batch_func_registry_string.h" +#include "batch_func_registry_varchar_vector.h" +#include "batch_func_registry_util.h" +#include "batch_func_registry_datetime.h" + +namespace omniruntime::codegen { +struct Hash { + std::size_t operator () (const FunctionSignature *signature) const + { + return signature->HashCode(); + } +}; +struct Equals { + bool operator () (const FunctionSignature *s1, const FunctionSignature *s2) const + { + return *s1 == *s2; + } +}; + +using FunctionMapPtr = std::unique_ptr>; +using HiveUdfMapPtr = std::unique_ptr>; + +class FunctionRegistry { +public: + ~FunctionRegistry(); + + static const Function *LookupFunction(FunctionSignature *signature); + + static bool LookupNullFunction(FunctionSignature *signature); + + static bool IsNullExecutionContextSet(FunctionSignature *signature); + + static const std::string &LookupHiveUdf(const std::string &udfName); + + static std::vector> GetRowFunctionRegistries(); + + static std::vector> GetBatchFunctionRegistries(); + + static std::vector &GetRowFunctions(); + + static std::vector &GetBatchFunctions(); + + static void InitHiveUdfMap(); + +private: + static std::vector registeredRowFunctions; + static std::vector registeredBatchFunctions; + static FunctionMapPtr functionRegistry; + static FunctionMapPtr functionNullRegistry; + static HiveUdfMapPtr hiveUdfMap; + static std::once_flag initHiveUdfMap; + + static std::vector InitializeRowFunc(); + static std::vector InitializeBatchFunc(); +}; +} +#endif // OMNI_RUNTIME_FUNC_REGISTRY_H diff --git a/core/src/codegen/func_registry_base.h b/core/src/codegen/func_registry_base.h new file mode 100644 index 0000000..0c00822 --- /dev/null +++ b/core/src/codegen/func_registry_base.h @@ -0,0 +1,18 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: registry external function + */ +#ifndef OMNI_RUNTIME_FUNC_REGISTRY_BASE_H +#define OMNI_RUNTIME_FUNC_REGISTRY_BASE_H + +#include + +namespace omniruntime::codegen { +class BaseFunctionRegistry { +public: + virtual std::vector GetFunctions() = 0; + virtual ~BaseFunctionRegistry() = default; +}; +} + +#endif // OMNI_RUNTIME_FUNC_REGISTRY_BASE_H \ No newline at end of file diff --git a/core/src/codegen/func_registry_context.cpp b/core/src/codegen/func_registry_context.cpp new file mode 100644 index 0000000..e54ab43 --- /dev/null +++ b/core/src/codegen/func_registry_context.cpp @@ -0,0 +1,17 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: Context Helper Functions Registry + */ +#include "func_registry_context.h" +#include "context_helper.h" + +namespace omniruntime::codegen { +using namespace omniruntime::type; + +std::vector ContextFunctionRegistry::GetFunctions() +{ + std::vector contextFnRegistry = { Function(reinterpret_cast(ArenaAllocatorMalloc), + "ArenaAllocatorMalloc", {}, { OMNI_LONG, OMNI_INT }, OMNI_CHAR) }; + return contextFnRegistry; +} +} \ No newline at end of file diff --git a/core/src/codegen/func_registry_context.h b/core/src/codegen/func_registry_context.h new file mode 100644 index 0000000..01dfbec --- /dev/null +++ b/core/src/codegen/func_registry_context.h @@ -0,0 +1,17 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: Context Helper Function Registry + */ +#ifndef OMNI_RUNTIME_FUNC_REGISTRY_CONTEXT_H +#define OMNI_RUNTIME_FUNC_REGISTRY_CONTEXT_H +#include "function.h" +#include "func_registry_base.h" + +namespace omniruntime::codegen { +class ContextFunctionRegistry : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; +} + +#endif // OMNI_RUNTIME_FUNC_REGISTRY_CONTEXT_H diff --git a/core/src/codegen/func_registry_datetime.cpp b/core/src/codegen/func_registry_datetime.cpp new file mode 100644 index 0000000..dddac3a --- /dev/null +++ b/core/src/codegen/func_registry_datetime.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + * Description: Date Time Function Registry + */ + +#include "func_registry_datetime.h" +#include "functions/datetime_functions.h" + +namespace omniruntime::codegen { +using namespace omniruntime::type; +using namespace omniruntime::codegen::function; + +std::vector DateTimeFunctionRegistry::GetFunctions() +{ + std::vector dateTimeFnRegistry = { + Function(reinterpret_cast(UnixTimestampFromStr), "unix_timestamp", {}, + { OMNI_VARCHAR, OMNI_VARCHAR, OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_LONG, INPUT_DATA_AND_NULL_AND_RETURN_NULL), + Function(reinterpret_cast(UnixTimestampFromDate), "unix_timestamp", {}, + { OMNI_DATE32, OMNI_VARCHAR, OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_LONG, INPUT_DATA), + Function(Function(reinterpret_cast(FromUnixTime), "from_unixtime", {}, + { OMNI_LONG, OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_VARCHAR, INPUT_DATA, true)), + Function(reinterpret_cast(FromUnixTimeRetNull), "from_unixtime_null", {}, + { OMNI_LONG, OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_VARCHAR, INPUT_DATA_AND_OVERFLOW_NULL, true), + Function(reinterpret_cast(DateTrunc), "trunc_date", {}, { OMNI_DATE32, OMNI_VARCHAR }, + OMNI_DATE32, INPUT_DATA, true), + Function(reinterpret_cast(DateTruncRetNull), "trunc_date_null", {}, { OMNI_DATE32, OMNI_VARCHAR }, + OMNI_DATE32, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(DateAdd), "date_add", {}, {OMNI_DATE32, OMNI_INT}, OMNI_DATE32, INPUT_DATA) + }; + + return dateTimeFnRegistry; +} +} diff --git a/core/src/codegen/func_registry_datetime.h b/core/src/codegen/func_registry_datetime.h new file mode 100644 index 0000000..7169b40 --- /dev/null +++ b/core/src/codegen/func_registry_datetime.h @@ -0,0 +1,19 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: Date Time Function Registry + */ + +#ifndef OMNI_RUNTIME_FUNC_REGISTRY_DATETIME_H +#define OMNI_RUNTIME_FUNC_REGISTRY_DATETIME_H + +#include "function.h" +#include "func_registry_base.h" + +namespace omniruntime::codegen { +class DateTimeFunctionRegistry : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; +} + +#endif // OMNI_RUNTIME_FUNC_REGISTRY_DATETIME_H diff --git a/core/src/codegen/func_registry_decimal.cpp b/core/src/codegen/func_registry_decimal.cpp new file mode 100644 index 0000000..ae1f860 --- /dev/null +++ b/core/src/codegen/func_registry_decimal.cpp @@ -0,0 +1,702 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: Decimal Function Registry + */ +#include "func_registry_decimal.h" +#include "functions/decimal_arithmetic_functions.h" +#include "functions/decimal_cast_functions.h" + +namespace omniruntime::codegen { +using namespace omniruntime::type; +using namespace codegen::function; + +const std::string DecimalCastFnStr() +{ + const std::string decimalCastFnStr = "CAST"; + return decimalCastFnStr; +} + +const std::string DecimalCastNullFnStr() +{ + const std::string decimalCastNullFnStr = "CAST_null"; + return decimalCastNullFnStr; +} + +const std::string DecimalGreatestFnStr() +{ + const std::string decimalGreatestFnStr = "Greatest"; + return decimalGreatestFnStr; +} + +const std::string DecimalGreatestNullFnStr() +{ + const std::string decimalGreatestNullFnStr = "Greatest_null"; + return decimalGreatestNullFnStr; +} + +const std::string DecimalAbsFnStr() +{ + const std::string decimalAbsFnStr = "abs"; + return decimalAbsFnStr; +} + +const std::string MakeDecimalFnStr() +{ + const std::string makeDecimalFnStr = "MakeDecimal"; + return makeDecimalFnStr; +} + +const std::string MakeDecimalNullFnStr() +{ + const std::string makeDecimalNullFnStr = "MakeDecimal_null"; + return makeDecimalNullFnStr; +} + +const std::string DecimalRoundFnStr() +{ + const std::string decimalRoundFnStr = "round"; + return decimalRoundFnStr; +} + +const std::string RoundNullFnStr() +{ + const std::string roundNullFnStr = "round_null"; + return roundNullFnStr; +} + +const std::string UnscaledValueFnStr() +{ + const std::string unscaledValueFnStr = "UnscaledValue"; + return unscaledValueFnStr; +} + +const std::string Decimal64CompareFnStr() +{ + const std::string decimal64CompareFnStr = "Decimal64Compare"; + return decimal64CompareFnStr; +} + +const std::string Decimal128CompareFnStr() +{ + const std::string decimal128CompareFnStr = "Decimal128Compare"; + return decimal128CompareFnStr; +} + +const std::string AddDecimal128FnStr() +{ + const std::string addDecimal128FnStr = "Add_decimal128"; + return addDecimal128FnStr; +} + +const std::string SubDecimal128FnStr() +{ + const std::string subDecimal128FnStr = "Sub_decimal128"; + return subDecimal128FnStr; +} + +const std::string MulDecimal128FnStr() +{ + const std::string mulDecimal128FnStr = "Mul_decimal128"; + return mulDecimal128FnStr; +} + +const std::string DivDecimal128FnStr() +{ + const std::string divDecimal128FnStr = "Div_decimal128"; + return divDecimal128FnStr; +} + +const std::string ModDecimal128FnStr() +{ + const std::string modDecimal128FnStr = "Mod_decimal128"; + return modDecimal128FnStr; +} + +const std::string AddDecimal64FnStr() +{ + const std::string addDecimal64FnStr = "Add_decimal64"; + return addDecimal64FnStr; +} + +const std::string SubDecimal64FnStr() +{ + const std::string subDecimal64FnStr = "Sub_decimal64"; + return subDecimal64FnStr; +} + +const std::string MulDecimal64FnStr() +{ + const std::string mulDecimal64FnStr = "Mul_decimal64"; + return mulDecimal64FnStr; +} + +const std::string DivDecimal64FnStr() +{ + const std::string divDecimal64FnStr = "Div_decimal64"; + return divDecimal64FnStr; +} + +const std::string ModDecimal64FnStr() +{ + const std::string modDecimal64FnStr = "Mod_decimal64"; + return modDecimal64FnStr; +} + +const std::string AddDecimal128NullFnStr() +{ + const std::string addDecimal128NullFnStr = "Add_decimal128_null"; + return addDecimal128NullFnStr; +} + +const std::string SubDecimal128NullFnStr() +{ + const std::string subDecimal128NullFnStr = "Sub_decimal128_null"; + return subDecimal128NullFnStr; +} + +const std::string MulDecimal128NullFnStr() +{ + const std::string mulDecimal128NullFnStr = "Mul_decimal128_null"; + return mulDecimal128NullFnStr; +} + +const std::string DivDecimal128NullFnStr() +{ + const std::string divDecimal128NullFnStr = "Div_decimal128_null"; + return divDecimal128NullFnStr; +} + +const std::string ModDecimal128NullFnStr() +{ + const std::string modDecimal128NullFnStr = "Mod_decimal128_null"; + return modDecimal128NullFnStr; +} + +const std::string AddDecimal64NullFnStr() +{ + const std::string addDecimal64NullFnStr = "Add_decimal64_null"; + return addDecimal64NullFnStr; +} + +const std::string SubDecimal64NullFnStr() +{ + const std::string subDecimal64NullFnStr = "Sub_decimal64_null"; + return subDecimal64NullFnStr; +} + +const std::string MulDecimal64NullFnStr() +{ + const std::string mulDecimal64NullFnStr = "Mul_decimal64_null"; + return mulDecimal64NullFnStr; +} + +const std::string DivDecimal64NullFnStr() +{ + const std::string divDecimal64NullFnStr = "Div_decimal64_null"; + return divDecimal64NullFnStr; +} + +const std::string ModDecimal64NullFnStr() +{ + const std::string modDecimal64NullFnStr = "Mod_decimal64_null"; + return modDecimal64NullFnStr; +} + +const std::string TryAddDecimal64FnStr() +{ + return "Try_add_decimal64"; +} + +const std::string TryAddDecimal128FnStr() +{ + return "Try_add_decimal128"; +} + +const std::string TrySubDecimal64FnStr() +{ + return "Try_sub_decimal64"; +} + +const std::string TrySubDecimal128FnStr() +{ + return "Try_sub_decimal128"; +} + +const std::string TryMulDecimal64FnStr() +{ + return "Try_mul_decimal64"; +} + +const std::string TryMulDecimal128FnStr() +{ + return "Try_mul_decimal128"; +} + +const std::string TryDivDecimal64FnStr() +{ + return "Try_div_decimal64"; +} + +const std::string TryDivDecimal128FnStr() +{ + return "Try_div_decimal128"; +} + +std::vector DecimalFunctionRegistry::GetFunctions() +{ + std::vector paramTypes128 = { OMNI_DECIMAL128, OMNI_DECIMAL128 }; + std::vector paramTypes64 = { OMNI_DECIMAL64, OMNI_DECIMAL64 }; + std::vector paramTypes64Op128 = { OMNI_DECIMAL64, OMNI_DECIMAL128 }; + std::vector paramTypes128Op64 = { OMNI_DECIMAL128, OMNI_DECIMAL64 }; + DataTypeId retType128 = OMNI_DECIMAL128; + DataTypeId retType64 = OMNI_DECIMAL64; + + static std::vector decimalFnRegistry = { + Function(reinterpret_cast(CastDecimal64To64), DecimalCastFnStr(), {}, { OMNI_DECIMAL64 }, + OMNI_DECIMAL64, INPUT_DATA, true), + Function(reinterpret_cast(CastDecimal128To128), DecimalCastFnStr(), {}, { OMNI_DECIMAL128 }, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(CastDecimal64To128), DecimalCastFnStr(), {}, { OMNI_DECIMAL64 }, + OMNI_DECIMAL128, INPUT_DATA, true), + Function(reinterpret_cast(CastDecimal128To64), DecimalCastFnStr(), {}, { OMNI_DECIMAL128 }, + OMNI_DECIMAL64, INPUT_DATA, true), + + Function(reinterpret_cast(CastIntToDecimal64), DecimalCastFnStr(), {}, { OMNI_INT }, OMNI_DECIMAL64, + INPUT_DATA, true), + Function(reinterpret_cast(CastInt16ToDecimal64), DecimalCastFnStr(), {}, { OMNI_SHORT }, OMNI_DECIMAL64, + INPUT_DATA, true), + Function(reinterpret_cast(CastInt8ToDecimal64), DecimalCastFnStr(), {}, { OMNI_BYTE }, OMNI_DECIMAL64, + INPUT_DATA, true), + Function(reinterpret_cast(CastLongToDecimal64), DecimalCastFnStr(), {}, { OMNI_LONG }, OMNI_DECIMAL64, + INPUT_DATA, true), + Function(reinterpret_cast(CastDoubleToDecimal64), DecimalCastFnStr(), {}, { OMNI_DOUBLE }, + OMNI_DECIMAL64, INPUT_DATA, true), + + Function(reinterpret_cast(CastIntToDecimal128), DecimalCastFnStr(), {}, { OMNI_INT }, OMNI_DECIMAL128, + INPUT_DATA, true), + Function(reinterpret_cast(CastInt16ToDecimal128), DecimalCastFnStr(), {}, { OMNI_SHORT }, OMNI_DECIMAL128, + INPUT_DATA, true), + Function(reinterpret_cast(CastInt8ToDecimal128), DecimalCastFnStr(), {}, { OMNI_BYTE }, OMNI_DECIMAL128, + INPUT_DATA, true), + Function(reinterpret_cast(CastLongToDecimal128), DecimalCastFnStr(), {}, { OMNI_LONG }, OMNI_DECIMAL128, + INPUT_DATA, true), + Function(reinterpret_cast(CastDoubleToDecimal128), DecimalCastFnStr(), {}, { OMNI_DOUBLE }, + OMNI_DECIMAL128, INPUT_DATA, true), + + Function(reinterpret_cast(AbsDecimal128), DecimalAbsFnStr(), {}, { OMNI_DECIMAL128 }, retType128, + INPUT_DATA), + Function(reinterpret_cast(AbsDecimal64), DecimalAbsFnStr(), {}, { OMNI_DECIMAL64 }, OMNI_DECIMAL64, + INPUT_DATA), + + Function(reinterpret_cast(RoundDecimal128), DecimalRoundFnStr(), {}, { OMNI_DECIMAL128, OMNI_INT }, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(RoundDecimal64), DecimalRoundFnStr(), {}, { OMNI_DECIMAL64, OMNI_INT }, + retType64, INPUT_DATA, true), + Function(reinterpret_cast(RoundDecimal128WithoutRound), DecimalRoundFnStr(), {}, { OMNI_DECIMAL128 }, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(RoundDecimal64WithoutRound), DecimalRoundFnStr(), {}, { OMNI_DECIMAL64 }, + retType64, INPUT_DATA, true), + + Function(reinterpret_cast(Decimal64Compare), Decimal64CompareFnStr(), {}, paramTypes64, OMNI_INT), + Function(reinterpret_cast(Decimal128Compare), Decimal128CompareFnStr(), {}, paramTypes128, OMNI_INT), + + // Return Null + Function(reinterpret_cast(RoundDecimal128RetNull), RoundNullFnStr(), {}, { OMNI_DECIMAL128, OMNI_INT }, + retType128, INPUT_DATA), + Function(reinterpret_cast(RoundDecimal64RetNull), RoundNullFnStr(), {}, { OMNI_DECIMAL64, OMNI_INT }, + retType64, INPUT_DATA), + + Function(reinterpret_cast(AddDec64Dec64Dec64RetNull), AddDecimal64NullFnStr(), {}, paramTypes64, + retType64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(AddDec64Dec64Dec128RetNull), AddDecimal64NullFnStr(), {}, paramTypes64, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(AddDec128Dec128Dec128RetNull), AddDecimal128NullFnStr(), {}, paramTypes128, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(AddDec64Dec128Dec128RetNull), AddDecimal64NullFnStr(), {}, paramTypes64Op128, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(AddDec128Dec64Dec128RetNull), AddDecimal128NullFnStr(), {}, paramTypes128Op64, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(AddDec64Dec64Dec64RetNull), TryAddDecimal64FnStr(), {}, paramTypes64, + retType64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(AddDec64Dec64Dec128RetNull), TryAddDecimal64FnStr(), {}, paramTypes64, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(AddDec64Dec128Dec128RetNull), TryAddDecimal64FnStr(), {}, paramTypes64Op128, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(AddDec128Dec128Dec128RetNull), TryAddDecimal128FnStr(), {}, paramTypes128, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(AddDec128Dec64Dec128RetNull), TryAddDecimal128FnStr(), {}, paramTypes128Op64, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + + Function(reinterpret_cast(SubDec64Dec64Dec64RetNull), SubDecimal64NullFnStr(), {}, paramTypes64, + retType64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(SubDec64Dec64Dec128RetNull), SubDecimal64NullFnStr(), {}, paramTypes64, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(SubDec128Dec128Dec128RetNull), SubDecimal128NullFnStr(), {}, paramTypes128, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(SubDec64Dec128Dec128RetNull), SubDecimal64NullFnStr(), {}, paramTypes64Op128, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(SubDec128Dec64Dec128RetNull), SubDecimal128NullFnStr(), {}, paramTypes128Op64, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(SubDec64Dec64Dec64RetNull), TrySubDecimal64FnStr(), {}, paramTypes64, + retType64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(SubDec64Dec64Dec128RetNull), TrySubDecimal64FnStr(), {}, paramTypes64, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(SubDec64Dec128Dec128RetNull), TrySubDecimal64FnStr(), {}, paramTypes64Op128, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(SubDec128Dec128Dec128RetNull), TrySubDecimal128FnStr(), {}, paramTypes128, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(SubDec128Dec64Dec128RetNull), TrySubDecimal128FnStr(), {}, paramTypes128Op64, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + + Function(reinterpret_cast(MulDec64Dec64Dec64RetNull), MulDecimal64NullFnStr(), {}, paramTypes64, + retType64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(MulDec64Dec64Dec128RetNull), MulDecimal64NullFnStr(), {}, paramTypes64, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(MulDec128Dec128Dec128RetNull), MulDecimal128NullFnStr(), {}, paramTypes128, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(MulDec64Dec128Dec128RetNull), MulDecimal64NullFnStr(), {}, paramTypes64Op128, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(MulDec128Dec64Dec128RetNull), MulDecimal128NullFnStr(), {}, paramTypes128Op64, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(MulDec64Dec64Dec64RetNull), TryMulDecimal64FnStr(), {}, paramTypes64, + retType64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(MulDec64Dec64Dec128RetNull), TryMulDecimal64FnStr(), {}, paramTypes64, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(MulDec64Dec128Dec128RetNull), TryMulDecimal64FnStr(), {}, paramTypes64Op128, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(MulDec128Dec128Dec128RetNull), TryMulDecimal128FnStr(), {}, paramTypes128, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(MulDec128Dec64Dec128RetNull), TryMulDecimal128FnStr(), {}, paramTypes128Op64, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + + Function(reinterpret_cast(DivDec64Dec64Dec64RetNull), DivDecimal64NullFnStr(), {}, paramTypes64, + retType64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(DivDec64Dec128Dec64RetNull), DivDecimal64NullFnStr(), {}, paramTypes64Op128, + retType64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(DivDec128Dec64Dec64RetNull), DivDecimal128NullFnStr(), {}, paramTypes128Op64, + retType64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(DivDec64Dec64Dec128RetNull), DivDecimal64NullFnStr(), {}, paramTypes64, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(DivDec128Dec128Dec128RetNull), DivDecimal128NullFnStr(), {}, paramTypes128, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(DivDec64Dec128Dec128RetNull), DivDecimal64NullFnStr(), {}, paramTypes64Op128, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(DivDec128Dec64Dec128RetNull), DivDecimal128NullFnStr(), {}, paramTypes128Op64, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(DivDec64Dec64Dec64RetNull), TryDivDecimal64FnStr(), {}, paramTypes64, + retType64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(DivDec64Dec128Dec64RetNull), TryDivDecimal64FnStr(), {}, paramTypes64Op128, + retType64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(DivDec64Dec64Dec128RetNull), TryDivDecimal64FnStr(), {}, paramTypes64, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(DivDec64Dec128Dec128RetNull), TryDivDecimal64FnStr(), {}, paramTypes64Op128, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(DivDec128Dec128Dec128RetNull), TryDivDecimal128FnStr(), {}, paramTypes128, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(DivDec128Dec64Dec64RetNull), TryDivDecimal128FnStr(), {}, paramTypes128Op64, + retType64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(DivDec128Dec64Dec128RetNull), TryDivDecimal128FnStr(), {}, paramTypes128Op64, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + + Function(reinterpret_cast(ModDec64Dec64Dec64RetNull), ModDecimal64NullFnStr(), {}, paramTypes64, + retType64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(ModDec64Dec128Dec64RetNull), ModDecimal64NullFnStr(), {}, paramTypes64Op128, + retType64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(ModDec128Dec64Dec64RetNull), ModDecimal128NullFnStr(), {}, paramTypes128Op64, + retType64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(ModDec128Dec64Dec128RetNull), ModDecimal128NullFnStr(), {}, paramTypes128Op64, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(ModDec128Dec128Dec128RetNull), ModDecimal128NullFnStr(), {}, paramTypes128, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(ModDec128Dec128Dec64RetNull), ModDecimal128NullFnStr(), {}, paramTypes128, + retType64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(ModDec64Dec128Dec128RetNull), ModDecimal64NullFnStr(), {}, paramTypes64Op128, + retType128, INPUT_DATA_AND_OVERFLOW_NULL), + + Function(reinterpret_cast(CastDecimal64To64RetNull), DecimalCastNullFnStr(), {}, { OMNI_DECIMAL64 }, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastDecimal128To128RetNull), DecimalCastNullFnStr(), {}, { OMNI_DECIMAL128 }, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastDecimal64To128RetNull), DecimalCastNullFnStr(), {}, { OMNI_DECIMAL64 }, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastDecimal128To64RetNull), DecimalCastNullFnStr(), {}, { OMNI_DECIMAL128 }, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + + Function(reinterpret_cast(CastIntToDecimal64RetNull), DecimalCastNullFnStr(), {}, { OMNI_INT }, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastInt16ToDecimal64RetNull), DecimalCastNullFnStr(), {}, { OMNI_SHORT }, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastInt8ToDecimal64RetNull), DecimalCastNullFnStr(), {}, { OMNI_BYTE }, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastLongToDecimal64RetNull), DecimalCastNullFnStr(), {}, { OMNI_LONG }, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastDoubleToDecimal64RetNull), DecimalCastNullFnStr(), {}, { OMNI_DOUBLE }, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + + Function(reinterpret_cast(CastIntToDecimal128RetNull), DecimalCastNullFnStr(), {}, { OMNI_INT }, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastInt16ToDecimal128RetNull), DecimalCastNullFnStr(), {}, { OMNI_SHORT }, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastInt8ToDecimal128RetNull), DecimalCastNullFnStr(), {}, { OMNI_BYTE }, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastLongToDecimal128RetNull), DecimalCastNullFnStr(), {}, { OMNI_LONG }, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastDoubleToDecimal128RetNull), DecimalCastNullFnStr(), {}, { OMNI_DOUBLE }, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + + Function(reinterpret_cast(CastDecimal64ToIntRetNull), DecimalCastNullFnStr(), {}, { OMNI_DECIMAL64 }, + OMNI_INT, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastDecimal64ToInt16RetNull), DecimalCastNullFnStr(), {}, { OMNI_DECIMAL64 }, + OMNI_SHORT, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastDecimal64ToInt8RetNull), DecimalCastNullFnStr(), {}, { OMNI_DECIMAL64 }, + OMNI_BYTE, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastDecimal64ToLongRetNull), DecimalCastNullFnStr(), {}, { OMNI_DECIMAL64 }, + OMNI_LONG, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastDecimal64ToDoubleRetNull), DecimalCastNullFnStr(), {}, { OMNI_DECIMAL64 }, + OMNI_DOUBLE, INPUT_DATA_AND_OVERFLOW_NULL), + + Function(reinterpret_cast(CastDecimal128ToIntRetNull), DecimalCastNullFnStr(), {}, { OMNI_DECIMAL128 }, + OMNI_INT, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastDecimal128ToInt16RetNull), DecimalCastNullFnStr(), {}, { OMNI_DECIMAL128 }, + OMNI_SHORT, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastDecimal128ToInt8RetNull), DecimalCastNullFnStr(), {}, { OMNI_DECIMAL128 }, + OMNI_BYTE, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastDecimal128ToLongRetNull), DecimalCastNullFnStr(), {}, { OMNI_DECIMAL128 }, + OMNI_LONG, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastDecimal128ToDoubleRetNull), DecimalCastNullFnStr(), {}, + { OMNI_DECIMAL128 }, OMNI_DOUBLE, INPUT_DATA_AND_OVERFLOW_NULL), + + // UnscaledValue + Function(reinterpret_cast(UnscaledValue64), UnscaledValueFnStr(), {}, { OMNI_DECIMAL64 }, OMNI_LONG, + INPUT_DATA), + // MakeDecimal + Function(reinterpret_cast(MakeDecimal64), MakeDecimalFnStr(), {}, { OMNI_LONG }, OMNI_DECIMAL64, + INPUT_DATA, true), + Function(reinterpret_cast(MakeDecimal64RetNull), MakeDecimalNullFnStr(), {}, { OMNI_LONG }, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + // DecimalGreatest + Function(reinterpret_cast(GreatestDecimal64), DecimalGreatestFnStr(), {}, paramTypes64, retType64, + INPUT_DATA_AND_NULL_AND_RETURN_NULL, true), + Function(reinterpret_cast(GreatestDecimal128), DecimalGreatestFnStr(), {}, paramTypes128, retType128, + INPUT_DATA_AND_NULL_AND_RETURN_NULL, true), + Function(reinterpret_cast(GreatestDecimal64RetNull), DecimalGreatestNullFnStr(), {}, paramTypes64, + retType64, INPUT_DATA_AND_NULL_AND_RETURN_NULL), + Function(reinterpret_cast(GreatestDecimal128RetNull), DecimalGreatestNullFnStr(), {}, paramTypes128, + retType128, INPUT_DATA_AND_NULL_AND_RETURN_NULL), + }; + + return decimalFnRegistry; +} + +std::vector DecimalFunctionRegistryDown::GetFunctions() +{ + static std::vector decimalFnRegistry = { + Function(reinterpret_cast(CastDecimal64ToLongDown), DecimalCastFnStr(), {}, { OMNI_DECIMAL64 }, + OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(CastDecimal64ToIntDown), DecimalCastFnStr(), {}, { OMNI_DECIMAL64 }, OMNI_INT, + INPUT_DATA, true), + Function(reinterpret_cast(CastDecimal64ToInt16Down), DecimalCastFnStr(), {}, { OMNI_DECIMAL64 }, OMNI_SHORT, + INPUT_DATA, true), + Function(reinterpret_cast(CastDecimal64ToInt8Down), DecimalCastFnStr(), {}, { OMNI_DECIMAL64 }, OMNI_BYTE, + INPUT_DATA, true), + Function(reinterpret_cast(CastDecimal64ToDoubleDown), DecimalCastFnStr(), {}, { OMNI_DECIMAL64 }, + OMNI_DOUBLE, INPUT_DATA), + + Function(reinterpret_cast(CastDecimal128ToLongDown), DecimalCastFnStr(), {}, { OMNI_DECIMAL128 }, + OMNI_LONG, INPUT_DATA, true), + Function(reinterpret_cast(CastDecimal128ToIntDown), DecimalCastFnStr(), {}, { OMNI_DECIMAL128 }, + OMNI_INT, INPUT_DATA, true), + Function(reinterpret_cast(CastDecimal128ToInt16Down), DecimalCastFnStr(), {}, { OMNI_DECIMAL128 }, + OMNI_SHORT, INPUT_DATA, true), + Function(reinterpret_cast(CastDecimal128ToInt8Down), DecimalCastFnStr(), {}, { OMNI_DECIMAL128 }, + OMNI_BYTE, INPUT_DATA, true), + Function(reinterpret_cast(CastDecimal128ToDoubleDown), DecimalCastFnStr(), {}, { OMNI_DECIMAL128 }, + OMNI_DOUBLE, INPUT_DATA), + }; + + return decimalFnRegistry; +} + +std::vector DecimalFunctionRegistryHalfUp::GetFunctions() +{ + static std::vector decimalFnRegistry = { + Function(reinterpret_cast(CastDecimal64ToLongHalfUp), DecimalCastFnStr(), {}, { OMNI_DECIMAL64 }, + OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(CastDecimal64ToIntHalfUp), DecimalCastFnStr(), {}, { OMNI_DECIMAL64 }, + OMNI_INT, INPUT_DATA, true), + Function(reinterpret_cast(CastDecimal64ToInt16HalfUp), DecimalCastFnStr(), {}, { OMNI_DECIMAL64 }, + OMNI_SHORT, INPUT_DATA, true), + Function(reinterpret_cast(CastDecimal64ToInt8HalfUp), DecimalCastFnStr(), {}, { OMNI_DECIMAL64 }, + OMNI_BYTE, INPUT_DATA, true), + Function(reinterpret_cast(CastDecimal64ToDoubleHalfUp), DecimalCastFnStr(), {}, { OMNI_DECIMAL64 }, + OMNI_DOUBLE, INPUT_DATA), + + Function(reinterpret_cast(CastDecimal128ToLongHalfUp), DecimalCastFnStr(), {}, { OMNI_DECIMAL128 }, + OMNI_LONG, INPUT_DATA, true), + Function(reinterpret_cast(CastDecimal128ToIntHalfUp), DecimalCastFnStr(), {}, { OMNI_DECIMAL128 }, + OMNI_INT, INPUT_DATA, true), + Function(reinterpret_cast(CastDecimal128ToInt16HalfUp), DecimalCastFnStr(), {}, { OMNI_DECIMAL128 }, + OMNI_SHORT, INPUT_DATA, true), + Function(reinterpret_cast(CastDecimal128ToInt8HalfUp), DecimalCastFnStr(), {}, { OMNI_DECIMAL128 }, + OMNI_BYTE, INPUT_DATA, true), + Function(reinterpret_cast(CastDecimal128ToDoubleHalfUp), DecimalCastFnStr(), {}, { OMNI_DECIMAL128 }, + OMNI_DOUBLE, INPUT_DATA), + }; + + return decimalFnRegistry; +} + +std::vector DecimalFunctionRegistryNotReScale::GetFunctions() +{ + std::vector paramTypes128 = { OMNI_DECIMAL128, OMNI_DECIMAL128 }; + std::vector paramTypes64 = { OMNI_DECIMAL64, OMNI_DECIMAL64 }; + std::vector paramTypes64Op128 = { OMNI_DECIMAL64, OMNI_DECIMAL128 }; + std::vector paramTypes128Op64 = { OMNI_DECIMAL128, OMNI_DECIMAL64 }; + DataTypeId retType128 = OMNI_DECIMAL128; + DataTypeId retType64 = OMNI_DECIMAL64; + + static std::vector decimalFnRegistry = { + Function(reinterpret_cast(AddDec64Dec64Dec64NotReScale), AddDecimal64FnStr(), {}, paramTypes64, + retType64, INPUT_DATA, true), + Function(reinterpret_cast(AddDec64Dec64Dec128NotReScale), AddDecimal64FnStr(), {}, paramTypes64, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(AddDec128Dec128Dec128NotReScale), AddDecimal128FnStr(), {}, paramTypes128, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(AddDec64Dec128Dec128NotReScale), AddDecimal64FnStr(), {}, paramTypes64Op128, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(AddDec128Dec64Dec128NotReScale), AddDecimal128FnStr(), {}, paramTypes128Op64, + retType128, INPUT_DATA, true), + + Function(reinterpret_cast(SubDec64Dec64Dec64NotReScale), SubDecimal64FnStr(), {}, paramTypes64, + retType64, INPUT_DATA, true), + Function(reinterpret_cast(SubDec64Dec64Dec128NotReScale), SubDecimal64FnStr(), {}, paramTypes64, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(SubDec128Dec128Dec128NotReScale), SubDecimal128FnStr(), {}, paramTypes128, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(SubDec64Dec128Dec128NotReScale), SubDecimal64FnStr(), {}, paramTypes64Op128, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(SubDec128Dec64Dec128NotReScale), SubDecimal128FnStr(), {}, paramTypes128Op64, + retType128, INPUT_DATA, true), + + Function(reinterpret_cast(MulDec64Dec64Dec64NotReScale), MulDecimal64FnStr(), {}, paramTypes64, + retType64, INPUT_DATA, true), + Function(reinterpret_cast(MulDec64Dec64Dec128NotReScale), MulDecimal64FnStr(), {}, paramTypes64, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(MulDec128Dec128Dec128NotReScale), MulDecimal128FnStr(), {}, paramTypes128, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(MulDec64Dec128Dec128NotReScale), MulDecimal64FnStr(), {}, paramTypes64Op128, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(MulDec128Dec64Dec128NotReScale), MulDecimal128FnStr(), {}, paramTypes128Op64, + retType128, INPUT_DATA, true), + + Function(reinterpret_cast(DivDec64Dec64Dec64NotReScale), DivDecimal64FnStr(), {}, paramTypes64, + retType64, INPUT_DATA, true), + Function(reinterpret_cast(DivDec64Dec128Dec64NotReScale), DivDecimal64FnStr(), {}, paramTypes64Op128, + retType64, INPUT_DATA, true), + Function(reinterpret_cast(DivDec128Dec64Dec64NotReScale), DivDecimal128FnStr(), {}, paramTypes128Op64, + retType64, INPUT_DATA, true), + Function(reinterpret_cast(DivDec64Dec64Dec128NotReScale), DivDecimal64FnStr(), {}, paramTypes64, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(DivDec128Dec128Dec128NotReScale), DivDecimal128FnStr(), {}, paramTypes128, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(DivDec64Dec128Dec128NotReScale), DivDecimal64FnStr(), {}, paramTypes64Op128, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(DivDec128Dec64Dec128NotReScale), DivDecimal128FnStr(), {}, paramTypes128Op64, + retType128, INPUT_DATA, true), + + Function(reinterpret_cast(ModDec64Dec64Dec64NotReScale), ModDecimal64FnStr(), {}, paramTypes64, + retType64, INPUT_DATA, true), + Function(reinterpret_cast(ModDec64Dec128Dec64NotReScale), ModDecimal64FnStr(), {}, paramTypes64Op128, + retType64, INPUT_DATA, true), + Function(reinterpret_cast(ModDec128Dec64Dec64NotReScale), ModDecimal128FnStr(), {}, paramTypes128Op64, + retType64, INPUT_DATA, true), + Function(reinterpret_cast(ModDec128Dec64Dec128ReScale), ModDecimal128FnStr(), {}, paramTypes128Op64, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(ModDec128Dec128Dec128NotReScale), ModDecimal128FnStr(), {}, paramTypes128, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(ModDec128Dec128Dec64NotReScale), ModDecimal128FnStr(), {}, paramTypes128, + retType64, INPUT_DATA, true), + Function(reinterpret_cast(ModDec64Dec128Dec128NotReScale), ModDecimal64FnStr(), {}, paramTypes64Op128, + retType128, INPUT_DATA, true), + }; + + return decimalFnRegistry; +} + +std::vector DecimalFunctionRegistryReScale::GetFunctions() +{ + std::vector paramTypes128 = { OMNI_DECIMAL128, OMNI_DECIMAL128 }; + std::vector paramTypes64 = { OMNI_DECIMAL64, OMNI_DECIMAL64 }; + std::vector paramTypes64Op128 = { OMNI_DECIMAL64, OMNI_DECIMAL128 }; + std::vector paramTypes128Op64 = { OMNI_DECIMAL128, OMNI_DECIMAL64 }; + DataTypeId retType128 = OMNI_DECIMAL128; + DataTypeId retType64 = OMNI_DECIMAL64; + + static std::vector decimalFnRegistry = { + Function(reinterpret_cast(AddDec64Dec64Dec64ReScale), AddDecimal64FnStr(), {}, paramTypes64, retType64, + INPUT_DATA, true), + Function(reinterpret_cast(AddDec64Dec64Dec128ReScale), AddDecimal64FnStr(), {}, paramTypes64, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(AddDec128Dec128Dec128ReScale), AddDecimal128FnStr(), {}, paramTypes128, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(AddDec64Dec128Dec128ReScale), AddDecimal64FnStr(), {}, paramTypes64Op128, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(AddDec128Dec64Dec128ReScale), AddDecimal128FnStr(), {}, paramTypes128Op64, + retType128, INPUT_DATA, true), + + Function(reinterpret_cast(SubDec64Dec64Dec64ReScale), SubDecimal64FnStr(), {}, paramTypes64, retType64, + INPUT_DATA, true), + Function(reinterpret_cast(SubDec64Dec64Dec128ReScale), SubDecimal64FnStr(), {}, paramTypes64, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(SubDec128Dec128Dec128ReScale), SubDecimal128FnStr(), {}, paramTypes128, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(SubDec64Dec128Dec128ReScale), SubDecimal64FnStr(), {}, paramTypes64Op128, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(SubDec128Dec64Dec128ReScale), SubDecimal128FnStr(), {}, paramTypes128Op64, + retType128, INPUT_DATA, true), + + Function(reinterpret_cast(MulDec64Dec64Dec64ReScale), MulDecimal64FnStr(), {}, paramTypes64, retType64, + INPUT_DATA, true), + Function(reinterpret_cast(MulDec64Dec64Dec128ReScale), MulDecimal64FnStr(), {}, paramTypes64, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(MulDec128Dec128Dec128ReScale), MulDecimal128FnStr(), {}, paramTypes128, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(MulDec64Dec128Dec128ReScale), MulDecimal64FnStr(), {}, paramTypes64Op128, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(MulDec128Dec64Dec128ReScale), MulDecimal128FnStr(), {}, paramTypes128Op64, + retType128, INPUT_DATA, true), + + Function(reinterpret_cast(DivDec64Dec64Dec64ReScale), DivDecimal64FnStr(), {}, paramTypes64, retType64, + INPUT_DATA, true), + Function(reinterpret_cast(DivDec64Dec128Dec64ReScale), DivDecimal64FnStr(), {}, paramTypes64Op128, + retType64, INPUT_DATA, true), + Function(reinterpret_cast(DivDec128Dec64Dec64ReScale), DivDecimal128FnStr(), {}, paramTypes128Op64, + retType64, INPUT_DATA, true), + Function(reinterpret_cast(DivDec64Dec64Dec128ReScale), DivDecimal64FnStr(), {}, paramTypes64, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(DivDec128Dec128Dec128ReScale), DivDecimal128FnStr(), {}, paramTypes128, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(DivDec64Dec128Dec128ReScale), DivDecimal64FnStr(), {}, paramTypes64Op128, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(DivDec128Dec64Dec128ReScale), DivDecimal128FnStr(), {}, paramTypes128Op64, + retType128, INPUT_DATA, true), + + Function(reinterpret_cast(ModDec64Dec64Dec64ReScale), ModDecimal64FnStr(), {}, paramTypes64, retType64, + INPUT_DATA, true), + Function(reinterpret_cast(ModDec64Dec128Dec64ReScale), ModDecimal64FnStr(), {}, paramTypes64Op128, + retType64, INPUT_DATA, true), + Function(reinterpret_cast(ModDec128Dec64Dec64ReScale), ModDecimal128FnStr(), {}, paramTypes128Op64, + retType64, INPUT_DATA, true), + Function(reinterpret_cast(ModDec128Dec64Dec128ReScale), ModDecimal128FnStr(), {}, paramTypes128Op64, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(ModDec128Dec128Dec128ReScale), ModDecimal128FnStr(), {}, paramTypes128, + retType128, INPUT_DATA, true), + Function(reinterpret_cast(ModDec128Dec128Dec64ReScale), ModDecimal128FnStr(), {}, paramTypes128, + retType64, INPUT_DATA, true), + Function(reinterpret_cast(ModDec64Dec128Dec128ReScale), ModDecimal64FnStr(), {}, paramTypes64Op128, + retType128, INPUT_DATA, true), + }; + + return decimalFnRegistry; +} +} diff --git a/core/src/codegen/func_registry_decimal.h b/core/src/codegen/func_registry_decimal.h new file mode 100644 index 0000000..df15131 --- /dev/null +++ b/core/src/codegen/func_registry_decimal.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: Decimal Function Registry + */ +#ifndef OMNI_RUNTIME_FUNC_REGISTRY_DECIMAL_H +#define OMNI_RUNTIME_FUNC_REGISTRY_DECIMAL_H +#include "function.h" +#include "func_registry_base.h" + +// functions called directly from codegen +const std::string decimal128CompareStr = "Decimal128Compare"; +const std::string decimal64CompareStr = "Decimal64Compare"; + +const std::string addDec128Str = "Add_decimal128"; +const std::string subDec128Str = "Sub_decimal128"; +const std::string mulDec128Str = "Mul_decimal128"; +const std::string divDec128Str = "Div_decimal128"; +const std::string modDec128Str = "Mod_decimal128"; + +const std::string addDec64Str = "Add_decimal64"; +const std::string subDec64Str = "Sub_decimal64"; +const std::string mulDec64Str = "Mul_decimal64"; +const std::string divDec64Str = "Div_decimal64"; +const std::string modDec64Str = "Mod_decimal64"; + +constexpr const char* tryAddDecimal64FnStr = "Try_add_decimal64"; +constexpr const char* tryAddDecimal128FnStr = "Try_add_decimal128"; +constexpr const char* trySubDecimal64FnStr = "Try_sub_decimal64"; +constexpr const char* trySubDecimal128FnStr = "Try_sub_decimal128"; +constexpr const char* tryMulDecimal64FnStr = "Try_mul_decimal64"; +constexpr const char* tryMulDecimal128FnStr = "Try_mul_decimal128"; +constexpr const char* tryDivDecimal64FnStr = "Try_div_decimal64"; +constexpr const char* tryDivDecimal128FnStr = "Try_div_decimal128"; + +namespace omniruntime::codegen { +class DecimalFunctionRegistry : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class DecimalFunctionRegistryDown : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class DecimalFunctionRegistryHalfUp : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class DecimalFunctionRegistryNotReScale : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class DecimalFunctionRegistryReScale : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; +} + +#endif // OMNI_RUNTIME_FUNC_REGISTRY_DECIMAL_H diff --git a/core/src/codegen/func_registry_dictionary.cpp b/core/src/codegen/func_registry_dictionary.cpp new file mode 100644 index 0000000..4f1ae82 --- /dev/null +++ b/core/src/codegen/func_registry_dictionary.cpp @@ -0,0 +1,32 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: Dictionary Functions Registry + */ +#include "func_registry_dictionary.h" +#include "functions/dictionaryfunctions.h" + +namespace omniruntime::codegen { +using namespace omniruntime::type; +using namespace omniruntime::codegen::function; + +std::vector DictionaryFunctionRegistry::GetFunctions() +{ + std::vector paramTypes = { OMNI_LONG, OMNI_INT }; + std::vector getStringParamTypes = { OMNI_LONG, OMNI_INT, OMNI_INT }; + std::vector dictionaryFnRegistry = { + Function(reinterpret_cast(GetIntFromDictionaryVector), "DictionaryGetInt", {}, paramTypes, OMNI_INT), + Function(reinterpret_cast(GetByteFromDictionaryVector), "DictionaryGetByte", {}, paramTypes, OMNI_BYTE), + Function(reinterpret_cast(GetShortFromDictionaryVector), "DictionaryGetShort", {}, paramTypes, OMNI_SHORT), + Function(reinterpret_cast(GetLongFromDictionaryVector), "DictionaryGetLong", {}, paramTypes, OMNI_LONG), + Function(reinterpret_cast(GetDoubleFromDictionaryVector), "DictionaryGetDouble", {}, paramTypes, + OMNI_DOUBLE), + Function(reinterpret_cast(GetBooleanFromDictionaryVector), "DictionaryGetBoolean", {}, paramTypes, + OMNI_BOOLEAN), + Function(reinterpret_cast(GetVarcharFromDictionaryVector), "DictionaryGetVarchar", {}, paramTypes, + OMNI_VARCHAR), + Function(reinterpret_cast(GetDecimalFromDictionaryVector), "DictionaryGetDecimal", {}, paramTypes, + OMNI_DECIMAL128) + }; + return dictionaryFnRegistry; +} +} diff --git a/core/src/codegen/func_registry_dictionary.h b/core/src/codegen/func_registry_dictionary.h new file mode 100644 index 0000000..1399af8 --- /dev/null +++ b/core/src/codegen/func_registry_dictionary.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: Dictionary Functions Registry + */ +#ifndef OMNI_RUNTIME_FUNC_REGISTRY_DICTIONARY_H +#define OMNI_RUNTIME_FUNC_REGISTRY_DICTIONARY_H +#include "function.h" +#include "func_registry_base.h" + +// functions called directly from codegen +const std::string dictionaryGetByteStr = "DictionaryGetByte"; +const std::string dictionaryGetShortStr = "DictionaryGetShort"; +const std::string dictionaryGetIntStr = "DictionaryGetInt"; +const std::string dictionaryGetLongStr = "DictionaryGetLong"; +const std::string dictionaryGetDoubleStr = "DictionaryGetDouble"; +const std::string dictionaryGetBooleanStr = "DictionaryGetBoolean"; +const std::string dictionaryGetVarcharStr = "DictionaryGetVarchar"; +const std::string dictionaryGetDecimalStr = "DictionaryGetDecimal"; + +namespace omniruntime::codegen { +class DictionaryFunctionRegistry : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; +} + +#endif // OMNI_RUNTIME_FUNC_REGISTRY_DICTIONARY_H diff --git a/core/src/codegen/func_registry_hash.cpp b/core/src/codegen/func_registry_hash.cpp new file mode 100644 index 0000000..359fd42 --- /dev/null +++ b/core/src/codegen/func_registry_hash.cpp @@ -0,0 +1,65 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: Murmur3 Hash Functions Registry + */ + +#include "func_registry_hash.h" +#include "functions/murmur3_hash.h" +#include "functions/xxhash64_hash.h" + +namespace omniruntime::codegen { +using namespace omniruntime::type; +using namespace codegen::function; + +std::vector HashFunctionRegistry::GetFunctions() +{ + DataTypeId retTypeInt = OMNI_INT; + DataTypeId retTypeLong = OMNI_LONG; + std::string mm3FnStr = "mm3hash"; + std::string xxH64FnStr = "xxhash64"; + std::vector hashRegistry = { // insert native function for combine hash math function + Function(reinterpret_cast(CombineHash), "combine_hash", {}, { OMNI_LONG, OMNI_LONG }, retTypeLong, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(Mm3Int32), mm3FnStr, {}, { OMNI_INT, OMNI_INT }, retTypeInt, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(Mm3Int32), mm3FnStr, {}, { OMNI_DATE32, OMNI_INT }, retTypeInt, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(Mm3Int64), mm3FnStr, {}, { OMNI_LONG, OMNI_INT }, retTypeInt, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(Mm3Int64), mm3FnStr, {}, { OMNI_TIMESTAMP, OMNI_INT }, retTypeInt, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(Mm3Double), mm3FnStr, {}, { OMNI_DOUBLE, OMNI_INT }, retTypeInt, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(Mm3String), mm3FnStr, {}, { OMNI_VARCHAR, OMNI_INT }, retTypeInt, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(Mm3Decimal64), mm3FnStr, {}, { OMNI_DECIMAL64, OMNI_INT }, retTypeInt, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(Mm3Decimal128), mm3FnStr, {}, { OMNI_DECIMAL128, OMNI_INT }, retTypeInt, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(Mm3Boolean), mm3FnStr, {}, { OMNI_BOOLEAN, OMNI_INT }, retTypeInt, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(XxH64Int16), xxH64FnStr, {}, { OMNI_SHORT, OMNI_LONG }, retTypeLong, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(XxH64Int32), xxH64FnStr, {}, { OMNI_INT, OMNI_LONG }, retTypeLong, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(XxH64Int32), xxH64FnStr, {}, { OMNI_DATE32, OMNI_LONG }, retTypeLong, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(XxH64Int64), xxH64FnStr, {}, { OMNI_LONG, OMNI_LONG }, retTypeLong, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(XxH64Int64), xxH64FnStr, {}, { OMNI_TIMESTAMP, OMNI_LONG }, retTypeLong, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(XxH64Double), xxH64FnStr, {}, { OMNI_DOUBLE, OMNI_LONG }, retTypeLong, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(XxH64String), xxH64FnStr, {}, { OMNI_VARCHAR, OMNI_LONG }, retTypeLong, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(XxH64Decimal64), xxH64FnStr, {}, { OMNI_DECIMAL64, OMNI_LONG }, retTypeLong, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(XxH64Decimal128), xxH64FnStr, {}, { OMNI_DECIMAL128, OMNI_LONG }, retTypeLong, + INPUT_DATA_AND_NULL), + Function(reinterpret_cast(XxH64Boolean), xxH64FnStr, {}, { OMNI_BOOLEAN, OMNI_LONG }, retTypeLong, + INPUT_DATA_AND_NULL) + }; + + return hashRegistry; +} +} diff --git a/core/src/codegen/func_registry_hash.h b/core/src/codegen/func_registry_hash.h new file mode 100644 index 0000000..4eb61f8 --- /dev/null +++ b/core/src/codegen/func_registry_hash.h @@ -0,0 +1,17 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: Murmur3 Hash Functions Registry + */ +#ifndef OMNI_RUNTIME_FUNC_REGISTRY_HASH_H +#define OMNI_RUNTIME_FUNC_REGISTRY_HASH_H +#include "function.h" +#include "func_registry_base.h" + +namespace omniruntime::codegen { +class HashFunctionRegistry : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; +} + +#endif // OMNI_RUNTIME_FUNC_REGISTRY_HASH_H diff --git a/core/src/codegen/func_registry_hive_udf.cpp b/core/src/codegen/func_registry_hive_udf.cpp new file mode 100644 index 0000000..7ca87b6 --- /dev/null +++ b/core/src/codegen/func_registry_hive_udf.cpp @@ -0,0 +1,68 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: registry hive udf. + */ +#include +#include +#include "util/debug.h" +#include "util/config_util.h" +#include "functions/udffunctions.h" +#include "func_registry_hive_udf.h" + +namespace omniruntime::codegen { +using namespace omniruntime::type; +using namespace codegen::function; + +std::vector HiveUdfRegistry::GetFunctions() +{ + std::vector hiveUdfFunctions = { Function(reinterpret_cast(EvaluateHiveUdfSingle), + "EvaluateHiveUdfSingle", {}, std::vector {}, OMNI_INT), + Function(reinterpret_cast(EvaluateHiveUdfBatch), "EvaluateHiveUdfBatch", {}, std::vector {}, + OMNI_INT) }; + return hiveUdfFunctions; +} + +static void Trim(std::string &value) +{ + value.erase(0, value.find_first_not_of(' ')); + value.erase(value.find_last_not_of(' ') + 1); +} + +void HiveUdfRegistry::GenerateHiveUdfMap(std::unordered_map &hiveUdfMap) +{ + std::string propertyFile = ConfigUtil::GetHiveUdfPropertyFilePath(); + if (propertyFile.empty()) { + LogWarn("No hive udf properties file."); + return; + } + Trim(propertyFile); + auto realPathRes = realpath(propertyFile.c_str(), nullptr); + if (realPathRes == nullptr) { + LogWarn("realpath failed."); + return; + } + + // the property file has been normalized in ConfigUtil + std::ifstream file(realPathRes); + if (!file.good()) { + LogWarn("%s does not exist.", realPathRes); + return; + } + + std::string s; + while (getline(file, s)) { + Trim(s); + auto pos = s.find(' '); + if (pos == std::string::npos) { + continue; + } + std::string udfName = s.substr(0, pos); + std::string udfClass = s.substr(pos + 1); + Trim(udfName); + Trim(udfClass); + std::transform(udfName.begin(), udfName.end(), udfName.begin(), ::tolower); + hiveUdfMap.insert(std::make_pair(udfName, udfClass)); + } + file.close(); +} +} \ No newline at end of file diff --git a/core/src/codegen/func_registry_hive_udf.h b/core/src/codegen/func_registry_hive_udf.h new file mode 100644 index 0000000..3994606 --- /dev/null +++ b/core/src/codegen/func_registry_hive_udf.h @@ -0,0 +1,19 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: registry hive udf. + */ +#ifndef OMNI_RUNTIME_FUNC_REGISTRY_HIVE_UDF_H +#define OMNI_RUNTIME_FUNC_REGISTRY_HIVE_UDF_H + +#include "function.h" +#include "func_registry_base.h" + +namespace omniruntime::codegen { +class HiveUdfRegistry : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; + + static void GenerateHiveUdfMap(std::unordered_map &hiveUdfMap); +}; +} +#endif // OMNI_RUNTIME_FUNC_REGISTRY_HIVE_UDF_H diff --git a/core/src/codegen/func_registry_math.cpp b/core/src/codegen/func_registry_math.cpp new file mode 100644 index 0000000..b8ae9ba --- /dev/null +++ b/core/src/codegen/func_registry_math.cpp @@ -0,0 +1,330 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: Math Functions Registry + */ +#include "func_registry_math.h" +#include "functions/mathfunctions.h" +namespace omniruntime::codegen { +using namespace omniruntime::type; +using namespace omniruntime::codegen::function; + +const std::string AbsFnStr() +{ + const std::string absFnStr = "abs"; + return absFnStr; +} + +const std::string RoundFnStr() +{ + const std::string roundFnStr = "round"; + return roundFnStr; +} + +const std::string AddFnStr() +{ + const std::string addFnStr = "add"; + return addFnStr; +} + +const std::string SubtractFnStr() +{ + const std::string subtractFnStr = "subtract"; + return subtractFnStr; +} + +const std::string MultiplyFnStr() +{ + const std::string multiplyFnStr = "multiply"; + return multiplyFnStr; +} + +const std::string DivideFnStr() +{ + const std::string divideFnStr = "divide"; + return divideFnStr; +} + +const std::string ModulusFnStr() +{ + const std::string modulusFnStr = "modulus"; + return modulusFnStr; +} + +const std::string LessThanFnStr() +{ + const std::string lessThanFnStr = "lessThan"; + return lessThanFnStr; +} + +const std::string LessThanEqualFnStr() +{ + const std::string lessThanEqualFnStr = "lessThanEqual"; + return lessThanEqualFnStr; +} + +const std::string GreaterThanFnStr() +{ + const std::string greaterThanFnStr = "greaterThan"; + return greaterThanFnStr; +} + +const std::string GreaterThanEqualFnStr() +{ + const std::string greaterThanEqualFnStr = "greaterThanEqual"; + return greaterThanEqualFnStr; +} + +const std::string EqualFnStr() +{ + const std::string equalFnStr = "equal"; + return equalFnStr; +} + +const std::string NotEqualFnStr() +{ + const std::string notEqualFnStr = "notEqual"; + return notEqualFnStr; +} + +const std::string MathCastFnStr() +{ + const std::string mathCastFnStr = "CAST"; + return mathCastFnStr; +} + +const std::string PmodFnStr() +{ + const std::string pmodFnStr = "pmod"; + return pmodFnStr; +} + +const std::string NormalizeNaNAndZeroFnStr() +{ + const std::string normalizeNaNAndZeroFnStr = "NormalizeNaNAndZero"; + return normalizeNaNAndZeroFnStr; +} + +const std::string GreatestFnStr() +{ + const std::string greatestFnStr = "Greatest"; + return greatestFnStr; +} + +const std::string PowerFnStr() +{ + const std::string powerFnStr = "power"; + return powerFnStr; +} + +const std::string TryAddFnStr() +{ + const std::string addFnStr = "try_add"; + return addFnStr; +} + +const std::string TrySubtractFnStr() +{ + const std::string subtractFnStr = "try_subtract"; + return subtractFnStr; +} + +const std::string TryMultiplyFnStr() +{ + const std::string multiplyFnStr = "try_multiply"; + return multiplyFnStr; +} + +const std::string TryDivideFnStr() +{ + const std::string tryDivideFnStr = "try_divide"; + return tryDivideFnStr; +} + +std::vector MathFunctionRegistry::GetFunctions() +{ + const std::vector doubleParams = { OMNI_DOUBLE, OMNI_DOUBLE }; + const std::vector longParams = { OMNI_LONG, OMNI_LONG }; + const std::vector intParams = { OMNI_INT, OMNI_INT }; + const std::vector shortParams = { OMNI_SHORT, OMNI_SHORT }; + const std::vector byteParams = { OMNI_BYTE, OMNI_BYTE} ; + + std::vector mathFnRegistry = { + // insert native functions for each absolute math function + Function(reinterpret_cast(CastInt32ToInt16), MathCastFnStr(), {}, { OMNI_INT }, OMNI_SHORT, INPUT_DATA), + Function(reinterpret_cast(CastInt32ToInt8), MathCastFnStr(), {}, { OMNI_INT }, OMNI_BYTE, INPUT_DATA), + Function(reinterpret_cast(CastInt64ToInt16), MathCastFnStr(), {}, { type::OMNI_LONG }, OMNI_SHORT, INPUT_DATA), + Function(reinterpret_cast(CastInt64ToInt8), MathCastFnStr(), {}, { type::OMNI_LONG }, OMNI_BYTE, INPUT_DATA), + Function(reinterpret_cast(Abs), AbsFnStr(), {}, { OMNI_INT }, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(Abs), AbsFnStr(), {}, { OMNI_LONG }, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(Abs), AbsFnStr(), {}, { OMNI_DOUBLE }, OMNI_DOUBLE, INPUT_DATA), + + // insert native functions for each cast math function + Function(reinterpret_cast(CastInt32ToDouble), MathCastFnStr(), {}, { OMNI_INT }, OMNI_DOUBLE, + INPUT_DATA), + Function(reinterpret_cast(CastInt64ToDouble), MathCastFnStr(), {}, { OMNI_LONG }, OMNI_DOUBLE, + INPUT_DATA), + Function(reinterpret_cast(CastInt32ToInt64), MathCastFnStr(), {}, { OMNI_INT }, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(CastInt64ToInt32), MathCastFnStr(), {}, { OMNI_LONG }, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(CastInt16ToInt32), MathCastFnStr(), {}, { OMNI_SHORT }, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(CastInt8ToInt32), MathCastFnStr(), {}, { OMNI_BYTE }, OMNI_INT, INPUT_DATA), + + Function(reinterpret_cast(CastInt16ToInt64), MathCastFnStr(), {}, { OMNI_SHORT }, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(CastInt8ToInt64), MathCastFnStr(), {}, { OMNI_BYTE }, OMNI_LONG, INPUT_DATA), + + Function(reinterpret_cast(CastInt16ToDouble), MathCastFnStr(), {}, { OMNI_SHORT }, OMNI_DOUBLE, INPUT_DATA), + Function(reinterpret_cast(CastInt8ToDouble), MathCastFnStr(), {}, { OMNI_BYTE }, OMNI_DOUBLE, INPUT_DATA), + + // insert native function for each double operations + Function(reinterpret_cast(AddDouble), AddFnStr(), {}, doubleParams, OMNI_DOUBLE, INPUT_DATA), + Function(reinterpret_cast(SubtractDouble), SubtractFnStr(), {}, doubleParams, OMNI_DOUBLE, INPUT_DATA), + Function(reinterpret_cast(MultiplyDouble), MultiplyFnStr(), {}, doubleParams, OMNI_DOUBLE, INPUT_DATA), + Function(reinterpret_cast(DivideDouble), DivideFnStr(), {}, doubleParams, OMNI_DOUBLE, INPUT_DATA), + Function(reinterpret_cast(ModulusDouble), ModulusFnStr(), {}, doubleParams, OMNI_DOUBLE, INPUT_DATA), + Function(reinterpret_cast(LessThanDouble), LessThanFnStr(), {}, doubleParams, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(LessThanEqualDouble), LessThanEqualFnStr(), {}, doubleParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(GreaterThanDouble), GreaterThanFnStr(), {}, doubleParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(GreaterThanEqualDouble), GreaterThanEqualFnStr(), {}, doubleParams, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(EqualDouble), EqualFnStr(), {}, doubleParams, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(NotEqualDouble), NotEqualFnStr(), {}, doubleParams, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(NormalizeNaNAndZero), NormalizeNaNAndZeroFnStr(), {}, { OMNI_DOUBLE }, + OMNI_DOUBLE, INPUT_DATA), + Function(reinterpret_cast(PowerDouble), PowerFnStr(), {}, doubleParams, OMNI_DOUBLE, INPUT_DATA), + + // insert native function for each long operations + Function(reinterpret_cast(AddInt64), AddFnStr(), {}, longParams, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(SubtractInt64), SubtractFnStr(), {}, longParams, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(MultiplyInt64), MultiplyFnStr(), {}, longParams, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(DivideInt64), DivideFnStr(), {}, longParams, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(ModulusInt64), ModulusFnStr(), {}, longParams, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(AddInt64RetNull), TryAddFnStr(), {}, longParams, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(SubtractInt64RetNull), TrySubtractFnStr(), {}, longParams, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(MultiplyInt64RetNull), TryMultiplyFnStr(), {}, longParams, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(DivideInt64), TryDivideFnStr(), {}, longParams, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(LessThanInt64), LessThanFnStr(), {}, longParams, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(LessThanEqualInt64), LessThanEqualFnStr(), {}, longParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(GreaterThanInt64), GreaterThanFnStr(), {}, longParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(GreaterThanEqualInt64), GreaterThanEqualFnStr(), {}, longParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(EqualInt64), EqualFnStr(), {}, longParams, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(NotEqualInt64), NotEqualFnStr(), {}, longParams, OMNI_BOOLEAN, INPUT_DATA), + + // insert native function for each int operations + Function(reinterpret_cast(AddInt32), AddFnStr(), {}, intParams, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(SubtractInt32), SubtractFnStr(), {}, intParams, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(MultiplyInt32), MultiplyFnStr(), {}, intParams, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(DivideInt32), DivideFnStr(), {}, intParams, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(ModulusInt32), ModulusFnStr(), {}, intParams, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(AddInt32RetNull), TryAddFnStr(), {}, intParams, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(SubtractInt32RetNull), TrySubtractFnStr(), {}, intParams, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(MultiplyInt32RetNull), TryMultiplyFnStr(), {}, intParams, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(DivideInt32), TryDivideFnStr(), {}, intParams, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(LessThanInt32), LessThanFnStr(), {}, intParams, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(LessThanEqualInt32), LessThanEqualFnStr(), {}, intParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(GreaterThanInt32), GreaterThanFnStr(), {}, intParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(GreaterThanEqualInt32), GreaterThanEqualFnStr(), {}, intParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(EqualInt32), EqualFnStr(), {}, intParams, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(NotEqualInt32), NotEqualFnStr(), {}, intParams, OMNI_BOOLEAN, INPUT_DATA), + + // insert pmod function for project operator support + Function(reinterpret_cast(Pmod), PmodFnStr(), {}, intParams, OMNI_INT, INPUT_DATA), + // insert native functions for each round math function + Function(reinterpret_cast(Round), RoundFnStr(), {}, intParams, OMNI_INT, INPUT_DATA), + Function(reinterpret_cast(RoundLong), RoundFnStr(), {}, { OMNI_LONG, OMNI_INT }, OMNI_LONG, + INPUT_DATA), + Function(reinterpret_cast(Round), RoundFnStr(), {}, { OMNI_DOUBLE, OMNI_INT }, OMNI_DOUBLE, + INPUT_DATA), + Function(reinterpret_cast(Greatest), GreatestFnStr(), {}, { OMNI_INT, OMNI_INT }, OMNI_INT, + INPUT_DATA_AND_NULL_AND_RETURN_NULL), + Function(reinterpret_cast(Greatest), GreatestFnStr(), {}, { OMNI_LONG, OMNI_LONG }, OMNI_LONG, + INPUT_DATA_AND_NULL_AND_RETURN_NULL), + Function(reinterpret_cast(Greatest), GreatestFnStr(), {}, { OMNI_BOOLEAN, OMNI_BOOLEAN }, + OMNI_BOOLEAN, INPUT_DATA_AND_NULL_AND_RETURN_NULL), + Function(reinterpret_cast(Greatest), GreatestFnStr(), {}, { OMNI_DOUBLE, OMNI_DOUBLE }, + OMNI_DOUBLE, INPUT_DATA_AND_NULL_AND_RETURN_NULL), + + // insert native function for each short operations + Function(reinterpret_cast(AddInt16), AddFnStr(), {}, shortParams, OMNI_SHORT, INPUT_DATA), + Function(reinterpret_cast(SubtractInt16), SubtractFnStr(), {}, shortParams, OMNI_SHORT, INPUT_DATA), + Function(reinterpret_cast(MultiplyInt16), MultiplyFnStr(), {}, shortParams, OMNI_SHORT, INPUT_DATA), + Function(reinterpret_cast(DivideInt16), DivideFnStr(), {}, shortParams, OMNI_SHORT, INPUT_DATA), + Function(reinterpret_cast(ModulusInt16), ModulusFnStr(), {}, shortParams, OMNI_SHORT, INPUT_DATA), + Function(reinterpret_cast(AddInt16RetNull), TryAddFnStr(), {}, shortParams, OMNI_SHORT, INPUT_DATA), + Function(reinterpret_cast(SubtractInt16RetNull), TrySubtractFnStr(), {}, shortParams, OMNI_SHORT, INPUT_DATA), + Function(reinterpret_cast(MultiplyInt16RetNull), TryMultiplyFnStr(), {}, shortParams, OMNI_SHORT, INPUT_DATA), + Function(reinterpret_cast(DivideInt16), TryDivideFnStr(), {}, shortParams, OMNI_SHORT, INPUT_DATA), + Function(reinterpret_cast(LessThanInt16), LessThanFnStr(), {}, shortParams, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(LessThanEqualInt16), LessThanEqualFnStr(), {}, shortParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(GreaterThanInt16), GreaterThanFnStr(), {}, shortParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(GreaterThanEqualInt16), GreaterThanEqualFnStr(), {}, shortParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(EqualInt16), EqualFnStr(), {}, shortParams, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(NotEqualInt16), NotEqualFnStr(), {}, shortParams, OMNI_BOOLEAN, INPUT_DATA), + + // insert native function for each byte operations + Function(reinterpret_cast(AddInt8), AddFnStr(), {}, byteParams, OMNI_BYTE, INPUT_DATA), + Function(reinterpret_cast(SubtractInt8), SubtractFnStr(), {}, byteParams, OMNI_BYTE, INPUT_DATA), + Function(reinterpret_cast(MultiplyInt8), MultiplyFnStr(), {}, byteParams, OMNI_BYTE, INPUT_DATA), + Function(reinterpret_cast(DivideInt8), DivideFnStr(), {}, byteParams, OMNI_BYTE, INPUT_DATA), + Function(reinterpret_cast(ModulusInt8), ModulusFnStr(), {}, byteParams, OMNI_BYTE, INPUT_DATA), + Function(reinterpret_cast(AddInt8RetNull), TryAddFnStr(), {}, byteParams, OMNI_BYTE, INPUT_DATA), + Function(reinterpret_cast(SubtractInt8RetNull), TrySubtractFnStr(), {}, byteParams, OMNI_BYTE, INPUT_DATA), + Function(reinterpret_cast(MultiplyInt8RetNull), TryMultiplyFnStr(), {}, byteParams, OMNI_BYTE, INPUT_DATA), + Function(reinterpret_cast(DivideInt8), TryDivideFnStr(), {}, byteParams, OMNI_BYTE, INPUT_DATA), + Function(reinterpret_cast(LessThanInt8), LessThanFnStr(), {}, byteParams, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(LessThanEqualInt8), LessThanEqualFnStr(), {}, byteParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(GreaterThanInt8), GreaterThanFnStr(), {}, byteParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(GreaterThanEqualInt8), GreaterThanEqualFnStr(), {}, byteParams, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(EqualInt8), EqualFnStr(), {}, byteParams, OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(NotEqualInt8), NotEqualFnStr(), {}, byteParams, OMNI_BOOLEAN, INPUT_DATA), + }; + + return mathFnRegistry; +} + +std::vector MathFunctionRegistryHalfUp::GetFunctions() +{ + std::vector mathFnRegistry = { + // insert native functions for each absolute math function + Function(reinterpret_cast(CastDoubleToInt64HalfUp), MathCastFnStr(), {}, { OMNI_DOUBLE }, OMNI_LONG, + INPUT_DATA), + Function(reinterpret_cast(CastDoubleToInt32HalfUp), MathCastFnStr(), {}, { OMNI_DOUBLE }, OMNI_INT, + INPUT_DATA), + Function(reinterpret_cast(CastDoubleToInt16HalfUp), MathCastFnStr(), {}, { OMNI_DOUBLE }, OMNI_SHORT, + INPUT_DATA), + Function(reinterpret_cast(CastDoubleToInt8HalfUp), MathCastFnStr(), {}, { OMNI_DOUBLE }, OMNI_BYTE, + INPUT_DATA), + }; + + return mathFnRegistry; +} + +std::vector MathFunctionRegistryDown::GetFunctions() +{ + std::vector mathFnRegistry = { + // insert native functions for each absolute math function + Function(reinterpret_cast(CastDoubleToInt64Down), MathCastFnStr(), {}, { OMNI_DOUBLE }, OMNI_LONG, + INPUT_DATA), + Function(reinterpret_cast(CastDoubleToInt32Down), MathCastFnStr(), {}, { OMNI_DOUBLE }, OMNI_INT, + INPUT_DATA), + Function(reinterpret_cast(CastDoubleToInt16Down), MathCastFnStr(), {}, { OMNI_DOUBLE }, OMNI_SHORT, + INPUT_DATA), + Function(reinterpret_cast(CastDoubleToInt8Down), MathCastFnStr(), {}, { OMNI_DOUBLE }, OMNI_BYTE, + INPUT_DATA), + }; + + return mathFnRegistry; +} +} \ No newline at end of file diff --git a/core/src/codegen/func_registry_math.h b/core/src/codegen/func_registry_math.h new file mode 100644 index 0000000..02c2bb4 --- /dev/null +++ b/core/src/codegen/func_registry_math.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: Math Functions Registry + */ +#ifndef OMNI_RUNTIME_FUNC_REGISTRY_MATH_H +#define OMNI_RUNTIME_FUNC_REGISTRY_MATH_H +#include "function.h" +#include "func_registry_base.h" + +namespace omniruntime::codegen { +class MathFunctionRegistry : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class MathFunctionRegistryHalfUp : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class MathFunctionRegistryDown : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; +} + +#endif // OMNI_RUNTIME_FUNC_REGISTRY_MATH_H diff --git a/core/src/codegen/func_registry_might_contain.cpp b/core/src/codegen/func_registry_might_contain.cpp new file mode 100644 index 0000000..5100121 --- /dev/null +++ b/core/src/codegen/func_registry_might_contain.cpp @@ -0,0 +1,23 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: MightContain Function Registry + */ +#include "func_registry_might_contain.h" +#include "functions/mightcontain.h" + +namespace omniruntime::codegen { +using namespace omniruntime::type; +using namespace omniruntime::codegen::function; + +std::vector MightContainFunctionRegistry::GetFunctions() +{ + DataTypeId retTypeBoolean = OMNI_BOOLEAN; + std::string mightContainFnStr = "might_contain"; + std::vector mightContainRegistry = { // insert native function for might contain function + Function(reinterpret_cast(MightContain), mightContainFnStr, {}, { OMNI_LONG, OMNI_LONG }, + retTypeBoolean, INPUT_DATA) + }; + + return mightContainRegistry; +} +} \ No newline at end of file diff --git a/core/src/codegen/func_registry_might_contain.h b/core/src/codegen/func_registry_might_contain.h new file mode 100644 index 0000000..2e537fe --- /dev/null +++ b/core/src/codegen/func_registry_might_contain.h @@ -0,0 +1,17 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: MightContain Function Registry + */ +#ifndef OMNI_RUNTIME_FUNC_REGISTRY_MIGHT_CONTAIN_H +#define OMNI_RUNTIME_FUNC_REGISTRY_MIGHT_CONTAIN_H +#include "function.h" +#include "func_registry_base.h" + +namespace omniruntime::codegen { +class MightContainFunctionRegistry : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; +} + +#endif // OMNI_RUNTIME_FUNC_REGISTRY_MIGHT_CONTAIN_H \ No newline at end of file diff --git a/core/src/codegen/func_registry_string.cpp b/core/src/codegen/func_registry_string.cpp new file mode 100644 index 0000000..b6fad2a --- /dev/null +++ b/core/src/codegen/func_registry_string.cpp @@ -0,0 +1,478 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved. + * Description: String Function Registry + */ +#include "func_registry_string.h" +#include "functions/stringfunctions.h" + +namespace omniruntime::codegen { +using namespace omniruntime::type; +using namespace codegen::function; + +const std::string StrEqualFnStr() +{ + const std::string compareFnStr = "strequal"; + return compareFnStr; +} + +const std::string ConcatFnStr() +{ + const std::string concatFnStr = "concat"; + return concatFnStr; +} + +const std::string ConcatWsFnStr() +{ + const std::string concatWsFnStr = "concat_ws"; + return concatWsFnStr; +} + +const std::string LikeFnStr() +{ + const std::string likeFnStr = "LIKE"; + return likeFnStr; +} + +const std::string CastFnStr() +{ + const std::string castFnStr = "CAST"; + return castFnStr; +} + +const std::string LowerFnStr() +{ + const std::string lowerFnStr = "lower"; + return lowerFnStr; +} + +const std::string UpperFnStr() +{ + const std::string upperFnStr = "upper"; + return upperFnStr; +} + +const std::string CompareFnStr() +{ + const std::string compareFnStr = "compare"; + return compareFnStr; +} + +const std::string LengthFnStr() +{ + const std::string lengthFnStr = "length"; + return lengthFnStr; +} + +const std::string CastNullFnStr() +{ + const std::string castNullFnStr = "CAST_null"; + return castNullFnStr; +} + +const std::string ConcatNullFnStr() +{ + const std::string concatNullFnStr = "concat_null"; + return concatNullFnStr; +} + +const std::string ReplaceFnStr() +{ + const std::string replaceFnStr = "replace"; + return replaceFnStr; +} + +const std::string EmptyToNullStr() +{ + const std::string empty2nullFnStr = "empty2null"; + return empty2nullFnStr; +} + +const std::string SubstrFnStr() +{ + const std::string substrFnStr = "substr"; + return substrFnStr; +} + +const std::string InStrFnStr() +{ + const std::string instrFnStr = "instr"; + return instrFnStr; +} + +const std::string StartsWithFnStr() +{ + const std::string startsWithFnStr = "StartsWith"; + return startsWithFnStr; +} + +const std::string EndsWithFnStr() +{ + const std::string endsWithFnStr = "EndsWith"; + return endsWithFnStr; +} + +const std::string RLikeFnStr() +{ + const std::string rLikeFnStr = "RLike"; + return rLikeFnStr; +} + +const std::string Md5FnStr() +{ + const std::string md5FnStr = "Md5"; + return md5FnStr; +} + +const std::string ContainsFnStr() +{ + const std::string containsFnStr = "Contains"; + return containsFnStr; +} + +const std::string GreatestStrFnStr() +{ + const std::string greatestStrFnStr = "Greatest"; + return greatestStrFnStr; +} + +const std::string StaticInvokeVarcharTypeWriteSideCheckFnStr() +{ + const std::string staticInvokeVarcharTypeWriteSideCheckFnStr = "StaticInvokeVarcharTypeWriteSideCheck"; + return staticInvokeVarcharTypeWriteSideCheckFnStr; +} + +const std::string StaticInvokeCharReadPaddingFnStr() +{ + const std::string staticInvokeCharReadPaddingFnStr = "StaticInvokeCharReadPadding"; + return staticInvokeCharReadPaddingFnStr; +} + +std::vector StringFunctionRegistry::GetFunctions() +{ + std::vector stringFnRegistry = { // concat functions + Function(reinterpret_cast(ConcatStrStr), ConcatFnStr(), {}, { OMNI_VARCHAR, OMNI_VARCHAR }, + OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(ConcatCharChar), ConcatFnStr(), {}, { OMNI_CHAR, OMNI_CHAR }, OMNI_CHAR, + INPUT_DATA, true), + Function(reinterpret_cast(ConcatCharStr), ConcatFnStr(), {}, { OMNI_CHAR, OMNI_VARCHAR }, OMNI_CHAR, + INPUT_DATA, true), + Function(reinterpret_cast(ConcatStrChar), ConcatFnStr(), {}, { OMNI_VARCHAR, OMNI_CHAR }, OMNI_CHAR, + INPUT_DATA, true), + Function(reinterpret_cast(ConcatWsStr), ConcatWsFnStr(), {}, { OMNI_VARCHAR, OMNI_VARCHAR }, + OMNI_VARCHAR, INPUT_DATA, true), + + Function(reinterpret_cast(LikeStr), LikeFnStr(), {}, { OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_BOOLEAN, + INPUT_DATA), + Function(reinterpret_cast(LikeChar), LikeFnStr(), {}, { OMNI_CHAR, OMNI_VARCHAR }, OMNI_BOOLEAN, + INPUT_DATA), + + Function(reinterpret_cast(ToUpperStr), UpperFnStr(), {}, { OMNI_VARCHAR }, OMNI_VARCHAR, INPUT_DATA, + true), + Function(reinterpret_cast(ToUpperChar), UpperFnStr(), {}, { OMNI_CHAR }, OMNI_CHAR, INPUT_DATA, true), + Function(reinterpret_cast(ToLowerStr), LowerFnStr(), {}, { OMNI_VARCHAR }, OMNI_VARCHAR, INPUT_DATA, + true), + Function(reinterpret_cast(ToLowerChar), LowerFnStr(), {}, { OMNI_CHAR }, OMNI_CHAR, INPUT_DATA, true), + + Function(reinterpret_cast(StrCompare), CompareFnStr(), {}, { OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_INT), + + Function(reinterpret_cast(StrEquals), StrEqualFnStr(), {}, { OMNI_VARCHAR, OMNI_VARCHAR }, + OMNI_BOOLEAN), + + Function(reinterpret_cast(CastIntToString), CastFnStr(), {}, { OMNI_INT }, OMNI_VARCHAR, INPUT_DATA, + true), + Function(reinterpret_cast(CastInt16ToString), CastFnStr(), {}, { OMNI_SHORT }, OMNI_VARCHAR, INPUT_DATA, + true), + Function(reinterpret_cast(CastInt8ToString), CastFnStr(), {}, { OMNI_BYTE }, OMNI_VARCHAR, INPUT_DATA, + true), + Function(reinterpret_cast(CastLongToString), CastFnStr(), {}, { OMNI_LONG }, OMNI_VARCHAR, INPUT_DATA, + true), + Function(reinterpret_cast(CastDoubleToString), CastFnStr(), {}, { OMNI_DOUBLE }, OMNI_VARCHAR, + INPUT_DATA, true), + Function(reinterpret_cast(CastDecimal64ToString), CastFnStr(), {}, { OMNI_DECIMAL64 }, OMNI_VARCHAR, + INPUT_DATA, true), + Function(reinterpret_cast(CastDecimal128ToString), CastFnStr(), {}, { OMNI_DECIMAL128 }, OMNI_VARCHAR, + INPUT_DATA, true), + Function(reinterpret_cast(CastDateToString), CastFnStr(), {}, { OMNI_DATE32 }, OMNI_VARCHAR, INPUT_DATA, + true), + + Function(reinterpret_cast(CastStringToByte), CastFnStr(), {}, { OMNI_VARCHAR }, OMNI_BYTE, INPUT_DATA, + true), + Function(reinterpret_cast(CastStringToShort), CastFnStr(), {}, { OMNI_VARCHAR }, OMNI_SHORT, INPUT_DATA, + true), + Function(reinterpret_cast(CastStringToInt), CastFnStr(), {}, { OMNI_VARCHAR }, OMNI_INT, INPUT_DATA, + true), + Function(reinterpret_cast(CastStringToLong), CastFnStr(), {}, { OMNI_VARCHAR }, OMNI_LONG, INPUT_DATA, + true), + Function(reinterpret_cast(CastStringToDouble), CastFnStr(), {}, { OMNI_VARCHAR }, OMNI_DOUBLE, + INPUT_DATA, true), + Function(reinterpret_cast(CastStrWithDiffWidths), CastFnStr(), {}, { OMNI_VARCHAR }, OMNI_VARCHAR, + INPUT_DATA, true), + + // length functions + Function(reinterpret_cast(LengthChar), LengthFnStr(), {}, { OMNI_CHAR }, OMNI_LONG, INPUT_DATA), + Function(reinterpret_cast(LengthStr), LengthFnStr(), {}, { OMNI_VARCHAR }, OMNI_LONG, INPUT_DATA), + + // replace functions + Function(reinterpret_cast(LengthCharReturnInt32), LengthFnStr(), {}, { OMNI_CHAR }, OMNI_INT, + INPUT_DATA), + Function(reinterpret_cast(LengthStrReturnInt32), LengthFnStr(), {}, { OMNI_VARCHAR }, OMNI_INT, + INPUT_DATA), + + Function(reinterpret_cast(ConcatStrStrRetNull), ConcatNullFnStr(), {}, { OMNI_VARCHAR, OMNI_VARCHAR }, + OMNI_CHAR, INPUT_DATA_AND_OVERFLOW_NULL, true), + Function(reinterpret_cast(ConcatCharCharRetNull), ConcatNullFnStr(), {}, { OMNI_CHAR, OMNI_CHAR }, + OMNI_CHAR, INPUT_DATA_AND_OVERFLOW_NULL, true), + Function(reinterpret_cast(ConcatCharStrRetNull), ConcatNullFnStr(), {}, { OMNI_CHAR, OMNI_VARCHAR }, + OMNI_CHAR, INPUT_DATA_AND_OVERFLOW_NULL, true), + Function(reinterpret_cast(ConcatStrCharRetNull), ConcatNullFnStr(), {}, { OMNI_VARCHAR, OMNI_CHAR }, + OMNI_CHAR, INPUT_DATA_AND_OVERFLOW_NULL, true), + + Function(reinterpret_cast(CastIntToStringRetNull), CastNullFnStr(), {}, { OMNI_INT }, OMNI_VARCHAR, + INPUT_DATA_AND_OVERFLOW_NULL, true), + Function(reinterpret_cast(CastInt16ToStringRetNull), CastNullFnStr(), {}, { OMNI_SHORT }, OMNI_VARCHAR, + INPUT_DATA_AND_OVERFLOW_NULL, true), + Function(reinterpret_cast(CastInt8ToStringRetNull), CastNullFnStr(), {}, { OMNI_BYTE }, OMNI_VARCHAR, + INPUT_DATA_AND_OVERFLOW_NULL, true), + Function(reinterpret_cast(CastLongToStringRetNull), CastNullFnStr(), {}, { OMNI_LONG }, OMNI_VARCHAR, + INPUT_DATA_AND_OVERFLOW_NULL, true), + Function(reinterpret_cast(CastDoubleToStringRetNull), CastNullFnStr(), {}, { OMNI_DOUBLE }, + OMNI_VARCHAR, INPUT_DATA_AND_OVERFLOW_NULL, true), + Function(reinterpret_cast(CastDecimal64ToStringRetNull), CastNullFnStr(), {}, { OMNI_DECIMAL64 }, + OMNI_VARCHAR, INPUT_DATA_AND_OVERFLOW_NULL, true), + Function(reinterpret_cast(CastDecimal128ToStringRetNull), CastNullFnStr(), {}, { OMNI_DECIMAL128 }, + OMNI_VARCHAR, INPUT_DATA_AND_OVERFLOW_NULL, true), + Function(reinterpret_cast(CastDateToStringRetNull), CastNullFnStr(), {}, { OMNI_DATE32 }, OMNI_VARCHAR, + INPUT_DATA_AND_OVERFLOW_NULL, true), + + Function(reinterpret_cast(CastStringToByteRetNull), CastNullFnStr(), {}, { OMNI_VARCHAR }, OMNI_BYTE, + INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastStringToShortRetNull), CastNullFnStr(), {}, { OMNI_VARCHAR }, OMNI_SHORT, + INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastStringToIntRetNull), CastNullFnStr(), {}, { OMNI_VARCHAR }, OMNI_INT, + INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastStringToLongRetNull), CastNullFnStr(), {}, { OMNI_VARCHAR }, OMNI_LONG, + INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastStringToDoubleRetNull), CastNullFnStr(), {}, { OMNI_VARCHAR }, + OMNI_DOUBLE, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastStrWithDiffWidthsRetNull), CastNullFnStr(), {}, { OMNI_VARCHAR }, + OMNI_VARCHAR, INPUT_DATA_AND_OVERFLOW_NULL, true), + + Function(reinterpret_cast(InStr), InStrFnStr(), {}, { OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_INT, + INPUT_DATA), + + // like functions + Function(reinterpret_cast(StartsWithStr), StartsWithFnStr(), {}, { OMNI_VARCHAR, OMNI_VARCHAR }, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(EndsWithStr), EndsWithFnStr(), {}, { OMNI_VARCHAR, OMNI_VARCHAR }, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(RegexMatch), RLikeFnStr(), {}, {OMNI_VARCHAR, OMNI_VARCHAR}, + OMNI_BOOLEAN, INPUT_DATA), + + Function(reinterpret_cast(Md5Str), Md5FnStr(), {}, { OMNI_VARCHAR }, OMNI_VARCHAR, + INPUT_DATA, true), + Function(reinterpret_cast(ContainsStr), ContainsFnStr(), {}, {OMNI_VARCHAR, OMNI_VARCHAR}, + OMNI_BOOLEAN, INPUT_DATA), + Function(reinterpret_cast(GreatestStr), GreatestStrFnStr(), {}, {OMNI_VARCHAR, OMNI_VARCHAR}, + OMNI_VARCHAR, INPUT_DATA_AND_NULL_AND_RETURN_NULL), + Function(reinterpret_cast(EmptyToNull), EmptyToNullStr(), {}, { OMNI_VARCHAR }, OMNI_VARCHAR, + INPUT_DATA, false), + Function(reinterpret_cast(StaticInvokeVarcharTypeWriteSideCheck), + StaticInvokeVarcharTypeWriteSideCheckFnStr(), {}, { OMNI_VARCHAR, OMNI_INT }, + OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(StaticInvokeCharReadPadding), StaticInvokeCharReadPaddingFnStr(), {}, + {OMNI_VARCHAR, OMNI_INT}, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstringIndex), "substring_index", {}, + {OMNI_VARCHAR, OMNI_VARCHAR, OMNI_INT}, OMNI_VARCHAR, INPUT_DATA, true) + }; + + return stringFnRegistry; +} + +std::vector StringFunctionRegistryNotAllowReducePrecison::GetFunctions() +{ + std::vector stringFnRegistry = { + Function(reinterpret_cast(CastStringToDateNotAllowReducePrecison), CastFnStr(), {}, { OMNI_VARCHAR }, + OMNI_DATE32, INPUT_DATA, true), + Function(reinterpret_cast(CastStringToDateRetNullNotAllowReducePrecison), CastNullFnStr(), {}, + { OMNI_VARCHAR }, OMNI_DATE32, INPUT_DATA_AND_OVERFLOW_NULL), + }; + + return stringFnRegistry; +} + +std::vector StringFunctionRegistryAllowReducePrecison::GetFunctions() +{ + std::vector stringFnRegistry = { + Function(reinterpret_cast(CastStringToDateAllowReducePrecison), CastFnStr(), {}, { OMNI_VARCHAR }, + OMNI_DATE32, INPUT_DATA, true), + Function(reinterpret_cast(CastStringToDateRetNullAllowReducePrecison), CastNullFnStr(), {}, + { OMNI_VARCHAR }, OMNI_DATE32, INPUT_DATA_AND_OVERFLOW_NULL), + }; + + return stringFnRegistry; +} + +std::vector StringFunctionRegistryNotReplace::GetFunctions() +{ + std::vector stringFnRegistry = { + Function(reinterpret_cast(ReplaceStrStrStrWithRepNotReplace), ReplaceFnStr(), {}, + { OMNI_VARCHAR, OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(ReplaceStrStrWithoutRepNotReplace), ReplaceFnStr(), {}, + { OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_VARCHAR, INPUT_DATA, true), + }; + + return stringFnRegistry; +} + +std::vector StringFunctionRegistryReplace::GetFunctions() +{ + std::vector stringFnRegistry = { + Function(reinterpret_cast(ReplaceStrStrStrWithRepReplace), ReplaceFnStr(), {}, + { OMNI_VARCHAR, OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(ReplaceStrStrWithoutRepReplace), ReplaceFnStr(), {}, + { OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_VARCHAR, INPUT_DATA, true), + }; + + return stringFnRegistry; +} + +std::vector StringFunctionRegistrySupportNegativeAndZeroIndex::GetFunctions() +{ + std::vector stringFnRegistry = { + // substr functions + Function(reinterpret_cast(SubstrVarchar), SubstrFnStr(), {}, + { OMNI_VARCHAR, OMNI_INT, OMNI_INT }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrChar), SubstrFnStr(), {}, + { OMNI_CHAR, OMNI_INT, OMNI_INT }, OMNI_CHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrVarchar), SubstrFnStr(), {}, + { OMNI_VARCHAR, OMNI_LONG, OMNI_LONG }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrChar), SubstrFnStr(), {}, + { OMNI_CHAR, OMNI_LONG, OMNI_LONG }, OMNI_CHAR, INPUT_DATA, true), + + // substr with start index functions + Function(reinterpret_cast(SubstrVarcharWithStart), SubstrFnStr(), {}, + { OMNI_VARCHAR, OMNI_INT }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrCharWithStart), SubstrFnStr(), {}, + { OMNI_CHAR, OMNI_INT }, OMNI_CHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrVarcharWithStart), SubstrFnStr(), {}, + { OMNI_VARCHAR, OMNI_LONG }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrCharWithStart), SubstrFnStr(), {}, + { OMNI_CHAR, OMNI_LONG }, OMNI_CHAR, INPUT_DATA, true), + }; + + return stringFnRegistry; +} + +std::vector StringFunctionRegistrySupportNotNegativeAndZeroIndex::GetFunctions() +{ + std::vector stringFnRegistry = { + // substr functions + Function(reinterpret_cast(SubstrVarchar), SubstrFnStr(), {}, + { OMNI_VARCHAR, OMNI_INT, OMNI_INT }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrChar), SubstrFnStr(), {}, + { OMNI_CHAR, OMNI_INT, OMNI_INT }, OMNI_CHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrVarchar), SubstrFnStr(), {}, + { OMNI_VARCHAR, OMNI_LONG, OMNI_LONG }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrChar), SubstrFnStr(), {}, + { OMNI_CHAR, OMNI_LONG, OMNI_LONG }, OMNI_CHAR, INPUT_DATA, true), + + // substr with start index functions + Function(reinterpret_cast(SubstrVarcharWithStart), SubstrFnStr(), {}, + { OMNI_VARCHAR, OMNI_INT }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrCharWithStart), SubstrFnStr(), {}, + { OMNI_CHAR, OMNI_INT }, OMNI_CHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrVarcharWithStart), SubstrFnStr(), {}, + { OMNI_VARCHAR, OMNI_LONG }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrCharWithStart), SubstrFnStr(), {}, + { OMNI_CHAR, OMNI_LONG }, OMNI_CHAR, INPUT_DATA, true), + }; + + return stringFnRegistry; +} + +std::vector StringFunctionRegistrySupportNegativeAndNotZeroIndex::GetFunctions() +{ + std::vector stringFnRegistry = { + // substr functions + Function(reinterpret_cast(SubstrVarchar), SubstrFnStr(), {}, + { OMNI_VARCHAR, OMNI_INT, OMNI_INT }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrChar), SubstrFnStr(), {}, + { OMNI_CHAR, OMNI_INT, OMNI_INT }, OMNI_CHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrVarchar), SubstrFnStr(), {}, + { OMNI_VARCHAR, OMNI_LONG, OMNI_LONG }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrChar), SubstrFnStr(), {}, + { OMNI_CHAR, OMNI_LONG, OMNI_LONG }, OMNI_CHAR, INPUT_DATA, true), + + // substr with start index functions + Function(reinterpret_cast(SubstrVarcharWithStart), SubstrFnStr(), {}, + { OMNI_VARCHAR, OMNI_INT }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrCharWithStart), SubstrFnStr(), {}, + { OMNI_CHAR, OMNI_INT }, OMNI_CHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrVarcharWithStart), SubstrFnStr(), {}, + { OMNI_VARCHAR, OMNI_LONG }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrCharWithStart), SubstrFnStr(), {}, + { OMNI_CHAR, OMNI_LONG }, OMNI_CHAR, INPUT_DATA, true), + }; + + return stringFnRegistry; +} + +std::vector StringFunctionRegistrySupportNotNegativeAndNotZeroIndex::GetFunctions() +{ + std::vector stringFnRegistry = { + // substr functions + Function(reinterpret_cast(SubstrVarchar), SubstrFnStr(), {}, + { OMNI_VARCHAR, OMNI_INT, OMNI_INT }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrChar), SubstrFnStr(), {}, + { OMNI_CHAR, OMNI_INT, OMNI_INT }, OMNI_CHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrVarchar), SubstrFnStr(), {}, + { OMNI_VARCHAR, OMNI_LONG, OMNI_LONG }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrChar), SubstrFnStr(), {}, + { OMNI_CHAR, OMNI_LONG, OMNI_LONG }, OMNI_CHAR, INPUT_DATA, true), + + // substr with start index functions + Function(reinterpret_cast(SubstrVarcharWithStart), SubstrFnStr(), {}, + { OMNI_VARCHAR, OMNI_INT }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrCharWithStart), SubstrFnStr(), {}, + { OMNI_CHAR, OMNI_INT }, OMNI_CHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrVarcharWithStart), SubstrFnStr(), {}, + { OMNI_VARCHAR, OMNI_LONG }, OMNI_VARCHAR, INPUT_DATA, true), + Function(reinterpret_cast(SubstrCharWithStart), SubstrFnStr(), {}, + { OMNI_CHAR, OMNI_LONG }, OMNI_CHAR, INPUT_DATA, true), + }; + + return stringFnRegistry; +} + +std::vector StringToDecimalFunctionRegistryAllowRoundUp::GetFunctions() +{ + std::vector stringFnRegistry = { + Function(reinterpret_cast(CastStringToDecimal64RoundUpRetNull), CastNullFnStr(), {}, {OMNI_VARCHAR}, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastStringToDecimal128RoundUpRetNull), CastNullFnStr(), {}, {OMNI_VARCHAR}, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastStringToDecimal64RoundUp), CastFnStr(), {}, {OMNI_VARCHAR}, + OMNI_DECIMAL64, + INPUT_DATA, true), + Function(reinterpret_cast(CastStringToDecimal128RoundUp), CastFnStr(), {}, {OMNI_VARCHAR}, + OMNI_DECIMAL128, + INPUT_DATA, true) + }; + return stringFnRegistry; +} + +std::vector StringToDecimalFunctionRegistry::GetFunctions() +{ + std::vector stringFnRegistry = { + Function(reinterpret_cast(CastStringToDecimal64RetNull), CastNullFnStr(), {}, {OMNI_VARCHAR}, + OMNI_DECIMAL64, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastStringToDecimal128RetNull), CastNullFnStr(), {}, {OMNI_VARCHAR}, + OMNI_DECIMAL128, INPUT_DATA_AND_OVERFLOW_NULL), + Function(reinterpret_cast(CastStringToDecimal64), CastFnStr(), {}, {OMNI_VARCHAR}, OMNI_DECIMAL64, + INPUT_DATA, true), + Function(reinterpret_cast(CastStringToDecimal128), CastFnStr(), {}, {OMNI_VARCHAR}, OMNI_DECIMAL128, + INPUT_DATA, true) + }; + return stringFnRegistry; +} +} diff --git a/core/src/codegen/func_registry_string.h b/core/src/codegen/func_registry_string.h new file mode 100644 index 0000000..d7bdaf4 --- /dev/null +++ b/core/src/codegen/func_registry_string.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved. + * Description: String Function Registry + */ +#ifndef OMNI_RUNTIME_FUNC_REGISTRY_STRING_H +#define OMNI_RUNTIME_FUNC_REGISTRY_STRING_H +#include "function.h" +#include "func_registry_base.h" +#include "util/type_util.h" + +// functions called directly from codegen +const std::string strCompareStr = "compare"; +const std::string strEqualStr = "strequal"; + +namespace omniruntime::codegen { +class StringFunctionRegistry : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class StringFunctionRegistryNotAllowReducePrecison : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class StringToDecimalFunctionRegistryAllowRoundUp : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class StringToDecimalFunctionRegistry : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class StringFunctionRegistryAllowReducePrecison : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class StringFunctionRegistryNotReplace : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class StringFunctionRegistryReplace : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class StringFunctionRegistrySupportNegativeAndZeroIndex : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class StringFunctionRegistrySupportNotNegativeAndZeroIndex : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class StringFunctionRegistrySupportNegativeAndNotZeroIndex : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; + +class StringFunctionRegistrySupportNotNegativeAndNotZeroIndex : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; +} + +#endif // OMNI_RUNTIME_FUNC_REGISTRY_STRING_H diff --git a/core/src/codegen/func_registry_varchar_vector.cpp b/core/src/codegen/func_registry_varchar_vector.cpp new file mode 100644 index 0000000..c4ce6ae --- /dev/null +++ b/core/src/codegen/func_registry_varchar_vector.cpp @@ -0,0 +1,21 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: Varchar Vector Functions Registry + */ +#include "func_registry_varchar_vector.h" +#include "functions/varcharVectorfunctions.h" + +namespace omniruntime::codegen { +using namespace omniruntime::type; +using namespace codegen::function; + +std::vector VarcharVectorFunctionRegistry::GetFunctions() +{ + std::vector paramTypes = { OMNI_LONG, OMNI_INT, OMNI_VARCHAR }; + std::vector varcharVectorFnRegistry = { Function(reinterpret_cast(WrapVarcharVector), + "WrapVarcharVector", {}, paramTypes, OMNI_INT), + Function(reinterpret_cast(WrapSetBitNull), "WrapSetBitNull", {}, { OMNI_INT }, OMNI_BOOLEAN), + Function(reinterpret_cast(WrapIsBitNull), "WrapIsBitNull", {}, { OMNI_INT }, OMNI_BOOLEAN) }; + return varcharVectorFnRegistry; +} +} diff --git a/core/src/codegen/func_registry_varchar_vector.h b/core/src/codegen/func_registry_varchar_vector.h new file mode 100644 index 0000000..aede533 --- /dev/null +++ b/core/src/codegen/func_registry_varchar_vector.h @@ -0,0 +1,20 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: Varchar Vector Functions Registry + */ +#ifndef OMNI_RUNTIME_FUNC_REGISTRY_VARCHAR_VECTOR_H +#define OMNI_RUNTIME_FUNC_REGISTRY_VARCHAR_VECTOR_H +#include "function.h" +#include "func_registry_base.h" + +// functions called directly from codegen +const std::string WrapVarcharVectorStr = "WrapVarcharVector"; + +namespace omniruntime::codegen { +class VarcharVectorFunctionRegistry : public BaseFunctionRegistry { +public: + std::vector GetFunctions() override; +}; +} + +#endif // OMNI_RUNTIME_FUNC_REGISTRY_VARCHAR_VECTOR_H diff --git a/core/src/codegen/func_signature.cpp b/core/src/codegen/func_signature.cpp new file mode 100644 index 0000000..2f19f68 --- /dev/null +++ b/core/src/codegen/func_signature.cpp @@ -0,0 +1,106 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: + */ +#include +#include "util/type_util.h" +#include "func_signature.h" + +namespace omniruntime::codegen { +using namespace omniruntime::type; + +FunctionSignature::FunctionSignature() = default; + +FunctionSignature::FunctionSignature(const std::string &name, std::vector params, + const omniruntime::type::DataTypeId &returnType, void *address) + : funcName(name), paramTypes(std::move(params)), retType(returnType), funcAddress(address) +{} + +// Copy constructor +FunctionSignature::FunctionSignature(const FunctionSignature &fs) + : funcName(fs.funcName), paramTypes(fs.paramTypes), retType(fs.retType), funcAddress(fs.funcAddress) +{} + +FunctionSignature::~FunctionSignature() = default; + +std::string FunctionSignature::GetName() const +{ + return this->funcName; +} + +const std::vector &FunctionSignature::GetParams() const +{ + return this->paramTypes; +} + +DataTypeId FunctionSignature::GetReturnType() const +{ + return this->retType; +} + +void *FunctionSignature::GetFunctionAddress() const +{ + return this->funcAddress; +} + +FunctionSignature &FunctionSignature::operator = (FunctionSignature other) +{ + std::swap(funcName, other.funcName); + std::swap(paramTypes, other.paramTypes); + std::swap(retType, other.retType); + std::swap(funcAddress, other.funcAddress); + return *this; +} + +bool FunctionSignature::operator == (const FunctionSignature &other) const +{ + if (this->funcName != other.funcName || this->retType != other.retType || + this->paramTypes.size() != other.paramTypes.size()) { + return false; + } + + for (uint32_t i = 0; i < this->paramTypes.size(); i++) { + if (this->paramTypes.at(i) != other.paramTypes.at(i)) { + return false; + } + } + return true; +} + +size_t FunctionSignature::HashCode() const +{ + auto hashName = std::hash {}(this->funcName); + auto hashReturnType = std::hash {}(static_cast(this->retType)); + auto combinedHash = hashName ^ (hashReturnType << 1); + for (auto param : this->paramTypes) { + auto hashParamType = std::hash {}(static_cast(param)); + combinedHash = hashParamType ^ (combinedHash << 1); + } + return combinedHash; +} + +std::string FunctionSignature::ToString() const +{ + auto result = this->funcName; + for (auto const & param : this->paramTypes) { + result += "_"; + result += TypeUtil::TypeToString(param); + } + result = result + "_" + TypeUtil::TypeToString(this->retType); + return result; +} + +std::string FunctionSignature::ToString(omniruntime::op::OverflowConfig *overflowConfig) const +{ + auto result = this->funcName; + if (overflowConfig != nullptr && overflowConfig->GetOverflowConfigId() == omniruntime::op::OVERFLOW_CONFIG_NULL) { + result += "_null"; + } + for (auto const & param : this->paramTypes) { + result += "_"; + result += TypeUtil::TypeToString(param); + } + result = result + "_" + TypeUtil::TypeToString(this->retType); + return result; +} +} diff --git a/core/src/codegen/func_signature.h b/core/src/codegen/func_signature.h new file mode 100644 index 0000000..1534f71 --- /dev/null +++ b/core/src/codegen/func_signature.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: + */ +#ifndef __FUNC_SIGNATURE_H__ +#define __FUNC_SIGNATURE_H__ +#include +#include +#include +#include +#include +#include "operator/config/operator_config.h" + +namespace omniruntime::codegen { +class FunctionSignature { +public: + FunctionSignature(); + FunctionSignature(const std::string &name, std::vector params, + const omniruntime::type::DataTypeId &returnType, void *address = nullptr); + FunctionSignature(const FunctionSignature &fs); + FunctionSignature &operator = (FunctionSignature other); + bool operator == (const FunctionSignature &other) const; + ~FunctionSignature(); + std::string GetName() const; + const std::vector &GetParams() const; + omniruntime::type::DataTypeId GetReturnType() const; + void *GetFunctionAddress() const; + size_t HashCode() const; + std::string ToString() const; + std::string ToString(omniruntime::op::OverflowConfig *overflowConfig) const; + +private: + std::string funcName; + std::vector paramTypes {}; + omniruntime::type::DataTypeId retType; + void *funcAddress = nullptr; +}; +} + +#endif \ No newline at end of file diff --git a/core/src/codegen/function.cpp b/core/src/codegen/function.cpp new file mode 100644 index 0000000..b16dd5f --- /dev/null +++ b/core/src/codegen/function.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: Maps a function in function expression to a precompiled function + */ + +#include "function.h" + +using namespace omniruntime::type; + +namespace omniruntime::codegen { +Function::Function(void *address, const std::string &name, const std::vector &aliases, + const std::vector ¶mTypes, const DataTypeId &retType, NullableResultType nullableResultType, + bool setExecutionContext) +{ + this->address = address; + this->nullableResultType = nullableResultType; + this->isExecContextSet = setExecutionContext; + // create function sig to register for codegen + this->signatures.emplace_back(name, paramTypes, retType, address); + // create function sigs for different functions calls in omni-runtime + for (auto &alias : aliases) { + this->signatures.emplace_back(alias, paramTypes, retType, address); + } +} + +Function::~Function() = default; + +const std::vector &Function::GetSignatures() const +{ + return this->signatures; +} + +std::string Function::GetId() const +{ + return this->signatures.at(0).ToString(); +} + +DataTypeId Function::GetReturnType() const +{ + return this->signatures.at(0).GetReturnType(); +} + +const std::vector &Function::GetParamTypes() const +{ + return this->signatures.at(0).GetParams(); +} + +const void *Function::GetAddress() const +{ + return this->address; +} + +const NullableResultType Function::GetNullableResultType() const +{ + return this->nullableResultType; +} + +bool Function::IsExecutionContextSet() const +{ + return this->isExecContextSet; +} +} diff --git a/core/src/codegen/function.h b/core/src/codegen/function.h new file mode 100644 index 0000000..ef160bb --- /dev/null +++ b/core/src/codegen/function.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: OmniRuntime Function Header + */ +#ifndef OMNI_RUNTIME_FUNCTION_H +#define OMNI_RUNTIME_FUNCTION_H + +#include "func_signature.h" +#include "util/type_util.h" + +namespace omniruntime::codegen { +enum NullableResultType { + INPUT_DATA, + INPUT_DATA_AND_NULL, + INPUT_DATA_AND_OVERFLOW_NULL, + INPUT_DATA_AND_NULL_AND_RETURN_NULL, + DEFAULT +}; + +class Function { +public: + Function() = default; + + /** + * Constructs an omni-runtime Function object that contains the functionality and attributes of an omni-runtime + * function + * + * @param name function name + * @param address contains a void pointer of the function + * @param aliases allows to specify multiple names for the same function + * @param paramTypes vector of datatypes of arguments - VARCHAR AND CHAR are expanded to their corresponding + * function signature equivalents to contain value and length for VARCHAR and value, length and width for CHAR + * @param retType data type of return value + * @param setExecutionContext if true - pass the execution context to func signature as a param, + * it will always be the first parameter in your function, default to false + */ + Function(void *address, const std::string &name, const std::vector &aliases, + const std::vector ¶mTypes, const omniruntime::type::DataTypeId &retType, + NullableResultType = DEFAULT, bool setExecutionContext = false); + + // Copy constructor + Function &operator = (Function other) + { + std::swap(signatures, other.signatures); + return *this; + } + + ~Function(); + const std::vector &GetSignatures() const; + omniruntime::type::DataTypeId GetReturnType() const; + const std::vector &GetParamTypes() const; + std::string GetId() const; + const void *GetAddress() const; + const NullableResultType GetNullableResultType() const; + bool IsExecutionContextSet() const; + +private: + void *address; + // signatures corresponding to that function + std::vector signatures = {}; + NullableResultType nullableResultType; + bool isExecContextSet = false; +}; +} + +#endif // OMNI_RUNTIME_FUNCTION_H diff --git a/core/src/codegen/functions/README.md b/core/src/codegen/functions/README.md new file mode 100644 index 0000000..44add3f --- /dev/null +++ b/core/src/codegen/functions/README.md @@ -0,0 +1,117 @@ +# Adding new external functions +To add new functions to omni-runtime, follow the steps below. You will need to modify: `externalfunctions.h`, `externalfunctions.cpp`, `func_registry.cpp`. + +## Steps to create function +1. If needed, create new `.h` and `.cpp` file for function types such as ``, otherwise put new functions in the `externalfunctions.h` and `externalfunctions.cpp` files. +2. Add the function declaration of the + function in the `externalfunctions.h` file. + ex: + ```c++ + extern DELLEXPORT int32_t add1(int32_t x); + ``` + Parameter types can be `int32_t`, `int64_t`, `double`, `boolean`, `string` or `decimal`. + * For variable length `string` parameter aka `VARCHAR`, it's passed in as `char*`(pointer to the data) and `int32_t`(length of the string) + * For fixed length `string` parameter aka `CHAR(width)`, it's passed in as `char*`(pointer to the data) `int32_t`(width) and `int32_t`(length of the string) + * For `decimal` 128bit parameter, it's passed in as `int64_t`(high 64 bit) and `int64_t`(low 64 bit) values. + * If you need to allocate memory for returning `string` values, you can also pass in a `int64_t` in the beginning of parameter list as the pointer address to an `ExecutionContext` object, and use this object to allocate new memory for better performance and void memory leak + + Return type can also be `int32_t`, `int64_t`, `double`, `boolean`, `string` or `decimal`. + * For `string` return type, the function return type should be `char*`, but a pointer to the return string length will also be passed in at the end of the param list + * For `decimal` 128bit return type, the pointers to the high bits and low bits must be passed in at the end of the param list. + +3. Write function in C++ in the `externalfunctions.cpp` file(If you want to use template for your functions you can put your implementation in header). + ex: + ```c++ + extern DLLEXPORT int32_t add1(int32_t x) { + return x + 1; + } + ``` + +6. Register your functions in Function Registry: + You can either register the functions in `external_func_registry.cpp` or create your own registry + + * If adding to `external_func_registry`, you only need to add your function in the `GetFunctions()` method. + * If creating your own registry, you need to implement the `BaseFunctionRegistry` interface in `func_registry_base.h` and add all your functions in `GetFunctions()` method. + + ```c++ + vector ExternalFunctionRegistry::GetFunctions() + { + std::vector externalFunctionRegistry = { + Function(reinterpret_cast(Increment), "Increment", {}, {OMNI_INT}, + OMNI_INT), + Function(reinterpret_cast(Increment), "Increment", {}, {OMNI_LONG}, + OMNI_LONG), + }; + return externalFunctionRegistry; + } + ``` + + The return types and parameter types in function signature registered can only be the data types, currently supporting: + * OMNI_INT + * OMNI_LONG + * OMNI_DOUBLE + * OMNI_BOOLEAN + * OMNI_VARCHAR + * OMNI_CHAR + * OMNI_DECIMAL64 + * OMNI_DECIMAL128 + + +7. Finally, if you are adding a new function registry, register it in the `FunctionRegistry` class in `func_registry` by adding it to the registries list in `GetFunctionRegistries()` method: + + ```c++ + vector> FunctionRegistry::GetFunctionRegistries() + { + vector> functionRegistries; + // Other registries... + // External functions + functionRegistries.push_back(make_unique()); + // Put your registry here + + return functionRegistries; + } + ``` + +## Exception handling + +When registering function, set `setExecutionContext` to `true`: +```c++ +Function(reinterpret_cast(Increment), "Increment", {}, {OMNI_INT}, OMNI_INT, true) +``` + +In your function whenever you need to throw an error or exception, set error message in the execution context by using helper function `SetError` in `context_helper.h`: +```c++ +#include "context_helper.h" + +// Make sure you have the contextPtr as the first arg in your function +extern "C" DLLEXPORT int64_t DivDec64Ret64(int64_t contextPtr, int64_t x, int64_t y) +{ + if (y == 0) { + char message[] = "Divided by zero error!"; + SetError(contextPtr, message, sizeof(message)/sizeof(char)); + return 0; + } + return round(double(x)/y); +} +``` + +`Filter` and `Projection` operators will throw `OMNI_EXCEPTION` which will be caught at JNI layer and be returned to engine side. + +# Adding new functions +## OmniRuntime Function Class +``` +Function(void *address, string &fnName, vector &aliases, vector ¶mTypes, DataType &retType, bool setExecutionContext); +``` +Constructs a omni-runtime `Function` object that contains the functionality and attributes of an omni-runtime function including function signature to facilitate registration and `funcID` to uniquely identify built-in or external functions in the `functions` dir. + +- `fnName` denotes the name the function will be referenced by the query. `substr`, `LIKE`, `abs`, etc. are examples of fnNames for base functions. +- `generateFuncID` is used to generate a `funcID` that is used to identify the corresponding function from the `functions` dir that the omniruntime `Function` refers to. +- `aliases` allows us to specify multiple names for the same function +- `address` is the void function pointer +- `paramTypes` is a vector of data types of arguments - `VARCHAR` and `CHAR` are expanded to their corresponding function signature equivalent types to contain value, length for VARCHAR and value, length and width for `CHAR` +- `retType` is the data type of the return value +- `setExecutionContext` if true - pass the execution context to func signature as a param, it will always be the first parameter in your function, default to false + +## Function Registry +- Instead of a single function registry class, each xxxfunctions.cpp file has a corresponding xxx_func_registry.cpp that appends the omniruntime `Function` to the single static vector `functionRegistry`. +- `LookupFunction` returns the omniruntime `Function` corresponding to the `funcID` provided. diff --git a/core/src/codegen/functions/datetime_functions.cpp b/core/src/codegen/functions/datetime_functions.cpp new file mode 100644 index 0000000..def73b3 --- /dev/null +++ b/core/src/codegen/functions/datetime_functions.cpp @@ -0,0 +1,128 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + * Description: date time functions implementation + */ + +#include "datetime_functions.h" +#include "codegen/context_helper.h" +#include "type/date32.h" +#include "codegen/time_util.h" +#include + +namespace omniruntime::codegen::function { +extern "C" DLLEXPORT int64_t UnixTimestampFromStr(const char *timeStr, int32_t timeLen, bool isNullTimeStr, + const char *fmtStr, int32_t fmtLen, bool isNullFmtStr, const char *tzStr, int32_t tzLen, bool isNullTzStr, + const char *policyStr, int32_t policyLen, bool isNullPolStr, bool *retIsNull) +{ + if (isNullTimeStr || isNullFmtStr || fmtLen == 0 || timeLen == 0) { + *retIsNull = true; + return 0; + } + std::string timeStr1(timeStr, timeLen); + std::string fmtStr1(fmtStr, fmtLen); + std::string fmtOmniTimeStr = toOmniTimeFormat(fmtStr1); + int fmtOmniTimeStrLen = fmtOmniTimeStr.length(); + if (!TimeUtil::IsTimeValid(timeStr, timeLen, fmtOmniTimeStr.c_str(), + fmtOmniTimeStrLen, policyStr)) { + *retIsNull = true; + return 0; + } + setenv("TZ", TimeZoneUtil::GetTZ(tzStr), 1); + tzset(); + struct tm timeInfo = { 0 }; + strptime(timeStr1.c_str(), fmtOmniTimeStr.c_str(), &timeInfo); + time_t timeStamp = mktime(&timeInfo); + if (TimeZoneUtil::JudgeDSTByUnixTimestampFromStr(tzStr, tzLen, &timeInfo, timeStr, timeLen, fmtStr, fmtLen)) { + timeStamp -= type::SECOND_OF_HOUR; + } + return timeStamp; +} + +extern "C" DLLEXPORT int64_t UnixTimestampFromDate(int32_t date, const char *fmtStr, int32_t fmtLen, + const char *tzStr, int32_t tzLen, const char *policyStr, int32_t policyLen, bool isNull) +{ + if (isNull) { + return 0; + } + setenv("TZ", TimeZoneUtil::GetTZ(tzStr), 1); + tzset(); + time_t desiredTime = type::SECOND_OF_DAY * date; + struct tm ltm; + localtime_r(&desiredTime, <m); + time_t result = desiredTime - ltm.tm_gmtoff; + result += TimeZoneUtil::AdjustDSTByUnixTimestampFromDate(tzStr, tzLen, <m, desiredTime) * type::SECOND_OF_HOUR; + return static_cast(result); +} + +extern "C" DLLEXPORT char *FromUnixTime(int64_t contextPtr, bool *isNull, int64_t timestamp, const char *fmtStr, + int32_t fmtLen, const char *tzStr, int32_t tzLen, int32_t *outLen) +{ + time_t timeStampVal = timestamp; + setenv("TZ", TimeZoneUtil::GetTZ(tzStr), 1); + tzset(); + struct tm ltm; + localtime_r(&timeStampVal, <m); + if (!TimeZoneUtil::JudgeDSTByFromUnixTime(tzStr, tzLen, <m)) { + timeStampVal -= type::SECOND_OF_HOUR; + localtime_r(&timeStampVal, <m); + } + int32_t resultLen = fmtLen + 3; + auto result = ArenaAllocatorMalloc(contextPtr, resultLen); + std::string fmtStr1(fmtStr, fmtLen); + std::string fmtOmniTimeStr = toOmniTimeFormat(fmtStr1); + auto ret = strftime(result, resultLen, fmtOmniTimeStr.c_str(), <m); + *isNull = static_cast(ret) == 0; + *outLen = ret; + return result; +} + +std::string toOmniTimeFormat(const std::string &format) +{ + std::string result = format; + const std::pair replacements[] = { + {"yyyy", "%Y"}, {"MM", "%m"}, {"dd", "%d"}, + {"HH", "%H"}, {"mm", "%M"}, {"ss", "%S"}}; + for (const auto &[from, to] : replacements) { + size_t pos = 0; + while ((pos = result.find(from, pos)) != std::string::npos) { + result.replace(pos, from.length(), to); + pos += to.length(); + } + } + return result; +} + +extern "C" DLLEXPORT char *FromUnixTimeRetNull(int64_t contextPtr, bool *isNull, int64_t timestamp, const char *fmtStr, + int32_t fmtLen, const char *tzStr, int32_t tzLen, int32_t *outLen) +{ + return FromUnixTime(contextPtr, isNull, timestamp, fmtStr, fmtLen, tzStr, tzLen, outLen); +} + +extern "C" DLLEXPORT int32_t DateTrunc(int64_t contextPtr, int32_t days, const char *levelStr, int32_t len) +{ + type::DateTruncMode level = type::Date32::ParseTruncLevel(std::string(levelStr, len)); + int32_t result; + if (type::Date32::TruncDate(days, level, result) != type::Status::CONVERT_SUCCESS) { + std::ostringstream errorMessage; + errorMessage << "The level is not supported yet: " << std::string(levelStr, len); + SetError(contextPtr, errorMessage.str()); + return 0; + } + return result; +} + +extern "C" DLLEXPORT int32_t DateTruncRetNull(bool *isNull, int32_t days, const char *levelStr, int32_t len) +{ + type::DateTruncMode level = type::Date32::ParseTruncLevel(std::string(levelStr, len)); + int32_t result; + if (type::Date32::TruncDate(days, level, result) != type::Status::CONVERT_SUCCESS) { + *isNull = true; + } + return result; +} + +extern "C" DLLEXPORT int32_t DateAdd(int32_t right, int32_t left) +{ + return right + left; +} +} diff --git a/core/src/codegen/functions/datetime_functions.h b/core/src/codegen/functions/datetime_functions.h new file mode 100644 index 0000000..6a8777f --- /dev/null +++ b/core/src/codegen/functions/datetime_functions.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: date time functions implementation + */ + +#ifndef OMNI_RUNTIME_DATETIME_FUNCTIONS_H +#define OMNI_RUNTIME_DATETIME_FUNCTIONS_H + +#include +#include +// All extern functions go here temporarily +#ifdef _WIN32 +#define DLLEXPORT __declspec(dllexport) +#else +#define DLLEXPORT +#endif + +namespace omniruntime::codegen::function { +extern "C" DLLEXPORT int64_t UnixTimestampFromStr(const char *timeStr, int32_t timeLen, bool isNullTimeStr, + const char *fmtStr, int32_t fmtLen, bool isNullFmtStr, const char *tzStr, int32_t tzLen, + bool isNullTzStr, const char *policyStr, int32_t policyLen, bool isNullPolStr, bool *retIsNull); + +extern "C" DLLEXPORT int64_t UnixTimestampFromDate(int32_t date, const char *fmtStr, int32_t fmtLen, + const char *tzStr, int32_t tzLen, const char *policyStr, int32_t policyLen, bool isNull); + +extern "C" DLLEXPORT char *FromUnixTime(int64_t contextPtr, bool *isNull, int64_t timestamp, const char *fmtStr, + int32_t fmtLen, const char *tzStr, int32_t tzLen, int32_t *outLen); + +extern "C" DLLEXPORT char *FromUnixTimeRetNull(int64_t contextPtr, bool *isNull, int64_t timestamp, const char *fmtStr, + int32_t fmtLen, const char *tzStr, int32_t tzLen, int32_t *outLen); + +extern "C" DLLEXPORT int32_t DateTrunc(int64_t contextPtr, int32_t days, const char *levelStr, int32_t len); + +extern "C" DLLEXPORT int32_t DateTruncRetNull(bool *isNull, int32_t days, const char *levelStr, int32_t len); + +extern "C" DLLEXPORT int32_t DateAdd(int32_t right, int32_t left); + +std::string toOmniTimeFormat(const std::string& format); +} +#endif // OMNI_RUNTIME_DATETIME_FUNCTIONS_H diff --git a/core/src/codegen/functions/decimal_arithmetic_functions.cpp b/core/src/codegen/functions/decimal_arithmetic_functions.cpp new file mode 100644 index 0000000..8ecf9fb --- /dev/null +++ b/core/src/codegen/functions/decimal_arithmetic_functions.cpp @@ -0,0 +1,1415 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: registry math function name + */ +#include "decimal_arithmetic_functions.h" + +using namespace omniruntime::type; + +namespace omniruntime::codegen::function { +const std::string DECIMAL_OVERFLOW { "Decimal overflow" }; /* NOLINT */ +const std::string DIVIDE_ZERO { "Division by zero" }; /* NOLINT */ + +// decimal128 arithmetical functions +extern "C" DLLEXPORT int32_t Decimal128Compare(int64_t xHigh, uint64_t xLow, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, bool isNull) +{ + int128_t xValue = Decimal128(xHigh, xLow).ToInt128(); + int128_t yValue = Decimal128(yHigh, yLow).ToInt128(); + if (xScale == yScale) { + if (xValue == yValue) { + return 0; + } else { + return xValue > yValue ? 1 : -1; + } + } + + Decimal128Wrapper x(xValue); + Decimal128Wrapper y(yValue); + return x.SetScale(xScale).Compare(y.SetScale(yScale)); +} + +extern "C" DLLEXPORT void AbsDecimal128(int64_t xHigh, uint64_t xLow, int32_t xPrecision, int32_t xScale, bool isNull, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result(xHigh, xLow); + result.Abs(); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +// decimal64 arithmetical functions +extern "C" DLLEXPORT int32_t Decimal64Compare(int64_t x, int32_t xPrecision, int32_t xScale, int64_t y, + int32_t yPrecision, int32_t yScale, bool isNull) +{ + Decimal64 left(x); + Decimal64 right(y); + return left.SetScale(xScale).Compare(right.SetScale(yScale)); +} + +extern "C" DLLEXPORT int64_t AbsDecimal64(int64_t x, int32_t xPrecision, int32_t xScale, bool isNull, + int32_t outPrecision, int32_t outScale) +{ + return std::abs(x); +} + +// Decimal AddOperator ReScale +extern "C" DLLEXPORT int64_t AddDec64Dec64Dec64ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale) +{ + Decimal64 result; + DecimalOperations::InternalDecimalAdd(Decimal64(x).SetScale(xScale), xScale, xPrecision, + Decimal64(y).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT void AddDec64Dec64Dec128ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, + uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void AddDec128Dec128Dec128ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void AddDec64Dec128Dec128ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void AddDec128Dec64Dec128ReScale(int64_t contextPtr, int64_t yHigh, uint64_t yLow, + int32_t yPrecision, int32_t yScale, int64_t x, int32_t xPrecision, int32_t xScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + + +// Decimal SubOperator ReScale +extern "C" DLLEXPORT int64_t SubDec64Dec64Dec64ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale) +{ + Decimal64 result; + DecimalOperations::InternalDecimalSubtract(Decimal64(x).SetScale(xScale), xScale, xPrecision, + Decimal64(y).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT void SubDec64Dec64Dec128ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, + uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void SubDec128Dec128Dec128ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void SubDec64Dec128Dec128ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void SubDec128Dec64Dec128ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +// Decimal MulOperator ReScale +extern "C" DLLEXPORT int64_t MulDec64Dec64Dec64ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale) +{ + Decimal64 result; + DecimalOperations::InternalDecimalMultiply(Decimal64(x).SetScale(xScale), xScale, xPrecision, + Decimal64(y).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT void MulDec64Dec64Dec128ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, + uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void MulDec128Dec128Dec128ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result = + Decimal128Wrapper(xHigh, xLow).MultiplyRoundUp(Decimal128Wrapper(yHigh, yLow), xScale + yScale - outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void MulDec64Dec128Dec128ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void MulDec128Dec64Dec128ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(xHigh, xLow), xScale, xPrecision, + Decimal128Wrapper(y).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +// Decimal DivOperation ReScale +extern "C" DLLEXPORT int64_t DivDec64Dec64Dec64ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale) +{ + CHECK_DIVIDE_BY_ZERO_RETURN(y); + Decimal64 result; + DecimalOperations::InternalDecimalDivide(Decimal64(x).SetScale(xScale), xScale, xPrecision, + Decimal64(y).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT int64_t DivDec64Dec128Dec64ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t yHigh, int64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale) +{ + Decimal128Wrapper divisor(yHigh, yLow); + CHECK_DIVIDE_BY_ZERO_RETURN(divisor); + Decimal128Wrapper result; + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN(result, outPrecision); + return result.HighBits() < 0 ? -static_cast(result.LowBits()) : static_cast(result.LowBits()); +} + +extern "C" DLLEXPORT int64_t DivDec128Dec64Dec64ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale) +{ + Decimal128Wrapper divisor(y); + CHECK_DIVIDE_BY_ZERO_RETURN(divisor); + Decimal128Wrapper result; + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN(result, outPrecision); + return result.HighBits() < 0 ? -static_cast(result.LowBits()) : static_cast(result.LowBits()); +} + +extern "C" DLLEXPORT void DivDec64Dec64Dec128ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, + uint64_t *outLowPtr) +{ + Decimal128Wrapper divisor(y); + CHECK_DIVIDE_BY_ZERO(divisor); + Decimal128Wrapper result; + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void DivDec128Dec128Dec128ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper divisor(yHigh, yLow); + CHECK_DIVIDE_BY_ZERO(divisor); + Decimal128Wrapper result; + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void DivDec64Dec128Dec128ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper divisor(yHigh, yLow); + CHECK_DIVIDE_BY_ZERO(divisor); + Decimal128Wrapper result; + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void DivDec128Dec64Dec128ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper divisor(y); + CHECK_DIVIDE_BY_ZERO(divisor); + Decimal128Wrapper result; + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +// Decimal ModOperation ReScale +extern "C" DLLEXPORT int64_t ModDec64Dec64Dec64ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale) +{ + CHECK_DIVIDE_BY_ZERO_RETURN(y); + Decimal64 result; + DecimalOperations::InternalDecimalMod(Decimal64(x).SetScale(xScale), xScale, xPrecision, + Decimal64(y).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT int64_t ModDec64Dec128Dec64ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale) +{ + Decimal128Wrapper divisor(yHigh, yLow); + CHECK_DIVIDE_BY_ZERO_RETURN(divisor); + Decimal64 result; + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT int64_t ModDec128Dec64Dec64ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale) +{ + Decimal128Wrapper divisor(y); + CHECK_DIVIDE_BY_ZERO_RETURN(divisor); + Decimal64 result; + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT void ModDec128Dec64Dec128ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper divisor(y); + CHECK_DIVIDE_BY_ZERO(divisor); + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void ModDec128Dec128Dec128ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper divisor(yHigh, yLow); + CHECK_DIVIDE_BY_ZERO(divisor); + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT int64_t ModDec128Dec128Dec64ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale) +{ + Decimal128Wrapper divisor(yHigh, yLow); + CHECK_DIVIDE_BY_ZERO_RETURN(divisor); + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN(result, outPrecision); + + return result.HighBits() < 0 ? -static_cast(result.LowBits()) : static_cast(result.LowBits()); +} + +extern "C" DLLEXPORT void ModDec64Dec128Dec128ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper divisor(yHigh, yLow); + CHECK_DIVIDE_BY_ZERO(divisor); + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +// return null +extern "C" DLLEXPORT int64_t AddDec64Dec64Dec64RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale) +{ + Decimal64 result; + DecimalOperations::InternalDecimalAdd(Decimal64(x).SetScale(xScale), xScale, xPrecision, + Decimal64(y).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN_NULL(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT void AddDec64Dec64Dec128RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, + uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void AddDec128Dec128Dec128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void AddDec64Dec128Dec128RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void AddDec128Dec64Dec128RetNull(bool *isNull, int64_t yHigh, uint64_t yLow, int32_t yPrecision, + int32_t yScale, int64_t x, int32_t xPrecision, int32_t xScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + + +// Decimal AddOperator NotReScale +extern "C" DLLEXPORT int64_t AddDec64Dec64Dec64NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale) +{ + Decimal64 result; + DecimalOperations::InternalDecimalAdd(Decimal64(x).SetScale(xScale), xScale, xPrecision, + Decimal64(y).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW_RETURN(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT void AddDec64Dec64Dec128NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void AddDec128Dec128Dec128NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void AddDec64Dec128Dec128NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void AddDec128Dec64Dec128NotReScale(int64_t contextPtr, int64_t yHigh, uint64_t yLow, + int32_t yPrecision, int32_t yScale, int64_t x, int32_t xPrecision, int32_t xScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalAdd(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + + +// Decimal SubOperator NotReScale +extern "C" DLLEXPORT int64_t SubDec64Dec64Dec64NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale) +{ + Decimal64 result; + DecimalOperations::InternalDecimalSubtract(Decimal64(x).SetScale(xScale), xScale, xPrecision, + Decimal64(y).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW_RETURN(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT void SubDec64Dec64Dec128NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void SubDec128Dec128Dec128NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void SubDec64Dec128Dec128NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void SubDec128Dec64Dec128NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +// Decimal MulOperator NotReScale +extern "C" DLLEXPORT int64_t MulDec64Dec64Dec64NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale) +{ + Decimal64 result; + DecimalOperations::InternalDecimalMultiply(Decimal64(x).SetScale(xScale), xScale, xPrecision, + Decimal64(y).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW_RETURN(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT void MulDec64Dec64Dec128NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void MulDec128Dec128Dec128NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void MulDec64Dec128Dec128NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void MulDec128Dec64Dec128NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(xHigh, xLow), xScale, xPrecision, + Decimal128Wrapper(y).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +// Decimal DivOperation NotReScale +extern "C" DLLEXPORT int64_t DivDec64Dec64Dec64NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale) +{ + CHECK_DIVIDE_BY_ZERO_RETURN(y); + Decimal64 result; + DecimalOperations::InternalDecimalDivide(Decimal64(x).SetScale(xScale), xScale, xPrecision, + Decimal64(y).SetScale(yScale), yScale, yPrecision, result, outScale); + CHECK_OVERFLOW_RETURN(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT int64_t DivDec64Dec128Dec64NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t yHigh, int64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale) +{ + Decimal128Wrapper divisor(yHigh, yLow); + CHECK_DIVIDE_BY_ZERO_RETURN(divisor); + Decimal128Wrapper result; + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result, outScale); + CHECK_OVERFLOW_RETURN(result, outPrecision); + return result.HighBits() < 0 ? -static_cast(result.LowBits()) : static_cast(result.LowBits()); +} + +extern "C" DLLEXPORT int64_t DivDec128Dec64Dec64NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale) +{ + Decimal128Wrapper divisor(y); + CHECK_DIVIDE_BY_ZERO_RETURN(divisor); + Decimal128Wrapper result; + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result, outScale); + CHECK_OVERFLOW_RETURN(result, outPrecision); + return result.HighBits() < 0 ? -static_cast(result.LowBits()) : static_cast(result.LowBits()); +} + +extern "C" DLLEXPORT void DivDec64Dec64Dec128NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper divisor(y); + CHECK_DIVIDE_BY_ZERO(divisor); + Decimal128Wrapper result; + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result, outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void DivDec128Dec128Dec128NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper divisor(yHigh, yLow); + CHECK_DIVIDE_BY_ZERO(divisor); + Decimal128Wrapper result; + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result, outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void DivDec64Dec128Dec128NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper divisor(yHigh, yLow); + CHECK_DIVIDE_BY_ZERO(divisor); + Decimal128Wrapper result; + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result, outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void DivDec128Dec64Dec128NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper divisor(y); + CHECK_DIVIDE_BY_ZERO(divisor); + Decimal128Wrapper result; + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result, outScale); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +// Decimal ModOperation NotReScale +extern "C" DLLEXPORT int64_t ModDec64Dec64Dec64NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale) +{ + CHECK_DIVIDE_BY_ZERO_RETURN(y); + Decimal64 result; + DecimalOperations::InternalDecimalMod(Decimal64(x).SetScale(xScale), xScale, xPrecision, + Decimal64(y).SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW_RETURN(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT int64_t ModDec64Dec128Dec64NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale) +{ + Decimal128Wrapper divisor(yHigh, yLow); + CHECK_DIVIDE_BY_ZERO_RETURN(divisor); + Decimal64 result; + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW_RETURN(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT int64_t ModDec128Dec64Dec64NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale) +{ + Decimal128Wrapper divisor(y); + CHECK_DIVIDE_BY_ZERO_RETURN(divisor); + Decimal64 result; + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW_RETURN(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT void ModDec128Dec64Dec128NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper divisor(y); + CHECK_DIVIDE_BY_ZERO(divisor); + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void ModDec128Dec128Dec128NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper divisor(yHigh, yLow); + CHECK_DIVIDE_BY_ZERO(divisor); + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT int64_t ModDec128Dec128Dec64NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale) +{ + Decimal128Wrapper divisor(yHigh, yLow); + CHECK_DIVIDE_BY_ZERO_RETURN(divisor); + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW_RETURN(result, outPrecision); + + return result.HighBits() < 0 ? -static_cast(result.LowBits()) : static_cast(result.LowBits()); +} + +extern "C" DLLEXPORT void ModDec64Dec128Dec128NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper divisor(yHigh, yLow); + CHECK_DIVIDE_BY_ZERO(divisor); + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + divisor.SetScale(yScale), yScale, yPrecision, result); + CHECK_OVERFLOW(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +// Decimal SubOperator +extern "C" DLLEXPORT int64_t SubDec64Dec64Dec64RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale) +{ + Decimal64 result; + DecimalOperations::InternalDecimalSubtract(Decimal64(x).SetScale(xScale), xScale, xPrecision, + Decimal64(y).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN_NULL(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT void SubDec64Dec64Dec128RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, + uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void SubDec128Dec128Dec128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void SubDec64Dec128Dec128RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void SubDec128Dec64Dec128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalSubtract(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +// Decimal MulOperator +extern "C" DLLEXPORT int64_t MulDec64Dec64Dec64RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale) +{ + Decimal64 result; + DecimalOperations::InternalDecimalMultiply(Decimal64(x).SetScale(xScale), xScale, xPrecision, + Decimal64(y).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN_NULL(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT void MulDec64Dec64Dec128RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, + uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void MulDec128Dec128Dec128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result = + Decimal128Wrapper(xHigh, xLow).MultiplyRoundUp(Decimal128Wrapper(yHigh, yLow), xScale + yScale - outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void MulDec64Dec128Dec128RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void MulDec128Dec64Dec128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMultiply(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +// Decimal DivOperation +extern "C" DLLEXPORT int64_t DivDec64Dec64Dec64RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale) +{ + Decimal64 result; + DecimalOperations::InternalDecimalDivide(Decimal64(x).SetScale(xScale), xScale, xPrecision, + Decimal64(y).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN_NULL(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT int64_t DivDec64Dec128Dec64RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, int64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale) +{ + Decimal64 result; + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN_NULL(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT int64_t DivDec128Dec64Dec64RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale) +{ + Decimal64 result; + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN_NULL(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT void DivDec64Dec64Dec128RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, + uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void DivDec128Dec128Dec128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void DivDec64Dec128Dec128RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void DivDec128Dec64Dec128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalDivide(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y).SetScale(yScale), yScale, yPrecision, result, outScale); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +// Decimal ModOperation +extern "C" DLLEXPORT int64_t ModDec64Dec64Dec64RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale) +{ + Decimal64 result; + DecimalOperations::InternalDecimalMod(Decimal64(x).SetScale(xScale), xScale, xPrecision, + Decimal64(y).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN_NULL(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT int64_t ModDec64Dec128Dec64RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, int64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale) +{ + Decimal64 result; + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + return result.GetValue(); +} + +extern "C" DLLEXPORT int64_t ModDec128Dec64Dec64RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale) +{ + Decimal64 result; + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y).SetScale(yScale).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + *isNull = true; + return 0; + } + return result.GetValue(); +} + +extern "C" DLLEXPORT void ModDec128Dec64Dec128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(y).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void ModDec128Dec128Dec128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT int64_t ModDec128Dec128Dec64RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale) +{ + Decimal64 result; + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(xHigh, xLow).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN_NULL(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT void ModDec64Dec128Dec128RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result; + DecimalOperations::InternalDecimalMod(Decimal128Wrapper(x).SetScale(xScale), xScale, xPrecision, + Decimal128Wrapper(yHigh, yLow).SetScale(yScale), yScale, yPrecision, result); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT int64_t UnscaledValue64(int64_t x, int32_t precision, int32_t scale, bool isNull) +{ + return x; +} + +extern "C" DLLEXPORT int64_t MakeDecimal64(int64_t contextPtr, int64_t x, bool isNull, int32_t precision, int32_t scale) +{ + if (DecimalOperations::IsUnscaledLongOverflow(x, precision, scale)) { + std::ostringstream errorMessage; + errorMessage << "Unscaled value " << x << " out of Decimal(" << precision << ", " << scale << ") range"; + SetError(contextPtr, errorMessage.str()); + return 0; + } + return x; +} + +extern "C" DLLEXPORT int64_t MakeDecimal64RetNull(bool *isNull, int64_t x, int32_t precision, int32_t scale) +{ + if (DecimalOperations::IsUnscaledLongOverflow(x, precision, scale)) { + *isNull = true; + return 0; + } + *isNull = false; + return x; +} + +extern "C" DLLEXPORT void RoundDecimal128(int64_t contextPtr, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int32_t round, bool isNull, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, + uint64_t *outLowPtr) +{ + if (isNull) { + *outHighPtr = 0; + *outLowPtr = 0; + return; + } + Decimal128Wrapper input(xHigh, xLow); + input.SetScale(xScale); + DecimalOperations::Round(input, outScale, round); + CHECK_OVERFLOW(input, outPrecision); + *outHighPtr = input.HighBits(); + *outLowPtr = input.LowBits(); +} + +extern "C" DLLEXPORT void RoundDecimal128WithoutRound(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, bool isNull, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, + uint64_t *outLowPtr) +{ + if (isNull) { + *outHighPtr = 0; + *outLowPtr = 0; + return; + } + Decimal128Wrapper input(xHigh, xLow); + input.SetScale(xScale); + DecimalOperations::Round(input, outScale, 0); + CHECK_OVERFLOW(input, outPrecision); + *outHighPtr = input.HighBits(); + *outLowPtr = input.LowBits(); +} + +extern "C" DLLEXPORT int64_t RoundDecimal64(int64_t contextPtr, int64_t x, int32_t xPrecision, int32_t xScale, + int32_t round, bool isNull, int32_t outPrecision, int32_t outScale) +{ + if (isNull) { + return 0; + } + Decimal64 input(x); + input.SetScale(xScale); + DecimalOperations::Round(input, outScale, round); + CHECK_OVERFLOW_RETURN(input, outPrecision); + return input.GetValue(); +} + +extern "C" DLLEXPORT int64_t RoundDecimal64WithoutRound(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, bool isNull, int32_t outPrecision, int32_t outScale) +{ + if (isNull) { + return 0; + } + Decimal64 input(x); + input.SetScale(xScale); + DecimalOperations::Round(input, outScale, 0); + CHECK_OVERFLOW_RETURN(input, outPrecision); + return input.GetValue(); +} + +extern "C" DLLEXPORT void RoundDecimal128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int32_t round, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper input(xHigh, xLow); + input.SetScale(xScale); + DecimalOperations::Round(input, outScale, 0); + CHECK_OVERFLOW_VOID_RETURN_NULL(input, outPrecision); + *outHighPtr = input.HighBits(); + *outLowPtr = input.LowBits(); +} + +extern "C" DLLEXPORT int64_t RoundDecimal64RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int32_t round, int32_t outPrecision, int32_t outScale) +{ + Decimal64 input(x); + input.SetScale(xScale); + DecimalOperations::Round(input, outScale, round); + CHECK_OVERFLOW_RETURN_NULL(input, outPrecision); + return input.GetValue(); +} + +extern "C" DLLEXPORT int64_t GreatestDecimal64(int64_t contextPtr, int64_t xValue, int32_t xPrecision, int32_t xScale, + bool xIsNull, int64_t yValue, int32_t yPrecision, int32_t yScale, bool yIsNull, bool *retIsNull, + int32_t newPrecision, int32_t newScale) +{ + if (xIsNull && yIsNull) { + *retIsNull = true; + return 0; + } + if (xPrecision == yPrecision && xScale == yScale) { + if (xIsNull || (!yIsNull && xValue < yValue)) { + return yValue; + } + return xValue; + } + Decimal64 x(xValue); + x.SetScale(xScale); + Decimal64 y(yValue); + y.SetScale(yScale); + if (xIsNull || (!yIsNull && x.Compare(y) < 0)) { + y.ReScale(newScale); + CHECK_OVERFLOW_RETURN(y, newPrecision); + return y.GetValue(); + } + x.ReScale(newScale); + CHECK_OVERFLOW_RETURN(x, newPrecision); + return x.GetValue(); +} + +extern "C" DLLEXPORT void GreatestDecimal128(int64_t contextPtr, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, bool xIsNull, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, bool yIsNull, + bool *retIsNull, int32_t newPrecision, int32_t newScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + if (xIsNull && yIsNull) { + *retIsNull = true; + *outHighPtr = 0; + *outLowPtr = 0; + return; + } + if (xPrecision == yPrecision && xScale == yScale) { + if (xIsNull || (!yIsNull && Decimal128(xHigh, xLow) < Decimal128(yHigh, yLow))) { + *outHighPtr = yHigh; + *outLowPtr = yLow; + return; + } + *outHighPtr = xHigh; + *outLowPtr = xLow; + return; + } + Decimal128Wrapper x(xHigh, xLow); + x.SetScale(xScale); + Decimal128Wrapper y(yHigh, yLow); + y.SetScale(yScale); + if (xIsNull || (!yIsNull && x.Compare(y) < 0)) { + y.ReScale(newScale); + CHECK_OVERFLOW(y, newPrecision); + *outHighPtr = y.HighBits(); + *outLowPtr = y.LowBits(); + return; + } + x.ReScale(newScale); + CHECK_OVERFLOW(x, newPrecision); + *outHighPtr = x.HighBits(); + *outLowPtr = x.LowBits(); +} + +extern "C" DLLEXPORT int64_t GreatestDecimal64RetNull(bool *isNull, int64_t xValue, int32_t xPrecision, int32_t xScale, + bool xIsNull, int64_t yValue, int32_t yPrecision, int32_t yScale, bool yIsNull, bool *retIsNull, + int32_t newPrecision, int32_t newScale) +{ + if (xIsNull && yIsNull) { + *retIsNull = true; + return 0; + } + if (xPrecision == yPrecision && xScale == yScale) { + if (xIsNull || (!yIsNull && xValue < yValue)) { + return yValue; + } + return xValue; + } + Decimal64 x(xValue); + x.SetScale(xScale); + Decimal64 y(yValue); + y.SetScale(yScale); + if (xIsNull || (!yIsNull && x.Compare(y) < 0)) { + y.ReScale(newScale); + CHECK_OVERFLOW_RETURN_NULL(y, newPrecision); + return y.GetValue(); + } + x.ReScale(newScale); + CHECK_OVERFLOW_RETURN_NULL(x, newPrecision); + return x.GetValue(); +} + +extern "C" DLLEXPORT void GreatestDecimal128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, bool xIsNull, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, bool yIsNull, + bool *retIsNull, int32_t newPrecision, int32_t newScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + if (xIsNull && yIsNull) { + *retIsNull = true; + *outHighPtr = 0; + *outLowPtr = 0; + return; + } + if (xPrecision == yPrecision && xScale == yScale) { + if (xIsNull || (!yIsNull && Decimal128(xHigh, xLow) < Decimal128(yHigh, yLow))) { + *outHighPtr = yHigh; + *outLowPtr = yLow; + return; + } + *outHighPtr = xHigh; + *outLowPtr = xLow; + return; + } + Decimal128Wrapper x(xHigh, xLow); + x.SetScale(xScale); + Decimal128Wrapper y(yHigh, yLow); + y.SetScale(yScale); + if (xIsNull || (!yIsNull && x.Compare(y) < 0)) { + y.ReScale(newScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(y, newPrecision); + *outHighPtr = y.HighBits(); + *outLowPtr = y.LowBits(); + return; + } + x.ReScale(newScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(x, newPrecision); + *outHighPtr = x.HighBits(); + *outLowPtr = x.LowBits(); +} +} \ No newline at end of file diff --git a/core/src/codegen/functions/decimal_arithmetic_functions.h b/core/src/codegen/functions/decimal_arithmetic_functions.h new file mode 100644 index 0000000..8343d2b --- /dev/null +++ b/core/src/codegen/functions/decimal_arithmetic_functions.h @@ -0,0 +1,426 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: registry math function name + */ +#ifndef OMNI_RUNTIME_DECIMAL_ARITHMETIC_FUNCTIONS_H +#define OMNI_RUNTIME_DECIMAL_ARITHMETIC_FUNCTIONS_H + +#include +#include +#include +#include +#include "type/decimal128.h" +#include "codegen/context_helper.h" +#include "type/decimal_operations.h" +#include "util/config_util.h" +#include "type/data_type.h" + +namespace omniruntime::codegen::function { +// All extern functions go here temporarily +#ifdef _WIN32 +#define DLLEXPORT __declspec(dllexport) +#else +#define DLLEXPORT +#endif + +extern "C" DLLEXPORT int32_t Decimal128Compare(int64_t xHigh, uint64_t xLow, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, bool isNull); + +extern "C" DLLEXPORT void AbsDecimal128(int64_t xHigh, uint64_t xLow, int32_t xPrecision, int32_t xScale, bool isNull, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT int32_t Decimal64Compare(int64_t x, int32_t xPrecision, int32_t xScale, int64_t y, + int32_t yPrecision, int32_t yScale, bool isNull); + +extern "C" DLLEXPORT int64_t AbsDecimal64(int64_t x, int32_t xPrecision, int32_t xScale, bool isNull, + int32_t outPrecision, int32_t outScale); + +// Decimal AddOperator ReScale +extern "C" DLLEXPORT int64_t AddDec64Dec64Dec64ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT void AddDec64Dec64Dec128ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, + uint64_t *outLowPtr); + +extern "C" DLLEXPORT void AddDec128Dec128Dec128ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void AddDec64Dec128Dec128ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void AddDec128Dec64Dec128ReScale(int64_t contextPtr, int64_t yHigh, uint64_t yLow, + int32_t yPrecision, int32_t yScale, int64_t x, int32_t xPrecision, int32_t xScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +// Decimal SubOperation ReScale +extern "C" DLLEXPORT int64_t SubDec64Dec64Dec64ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT void SubDec64Dec64Dec128ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, + uint64_t *outLowPtr); + +extern "C" DLLEXPORT void SubDec128Dec128Dec128ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void SubDec64Dec128Dec128ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void SubDec128Dec64Dec128ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +// Decimal MulOperation ReScale +extern "C" DLLEXPORT int64_t MulDec64Dec64Dec64ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT void MulDec64Dec64Dec128ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, + uint64_t *outLowPtr); + +extern "C" DLLEXPORT void MulDec128Dec128Dec128ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void MulDec64Dec128Dec128ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void MulDec128Dec64Dec128ReScale(int64_t contextPtr, int64_t yHigh, uint64_t yLow, + int32_t yPrecision, int32_t yScale, int64_t x, int32_t xPrecision, int32_t xScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +// Decimal DivOperation ReScale +extern "C" DLLEXPORT int64_t DivDec64Dec64Dec64ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT int64_t DivDec64Dec128Dec64ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t yHigh, int64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale); + +extern "C" DLLEXPORT int64_t DivDec128Dec64Dec64ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale); + +extern "C" DLLEXPORT void DivDec64Dec64Dec128ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, + uint64_t *outLowPtr); + +extern "C" DLLEXPORT void DivDec128Dec128Dec128ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void DivDec64Dec128Dec128ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void DivDec128Dec64Dec128ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +// Decimal ModOperation ReScale +extern "C" DLLEXPORT int64_t ModDec64Dec64Dec64ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT int64_t ModDec64Dec128Dec64ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale); + +extern "C" DLLEXPORT int64_t ModDec128Dec64Dec64ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale); + +extern "C" DLLEXPORT void ModDec128Dec64Dec128ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void ModDec128Dec128Dec128ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT int64_t ModDec128Dec128Dec64ReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT void ModDec64Dec128Dec128ReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +// Decimal AddOperator NotReScale +extern "C" DLLEXPORT int64_t AddDec64Dec64Dec64NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT void AddDec64Dec64Dec128NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void AddDec128Dec128Dec128NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void AddDec64Dec128Dec128NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void AddDec128Dec64Dec128NotReScale(int64_t contextPtr, int64_t yHigh, uint64_t yLow, + int32_t yPrecision, int32_t yScale, int64_t x, int32_t xPrecision, int32_t xScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +// Decimal SubOperation NotReScale +extern "C" DLLEXPORT int64_t SubDec64Dec64Dec64NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT void SubDec64Dec64Dec128NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void SubDec128Dec128Dec128NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void SubDec64Dec128Dec128NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void SubDec128Dec64Dec128NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +// Decimal MulOperation NotReScale +extern "C" DLLEXPORT int64_t MulDec64Dec64Dec64NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT void MulDec64Dec64Dec128NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void MulDec128Dec128Dec128NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void MulDec64Dec128Dec128NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void MulDec128Dec64Dec128NotReScale(int64_t contextPtr, int64_t yHigh, uint64_t yLow, + int32_t yPrecision, int32_t yScale, int64_t x, int32_t xPrecision, int32_t xScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +// Decimal DivOperation NotReScale +extern "C" DLLEXPORT int64_t DivDec64Dec64Dec64NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT int64_t DivDec64Dec128Dec64NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t yHigh, int64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale); + +extern "C" DLLEXPORT int64_t DivDec128Dec64Dec64NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale); + +extern "C" DLLEXPORT void DivDec64Dec64Dec128NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void DivDec128Dec128Dec128NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void DivDec64Dec128Dec128NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void DivDec128Dec64Dec128NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +// Decimal ModOperation NotReScale +extern "C" DLLEXPORT int64_t ModDec64Dec64Dec64NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT int64_t ModDec64Dec128Dec64NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale); + +extern "C" DLLEXPORT int64_t ModDec128Dec64Dec64NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale); + +extern "C" DLLEXPORT void ModDec128Dec64Dec128NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void ModDec128Dec128Dec128NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT int64_t ModDec128Dec128Dec64NotReScale(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, + int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT void ModDec64Dec128Dec128NotReScale(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +// Return Null +extern "C" DLLEXPORT int64_t AddDec64Dec64Dec64RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT void AddDec64Dec64Dec128RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, + uint64_t *outLowPtr); + +extern "C" DLLEXPORT void AddDec128Dec128Dec128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void AddDec64Dec128Dec128RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void AddDec128Dec64Dec128RetNull(bool *isNull, int64_t yHigh, uint64_t yLow, int32_t yPrecision, + int32_t yScale, int64_t x, int32_t xPrecision, int32_t xScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +// Decimal SubOperator +extern "C" DLLEXPORT int64_t SubDec64Dec64Dec64RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT void SubDec64Dec64Dec128RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, + uint64_t *outLowPtr); + +extern "C" DLLEXPORT void SubDec128Dec128Dec128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void SubDec64Dec128Dec128RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void SubDec128Dec64Dec128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +// Decimal MulOperator +extern "C" DLLEXPORT int64_t MulDec64Dec64Dec64RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT void MulDec64Dec64Dec128RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, + uint64_t *outLowPtr); + +extern "C" DLLEXPORT void MulDec128Dec128Dec128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void MulDec64Dec128Dec128RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void MulDec128Dec64Dec128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +// Decimal DivOperation +extern "C" DLLEXPORT int64_t DivDec64Dec64Dec64RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT int64_t DivDec64Dec128Dec64RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, int64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT int64_t DivDec128Dec64Dec64RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT void DivDec64Dec64Dec128RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, + uint64_t *outLowPtr); + +extern "C" DLLEXPORT void DivDec128Dec128Dec128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void DivDec64Dec128Dec128RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void DivDec128Dec64Dec128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +// Decimal ModOperation +extern "C" DLLEXPORT int64_t ModDec64Dec64Dec64RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT int64_t ModDec64Dec128Dec64RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, int64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT int64_t ModDec128Dec64Dec64RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT void ModDec128Dec64Dec128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t y, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void ModDec128Dec128Dec128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT int64_t ModDec128Dec128Dec64RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, + int32_t outScale); + +extern "C" DLLEXPORT void ModDec64Dec128Dec128RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT int64_t UnscaledValue64(int64_t x, int32_t precision, int32_t scale, bool isNull); + +extern "C" DLLEXPORT int64_t MakeDecimal64(int64_t contextPtr, int64_t x, bool isNull, int32_t precision, + int32_t scale); + +extern "C" DLLEXPORT int64_t MakeDecimal64RetNull(bool *isNull, int64_t x, int32_t precision, int32_t scale); + +extern "C" DLLEXPORT void RoundDecimal128(int64_t contextPtr, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int32_t round, bool isNull, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, + uint64_t *outLowPtr); + +extern "C" DLLEXPORT void RoundDecimal128WithoutRound(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t xPrecision, int32_t xScale, bool isNull, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, + uint64_t *outLowPtr); + +extern "C" DLLEXPORT int64_t RoundDecimal64(int64_t contextPtr, int64_t x, int32_t xPrecision, int32_t xScale, + int32_t round, bool isNull, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT int64_t RoundDecimal64WithoutRound(int64_t contextPtr, int64_t x, int32_t xPrecision, + int32_t xScale, bool isNull, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT void RoundDecimal128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, int32_t round, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT int64_t RoundDecimal64RetNull(bool *isNull, int64_t x, int32_t xPrecision, int32_t xScale, + int32_t round, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT int64_t GreatestDecimal64(int64_t contextPtr, int64_t xValue, int32_t xPrecision, int32_t xScale, + bool xIsNull, int64_t yValue, int32_t yPrecision, int32_t yScale, bool yIsNull, bool *retIsNull, + int32_t newPrecision, int32_t newScale); + +extern "C" DLLEXPORT void GreatestDecimal128(int64_t contextPtr, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, bool xIsNull, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, bool yIsNull, + bool *retIsNull, int32_t newPrecision, int32_t newScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT int64_t GreatestDecimal64RetNull(bool *isNull, int64_t xValue, int32_t xPrecision, int32_t xScale, + bool xIsNull, int64_t yValue, int32_t yPrecision, int32_t yScale, bool yIsNull, bool *retIsNull, + int32_t newPrecision, int32_t newScale); + +extern "C" DLLEXPORT void GreatestDecimal128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t xPrecision, + int32_t xScale, bool xIsNull, int64_t yHigh, uint64_t yLow, int32_t yPrecision, int32_t yScale, bool yIsNull, + bool *retIsNull, int32_t newPrecision, int32_t newScale, int64_t *outHighPtr, uint64_t *outLowPtr); +} + +#endif // OMNI_RUNTIME_DECIMAL_ARITHMETIC_FUNCTIONS_H diff --git a/core/src/codegen/functions/decimal_cast_functions.cpp b/core/src/codegen/functions/decimal_cast_functions.cpp new file mode 100644 index 0000000..6701368 --- /dev/null +++ b/core/src/codegen/functions/decimal_cast_functions.cpp @@ -0,0 +1,728 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch decimal functions implementation + */ + +#include "decimal_cast_functions.h" + +namespace omniruntime::codegen::function { + +// Cast Function +extern "C" DLLEXPORT int64_t CastDecimal64To64(int64_t contextPtr, int64_t x, int32_t precision, int32_t scale, + bool isNull, int32_t newPrecision, int32_t newScale) +{ + if (isNull) { + return 0; + } + Decimal64 result(x); + result.SetScale(scale).ReScale(newScale); + if (result.IsOverflow(newPrecision) != OpStatus::SUCCESS) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL64, OMNI_DECIMAL64, x, OpStatus::SUCCESS, precision, scale, + newPrecision, newScale)); + return 0; + } + return result.GetValue(); +} + +extern "C" DLLEXPORT void CastDecimal128To128(int64_t contextPtr, int64_t xHigh, uint64_t xLow, int32_t precision, + int32_t scale, bool isNull, int32_t newPrecision, int32_t newScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + if (isNull) { + *outHighPtr = 0; + *outLowPtr = 0; + return; + } + Decimal128Wrapper result(xHigh, xLow); + result.SetScale(scale).ReScale(newScale); + if (result.IsOverflow(newPrecision) != OpStatus::SUCCESS) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL128, OMNI_DECIMAL128, ((int128_t(xHigh) << 64) + xLow), + OpStatus::SUCCESS, precision, scale, newPrecision, newScale)); + return; + } + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void CastDecimal64To128(int64_t contextPtr, int64_t x, int32_t precision, int32_t scale, + bool isNull, int32_t newPrecision, int32_t newScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + if (isNull) { + *outHighPtr = 0; + *outLowPtr = 0; + return; + } + Decimal128Wrapper result(x); + result.SetScale(scale).ReScale(newScale); + if (result.IsOverflow(newPrecision) != OpStatus::SUCCESS) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL64, OMNI_DECIMAL128, x, OpStatus::SUCCESS, precision, scale, + newPrecision, newScale)); + return; + } + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT int64_t CastDecimal128To64(int64_t contextPtr, int64_t xHigh, uint64_t xLow, int32_t precision, + int32_t scale, bool isNull, int32_t newPrecision, int32_t newScale) +{ + if (isNull) { + return 0; + } + Decimal64 result(Decimal128Wrapper(xHigh, xLow).SetScale(scale).ReScale(newScale)); + if (result.IsOverflow(newPrecision) != OpStatus::SUCCESS) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL128, OMNI_DECIMAL64, (int128_t(xHigh) << 64 | xLow), + OpStatus::SUCCESS, precision, scale, newPrecision, newScale)); + return 0; + } + return result.GetValue(); +} + +extern "C" DLLEXPORT int64_t CastIntToDecimal64(int64_t contextPtr, int32_t x, bool isNull, int32_t precision, + int32_t scale) +{ + if (isNull) { + return 0; + } + Decimal64 result(x); + result.ReScale(scale); + if (result.IsOverflow(precision) != OpStatus::SUCCESS) { + SetError(contextPtr, CastErrorMessage(OMNI_INT, OMNI_DECIMAL64, x, OpStatus::SUCCESS, precision, scale)); + return 0; + } + return result.GetValue(); +} + +extern "C" DLLEXPORT int64_t CastInt16ToDecimal64(int64_t contextPtr, int16_t x, bool isNull, int32_t precision, + int32_t scale) +{ + if (isNull) { + return 0; + } + Decimal64 result(x); + result.ReScale(scale); + if (result.IsOverflow(precision) != OpStatus::SUCCESS) { + SetError(contextPtr, CastErrorMessage(OMNI_SHORT, OMNI_DECIMAL64, x, OpStatus::SUCCESS, precision, scale)); + return 0; + } + return result.GetValue(); +} + +extern "C" DLLEXPORT int64_t CastInt8ToDecimal64(int64_t contextPtr, int8_t x, bool isNull, int32_t precision, + int32_t scale) +{ + if (isNull) { + return 0; + } + Decimal64 result(x); + result.ReScale(scale); + if (result.IsOverflow(precision) != OpStatus::SUCCESS) { + SetError(contextPtr, CastErrorMessage(OMNI_BYTE, OMNI_DECIMAL64, x, OpStatus::SUCCESS, precision, scale)); + return 0; + } + return result.GetValue(); +} + +extern "C" DLLEXPORT int64_t CastLongToDecimal64(int64_t contextPtr, int64_t x, bool isNull, int32_t outPrecision, + int32_t outScale) +{ + if (isNull) { + return 0; + } + Decimal64 result(x); + result.ReScale(outScale); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + SetError(contextPtr, CastErrorMessage(OMNI_LONG, OMNI_DECIMAL64, x, OpStatus::SUCCESS, outPrecision, outScale)); + return 0; + } + return result.GetValue(); +} + +extern "C" DLLEXPORT int64_t CastDoubleToDecimal64(int64_t contextPtr, double x, bool isNull, int32_t outPrecision, + int32_t outScale) +{ + if (isNull) { + return 0; + } + Decimal64 result(x); + result.ReScale(outScale); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + SetError(contextPtr, + CastErrorMessage(OMNI_DOUBLE, OMNI_DECIMAL64, x, OpStatus::SUCCESS, outPrecision, outScale)); + return 0; + } + return result.GetValue(); +} + +extern "C" DLLEXPORT void CastIntToDecimal128(int64_t contextPtr, int32_t x, bool isNull, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + if (isNull) { + *outHighPtr = 0; + *outLowPtr = 0; + return; + } + Decimal128Wrapper result(x); + result.ReScale(outScale); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + SetError(contextPtr, + CastErrorMessage(OMNI_INT, OMNI_DECIMAL128, x, OpStatus::OP_OVERFLOW, outPrecision, outScale)); + return; + } + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void CastInt16ToDecimal128(int64_t contextPtr, int16_t x, bool isNull, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + if (isNull) { + *outHighPtr = 0; + *outLowPtr = 0; + return; + } + Decimal128Wrapper result(x); + result.ReScale(outScale); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + SetError(contextPtr, + CastErrorMessage(OMNI_SHORT, OMNI_DECIMAL128, x, OpStatus::OP_OVERFLOW, outPrecision, outScale)); + return; + } + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void CastInt8ToDecimal128(int64_t contextPtr, int8_t x, bool isNull, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + if (isNull) { + *outHighPtr = 0; + *outLowPtr = 0; + return; + } + Decimal128Wrapper result(x); + result.ReScale(outScale); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + SetError(contextPtr, + CastErrorMessage(OMNI_BYTE, OMNI_DECIMAL128, x, OpStatus::OP_OVERFLOW, outPrecision, outScale)); + return; + } + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void CastLongToDecimal128(int64_t contextPtr, int64_t x, bool isNull, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + if (isNull) { + return; + } + Decimal128Wrapper result(x); + result.ReScale(outScale); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + SetError(contextPtr, + CastErrorMessage(OMNI_LONG, OMNI_DECIMAL128, x, OpStatus::OP_OVERFLOW, outPrecision, outScale)); + return; + } + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void CastDoubleToDecimal128(int64_t contextPtr, double x, bool isNull, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + if (isNull) { + *outHighPtr = 0; + *outLowPtr = 0; + return; + } + Decimal128Wrapper result(x); + result.ReScale(outScale); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + SetError(contextPtr, + CastErrorMessage(OMNI_DOUBLE, OMNI_DECIMAL128, x, OpStatus::OP_OVERFLOW, outPrecision, outScale)); + return; + } + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT int32_t CastDecimal64ToIntHalfUp(int64_t contextPtr, int64_t x, int32_t precision, int32_t scale, + bool isNull) +{ + if (isNull) { + return 0; + } + + int32_t result; + try { + result = static_cast(Decimal64(x).SetScale(scale)); + } catch (std::overflow_error &e) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL64, OMNI_INT, x, OpStatus::SUCCESS, precision, scale)); + return 0; + } + return result; +} + +extern "C" DLLEXPORT int16_t CastDecimal64ToInt16HalfUp(int64_t contextPtr, int64_t x, int32_t precision, int32_t scale, + bool isNull) +{ + if (isNull) { + return 0; + } + + int16_t result; + try { + result = static_cast(Decimal64(x).SetScale(scale)); + } catch (std::overflow_error &e) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL64, OMNI_SHORT, x, OpStatus::SUCCESS, precision, scale)); + return 0; + } + return result; +} + +extern "C" DLLEXPORT int8_t CastDecimal64ToInt8HalfUp(int64_t contextPtr, int64_t x, int32_t precision, int32_t scale, + bool isNull) +{ + if (isNull) { + return 0; + } + + int8_t result; + try { + result = static_cast(Decimal64(x).SetScale(scale)); + } catch (std::overflow_error &e) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL64, OMNI_BYTE, x, OpStatus::SUCCESS, precision, scale)); + return 0; + } + return result; +} + +extern "C" DLLEXPORT double CastDecimal64ToDoubleHalfUp(int64_t x, int32_t precision, int32_t scale, bool isNull) +{ + if (isNull) { + return 0; + } + double result = static_cast(Decimal64(x).SetScale(scale)); + return result; +} + +extern "C" DLLEXPORT int32_t CastDecimal128ToIntHalfUp(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t precision, int32_t scale, bool isNull) +{ + if (isNull) { + return 0; + } + int32_t result; + try { + result = static_cast(Decimal128Wrapper(xHigh, xLow).SetScale(scale)); + } catch (std::overflow_error &e) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL128, OMNI_INT, (int128_t(xHigh) << 64 | xLow), + OpStatus::OP_OVERFLOW, precision, scale)); + return 0; + } + return result; +} + +extern "C" DLLEXPORT int16_t CastDecimal128ToInt16HalfUp(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t precision, int32_t scale, bool isNull) +{ + if (isNull) { + return 0; + } + int16_t result; + try { + result = static_cast(Decimal128Wrapper(xHigh, xLow).SetScale(scale)); + } catch (std::overflow_error &e) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL128, OMNI_SHORT, (int128_t(xHigh) << 64 | xLow), + OpStatus::OP_OVERFLOW, precision, scale)); + return 0; + } + return result; +} + +extern "C" DLLEXPORT int8_t CastDecimal128ToInt8HalfUp(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t precision, int32_t scale, bool isNull) +{ + if (isNull) { + return 0; + } + int8_t result; + try { + result = static_cast(Decimal128Wrapper(xHigh, xLow).SetScale(scale)); + } catch (std::overflow_error &e) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL128, OMNI_BYTE, (int128_t(xHigh) << 64 | xLow), + OpStatus::OP_OVERFLOW, precision, scale)); + return 0; + } + return result; +} + +extern "C" DLLEXPORT int64_t CastDecimal128ToLongHalfUp(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t precision, int32_t scale, bool isNull) +{ + if (isNull) { + return 0; + } + int64_t result; + try { + result = static_cast(Decimal128Wrapper(xHigh, xLow).SetScale(scale)); + } catch (std::overflow_error &e) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL128, OMNI_LONG, (int128_t(xHigh) << 64 | xLow), + OpStatus::OP_OVERFLOW, precision, scale)); + return 0; + } + return result; +} + +extern "C" DLLEXPORT double CastDecimal128ToDoubleHalfUp(int64_t high, uint64_t low, int32_t precision, int32_t scale, + bool isNull) +{ + if (isNull) { + return 0.0; + } + Decimal128Wrapper input(high, low); + + return (double)input.SetScale(scale); +} + +extern "C" DLLEXPORT int32_t CastDecimal64ToIntDown(int64_t contextPtr, int64_t x, int32_t precision, int32_t scale, + bool isNull) +{ + if (isNull) { + return 0; + } + return static_cast(Decimal64(x).ReScale(-scale, RoundingMode::ROUND_FLOOR).GetValue()); +} + +extern "C" DLLEXPORT int16_t CastDecimal64ToInt16Down(int64_t contextPtr, int64_t x, int32_t precision, int32_t scale, + bool isNull) +{ + if (isNull) { + return 0; + } + return static_cast(Decimal64(x).ReScale(-scale, RoundingMode::ROUND_FLOOR).GetValue()); +} + +extern "C" DLLEXPORT int8_t CastDecimal64ToInt8Down(int64_t contextPtr, int64_t x, int32_t precision, int32_t scale, + bool isNull) +{ + if (isNull) { + return 0; + } + return static_cast(Decimal64(x).ReScale(-scale, RoundingMode::ROUND_FLOOR).GetValue()); +} + +extern "C" DLLEXPORT int64_t CastDecimal64ToLongDown(int64_t x, int32_t precision, int32_t scale, bool isNull) +{ + if (isNull) { + return 0; + } + int64_t scaledValue = Decimal64(x).SetScale(scale).ReScale(0, RoundingMode::ROUND_FLOOR).GetValue(); + return scaledValue; +} + + +extern "C" DLLEXPORT int64_t CastDecimal64ToLongHalfUp(int64_t x, int32_t precision, int32_t scale, bool isNull) +{ + if (isNull) { + return 0; + } + int64_t result = static_cast(Decimal64(x).SetScale(scale)); + return result; +} + +extern "C" DLLEXPORT double CastDecimal64ToDoubleDown(int64_t x, int32_t precision, int32_t scale, bool isNull) +{ + if (isNull) { + return 0; + } + std::string doubleString = Decimal64(x).SetScale(scale).ToString(); + return stod(doubleString); +} + +extern "C" DLLEXPORT int32_t CastDecimal128ToIntDown(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t precision, int32_t scale, bool isNull) +{ + if (isNull) { + return 0; + } + int32_t result; + try { + result = static_cast(Decimal128Wrapper(xHigh, xLow) + .ReScale(-scale, RoundingMode::ROUND_FLOOR) + .ToInt128()); + } catch (std::overflow_error &e) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL128, OMNI_INT, (int128_t(xHigh) << 64 | xLow), + OpStatus::OP_OVERFLOW, precision, scale)); + return 0; + } + return result; +} + +extern "C" DLLEXPORT int16_t CastDecimal128ToInt16Down(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t precision, int32_t scale, bool isNull) +{ + if (isNull) { + return 0; + } + int16_t result; + try { + result = static_cast(Decimal128Wrapper(xHigh, xLow) + .ReScale(-scale, RoundingMode::ROUND_FLOOR) + .ToInt128()); + } catch (std::overflow_error &e) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL128, OMNI_SHORT, (int128_t(xHigh) << 64 | xLow), + OpStatus::OP_OVERFLOW, precision, scale)); + return 0; + } + return result; +} + +extern "C" DLLEXPORT int8_t CastDecimal128ToInt8Down(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t precision, int32_t scale, bool isNull) +{ + if (isNull) { + return 0; + } + int8_t result; + try { + result = static_cast(Decimal128Wrapper(xHigh, xLow) + .ReScale(-scale, RoundingMode::ROUND_FLOOR) + .ToInt128()); + } catch (std::overflow_error &e) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL128, OMNI_BYTE, (int128_t(xHigh) << 64 | xLow), + OpStatus::OP_OVERFLOW, precision, scale)); + return 0; + } + return result; +} + +extern "C" DLLEXPORT int64_t CastDecimal128ToLongDown(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t precision, int32_t scale, bool isNull) +{ + if (isNull) { + return 0; + } + int64_t result; + try { + result = static_cast(Decimal128Wrapper(xHigh, xLow) + .ReScale(-scale, RoundingMode::ROUND_FLOOR) + .ToInt128()); + } catch (std::overflow_error &e) { + SetError(contextPtr, CastErrorMessage(OMNI_DECIMAL128, OMNI_LONG, (int128_t(xHigh) << 64 | xLow), + OpStatus::OP_OVERFLOW, precision, scale)); + return 0; + } + return result; +} + +extern "C" DLLEXPORT double CastDecimal128ToDoubleDown(int64_t high, uint64_t low, int32_t precision, int32_t scale, + bool isNull) +{ + if (isNull) { + return 0.0; + } + Decimal128Wrapper input(high, low); + return (double)input / DOUBLE_10_POW[scale]; +} + +// Cast Function +extern "C" DLLEXPORT int64_t CastDecimal64To64RetNull(bool *isNull, int64_t x, int32_t precision, int32_t scale, + int32_t newPrecision, int32_t newScale) +{ + Decimal64 result(x); + result.SetScale(scale).ReScale(newScale); + CHECK_OVERFLOW_RETURN_NULL(result, newPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT void CastDecimal128To128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t precision, + int32_t scale, int32_t newPrecision, int32_t newScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result(xHigh, xLow); + result.SetScale(scale).ReScale(newScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, newPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void CastDecimal64To128RetNull(bool *isNull, int64_t x, int32_t precision, int32_t scale, + int32_t newPrecision, int32_t newScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result(x); + result.SetScale(scale).ReScale(newScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, newPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT int64_t CastDecimal128To64RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t precision, + int32_t scale, int32_t newPrecision, int32_t newScale) +{ + Decimal64 result(Decimal128Wrapper(xHigh, xLow).SetScale(scale).ReScale(newScale)); + CHECK_OVERFLOW_RETURN_NULL(result, newPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT int64_t CastIntToDecimal64RetNull(bool *isNull, int32_t x, int32_t outPrecision, int32_t outScale) +{ + Decimal64 result(x); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN_NULL(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT int64_t CastInt16ToDecimal64RetNull(bool *isNull, int16_t x, int32_t outPrecision, int32_t outScale) +{ + Decimal64 result(x); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN_NULL(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT int64_t CastInt8ToDecimal64RetNull(bool *isNull, int8_t x, int32_t outPrecision, int32_t outScale) +{ + Decimal64 result(x); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN_NULL(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT int64_t CastLongToDecimal64RetNull(bool *isNull, int64_t x, int32_t outPrecision, int32_t outScale) +{ + Decimal64 result(x); + result.ReScale(outScale); + CHECK_OVERFLOW_RETURN_NULL(result, outPrecision); + return result.GetValue(); +} + +extern "C" DLLEXPORT int64_t CastDoubleToDecimal64RetNull(bool *isNull, double x, int32_t outPrecision, + int32_t outScale) +{ + Decimal64 result(x); + result.ReScale(outScale); + if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) { + *isNull = true; + return 0; + } + return result.GetValue(); +} + +extern "C" DLLEXPORT void CastIntToDecimal128RetNull(bool *isNull, int32_t x, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result(x); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void CastInt16ToDecimal128RetNull(bool *isNull, int16_t x, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result(x); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void CastInt8ToDecimal128RetNull(bool *isNull, int8_t x, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result(x); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void CastLongToDecimal128RetNull(bool *isNull, int64_t x, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result(x); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT void CastDoubleToDecimal128RetNull(bool *isNull, double x, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr) +{ + Decimal128Wrapper result(x); + result.ReScale(outScale); + CHECK_OVERFLOW_VOID_RETURN_NULL(result, outPrecision); + *outHighPtr = result.HighBits(); + *outLowPtr = result.LowBits(); +} + +extern "C" DLLEXPORT int32_t CastDecimal64ToIntRetNull(bool *isNull, int64_t x, int32_t precision, int32_t scale) +{ + return static_cast(Decimal64(x).ReScale(-scale, RoundingMode::ROUND_FLOOR).GetValue()); +} + +extern "C" DLLEXPORT int16_t CastDecimal64ToInt16RetNull(bool *isNull, int64_t x, int32_t precision, int32_t scale) +{ + return static_cast(Decimal64(x).ReScale(-scale, RoundingMode::ROUND_FLOOR).GetValue()); +} + +extern "C" DLLEXPORT int8_t CastDecimal64ToInt8RetNull(bool *isNull, int64_t x, int32_t precision, int32_t scale) +{ + return static_cast(Decimal64(x).ReScale(-scale, RoundingMode::ROUND_FLOOR).GetValue()); +} + +extern "C" DLLEXPORT int64_t CastDecimal64ToLongRetNull(bool *isNull, int64_t x, int32_t precision, int32_t scale) +{ + Decimal64 result(x); + result.SetScale(scale).ReScale(0, RoundingMode::ROUND_FLOOR); + return result.GetValue(); +} + +extern "C" DLLEXPORT double CastDecimal64ToDoubleRetNull(bool *isNull, int64_t x, int32_t precision, int32_t scale) +{ + std::string doubleString = Decimal64(x).SetScale(scale).ToString(); + double result; + ConvertStringToDouble(result, doubleString); + return result; +} + +extern "C" DLLEXPORT int32_t CastDecimal128ToIntRetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t precision, + int32_t scale) +{ + return static_cast( + Decimal128Wrapper(xHigh, xLow).ReScale(-scale, RoundingMode::ROUND_FLOOR).ToInt128()); +} + +extern "C" DLLEXPORT int16_t CastDecimal128ToInt16RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t precision, + int32_t scale) +{ + return static_cast( + Decimal128Wrapper(xHigh, xLow).ReScale(-scale, RoundingMode::ROUND_FLOOR).ToInt128()); +} + +extern "C" DLLEXPORT int8_t CastDecimal128ToInt8RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t precision, + int32_t scale) +{ + return static_cast( + Decimal128Wrapper(xHigh, xLow).ReScale(-scale, RoundingMode::ROUND_FLOOR).ToInt128()); +} + + +extern "C" DLLEXPORT int64_t CastDecimal128ToLongRetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t precision, + int32_t scale) +{ + return static_cast( + Decimal128Wrapper(xHigh, xLow).ReScale(-scale, RoundingMode::ROUND_FLOOR).ToInt128()); +} + +extern "C" DLLEXPORT double CastDecimal128ToDoubleRetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t precision, + int32_t scale) +{ + std::string doubleString = Decimal128Wrapper(xHigh, xLow).SetScale(scale).ToString(); + double result; + ConvertStringToDouble(result, doubleString); + return result; +} +} \ No newline at end of file diff --git a/core/src/codegen/functions/decimal_cast_functions.h b/core/src/codegen/functions/decimal_cast_functions.h new file mode 100644 index 0000000..e9ca059 --- /dev/null +++ b/core/src/codegen/functions/decimal_cast_functions.h @@ -0,0 +1,194 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: batch decimal functions implementation + */ + +#ifndef OMNI_RUNTIME_DECIMAL_CAST_FUNCTIONS_H +#define OMNI_RUNTIME_DECIMAL_CAST_FUNCTIONS_H + +#include +#include +#include +#include +#include "type/decimal128.h" +#include "codegen/context_helper.h" +#include "type/decimal_operations.h" +#include "util/config_util.h" +#include "type/data_type.h" + +#ifdef _WIN32 +#define DLLEXPORT __declspec(dllexport) +#else +#define DLLEXPORT +#endif + +namespace omniruntime::codegen::function { +using namespace omniruntime::type; + +// Cast +extern "C" DLLEXPORT int64_t CastDecimal64To64(int64_t contextPtr, int64_t x, int32_t precision, int32_t scale, + bool isNull, int32_t newPrecision, int32_t newScale); + +extern "C" DLLEXPORT void CastDecimal128To128(int64_t contextPtr, int64_t xHigh, uint64_t xLow, int32_t precision, + int32_t scale, bool isNull, int32_t newPrecision, int32_t newScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void CastDecimal64To128(int64_t contextPtr, int64_t x, int32_t precision, int32_t scale, + bool isNull, int32_t newPrecision, int32_t newScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT int64_t CastDecimal128To64(int64_t contextPtr, int64_t xHigh, uint64_t xLow, int32_t precision, + int32_t scale, bool isNull, int32_t newPrecision, int32_t newScale); + +extern "C" DLLEXPORT int64_t CastIntToDecimal64(int64_t contextPtr, int32_t x, bool isNull, int32_t precision, + int32_t scale); + +extern "C" DLLEXPORT int64_t CastInt16ToDecimal64(int64_t contextPtr, int16_t x, bool isNull, int32_t precision, + int32_t scale); + +extern "C" DLLEXPORT int64_t CastInt8ToDecimal64(int64_t contextPtr, int8_t x, bool isNull, int32_t precision, + int32_t scale); + +extern "C" DLLEXPORT int64_t CastLongToDecimal64(int64_t contextPtr, int64_t x, bool isNull, int32_t outPrecision, + int32_t outScale); + +extern "C" DLLEXPORT int64_t CastDoubleToDecimal64(int64_t contextPtr, double x, bool isNull, int32_t precision, + int32_t scale); + +extern "C" DLLEXPORT void CastIntToDecimal128(int64_t contextPtr, int32_t x, bool isNull, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void CastInt16ToDecimal128(int64_t contextPtr, int16_t x, bool isNull, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void CastInt8ToDecimal128(int64_t contextPtr, int8_t x, bool isNull, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void CastLongToDecimal128(int64_t contextPtr, int64_t x, bool isNull, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void CastDoubleToDecimal128(int64_t contextPtr, double x, bool isNull, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT int32_t CastDecimal64ToIntHalfUp(int64_t contextPtr, int64_t x, int32_t precision, int32_t scale, + bool isNull); + +extern "C" DLLEXPORT int16_t CastDecimal64ToInt16HalfUp(int64_t contextPtr, int64_t x, int32_t precision, int32_t scale, + bool isNull); + +extern "C" DLLEXPORT int8_t CastDecimal64ToInt8HalfUp(int64_t contextPtr, int64_t x, int32_t precision, int32_t scale, + bool isNull); + +extern "C" DLLEXPORT int64_t CastDecimal64ToLongHalfUp(int64_t x, int32_t precision, int32_t scale, bool isNull); + +extern "C" DLLEXPORT double CastDecimal64ToDoubleHalfUp(int64_t x, int32_t precision, int32_t scale, bool isNull); + +extern "C" DLLEXPORT int32_t CastDecimal128ToIntHalfUp(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t precision, int32_t scale, bool isNull); + +extern "C" DLLEXPORT int16_t CastDecimal128ToInt16HalfUp(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t precision, int32_t scale, bool isNull); + +extern "C" DLLEXPORT int8_t CastDecimal128ToInt8HalfUp(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t precision, int32_t scale, bool isNull); + +extern "C" DLLEXPORT int64_t CastDecimal128ToLongHalfUp(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t precision, int32_t scale, bool isNull); + +extern "C" DLLEXPORT double CastDecimal128ToDoubleHalfUp(int64_t xHigh, uint64_t xLow, int32_t precision, int32_t scale, + bool isNull); + +extern "C" DLLEXPORT int32_t CastDecimal64ToIntDown(int64_t contextPtr, int64_t x, int32_t precision, int32_t scale, + bool isNull); + +extern "C" DLLEXPORT int16_t CastDecimal64ToInt16Down(int64_t contextPtr, int64_t x, int32_t precision, int32_t scale, + bool isNull); + +extern "C" DLLEXPORT int8_t CastDecimal64ToInt8Down(int64_t contextPtr, int64_t x, int32_t precision, int32_t scale, + bool isNull); + +extern "C" DLLEXPORT int64_t CastDecimal64ToLongDown(int64_t x, int32_t precision, int32_t scale, bool isNull); + +extern "C" DLLEXPORT double CastDecimal64ToDoubleDown(int64_t x, int32_t precision, int32_t scale, bool isNull); + +extern "C" DLLEXPORT int32_t CastDecimal128ToIntDown(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t precision, int32_t scale, bool isNull); + +extern "C" DLLEXPORT int16_t CastDecimal128ToInt16Down(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t precision, int32_t scale, bool isNull); + +extern "C" DLLEXPORT int8_t CastDecimal128ToInt8Down(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t precision, int32_t scale, bool isNull); + +extern "C" DLLEXPORT int64_t CastDecimal128ToLongDown(int64_t contextPtr, int64_t xHigh, uint64_t xLow, + int32_t precision, int32_t scale, bool isNull); + +extern "C" DLLEXPORT double CastDecimal128ToDoubleDown(int64_t xHigh, uint64_t xLow, int32_t precision, int32_t scale, + bool isNull); + +// Cast +extern "C" DLLEXPORT int64_t CastDecimal64To64RetNull(bool *isNull, int64_t x, int32_t precision, int32_t scale, + int32_t newPrecision, int32_t newScale); + +extern "C" DLLEXPORT void CastDecimal128To128RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t precision, + int32_t scale, int32_t newPrecision, int32_t newScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void CastDecimal64To128RetNull(bool *isNull, int64_t x, int32_t precision, int32_t scale, + int32_t newPrecision, int32_t newScale, int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT int64_t CastDecimal128To64RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t precision, + int32_t scale, int32_t newPrecision, int32_t newScale); + +extern "C" DLLEXPORT int64_t CastIntToDecimal64RetNull(bool *isNull, int32_t x, int32_t precision, int32_t scale); + +extern "C" DLLEXPORT int64_t CastInt16ToDecimal64RetNull(bool *isNull, int16_t x, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT int64_t CastInt8ToDecimal64RetNull(bool *isNull, int8_t x, int32_t outPrecision, int32_t outScale); + +extern "C" DLLEXPORT int64_t CastLongToDecimal64RetNull(bool *isNull, int64_t x, int32_t outPrecision, + int32_t outScale); + +extern "C" DLLEXPORT int64_t CastDoubleToDecimal64RetNull(bool *isNull, double x, int32_t outPrecision, + int32_t outScale); + +extern "C" DLLEXPORT void CastIntToDecimal128RetNull(bool *isNull, int32_t x, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void CastInt16ToDecimal128RetNull(bool *isNull, int16_t x, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void CastInt8ToDecimal128RetNull(bool *isNull, int8_t x, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void CastLongToDecimal128RetNull(bool *isNull, int64_t x, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT void CastDoubleToDecimal128RetNull(bool *isNull, double x, int32_t outPrecision, int32_t outScale, + int64_t *outHighPtr, uint64_t *outLowPtr); + +extern "C" DLLEXPORT int32_t CastDecimal64ToIntRetNull(bool *isNull, int64_t x, int32_t precision, int32_t scale); + +extern "C" DLLEXPORT int16_t CastDecimal64ToInt16RetNull(bool *isNull, int64_t x, int32_t precision, int32_t scale); + +extern "C" DLLEXPORT int8_t CastDecimal64ToInt8RetNull(bool *isNull, int64_t x, int32_t precision, int32_t scale); + +extern "C" DLLEXPORT int64_t CastDecimal64ToLongRetNull(bool *isNull, int64_t x, int32_t precision, int32_t scale); + +extern "C" DLLEXPORT double CastDecimal64ToDoubleRetNull(bool *isNull, int64_t x, int32_t precision, int32_t scale); + +extern "C" DLLEXPORT int32_t CastDecimal128ToIntRetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t precision, + int32_t scale); + +extern "C" DLLEXPORT int16_t CastDecimal128ToInt16RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t precision, + int32_t scale); + +extern "C" DLLEXPORT int8_t CastDecimal128ToInt8RetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t precision, + int32_t scale); + +extern "C" DLLEXPORT int64_t CastDecimal128ToLongRetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t precision, + int32_t scale); + +extern "C" DLLEXPORT double CastDecimal128ToDoubleRetNull(bool *isNull, int64_t xHigh, uint64_t xLow, int32_t precision, + int32_t scale); +} + + +#endif // OMNI_RUNTIME_DECIMAL_CAST_FUNCTIONS_H diff --git a/core/src/codegen/functions/dictionaryfunctions.cpp b/core/src/codegen/functions/dictionaryfunctions.cpp new file mode 100644 index 0000000..184dbb6 --- /dev/null +++ b/core/src/codegen/functions/dictionaryfunctions.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: registry dictionary functions + */ + +#include "dictionaryfunctions.h" +#include "vector/vector.h" +#include "codegen/context_helper.h" + +using namespace omniruntime::vec; +using namespace std; + +namespace omniruntime::codegen::function { +extern DLLEXPORT int32_t GetIntFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t index) +{ + auto dictionaryVectorPtr = + reinterpret_cast> *>(dictionaryVectorAddr); + return dictionaryVectorPtr->GetValue(index); +} + +extern DLLEXPORT int8_t GetByteFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t index) +{ + auto dictionaryVectorPtr = + reinterpret_cast> *>(dictionaryVectorAddr); + return dictionaryVectorPtr->GetValue(index); +} + +extern DLLEXPORT int16_t GetShortFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t index) +{ + auto dictionaryVectorPtr = + reinterpret_cast> *>(dictionaryVectorAddr); + return dictionaryVectorPtr->GetValue(index); +} + +extern DLLEXPORT int64_t GetLongFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t index) +{ + auto dictionaryVectorPtr = + reinterpret_cast> *>(dictionaryVectorAddr); + return dictionaryVectorPtr->GetValue(index); +} + +extern DLLEXPORT double GetDoubleFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t index) +{ + auto dictionaryVectorPtr = + reinterpret_cast> *>(dictionaryVectorAddr); + return dictionaryVectorPtr->GetValue(index); +} + +extern DLLEXPORT bool GetBooleanFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t index) +{ + auto dictionaryVectorPtr = + reinterpret_cast> *>(dictionaryVectorAddr); + return dictionaryVectorPtr->GetValue(index); +} + +extern DLLEXPORT uint8_t *GetVarcharFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t index, + int32_t *lengthPtr) +{ + auto dictionaryVectorPtr = + reinterpret_cast> *>(dictionaryVectorAddr); + auto stringView = dictionaryVectorPtr->GetValue(index); + int32_t length = stringView.length(); + *lengthPtr = length; + return (uint8_t *)stringView.data(); +} + +extern DLLEXPORT void GetDecimalFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t index, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr) +{ + auto dictionaryVectorPtr = + reinterpret_cast> *>(dictionaryVectorAddr); + auto value = dictionaryVectorPtr->GetValue(index); + *outLowPtr = value.LowBits(); + *outHighPtr = value.HighBits(); +} +} \ No newline at end of file diff --git a/core/src/codegen/functions/dictionaryfunctions.h b/core/src/codegen/functions/dictionaryfunctions.h new file mode 100644 index 0000000..a755ddd --- /dev/null +++ b/core/src/codegen/functions/dictionaryfunctions.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * Description: registry dictionary functions + */ + +#ifndef OMNI_RUNTIME_DICTIONARYFUNCTIONS_H +#define OMNI_RUNTIME_DICTIONARYFUNCTIONS_H + +#include + +namespace omniruntime::codegen::function { +#ifdef _WIN32 +#define DLLEXPORT __declspec(dllexport) +#else +#define DLLEXPORT +#endif + +extern DLLEXPORT int32_t GetIntFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t index); + +extern DLLEXPORT int8_t GetByteFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t index); + +extern DLLEXPORT int16_t GetShortFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t index); + +extern DLLEXPORT int64_t GetLongFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t index); + +extern DLLEXPORT double GetDoubleFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t index); + +extern DLLEXPORT bool GetBooleanFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t index); + +extern DLLEXPORT uint8_t *GetVarcharFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t index, + int32_t *lengthPtr); + +extern DLLEXPORT void GetDecimalFromDictionaryVector(int64_t dictionaryVectorAddr, int32_t index, int32_t outPrecision, + int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr); +extern "C" DLLEXPORT uint8_t *GetStringViewValueAndLength(int64_t stringValueAddr, int32_t index, int32_t *length); +} +#endif // OMNI_RUNTIME_DICTIONARYFUNCTIONS_H diff --git a/core/src/codegen/functions/dtoa.cpp b/core/src/codegen/functions/dtoa.cpp new file mode 100644 index 0000000..fb73196 --- /dev/null +++ b/core/src/codegen/functions/dtoa.cpp @@ -0,0 +1,1218 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * Description: registry function implementation + */ + +#include +#include "dtoa.h" + +namespace omniruntime::codegen::function { +template +static ALWAYS_INLINE Dest BitCast(const Source &source) +{ + Dest dest; + memmove_s(&dest, sizeof(dest), &source, sizeof(dest)); + return dest; +} + +static ALWAYS_INLINE int64_t Unsigned64RightShift(int64_t input, int32_t shift) +{ + return static_cast(static_cast(input) >> shift); +} + +static ALWAYS_INLINE int32_t Unsigned32RightShift(int32_t input, int32_t shift) +{ + return static_cast(static_cast(input) >> shift); +} + +template +static ALWAYS_INLINE int32_t NumberOfLeadingZeros(T input) +{ + if constexpr (std::is_same_v) { + return __builtin_clzl(input); + } else { + return __builtin_clz(input); + } +} + +template +static ALWAYS_INLINE int32_t NumberOfTrailingZeros(T input) +{ + if constexpr (std::is_same_v) { + return __builtin_ctzl(input); + } else { + return __builtin_ctz(input); + } +} + +static FDBigInteger GetZero() noexcept +{ + int temp[] = {0}; + auto zero = FDBigInteger(temp, 0, 1); + zero.MakeImmutable(); + return zero; +} + +static std::vector GetPow5Cache() noexcept +{ + static std::vector pow5Cache(FDBigInteger::MAX_FIVE_POW); + int i = 0; + while (i < FDBigInteger::SMALL_5_POW_SIZE) { + int temp[] = {FDBigInteger::SMALL_5_POW[i]}; + FDBigInteger pow5 = FDBigInteger(temp, 0, 1); + pow5.MakeImmutable(); + pow5Cache[i] = pow5; + i++; + } + FDBigInteger prev = pow5Cache[i - 1]; + while (i < FDBigInteger::MAX_FIVE_POW) { + prev = prev.Mul(5); + prev.MakeImmutable(); + pow5Cache[i] = prev; + i++; + } + return pow5Cache; +} + +FDBigInteger::FDBigInteger(long lValue, const char *digits, int kDigits, int nDigits) +{ + data[0] = static_cast(lValue); // starting value + data[1] = static_cast(Unsigned64RightShift(lValue, 32)); + offset = 0; + nWords = 2; + int i = kDigits; + int limit = nDigits - 5; // slurp digits 5 at a time. + int v; + while (i < limit) { + int iLim = i + 5; + v = static_cast(digits[i++]) - static_cast('0'); + while (i < iLim) { + v = 10 * v + static_cast(digits[i++]) - static_cast('0'); + } + MulAddMe(100000, v); // ... where 100000 is 10^5. + } + int factor = 1; + v = 0; + while (i < nDigits) { + v = 10 * v + static_cast(digits[i++]) - static_cast('0'); + factor *= 10; + } + if (factor != 1) { + MulAddMe(factor, v); + } + TrimLeadingZeros(); +} + +FDBigInteger FDBigInteger::ZERO = GetZero(); +std::vector FDBigInteger::POW_5_CACHE = GetPow5Cache(); + +void FDBigInteger::TrimLeadingZeros() +{ + int i = nWords; + if (i > 0 && (data[--i] == 0)) { + while (i > 0 && data[i - 1] == 0) { + i--; + } + nWords = i; + if (i == 0) { // all words are zero + offset = 0; + } + } +} + +void FDBigInteger::MakeImmutable() +{ + isImmutable = true; +} + +int FDBigInteger::Cmp(const FDBigInteger &other) const +{ + int aSize = nWords + offset; + int bSize = other.nWords + other.offset; + if (aSize > bSize) { + return 1; + } else if (aSize < bSize) { + return -1; + } + int aLen = nWords; + int bLen = other.nWords; + while (aLen > 0 && bLen > 0) { + int a = data[--aLen]; + int b = other.data[--bLen]; + if (a != b) { + return ((a & LONG_MASK) < (b & LONG_MASK)) ? -1 : 1; + } + } + if (aLen > 0) { + return CheckZeroTail(data, aLen); + } + if (bLen > 0) { + return -CheckZeroTail(other.data, bLen); + } + return 0; +} + +FDBigInteger FDBigInteger::Mul(int i) +{ + if (nWords == 0) { + return *this; + } + int r[MAX_DATA_LENGTH]{0}; + Mul(data, nWords, i, r); + return {r, offset, nWords + 1}; +} + +FDBigInteger FDBigInteger::LeftShift(int shift) +{ + if (shift == 0 || nWords == 0) { + return *this; + } + int wordCount = shift >> 5; + int bitCount = shift & 0x1f; + if (isImmutable) { + if (bitCount == 0) { + int newData[MAX_DATA_LENGTH]{0}; + int count = nWords * static_cast(sizeof(int)); + memcpy_s(newData, count, data, count); + return {newData, offset + wordCount, nWords}; + } else { + int antiCount = 32 - bitCount; + int idx = nWords - 1; + int prev = data[idx]; + int hi = Unsigned32RightShift(prev, antiCount); + int result[MAX_DATA_LENGTH]{0}; + int tmpLen = 0; + if (hi != 0) { + tmpLen = nWords + 1; + result[nWords] = hi; + } else { + tmpLen = nWords; + } + LeftShift(data, idx, result, bitCount, antiCount, prev); + return {result, offset + wordCount, tmpLen}; + } + } else { + if (bitCount != 0) { + int antiCount = 32 - bitCount; + if ((data[0] << bitCount) == 0) { + int idx = 0; + int prev = data[idx]; + for (; idx < nWords - 1; idx++) { + int v = Unsigned32RightShift(prev, antiCount); + prev = data[idx + 1]; + v |= (prev << bitCount); + data[idx] = v; + } + int v = Unsigned32RightShift(prev, antiCount); + data[idx] = v; + if (v == 0) { + nWords--; + } + offset++; + } else { + int idx = nWords - 1; + int prev = data[idx]; + int hi = Unsigned32RightShift(prev, antiCount); + int src[MAX_DATA_LENGTH]{0}; + memcpy_s(src, nWords * sizeof(int), data, nWords * sizeof(int)); + if (hi != 0) { + if (nWords == DataSize()) { + Resize(nWords + 1); + } + data[nWords++] = hi; + } + LeftShift(src, idx, data, bitCount, antiCount, prev); + } + } + offset += wordCount; + return *this; + } +} + +int FDBigInteger::AddAndCmp(const FDBigInteger &x, const FDBigInteger &y) +{ + FDBigInteger big; + FDBigInteger small; + int xSize = x.Size(); + int ySize = y.Size(); + int bSize; + int sSize; + if (xSize >= ySize) { + big = x; + small = y; + bSize = xSize; + sSize = ySize; + } else { + big = y; + small = x; + bSize = ySize; + sSize = xSize; + } + int thSize = Size(); + if (bSize == 0) { + return thSize == 0 ? 0 : 1; + } + if (sSize == 0) { + return Cmp(big); + } + if (bSize > thSize) { + return -1; + } + if (bSize + 1 < thSize) { + return 1; + } + long top = (big.data[big.nWords - 1] & LONG_MASK); + if (sSize == bSize) { + top += (small.data[small.nWords - 1] & LONG_MASK); + } + if (Unsigned64RightShift(top, 32) == 0) { + if (Unsigned64RightShift(top + 1, 32) == 0) { + // good case - no carry extension + if (bSize < thSize) { + return 1; + } + // here sum.nWords == nWords + long v = (data[nWords - 1] & LONG_MASK); + if (v < top) { + return -1; + } + if (v > top + 1) { + return 1; + } + } + } else { // (top>>>32)!=0 guaranteed carry extension + if (bSize + 1 > thSize) { + return -1; + } + // here sum.nWords == nWords + Unsigned64RightShift(top, 32); + long v = (data[nWords - 1] & LONG_MASK); + if (v < top) { + return -1; + } + if (v > top + 1) { + return 1; + } + } + return Cmp(big.Add(small)); +} + +FDBigInteger FDBigInteger::MulBy10() +{ + if (nWords == 0) { + return *this; + } + if (isImmutable) { + int res[MAX_DATA_LENGTH]{0}; + res[nWords] = MulAndCarryBy10(data, nWords, res); + return {res, offset, nWords + 1}; + } else { + int p = MulAndCarryBy10(data, nWords, data); + if (p != 0) { + if (nWords == DataSize()) { + if (data[0] == 0) { + auto l = --nWords; + memcpy_s(data, l * sizeof(int), data + 1, l * sizeof(int)); + offset++; + } else { + Resize(DataSize() + 1); + } + } + data[nWords++] = p; + } else { + TrimLeadingZeros(); + } + return *this; + } +} + +void FDBigInteger::Mul(const int *src, int srcLen, int value, int *dst) +{ + long val = value & LONG_MASK; + long carry = 0; + for (int i = 0; i < srcLen; i++) { + long product = (src[i] & LONG_MASK) * val + carry; + dst[i] = static_cast(product); + carry = Unsigned64RightShift(product, 32); + } + dst[srcLen] = static_cast(carry); +} + +FDBigInteger FDBigInteger::Big5PowRec(int p) +{ + if (p < MAX_FIVE_POW) { + return POW_5_CACHE[p]; + } + // construct the value. + // recursively. + int q, r; + // in order to compute 5^p, + // compute its square root, 5^(p/2) and square. + // or, let q = p / 2, r = p -q, then + // 5^p = 5^(q+r) = 5^q * 5^r + q = p >> 1; + r = p - q; + FDBigInteger bigQ = Big5PowRec(q); + if (r < SMALL_5_POW_SIZE) { + return bigQ.Mul(SMALL_5_POW[r]); + } else { + return bigQ.Mul(Big5PowRec(r)); + } +} + +FDBigInteger FDBigInteger::Big5Pow(int p) +{ + if (p < MAX_FIVE_POW) { + return POW_5_CACHE[p]; + } + return Big5PowRec(p); +} + +FDBigInteger FDBigInteger::ValueOfPow52(int p5, int p2) +{ + if (p5 != 0) { + if (p2 == 0) { + return Big5Pow(p5); + } else if (p5 < SMALL_5_POW_SIZE) { + int pow5 = SMALL_5_POW[p5]; + int wordcount = p2 >> 5; + int bitCount = p2 & 0x1f; + if (bitCount == 0) { + int temp[] = {pow5}; + return {temp, wordcount, 1}; + } else { + int temp[] = {pow5 << bitCount, Unsigned32RightShift(pow5, 32 - bitCount)}; + return {temp, wordcount, 2}; + } + } else { + return Big5Pow(p5).LeftShift(p2); + } + } else { + return ValueOfPow2(p2); + } +} + +FDBigInteger FDBigInteger::ValueOfMulPow52(long value, int p5, int p2) +{ + int v0 = static_cast(value); + int v1 = static_cast(Unsigned64RightShift(value, 32)); + int wordcount = p2 >> 5; + int bitCount = p2 & 0x1f; + if (p5 != 0) { + if (p5 < SMALL_5_POW_SIZE) { + long pow5 = SMALL_5_POW[p5] & LONG_MASK; + long carry = (v0 & LONG_MASK) * pow5; + v0 = static_cast(carry); + carry = Unsigned64RightShift(carry, 32); + carry = (v1 & LONG_MASK) * pow5 + carry; + v1 = static_cast(carry); + int v2 = static_cast(Unsigned64RightShift(carry, 32)); + if (bitCount == 0) { + int temp[] = {v0, v1, v2}; + return {temp, wordcount, 3}; + } else { + int temp[] = {v0 << bitCount, (v1 << bitCount) | Unsigned32RightShift(v0, 32 - bitCount), + (v2 << bitCount) | Unsigned32RightShift(v1, 32 - bitCount), + Unsigned32RightShift(v2, 32 - bitCount)}; + return {temp, wordcount, 4}; + } + } else { + FDBigInteger pow5 = Big5Pow(p5); + int r[MAX_DATA_LENGTH]{0}; + int len = 0; + if (v1 == 0) { + len = pow5.nWords + 1 + ((p2 != 0) ? 1 : 0); + Mul(pow5.data, pow5.nWords, v0, r); + } else { + len = pow5.nWords + 2 + ((p2 != 0) ? 1 : 0); + Mul(pow5.data, pow5.nWords, v0, v1, r); + } + return (FDBigInteger(r, pow5.offset, len)).LeftShift(p2); + } + } else if (p2 != 0) { + if (bitCount == 0) { + int temp[] = {v0, v1}; + return {temp, wordcount, 2}; + } else { + int temp[] = {v0 << bitCount, (v1 << bitCount) | Unsigned32RightShift(v0, 32 - bitCount), + Unsigned32RightShift(v1, 32 - bitCount)}; + return {temp, wordcount, 3}; + } + } + int temp[] = {v0, v1}; + return {temp, 0, 2}; +} + +int FDBigInteger::MulAndCarryBy10(const int *src, int srcLen, int *dst) +{ + long carry = 0; + for (int i = 0; i < srcLen; i++) { + long product = (src[i] & LONG_MASK) * 10L + carry; + dst[i] = static_cast(product); + carry = Unsigned64RightShift(product, 32); + } + return static_cast(carry); +} + +void FDBigInteger::Mul(const int *src, int srcLen, int v0, int v1, int *dst) +{ + long v = v0 & LONG_MASK; + long carry = 0; + for (int j = 0; j < srcLen; j++) { + long product = v * (src[j] & LONG_MASK) + carry; + dst[j] = static_cast(product); + carry = Unsigned64RightShift(product, 32); + } + dst[srcLen] = static_cast(carry); + v = v1 & LONG_MASK; + carry = 0; + for (int j = 0; j < srcLen; j++) { + long product = (dst[j + 1] & LONG_MASK) + v * (src[j] & LONG_MASK) + carry; + dst[j + 1] = static_cast(product); + carry = Unsigned64RightShift(product, 32); + } + dst[srcLen + 1] = static_cast(carry); +} + +void FDBigInteger::Mul(const int *s1, int s1Len, const int *s2, int s2Len, int *dst) +{ + for (int i = 0; i < s1Len; i++) { + long v = s1[i] & LONG_MASK; + long p = 0L; + for (int j = 0; j < s2Len; j++) { + p += (dst[i + j] & LONG_MASK) + v * (s2[j] & LONG_MASK); + dst[i + j] = static_cast(p); + p = Unsigned64RightShift(p, 32); + } + dst[i + s2Len] = static_cast(p); + } +} + +int FDBigInteger::CheckZeroTail(const int *a, int from) +{ + while (from > 0) { + if (a[--from] != 0) { + return 1; + } + } + return 0; +} + +void FDBigInteger::LeftShift(const int *src, int idx, int *result, int bitCount, int antiCount, int prev) +{ + for (; idx > 0; idx--) { + int v = (prev << bitCount); + prev = src[idx - 1]; + v |= Unsigned32RightShift(prev, antiCount); + result[idx] = v; + } + int v = prev << bitCount; + result[0] = v; +} + +int FDBigInteger::GetNormalizationBias() const +{ + if (nWords == 0) { + throw omniruntime::exception::OmniException("OPERATOR_RUNTIME_ERROR", "Zero value cannot be normalized"); + } + int zeros = NumberOfLeadingZeros(data[nWords - 1]); + return (zeros < 4) ? 28 + zeros : zeros - 4; +} + +int FDBigInteger::QuoRemIteration(FDBigInteger &s) +{ + // ensure that this and S have the same number of + // digits. If S is properly normalized and q < 10 then + // this must be so. + int thSize = Size(); + int sSize = s.Size(); + if (thSize < sSize) { + // this value is significantly less than S, result of division is zero. + // just mul this by 10. + int p = MulAndCarryBy10(data, nWords, data); + if (p != 0) { + data[nWords++] = p; + } else { + TrimLeadingZeros(); + } + return 0; + } else if (thSize > sSize) { + throw omniruntime::exception::OmniException("OPERATOR_RUNTIME_ERROR", "disparate values"); + } + // estimate q the obvious way. We will usually be + // right. If not, then we're only off by a little and + // will re-add. + long q = (data[nWords - 1] & LONG_MASK) / (s.data[s.nWords - 1] & LONG_MASK); + long diff = MulDiffMe(q, s); + if (diff != 0L) { + //@ assert q != 0; + //@ assert offset == \old(Math.min(offset, S.offset)); + //@ assert offset <= S.offset; + + // q is too big. + // add S back in until this turns +. This should + // not be very many times! + long sum = 0L; + int tStart = s.offset - offset; + //@ assert tStart >= 0; + int *sd = s.data; + int *td = data; + while (sum == 0L) { + for (int sIndex = 0, tIndex = tStart; tIndex < nWords; sIndex++, tIndex++) { + sum += (td[tIndex] & LONG_MASK) + (sd[sIndex] & LONG_MASK); + td[tIndex] = static_cast(sum); + sum = Unsigned64RightShift(sum, 32); // Signed or unsigned, answer is 0 or 1 + } + // + // Originally the following line read + // "if ( sum !=0 && sum != -1 )" + // but that would be wrong, because of the + // treatment of the two values as entirely unsigned, + // it would be impossible for a carry-out to be interpreted + // as -1 -- it would have to be a single-bit carry-out, or +1. + // + q -= 1; + } + } + // finally, we can multiply this by 10. + // it cannot overflow, right, as the high-order word has + // at least 4 high-order zeros! + MulAndCarryBy10(data, nWords, data); + TrimLeadingZeros(); + return static_cast(q); +} + +FDBigInteger FDBigInteger::Add(const FDBigInteger &other) +{ + FDBigInteger big, small; + int bigLen, smallLen; + int tSize = Size(); + int oSize = other.Size(); + if (tSize >= oSize) { + big = *this; + bigLen = tSize; + small = other; + smallLen = oSize; + } else { + big = other; + bigLen = oSize; + small = *this; + smallLen = tSize; + } + int r[MAX_DATA_LENGTH]{0}; + int i = 0; + long carry = 0L; + for (; i < smallLen; i++) { + carry += (i < big.offset ? 0L : (big.data[i - big.offset] & LONG_MASK)) + + ((i < small.offset ? 0L : (small.data[i - small.offset] & LONG_MASK))); + r[i] = static_cast(carry); + carry >>= 32; // signed shift. + } + for (; i < bigLen; i++) { + carry += (i < big.offset ? 0L : (big.data[i - big.offset] & LONG_MASK)); + r[i] = static_cast(carry); + carry >>= 32; // signed shift. + } + r[bigLen] = static_cast(carry); + return {r, 0, bigLen + 1}; +} + +long FDBigInteger::MulDiffMe(long q, FDBigInteger &s) +{ + long diff = 0L; + if (q != 0) { + int deltaSize = s.offset - offset; + if (deltaSize >= 0) { + int *sd = s.data; + int *td = data; + for (int sIndex = 0, tIndex = deltaSize; sIndex < s.nWords; sIndex++, tIndex++) { + diff += (td[tIndex] & LONG_MASK) - q * (sd[sIndex] & LONG_MASK); + td[tIndex] = static_cast(diff); + diff >>= 32; // N.B. SIGNED shift. + } + } else { + deltaSize = -deltaSize; + int rd[MAX_DATA_LENGTH]{0}; + int sIndex = 0; + int rIndex = 0; + int *sd = s.data; + for (; rIndex < deltaSize && sIndex < s.nWords; sIndex++, rIndex++) { + diff -= q * (sd[sIndex] & LONG_MASK); + rd[rIndex] = static_cast(diff); + diff >>= 32; // N.B. SIGNED shift. + } + int tIndex = 0; + int *td = data; + for (; sIndex < s.nWords; sIndex++, tIndex++, rIndex++) { + diff += (td[tIndex] & LONG_MASK) - q * (sd[sIndex] & LONG_MASK); + rd[rIndex] = static_cast(diff); + diff >>= 32; // N.B. SIGNED shift. + } + nWords += deltaSize; + offset -= deltaSize; + UpdateDataVector(rd); + } + } + return diff; +} + +FDBigInteger FDBigInteger::Mul(FDBigInteger other) +{ + if (nWords == 0) { + return *this; + } + if (Size() == 1) { + return other.Mul(data[0]); + } + if (other.nWords == 0) { + return other; + } + if (other.Size() == 1) { + return Mul(other.data[0]); + } + int r[MAX_DATA_LENGTH]{0}; + Mul(data, nWords, other.data, other.nWords, r); + return {r, offset + other.offset, nWords + other.nWords}; +} + +void FDBigInteger::MulAddMe(int iv, int addend) +{ + long v = iv & LONG_MASK; + // unroll 0th iteration, doing addition. + long p = v * (data[0] & LONG_MASK) + (addend & LONG_MASK); + data[0] = static_cast(p); + p = Unsigned64RightShift(p, 32); + for (int i = 1; i < nWords; i++) { + p += v * (data[i] & LONG_MASK); + data[i] = static_cast(p); + p = Unsigned64RightShift(p, 32); + } + if (p != 0L) { + data[nWords++] = static_cast(p); // will fail noisily if illegal! + } +} + +void DoubleToString::DevelopLongDigits(int exponent, long leftValue, int insignificantDigits) +{ + if (insignificantDigits != 0) { + // Discard non-significant low-order bits, while rounding, + // up to insignificant value. + long pow10 = FDBigInteger::LONG_5_POW[insignificantDigits] << insignificantDigits; // 10^i == 5^i * 2^i; + long residue = leftValue % pow10; + leftValue /= pow10; + exponent += insignificantDigits; + if (residue >= (pow10 >> 1)) { + // round up based on the low-order bits we're discarding + leftValue++; + } + } + int digitNo = 20 - 1; + int c; + if (leftValue <= INT32_MAX) { + // even easier subcase! + // can do int arithmetic rather than long! + int iValue = static_cast(leftValue); + c = iValue % 10; + iValue /= 10; + while (c == 0) { + exponent++; + c = iValue % 10; + iValue /= 10; + } + while (iValue != 0) { + digits[digitNo--] = static_cast(c + '0'); + exponent++; + c = iValue % 10; + iValue /= 10; + } + digits[digitNo] = static_cast(c + '0'); + } else { + // same algorithm as above (same bugs, too ) + // but using long arithmetic. + c = static_cast(leftValue % 10L); + leftValue /= 10L; + while (c == 0) { + exponent++; + c = static_cast(leftValue % 10L); + leftValue /= 10L; + } + while (leftValue != 0L) { + digits[digitNo--] = static_cast(c + '0'); + exponent++; + c = static_cast(leftValue % 10L); + leftValue /= 10; + } + digits[digitNo] = static_cast(c + '0'); + } + this->decExponent = exponent + 1; + this->firstDigitIndex = digitNo; + this->nDigits = 20 - digitNo; +} + +void DoubleToString::Dtoa(int binExp, long fractBits, int nSignificantBits, bool isCompatibleFormat) +{ + // Examine number. Determine if it is an easy case, + // which we can do pretty trivially using float/long conversion, + // or whether we must do real work. + int tailZeros = NumberOfTrailingZeros(fractBits); + + // number of significant bits of fractBits; + int nFractBits = EXP_SHIFT + 1 - tailZeros; + + // number of significant bits to the right of the point. + int nTinyBits = std::max(0, nFractBits - binExp - 1); + if (binExp <= MAX_SMALL_BIN_EXP && binExp >= MIN_SMALL_BIN_EXP) { + // Look more closely at the number to decide if, + // with scaling by 10^nTinyBits, the result will fit in + // a long. + if ((nTinyBits < FDBigInteger::LONG_5_POW_SIZE) && ((nFractBits + N_5_BITS[nTinyBits]) < 64)) { + // + // We can do this: + // take the fraction bits, which are normalized. + // (1) nTinyBits == 0: Shift left or right appropriately to align the binary point + // at the extreme right, i.e.where a long int point is expected to be. + // The integer result is easily converted to a string. + // (2) nTinyBits > 0: Shift right by EXP_SHIFT-n FractBits, which effectively converts to + // long and scales by 2^nTinyBits. Then multiply by 5^nTinyBits to complete the scaling. + // We know this won't overflow because we just counted the number of bits necessary in the result. + // The integer you get from this can then be converted to a string pretty easily. + // + if (nTinyBits == 0) { + int insignificant; + if (binExp > nSignificantBits) { + insignificant = InsignificantDigitsForPow2(binExp - nSignificantBits - 1); + } else { + insignificant = 0; + } + if (binExp >= EXP_SHIFT) { + fractBits <<= (binExp - EXP_SHIFT); + } else { + fractBits = Unsigned64RightShift(fractBits, EXP_SHIFT - binExp); + } + DevelopLongDigits(0, fractBits, insignificant); + return; + } + } + } + // + // This is the hard case. We are going to compute large positive integers B and S and integer decExp, s.t. + // d = ( B / S )// 10^decExp + // 1 <= B / S < 10 + // Obvious choices are: + // decExp = floor( log10(d) ) + // B = d// 2^nTinyBits// 10^max( 0, -decExp ) + // S = 10^max( 0, decExp)// 2^nTinyBits + // (noting that nTinyBits has already been forced to non-negative) + // I am also going to compute a large positive integer + // M = (1/2^nSignificantBits)// 2^nTinyBits// 10^max( 0, -decExp ) + // i.e. M is (1/2) of the ULP of d, scaled like B. + // When we iterate through dividing B/S and picking off the quotient bits, + // we will know when to stop when the remainder + // is <= M. + // + // We keep track of powers of 2 and powers of 5. + int decExp = EstimateDecExp(fractBits, binExp); + int d2, d5; // powers of 2 and powers of 5, respectively, in D + int s2, s5; // powers of 2 and powers of 5, respectively, in S + int b2, b5; // powers of 2 and powers of 5, respectively, in B + + d5 = std::max(0, -decExp); + d2 = d5 + nTinyBits + binExp; + + s5 = std::max(0, decExp); + s2 = s5 + nTinyBits; + + b5 = d5; + b2 = d2 - nSignificantBits; + + // the long integer fractBits contains the (nFractBits) interesting bits from the + // mantissa of d ( hidden 1 added if necessary) followed by (EXP_SHIFT+1-nFractBits) zeros. + // In the interest of compactness, I will shift out those zeros before turning fractBits + // into a FDBigInteger. + // The resulting whole number will be + // d * 2^(nFractBits-1-binExp). + fractBits = Unsigned64RightShift(fractBits, tailZeros); + d2 -= nFractBits - 1; + int common2factor = std::min(d2, s2); + d2 -= common2factor; + s2 -= common2factor; + b2 -= common2factor; + + // HACK!! For exact powers of two, the next smallest number is only half as far away + // as we think (because the meaning of ULP changes at power-of-two bounds) + // for this reason, we hack M2. Hope this works. + if (nFractBits == 1) { + b2 -= 1; + } + + if (b2 < 0) { + // oops. since we cannot scale M down far enough, we must scale the other values up. + d2 -= b2; + s2 -= b2; + b2 = 0; + } + + // Construct, Scale, iterate. Some day, we'll write a stopping test that + // takes account of the asymmetry of the spacing of floating-point numbers + // below perfect powers of 2 26 Sept 96 is not that day. So we use a symmetric test. + int nDigit = 0; + bool low, high; + long lowDigitDifference; + int q; + + // Detect the special cases where all the numbers we are about to compute will fit in int or long integers. + // In these cases, we will avoid doing FDBigInteger arithmetic. We use the same algorithms, + // except that we "normalize" our FDBigIntegers before iterating. This is to make division easier, + // as it makes our fist guess (quotient of high-order words) more accurate! + // + // Some day, we'll write a stopping test that takes account of the asymmetry of + // the spacing of floating-point numbers below perfect powers of 2 26 Sept 96 is not that day. + // So we use a symmetric test. + // + // binary digits needed to represent B, approx. + int bBits = nFractBits + d2 + ((d5 < N_5_BITS_SIZE) ? N_5_BITS[d5] : (d5 * 3)); + + // binary digits needed to represent 10*S, approx. + int tensBits = s2 + 1 + (((s5 + 1) < N_5_BITS_SIZE) ? N_5_BITS[(s5 + 1)] : ((s5 + 1) * 3)); + if (bBits < 64 && tensBits < 64) { + if (bBits < 32 && tensBits < 32) { + // wa-hoo! They're all ints! + int b = (static_cast(fractBits) * FDBigInteger::SMALL_5_POW[d5]) << d2; + int s = FDBigInteger::SMALL_5_POW[s5] << s2; + int m = FDBigInteger::SMALL_5_POW[b5] << b2; + int tens = s * 10; + // + // Unroll the first iteration. If our decExp estimate + // was too high, our first quotient will be zero. In this + // case, we discard it and decrement decExp. + nDigit = 0; + q = b / s; + b = 10 * (b % s); + m *= 10; + low = (b < m); + high = (b + m > tens); + if ((q == 0) && !high) { + // oops. Usually ignore leading zero. + decExp--; + } else { + digits[nDigit++] = static_cast('0' + q); + } + + // HACK! Java spec sez that we always have at least one digit + // after the . in either F- or E-form output. + // Thus, we will need more than one digit if we're using E-form + if (!isCompatibleFormat || decExp < -3 || decExp >= 8) { + high = false; + low = false; + } + while (!low && !high) { + b = 10 * (b % s); + q = b / s; + bool overflow = __builtin_mul_overflow(m, 10, &m); + if (!overflow) { + low = (b < m); + high = (b + m > tens); + } else { + // hack -- m might overflow! in this case, it is certainly > b, + // which won't and b+m > tens, too, since that has overflowed either! + low = true; + high = true; + } + digits[nDigit++] = static_cast('0' + q); + } + lowDigitDifference = (b << 1) - tens; + } else { + // still good! they're all longs! + long b = (fractBits * FDBigInteger::LONG_5_POW[d5]) << d2; + long s = FDBigInteger::LONG_5_POW[s5] << s2; + long m = FDBigInteger::LONG_5_POW[b5] << b2; + long tens = s * 10L; + // + // Unroll the first iteration. If our decExp estimate + // was too high, our first quotient will be zero. In this + // case, we discard it and decrement decExp. + // + nDigit = 0; + q = static_cast(b / s); + b = 10L * (b % s); + m *= 10L; + low = (b < m); + high = (b + m > tens); + if ((q == 0) && !high) { + // oops. Usually ignore leading zero. + decExp--; + } else { + digits[nDigit++] = static_cast('0' + q); + } + + // HACK! Java spec sez that we always have at least one digit + // after the . in either F- or E-form output. + // Thus, we will need more than one digit if we're using E-form + if (!isCompatibleFormat || decExp < -3 || decExp >= 8) { + high = low = false; + } + while (!low && !high) { + q = static_cast(b / s); + b = 10 * (b % s); + bool overflow = __builtin_mul_overflow(m, 10, &m); + if (!overflow) { + high = (b + m > tens); + low = (b < m); + } else { + // m might overflow! in this case, it is certainly > b, + // which won't and b+m > tens, too, since that has overflowed either! + low = true; + high = true; + } + digits[nDigit++] = static_cast('0' + q); + } + lowDigitDifference = (b << 1) - tens; + } + } else { + // We really must do FDBigInteger arithmetic. + // Fist, construct our FDBigInteger initial values. + FDBigInteger sval = FDBigInteger::ValueOfPow52(s5, s2); + int shiftBias = sval.GetNormalizationBias(); + sval = sval.LeftShift(shiftBias); // normalize so that division works better + + FDBigInteger bval = FDBigInteger::ValueOfMulPow52(fractBits, d5, d2 + shiftBias); + FDBigInteger mval = FDBigInteger::ValueOfPow52(b5 + 1, b2 + shiftBias + 1); + FDBigInteger tenSval = FDBigInteger::ValueOfPow52(s5 + 1, s2 + shiftBias + 1); // Sval.mul( 10 ); + + // Unroll the first iteration. If our decExp estimate + // was too high, our first quotient will be zero. In this + // case, we discard it and decrement decExp. + nDigit = 0; + q = bval.QuoRemIteration(sval); + low = (bval.Cmp(mval) < 0); + high = tenSval.AddAndCmp(bval, mval) <= 0; + + if ((q == 0) && !high) { + // oops. Usually ignore leading zero. + decExp--; + } else { + digits[nDigit++] = static_cast('0' + q); + } + + // HACK! Java spec sez that we always have at least one digit + // after the . in either F- or E-form output. Thus, we will need more than + // one digit if we're using E-form + if (!isCompatibleFormat || decExp < -3 || decExp >= 8) { + high = low = false; + } + while (!low && !high) { + q = bval.QuoRemIteration(sval); + mval = mval.MulBy10(); // Mval = Mval.mul( 10 ); + low = (bval.Cmp(mval) < 0); + high = tenSval.AddAndCmp(bval, mval) <= 0; + digits[nDigit++] = static_cast('0' + q); + } + if (high && low) { + bval = bval.LeftShift(1); + lowDigitDifference = bval.Cmp(tenSval); + } else { + lowDigitDifference = 0L; // this here only for flow analysis! + } + } + decExponent = decExp + 1; + firstDigitIndex = 0; + nDigits = nDigit; + + // Last digit gets rounded based on stopping condition. + if (high) { + if (low) { + if (lowDigitDifference == 0L) { + // it's a tie! choose based on which digits we like. + auto index = firstDigitIndex + nDigits - 1; + if (index >= MAX_DIGIT_INDEX) { + throw std::runtime_error("digits index overflow! index value:" + std::to_string(index)); + } + if ((digits[index] & 1) != 0) { + Roundup(); + } + } else if (lowDigitDifference > 0) { + Roundup(); + } + } else { + Roundup(); + } + } +} + +void DoubleToString::Roundup() +{ + int index = (firstDigitIndex + nDigits - 1); + int q = digits[index]; + if (q == '9') { + while (q == '9' && index > firstDigitIndex) { + digits[index] = '0'; + q = digits[--index]; + } + if (q == '9') { + // carryout! High-order 1, rest 0s, larger exp. + decExponent += 1; + digits[firstDigitIndex] = '1'; + return; + } + // else fall through. + } + digits[index] = static_cast(q + 1); +} + +int DoubleToString::GetChars(char *result) const +{ + int index = 0; + if (isNegative) { + result[0] = '-'; + index = 1; + } + if (decExponent > 0 && decExponent < 8) { + // print digits.digits. + int charLength = std::min(nDigits, decExponent); + memcpy_s(result + index, charLength, digits + firstDigitIndex, charLength); + index += charLength; + if (charLength < decExponent) { + charLength = decExponent - charLength; + std::fill(result + index, result + (index + charLength), '0'); + index += charLength; + result[index++] = '.'; + result[index++] = '0'; + } else { + result[index++] = '.'; + if (charLength < nDigits) { + int t = nDigits - charLength; + memcpy_s(result + index, t, digits + (firstDigitIndex + charLength), t); + index += t; + } else { + result[index++] = '0'; + } + } + } else if (decExponent <= 0 && decExponent > -3) { + result[index++] = '0'; + result[index++] = '.'; + if (decExponent != 0) { + std::fill(result + index, result + (index - decExponent), '0'); + index -= decExponent; + } + memcpy_s(result + index, nDigits, digits + firstDigitIndex, nDigits); + index += nDigits; + } else { + result[index++] = digits[firstDigitIndex]; + result[index++] = '.'; + if (nDigits > 1) { + memcpy_s(result + index, nDigits - 1, digits + (firstDigitIndex + 1), nDigits - 1); + index += nDigits - 1; + } else { + result[index++] = '0'; + } + result[index++] = 'E'; + int e; + if (decExponent <= 0) { + result[index++] = '-'; + e = -decExponent + 1; + } else { + e = decExponent - 1; + } + // decExponent has 1, 2, or 3, digits + if (e <= 9) { + result[index++] = static_cast(e + '0'); + } else if (e <= 99) { + result[index++] = static_cast(e / 10 + '0'); + result[index++] = static_cast(e % 10 + '0'); + } else { + result[index++] = static_cast(e / 100 + '0'); + e %= 100; + result[index++] = static_cast(e / 10 + '0'); + result[index++] = static_cast(e % 10 + '0'); + } + } + return index; +} + +int DoubleToString::InsignificantDigitsForPow2(int p2) +{ + if (p2 > 1 && p2 < INSIGNIFICANT_DIGITS_NUMBER_SIZE) { + return INSIGNIFICANT_DIGITS_NUMBER[p2]; + } + return 0; +} + +int DoubleToString::EstimateDecExp(long fractBits, int binExp) +{ + auto d2 = BitCast(EXP_ONE | (fractBits & DoubleConsts::SIGNIF_BIT_MASK)); + double d = (d2 - 1.5) * 0.289529654 + 0.176091259 + static_cast(binExp) * 0.301029995663981; + long dBits = BitCast(d); //can't be NaN here so use raw + int exponent = static_cast((dBits & DoubleConsts::EXP_BIT_MASK) >> EXP_SHIFT) - DoubleConsts::EXP_BIAS; + bool isNeg = (dBits & DoubleConsts::SIGN_BIT_MASK) != 0; // discover sign + if (exponent >= 0 && exponent < 52) { // hot path + long mask = DoubleConsts::SIGNIF_BIT_MASK >> exponent; + int r = static_cast(((dBits & DoubleConsts::SIGNIF_BIT_MASK) | FRACT_HOB) >> (EXP_SHIFT - exponent)); + return isNeg ? (((mask & dBits) == 0L) ? -r : -r - 1) : r; + } else if (exponent < 0) { + return (((dBits & ~DoubleConsts::SIGN_BIT_MASK) == 0) ? 0 : + ((isNeg) ? -1 : 0)); + } else { + return static_cast(d); + } +} + +std::size_t DoubleToString::DoubleToStringConverter(double d, char *result) +{ + long dBits = BitCast(d); + bool isNeg = (dBits & DoubleConsts::SIGN_BIT_MASK) != 0; // discover sign + long fractBits = dBits & DoubleConsts::SIGNIF_BIT_MASK; + int binExp = static_cast((dBits & DoubleConsts::EXP_BIT_MASK) >> EXP_SHIFT); + + auto setValueAndGetSize = [&result](const std::string &inputString) -> std::size_t { + auto size = inputString.size(); + memcpy_s(result, size, inputString.c_str(), size); + return size; + }; + // Discover obvious special cases of NaN and Infinity. + if (binExp == static_cast(DoubleConsts::EXP_BIT_MASK >> EXP_SHIFT)) { + if (fractBits == 0L) { + if (isNeg) { + return setValueAndGetSize("-Infinity"); + } + return setValueAndGetSize("Infinity"); + } else { + return setValueAndGetSize("NaN"); + } + } + // Finish unpacking + // Normalize denormalized numbers. Insert assumed high-order bit + // for normalized numbers. Subtract exponent bias. + int nSignificantBits; + if (binExp == 0) { + if (fractBits == 0L) { + if (isNeg) { + return setValueAndGetSize("-0.0"); + } + return setValueAndGetSize("0.0"); + } + int leadingZeros = NumberOfLeadingZeros(fractBits); + int shift = leadingZeros - (63 - EXP_SHIFT); + fractBits <<= shift; + binExp = 1 - shift; + // recall binExp is - shift count. + nSignificantBits = 64 - leadingZeros; + } else { + fractBits |= FRACT_HOB; + nSignificantBits = EXP_SHIFT + 1; + } + binExp -= DoubleConsts::EXP_BIAS; + DoubleToString buf = DoubleToString(); + buf.setSign(isNeg); + // call the routine that actually does all the hard work. + buf.Dtoa(binExp, fractBits, nSignificantBits, true); + return buf.ToString(result); +} + +// just for ut test +std::string DoubleToString::DoubleToStringConverter(double d) +{ + char result[MAX_DATA_LENGTH]; + auto length = DoubleToStringConverter(d, result); + return std::string{result, length}; +} +} \ No newline at end of file diff --git a/core/src/codegen/functions/dtoa.h b/core/src/codegen/functions/dtoa.h new file mode 100644 index 0000000..b8bc900 --- /dev/null +++ b/core/src/codegen/functions/dtoa.h @@ -0,0 +1,410 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * Description: registry function implementation + */ + +#ifndef OMNI_RUNTIME_DTOA_H +#define OMNI_RUNTIME_DTOA_H + +#include +#include +#include +#include +#include +#include +#include +#include "util/omni_exception.h" +#include "util/compiler_util.h" + +namespace omniruntime::codegen::function { +constexpr std::size_t MAX_DATA_LENGTH = 27; + +class FDBigInteger { +public: + FDBigInteger() = default; + + FDBigInteger(int *data, int offset, int length) + { + this->offset = offset; + this->nWords = length; + this->length = length; + memcpy_s(this->data, nWords * sizeof(int), data, nWords * sizeof(int)); + this->TrimLeadingZeros(); + } + + FDBigInteger(long lValue, const char *digits, int kDigits, int nDigits); + + ~FDBigInteger() = default; + + void Resize(int size) + { + nWords = size; + } + + /** + * Removes all leading zeros from this FDBigInteger adjusting + * the offset and number of non-zero leading words accordingly. + */ + void TrimLeadingZeros(); + + void MakeImmutable(); + + /** + * Compares the parameter with this FDBigInteger. Returns an + * integer accordingly as: + *
+     * >0: this > other
+     *  0: this == other
+     * <0: this < other
+     * 
+ * + * @param other The FDBigInteger to compare. + * @return A negative value, zero, or a positive value according to the + * result of the comparison. + */ + int Cmp(const FDBigInteger &other) const; + + /** + * Multiplies by a constant value a big integer represented as an array. + * The constant factor is an int. + * + * @param src The array representation of the big integer. + * @param srcLen The number of elements of src to use. + * @param value The constant factor by which to multiply. + * @param dst The product array. + */ + ALWAYS_INLINE static void Mul(const int *src, int srcLen, int value, int *dst); + + /** + * Multiplies this FDBigInteger by an integer. + * + * @param i The factor by which to multiply this FDBigInteger. + * @return This FDBigInteger multiplied by an integer. + */ + FDBigInteger Mul(int i); + + /** + * Shifts this FDBigInteger to the left. The shift is performed + * in-place unless the FDBigInteger is immutable in which case + * a new instance of FDBigInteger is returned. + * + * @param shift The number of bits to shift left. + * @return The shifted FDBigInteger. + */ + FDBigInteger LeftShift(int shift); + + /** + * Computes 5 raised to a given power. + * + * @param p The exponent of 5. + * @return 5p. + */ + static FDBigInteger Big5PowRec(int p); + + /** + * Computes 5 raised to a given power. + * + * @param p The exponent of 5. + * @return 5p. + */ + ALWAYS_INLINE static FDBigInteger Big5Pow(int p); + + /** + * Returns an FDBigInteger with the numerical value + * 5p5 * 2p2. + * + * @param p5 The exponent of the power-of-five factor. + * @param p2 The exponent of the power-of-two factor. + * @return 5p5 * 2p2 + */ + ALWAYS_INLINE static FDBigInteger ValueOfPow52(int p5, int p2); + + /** + * Compares this FDBigInteger with x + y. Returns a + * value according to the comparison as: + *
+     * -1: this <  x + y
+     *  0: this == x + y
+     *  1: this >  x + y
+     * 
+ * @param x The first addend of the sum to compare. + * @param y The second addend of the sum to compare. + * @return -1, 0, or 1 according to the result of the comparison. + */ + int AddAndCmp(const FDBigInteger &x, const FDBigInteger &y); + + /** + * Multiplies this FDBigInteger by 10. The operation will be + * performed in place unless the FDBigInteger is immutable in + * which case a new FDBigInteger will be returned. + * + * @return The FDBigInteger multiplied by 10. + */ + FDBigInteger MulBy10(); + + /** + * Retrieves the normalization bias of the FDBigInteger. The + * normalization bias is a left shift such that after it the highest word + * of the value will have the 4 highest bits equal to zero: + * (highestWord & 0xf0000000) == 0, but the next bit should be 1 + * (highestWord & 0x08000000) != 0. + * + * @return The normalization bias. + */ + int GetNormalizationBias() const; + + /** + * Computes + *
+     * q = (int)( this / S )
+     * this = 10 * ( this mod S )
+     * Return q.
+     * 
+ * This is the iteration step of digit development for output. + * We assume that S has been normalized, as above, and that + * "this" has been left-shifted accordingly. + * Also assumed, of course, is that the result, q, can be expressed + * as an integer, 0 <= q < 10. + * + * @param The divisor of this FDBigInteger. + * @return q = (int)(this / S). + */ + int QuoRemIteration(FDBigInteger &s); + + /** + * Returns an FDBigInteger with the numerical value + * value * 5p5 * 2p2. + * + * @param value The constant factor. + * @param p5 The exponent of the power-of-five factor. + * @param p2 The exponent of the power-of-two factor. + * @return value * 5p5 * 2p2 + */ + ALWAYS_INLINE static FDBigInteger ValueOfMulPow52(long value, int p5, int p2); + + int *GetData() + { + return data; + } + + constexpr static int SMALL_5_POW[] = { + 1, 5, 25, 125, 625, 3125, 15625, 78125, 390625, 1953125, 9765625, 48828125, 244140625, 1220703125}; + + constexpr static int SMALL_5_POW_SIZE = 14; + + constexpr static long LONG_5_POW[] = { + 1L, 5L, 25L, 125L, 625L, 3125L, 15625L, 78125L, 390625L, 1953125L, 9765625L, 48828125L, 244140625L, 1220703125L, + 6103515625L, 30517578125L, 152587890625L, 762939453125L, 3814697265625L, 19073486328125L, 95367431640625L, + 476837158203125L, 2384185791015625L, 11920928955078125L, 59604644775390625L, 298023223876953125L, + 1490116119384765625L}; + + constexpr static long LONG_5_POW_SIZE = 27; + + constexpr static int MAX_FIVE_POW = 340; + static std::vector POW_5_CACHE; + static FDBigInteger ZERO; + constexpr static long LONG_MASK = 4294967295L; +private: + int Size() const + { + return nWords + offset; + } + + int DataSize() const + { + return static_cast(this->length); + } + + /** + * Determines whether all elements of an array are zero for all indices less + * than a given index. + * + * @param a The array to be examined. + * @param from The index strictly below which elements are to be examined. + * @return Zero if all elements in range are zero, 1 otherwise. + */ + static int CheckZeroTail(const int *a, int from); + + /** + * Returns an FDBigInteger with the numerical value + * 2p2. + * + * @param p2 The exponent of 2. + * @return 2p2 + */ + ALWAYS_INLINE static FDBigInteger ValueOfPow2(int p2) + { + int wordcount = p2 >> 5; + int bitCount = p2 & 0x1f; + int temp[] = {1 << bitCount}; + return {temp, wordcount, 1}; + } + + FDBigInteger Add(const FDBigInteger &other); + + ALWAYS_INLINE static int MulAndCarryBy10(const int *src, int srcLen, int *dst); + + void UpdateDataVector(int *newData) + { + memcpy_s(this->data, nWords * sizeof(int), newData, nWords * sizeof(int)); + } + + /** + * Multiplies the parameters and subtracts them from this + * FDBigInteger. + * + * @param q The integer parameter. + * @param s The FDBigInteger parameter. + * @return this - q*S. + */ + long MulDiffMe(long q, FDBigInteger &s); + + /** + * Multiplies by a constant value a big integer represented as an array. + * The constant factor is a long represent as two ints. + * + * @param src The array representation of the big integer. + * @param srcLen The number of elements of src to use. + * @param v0 The lower 32 bits of the long factor. + * @param v1 The upper 32 bits of the long factor. + * @param dst The product array. + */ + static void Mul(const int *src, int srcLen, int v0, int v1, int *dst); + + /** + * Left shifts the contents of one int array into another. + * + * @param src The source array. + * @param idx The initial index of the source array. + * @param result The destination array. + * @param bitCount The left shift. + * @param antiCount The left anti-shift, e.g., 32-bitCount. + * @param prev The prior source value. + */ + static void LeftShift(const int *src, int idx, int *result, int bitCount, int antiCount, int prev); + + /** + * Multiplies two big integers represented as int arrays. + * + * @param s1 The first array factor. + * @param s1Len The number of elements of s1 to use. + * @param s2 The second array factor. + * @param s2Len The number of elements of s2 to use. + * @param dst The product array. + */ + static void Mul(const int *s1, int s1Len, const int *s2, int s2Len, int *dst); + + /** + * Multiplies this FDBigInteger by another FDBigInteger. + * + * @param other The FDBigInteger factor by which to multiply. + * @return The product of this and the parameter FDBigIntegers. + */ + FDBigInteger Mul(FDBigInteger other); + + void MulAddMe(int iv, int addend); + + int data[MAX_DATA_LENGTH]{0}; + int offset = 0; + int nWords = 0; + int length = 0; + bool isImmutable = false; +}; + +class DoubleConsts { +public: + static constexpr int EXP_BIAS = 1023; + static constexpr long SIGNIF_BIT_MASK = 0x000FFFFFFFFFFFFFL; + static constexpr long EXP_BIT_MASK = 0x7FF0000000000000L; + static constexpr long SIGN_BIT_MASK = 0x8000000000000000L; +}; + +class DoubleToString { +public: + DoubleToString() = default; + + void setSign(bool value) + { + this->isNegative = value; + } + + /** + * This is the easy subcase + * all the significant bits, after scaling, are held in lvalue.negSign and decExponent tell us + * what processing and scaling has already been done. Exceptional cases have already been stripped out. + * In particular: + * lvalue is a finite number (not Inf, nor NaN) + * lvalue > 0L (not zero, nor negative). + * + * The only reason that we develop the digits here, rather than calling on Long.toString() is + * that we can do it a little faster,and besides want to treat trailing 0s specially. + * If Long.toString changes, we should re-evaluate this strategy! + */ + void DevelopLongDigits(int exponent, long leftValue, int insignificantDigits); + + /** + * Calculates + *
+    * InsignificantDigitsForPow2(v) == insignificantDigits(1L<
+    */
+    static int InsignificantDigitsForPow2(int p2);
+
+    /**
+     * Estimate decimal exponent. (If it is small-ish, we could double-check.)
+     *
+     * First, scale the mantissa bits such that 1 <= d2 < 2. We are then going to estimate
+     *     log10(d2) ~=~  (d2-1.5)/1.5 + log(1.5)
+     * and so we can estimate
+     *     log10(d) ~=~ log10(d2) + binExp * log10(2)
+     * take the floor and call it decExp.
+     */
+    static int EstimateDecExp(long fractBits, int binExp);
+
+    void Dtoa(int binExp, long fractBits, int nSignificantBits, bool isCompatibleFormat);
+
+    void Roundup();
+
+    int GetChars(char *result) const;
+
+    // just for ut test
+    static std::string DoubleToStringConverter(double d);
+
+    static std::size_t DoubleToStringConverter(double d, char *result);
+
+    std::string ToString()
+    {
+        int len = GetChars(buffer);
+        auto res = std::string(buffer, len);
+        return res;
+    }
+
+    std::size_t ToString(char *result)
+    {
+        return GetChars(result);
+    }
+
+    static constexpr int SIGNIFICAND_WIDTH = 53;
+    static constexpr int MAX_DIGIT_INDEX = 20;
+    static constexpr int EXP_SHIFT = SIGNIFICAND_WIDTH - 1;
+    static constexpr long FRACT_HOB = (1L << EXP_SHIFT);
+    static constexpr int MAX_SMALL_BIN_EXP = 62;
+    static constexpr int MIN_SMALL_BIN_EXP = -(63 / 3);
+    static constexpr long EXP_ONE = ((long)DoubleConsts::EXP_BIAS) << EXP_SHIFT; // exponent of 1.0
+    static constexpr int INSIGNIFICANT_DIGITS_NUMBER[] = {
+        0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 9, 10, 10,
+        10, 11, 11, 11, 12, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15, 15, 16, 16, 16, 17, 17, 17, 18, 18, 18, 19};
+    static constexpr int INSIGNIFICANT_DIGITS_NUMBER_SIZE = 64;
+    static constexpr int N_5_BITS[] = {
+        0, 3, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28, 31, 33, 35, 38, 40, 42, 45, 47, 49, 52, 54, 56, 59, 61};
+    static constexpr int N_5_BITS_SIZE = 27;
+private:
+    bool isNegative;
+    int decExponent;
+    int firstDigitIndex;
+    int nDigits;
+    char digits[MAX_DIGIT_INDEX];
+    char buffer[26];
+};
+}
+#endif // OMNI_RUNTIME_DTOA_H
diff --git a/core/src/codegen/functions/mathfunctions.cpp b/core/src/codegen/functions/mathfunctions.cpp
new file mode 100644
index 0000000..c1bbcb9
--- /dev/null
+++ b/core/src/codegen/functions/mathfunctions.cpp
@@ -0,0 +1,584 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: registry math function name
+ */
+#include "mathfunctions.h"
+#include 
+#include 
+#include "codegen/context_helper.h"
+#include "codegen/common_util.h"
+
+
+#ifdef _WIN32
+#define DLLEXPORT __declspec(dllexport)
+#else
+#define DLLEXPORT
+#endif
+
+const double DOUBLE_NAN = (0.0 / 0.0);
+const uint64_t DOUBLE_BIT_MASK = ((static_cast(1) << (sizeof(double) * 8 - 1)) - 1);
+
+namespace omniruntime::codegen::function {
+static constexpr char DIVIDE_ZERO_EROR[] = "Divided by zero error!";
+
+extern "C" DLLEXPORT int16_t CastInt32ToInt16(int32_t x)
+{
+    return static_cast(x);
+}
+
+extern "C" DLLEXPORT int8_t CastInt32ToInt8(int32_t x)
+{
+    return static_cast(x);
+}
+
+extern "C" DLLEXPORT int16_t CastInt64ToInt16(int64_t x)
+{
+    return static_cast(x);
+}
+
+extern "C" DLLEXPORT int8_t CastInt64ToInt8(int64_t x)
+{
+    return static_cast(x);
+}
+
+extern "C" DLLEXPORT int64_t CastInt32ToInt64(int32_t x)
+{
+    return static_cast(x);
+}
+
+extern "C" DLLEXPORT int32_t CastInt64ToInt32(int64_t x)
+{
+    return static_cast(x);
+}
+
+
+extern "C" DLLEXPORT int32_t CastInt16ToInt32(int16_t x)
+{
+    return static_cast(x);
+}
+
+extern "C" DLLEXPORT int32_t CastInt8ToInt32(int8_t x)
+{
+    return static_cast(x);
+}
+
+extern "C" DLLEXPORT int64_t CastInt16ToInt64(int16_t x)
+{
+    return static_cast(x);
+}
+
+extern "C" DLLEXPORT int64_t CastInt8ToInt64(int8_t x)
+{
+    return static_cast(x);
+}
+
+extern "C" DLLEXPORT double CastInt16ToDouble(int16_t x)
+{
+    return static_cast(x);
+}
+
+extern "C" DLLEXPORT double CastInt8ToDouble(int8_t x)
+{
+    return static_cast(x);
+}
+
+extern "C" DLLEXPORT double CastInt32ToDouble(int32_t x)
+{
+    return static_cast(x);
+}
+
+extern "C" DLLEXPORT double CastInt64ToDouble(int64_t x)
+{
+    return static_cast(x);
+}
+
+extern "C" DLLEXPORT int32_t CastDoubleToInt32Down(double x)
+{
+    return static_cast(x);
+}
+
+extern "C" DLLEXPORT int16_t CastDoubleToInt16Down(double x)
+{
+    return static_cast(x);
+}
+
+extern "C" DLLEXPORT int8_t CastDoubleToInt8Down(double x)
+{
+    return static_cast(x);
+}
+
+extern "C" DLLEXPORT int64_t CastDoubleToInt64Down(double x)
+{
+    return static_cast(x);
+}
+
+extern "C" DLLEXPORT int32_t CastDoubleToInt32HalfUp(double x)
+{
+    return static_cast(Round(x, 0));
+}
+
+extern "C" DLLEXPORT int16_t CastDoubleToInt16HalfUp(double x)
+{
+    return static_cast(Round(x, 0));
+}
+
+extern "C" DLLEXPORT int8_t CastDoubleToInt8HalfUp(double x)
+{
+    return static_cast(Round(x, 0));
+}
+
+extern "C" DLLEXPORT int64_t CastDoubleToInt64HalfUp(double x)
+{
+    return static_cast(Round(x, 0));
+}
+
+// double functions
+
+extern "C" DLLEXPORT double AddDouble(double left, double right)
+{
+    return left + right;
+}
+
+extern "C" DLLEXPORT double SubtractDouble(double left, double right)
+{
+    return left - right;
+}
+
+extern "C" DLLEXPORT double MultiplyDouble(double left, double right)
+{
+    return left * right;
+}
+
+extern "C" DLLEXPORT double DivideDouble(bool *isNull, double divident, double divisor)
+{
+    if (divisor == 0) {
+        *isNull = true;
+        return 0;
+    }
+    return divident / divisor;
+}
+
+extern "C" DLLEXPORT double ModulusDouble(bool *isNull, double divident, double divisor)
+{
+    if (divisor == 0) {
+        *isNull = true;
+        return 0;
+    }
+    return std::fmod(divident, divisor);
+}
+
+extern "C" DLLEXPORT bool LessThanDouble(double left, double right)
+{
+    return left < right;
+}
+
+extern "C" DLLEXPORT bool LessThanEqualDouble(double left, double right)
+{
+    return left <= right;
+}
+
+extern "C" DLLEXPORT bool GreaterThanDouble(double left, double right)
+{
+    return left > right;
+}
+
+extern "C" DLLEXPORT bool GreaterThanEqualDouble(double left, double right)
+{
+    return left >= right;
+}
+
+extern "C" DLLEXPORT bool EqualDouble(double left, double right)
+{
+    return std::fabs(left - right) < DBL_EPSILON;
+}
+
+extern "C" DLLEXPORT bool NotEqualDouble(double left, double right)
+{
+    return std::fabs(left - right) >= DBL_EPSILON;
+}
+
+extern "C" DLLEXPORT double NormalizeNaNAndZero(double value)
+{
+    if (std::isnan(value)) {
+        return DOUBLE_NAN;
+    }
+    union {
+        uint64_t l;
+        double d;
+    } u;
+    u.d = value;
+    if (u.l & DOUBLE_BIT_MASK) {
+        return value;
+    }
+    return 0.0;
+}
+
+extern "C" DLLEXPORT double PowerDouble(double base, double exponent)
+{
+    return pow(base, exponent);
+}
+
+// long functions
+
+extern "C" DLLEXPORT int64_t AddInt64(int64_t left, int64_t right)
+{
+    return left + right;
+}
+
+extern "C" DLLEXPORT int64_t SubtractInt64(int64_t left, int64_t right)
+{
+    return left - right;
+}
+
+extern "C" DLLEXPORT int64_t MultiplyInt64(int64_t left, int64_t right)
+{
+    return left * right;
+}
+
+extern "C" DLLEXPORT int64_t DivideInt64(bool *isNull, int64_t divident, int64_t divisor)
+{
+    if (divisor == 0) {
+        *isNull = true;
+        return 0;
+    }
+    return divident / divisor;
+}
+
+extern "C" DLLEXPORT int64_t ModulusInt64(bool *isNull, int64_t divident, int64_t divisor)
+{
+    if (divisor == 0) {
+        *isNull = true;
+        return 0;
+    }
+    return divident % divisor;
+}
+
+extern "C" DLLEXPORT int64_t AddInt64RetNull(bool *isNull, int64_t left, int64_t right)
+{
+    int64_t result;
+    *isNull = __builtin_add_overflow(left, right, &result);
+    return result;
+}
+
+extern "C" DLLEXPORT int64_t SubtractInt64RetNull(bool *isNull, int64_t left, int64_t right)
+{
+    int64_t result;
+    *isNull = __builtin_sub_overflow(left, right, &result);
+    return result;
+}
+
+extern "C" DLLEXPORT int64_t MultiplyInt64RetNull(bool *isNull, int64_t left, int64_t right)
+{
+    int64_t result;
+    *isNull = __builtin_mul_overflow(left, right, &result);
+    return result;
+}
+
+
+extern "C" DLLEXPORT bool LessThanInt64(int64_t left, int64_t right)
+{
+    return left < right;
+}
+
+extern "C" DLLEXPORT bool LessThanEqualInt64(int64_t left, int64_t right)
+{
+    return left <= right;
+}
+
+extern "C" DLLEXPORT bool GreaterThanInt64(int64_t left, int64_t right)
+{
+    return left > right;
+}
+
+extern "C" DLLEXPORT bool GreaterThanEqualInt64(int64_t left, int64_t right)
+{
+    return left >= right;
+}
+
+extern "C" DLLEXPORT bool EqualInt64(int64_t left, int64_t right)
+{
+    return left == right;
+}
+
+extern "C" DLLEXPORT bool NotEqualInt64(int64_t left, int64_t right)
+{
+    return left != right;
+}
+
+// int functions
+
+extern "C" DLLEXPORT int32_t AddInt32(int32_t left, int32_t right)
+{
+    return left + right;
+}
+
+extern "C" DLLEXPORT int32_t SubtractInt32(int32_t left, int32_t right)
+{
+    return left - right;
+}
+
+extern "C" DLLEXPORT int32_t MultiplyInt32(int32_t left, int32_t right)
+{
+    return left * right;
+}
+
+extern "C" DLLEXPORT int32_t DivideInt32(bool *isNull, int32_t divident, int32_t divisor)
+{
+    if (divisor == 0) {
+        *isNull = true;
+        return 0;
+    }
+    return divident / divisor;
+}
+
+extern "C" DLLEXPORT int32_t ModulusInt32(bool *isNull, int32_t divident, int32_t divisor)
+{
+    if (divisor == 0) {
+        *isNull = true;
+        return 0;
+    }
+    return divident % divisor;
+}
+
+extern "C" DLLEXPORT int32_t AddInt32RetNull(bool *isNull, int32_t left, int32_t right)
+{
+    int32_t result;
+    *isNull = __builtin_add_overflow(left, right, &result);
+    return result;
+}
+
+extern "C" DLLEXPORT int32_t SubtractInt32RetNull(bool *isNull, int32_t left, int32_t right)
+{
+    int32_t result;
+    *isNull = __builtin_sub_overflow(left, right, &result);
+    return result;
+}
+
+extern "C" DLLEXPORT int32_t MultiplyInt32RetNull(bool *isNull, int32_t left, int32_t right)
+{
+    int32_t result;
+    *isNull = __builtin_mul_overflow(left, right, &result);
+    return result;
+}
+
+
+extern "C" DLLEXPORT bool LessThanInt32(int32_t left, int32_t right)
+{
+    return left < right;
+}
+
+extern "C" DLLEXPORT bool LessThanEqualInt32(int32_t left, int32_t right)
+{
+    return left <= right;
+}
+
+extern "C" DLLEXPORT bool GreaterThanInt32(int32_t left, int32_t right)
+{
+    return left > right;
+}
+
+extern "C" DLLEXPORT bool GreaterThanEqualInt32(int32_t left, int32_t right)
+{
+    return left >= right;
+}
+
+extern "C" DLLEXPORT bool EqualInt32(int32_t left, int32_t right)
+{
+    return left == right;
+}
+
+extern "C" DLLEXPORT bool NotEqualInt32(int32_t left, int32_t right)
+{
+    return left != right;
+}
+
+extern "C" DLLEXPORT int32_t Pmod(int32_t x, int32_t y)
+{
+    if (y == 0) {
+        return 0;
+    }
+    int32_t r = x % y;
+    if (r < 0) {
+        return (r + y) % y;
+    } else {
+        return r;
+    }
+}
+
+extern "C" DLLEXPORT int64_t RoundLong(int64_t num, int32_t decimals)
+{
+    return RoundOperator(num, decimals);
+}
+}
+
+// short functions
+
+extern "C" DLLEXPORT int16_t AddInt16(int16_t left, int16_t right)
+{
+    return left + right;
+}
+
+extern "C" DLLEXPORT int16_t SubtractInt16(int16_t left, int16_t right)
+{
+    return left - right;
+}
+
+extern "C" DLLEXPORT int16_t MultiplyInt16(int16_t left, int16_t right)
+{
+    return left * right;
+}
+
+extern "C" DLLEXPORT int16_t DivideInt16(bool *isNull, int16_t divident, int16_t divisor)
+{
+    if (divisor == 0) {
+        *isNull = true;
+        return 0;
+    }
+    return divident / divisor;
+}
+
+extern "C" DLLEXPORT int16_t ModulusInt16(bool *isNull, int16_t divident, int16_t divisor)
+{
+    if (divisor == 0) {
+        *isNull = true;
+        return 0;
+    }
+    return divident % divisor;
+}
+
+extern "C" DLLEXPORT int16_t AddInt16RetNull(bool *isNull, int16_t left, int16_t right)
+{
+    int16_t result;
+    *isNull = __builtin_add_overflow(left, right, &result);
+    return result;
+}
+
+extern "C" DLLEXPORT int16_t SubtractInt16RetNull(bool *isNull, int16_t left, int16_t right)
+{
+    int16_t result;
+    *isNull = __builtin_sub_overflow(left, right, &result);
+    return result;
+}
+
+extern "C" DLLEXPORT int16_t MultiplyInt16RetNull(bool *isNull, int16_t left, int16_t right)
+{
+    int16_t result;
+    *isNull = __builtin_mul_overflow(left, right, &result);
+    return result;
+}
+
+extern "C" DLLEXPORT bool LessThanInt16(int16_t left, int16_t right)
+{
+    return left < right;
+}
+
+extern "C" DLLEXPORT bool LessThanEqualInt16(int16_t left, int16_t right)
+{
+    return left <= right;
+}
+
+extern "C" DLLEXPORT bool GreaterThanInt16(int16_t left, int16_t right)
+{
+    return left > right;
+}
+
+extern "C" DLLEXPORT bool GreaterThanEqualInt16(int16_t left, int16_t right)
+{
+    return left >= right;
+}
+
+extern "C" DLLEXPORT bool EqualInt16(int16_t left, int16_t right)
+{
+    return left == right;
+}
+
+extern "C" DLLEXPORT bool NotEqualInt16(int16_t left, int16_t right)
+{
+    return left != right;
+}
+
+// byte functions
+
+extern "C" DLLEXPORT int8_t AddInt8(int8_t left, int8_t right)
+{
+    return left + right;
+}
+
+extern "C" DLLEXPORT int8_t SubtractInt8(int8_t left, int8_t right)
+{
+    return left - right;
+}
+
+extern "C" DLLEXPORT int8_t MultiplyInt8(int8_t left, int8_t right)
+{
+    return left * right;
+}
+
+extern "C" DLLEXPORT int8_t DivideInt8(bool *isNull, int8_t divident, int8_t divisor)
+{
+    if (divisor == 0) {
+        *isNull = true;
+        return 0;
+    }
+    return divident / divisor;
+}
+
+extern "C" DLLEXPORT int8_t ModulusInt8(bool *isNull, int8_t divident, int8_t divisor)
+{
+    if (divisor == 0) {
+        *isNull = true;
+        return 0;
+    }
+    return divident % divisor;
+}
+
+extern "C" DLLEXPORT int8_t AddInt8RetNull(bool *isNull, int8_t left, int8_t right)
+{
+    int8_t result;
+    *isNull = __builtin_add_overflow(left, right, &result);
+    return result;
+}
+
+extern "C" DLLEXPORT int8_t SubtractInt8RetNull(bool *isNull, int8_t left, int8_t right)
+{
+    int8_t result;
+    *isNull = __builtin_sub_overflow(left, right, &result);
+    return result;
+}
+
+extern "C" DLLEXPORT int8_t MultiplyInt8RetNull(bool *isNull, int8_t left, int8_t right)
+{
+    int8_t result;
+    *isNull = __builtin_mul_overflow(left, right, &result);
+    return result;
+}
+
+extern "C" DLLEXPORT bool LessThanInt8(int8_t left, int8_t right)
+{
+    return left < right;
+}
+
+extern "C" DLLEXPORT bool LessThanEqualInt8(int8_t left, int8_t right)
+{
+    return left <= right;
+}
+
+extern "C" DLLEXPORT bool GreaterThanInt8(int8_t left, int8_t right)
+{
+    return left > right;
+}
+
+extern "C" DLLEXPORT bool GreaterThanEqualInt8(int8_t left, int8_t right)
+{
+    return left >= right;
+}
+
+extern "C" DLLEXPORT bool EqualInt8(int8_t left, int8_t right)
+{
+    return left == right;
+}
+
+extern "C" DLLEXPORT bool NotEqualInt8(int8_t left, int8_t right)
+{
+    return left != right;
+}
\ No newline at end of file
diff --git a/core/src/codegen/functions/mathfunctions.h b/core/src/codegen/functions/mathfunctions.h
new file mode 100644
index 0000000..13f6c23
--- /dev/null
+++ b/core/src/codegen/functions/mathfunctions.h
@@ -0,0 +1,270 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: registry math function name
+ */
+#ifndef __MATHFUNCTIONS_H__
+#define __MATHFUNCTIONS_H__
+
+#include 
+#include 
+
+// All extern functions go here temporarily
+#ifdef _WIN32
+#define DLLEXPORT __declspec(dllexport)
+#else
+#define DLLEXPORT
+#endif
+
+namespace omniruntime::codegen::function {
+// Absolute value
+template  extern DLLEXPORT T Abs(T x)
+{
+    return std::abs(x);
+}
+extern "C" DLLEXPORT int16_t CastInt32ToInt16(int32_t x);
+
+extern "C" DLLEXPORT int8_t CastInt32ToInt8(int32_t x);
+
+extern "C" DLLEXPORT int16_t CastInt64ToInt16(int64_t x);
+
+extern "C" DLLEXPORT int8_t CastInt64ToInt8(int64_t x);
+
+extern "C" DLLEXPORT double CastInt32ToDouble(int32_t x);
+
+extern "C" DLLEXPORT double CastInt64ToDouble(int64_t x);
+
+extern "C" DLLEXPORT int64_t CastInt32ToInt64(int32_t x);
+
+extern "C" DLLEXPORT int32_t CastInt64ToInt32(int64_t x);
+
+extern "C" DLLEXPORT int32_t CastInt16ToInt32(int16_t x);
+
+extern "C" DLLEXPORT int32_t CastInt8ToInt32(int8_t x);
+
+extern "C" DLLEXPORT int64_t CastInt16ToInt64(int16_t x);
+
+extern "C" DLLEXPORT int64_t CastInt8ToInt64(int8_t x);
+
+extern "C" DLLEXPORT double CastInt16ToDouble(int16_t x);
+
+extern "C" DLLEXPORT double CastInt8ToDouble(int8_t x);
+
+extern "C" DLLEXPORT int32_t CastDoubleToInt32HalfUp(double x);
+
+extern "C" DLLEXPORT int16_t CastDoubleToInt16HalfUp(double x);
+
+extern "C" DLLEXPORT int8_t CastDoubleToInt8HalfUp(double x);
+
+extern "C" DLLEXPORT int64_t CastDoubleToInt64HalfUp(double x);
+
+extern "C" DLLEXPORT int32_t CastDoubleToInt32Down(double x);
+
+extern "C" DLLEXPORT int16_t CastDoubleToInt16Down(double x);
+
+extern "C" DLLEXPORT int8_t CastDoubleToInt8Down(double x);
+
+extern "C" DLLEXPORT int64_t CastDoubleToInt64Down(double x);
+
+extern "C" DLLEXPORT double CastInt16ToDouble(int16_t x);
+
+extern "C" DLLEXPORT double CastInt8ToDouble(int8_t x);
+
+extern "C" DLLEXPORT int32_t CastInt16ToInt32(int16_t x);
+
+extern "C" DLLEXPORT int32_t CastInt8ToInt32(int8_t x);
+
+extern "C" DLLEXPORT int64_t CastInt16ToInt64(int16_t x);
+
+extern "C" DLLEXPORT int64_t CastInt8ToInt64(int8_t x);
+
+extern "C" DLLEXPORT int32_t CastDoubleToInt32HalfUp(double x);
+
+extern "C" DLLEXPORT int16_t CastDoubleToInt16HalfUp(double x);
+
+extern "C" DLLEXPORT int8_t CastDoubleToInt8HalfUp(double x);
+
+extern "C" DLLEXPORT int64_t CastDoubleToInt64HalfUp(double x);
+
+extern "C" DLLEXPORT int32_t CastDoubleToInt32Down(double x);
+
+extern "C" DLLEXPORT int16_t CastDoubleToInt16Down(double x);
+
+extern "C" DLLEXPORT int8_t CastDoubleToInt8Down(double x);
+
+extern "C" DLLEXPORT int64_t CastDoubleToInt64Down(double x);
+
+// double binary operations
+extern "C" DLLEXPORT double AddDouble(double left, double right);
+
+extern "C" DLLEXPORT double SubtractDouble(double left, double right);
+
+extern "C" DLLEXPORT double MultiplyDouble(double left, double right);
+
+extern "C" DLLEXPORT double DivideDouble(bool *isNull, double divident, double divisor);
+
+extern "C" DLLEXPORT double ModulusDouble(bool *isNull, double divident, double divisor);
+
+extern "C" DLLEXPORT bool LessThanDouble(double left, double right);
+
+extern "C" DLLEXPORT bool LessThanEqualDouble(double left, double right);
+
+extern "C" DLLEXPORT bool GreaterThanDouble(double left, double right);
+
+extern "C" DLLEXPORT bool GreaterThanEqualDouble(double left, double right);
+
+extern "C" DLLEXPORT bool EqualDouble(double left, double right);
+
+extern "C" DLLEXPORT bool NotEqualDouble(double left, double right);
+
+extern "C" DLLEXPORT double NormalizeNaNAndZero(double value);
+
+extern "C" DLLEXPORT double PowerDouble(double base, double exponent);
+
+// long binary operations
+extern "C" DLLEXPORT int64_t AddInt64(int64_t left, int64_t right);
+
+extern "C" DLLEXPORT int64_t SubtractInt64(int64_t left, int64_t right);
+
+extern "C" DLLEXPORT int64_t MultiplyInt64(int64_t left, int64_t right);
+
+extern "C" DLLEXPORT int64_t DivideInt64(bool *isNull, int64_t divident, int64_t divisor);
+
+extern "C" DLLEXPORT int64_t ModulusInt64(bool *isNull, int64_t divident, int64_t divisor);
+
+extern "C" DLLEXPORT int64_t AddInt64RetNull(bool *isNull, int64_t left, int64_t right);
+
+extern "C" DLLEXPORT int64_t SubtractInt64RetNull(bool *isNull, int64_t left, int64_t right);
+
+extern "C" DLLEXPORT int64_t MultiplyInt64RetNull(bool *isNull, int64_t left, int64_t right);
+
+extern "C" DLLEXPORT bool LessThanInt64(int64_t left, int64_t right);
+
+extern "C" DLLEXPORT bool LessThanEqualInt64(int64_t left, int64_t right);
+
+extern "C" DLLEXPORT bool GreaterThanInt64(int64_t left, int64_t right);
+
+extern "C" DLLEXPORT bool GreaterThanEqualInt64(int64_t left, int64_t right);
+
+extern "C" DLLEXPORT bool EqualInt64(int64_t left, int64_t right);
+
+extern "C" DLLEXPORT bool NotEqualInt64(int64_t left, int64_t right);
+
+// int binary operations
+extern "C" DLLEXPORT int32_t AddInt32(int32_t left, int32_t right);
+
+extern "C" DLLEXPORT int32_t SubtractInt32(int32_t left, int32_t right);
+
+extern "C" DLLEXPORT int32_t MultiplyInt32(int32_t left, int32_t right);
+
+extern "C" DLLEXPORT int32_t DivideInt32(bool *isNull, int32_t divident, int32_t divisor);
+
+extern "C" DLLEXPORT int32_t ModulusInt32(bool *isNull, int32_t divident, int32_t divisor);
+
+extern "C" DLLEXPORT int32_t AddInt32RetNull(bool *isNull, int32_t left, int32_t right);
+
+extern "C" DLLEXPORT int32_t SubtractInt32RetNull(bool *isNull, int32_t left, int32_t right);
+
+extern "C" DLLEXPORT int32_t MultiplyInt32RetNull(bool *isNull, int32_t left, int32_t right);
+
+extern "C" DLLEXPORT bool LessThanInt32(int32_t left, int32_t right);
+
+extern "C" DLLEXPORT bool LessThanEqualInt32(int32_t left, int32_t right);
+
+extern "C" DLLEXPORT bool GreaterThanInt32(int32_t left, int32_t right);
+
+extern "C" DLLEXPORT bool GreaterThanEqualInt32(int32_t left, int32_t right);
+
+extern "C" DLLEXPORT bool EqualInt32(int32_t left, int32_t right);
+
+extern "C" DLLEXPORT bool NotEqualInt32(int32_t left, int32_t right);
+
+extern "C" DLLEXPORT int32_t Pmod(int32_t x, int32_t y);
+
+extern "C" DLLEXPORT int64_t RoundLong(int64_t num, int32_t decimals);
+
+// short binary operations
+extern "C" DLLEXPORT int16_t AddInt16(int16_t left, int16_t right);
+
+extern "C" DLLEXPORT int16_t SubtractInt16(int16_t left, int16_t right);
+
+extern "C" DLLEXPORT int16_t MultiplyInt16(int16_t left, int16_t right);
+
+extern "C" DLLEXPORT int16_t DivideInt16(bool *isNull, int16_t divident, int16_t divisor);
+
+extern "C" DLLEXPORT int16_t ModulusInt16(bool *isNull, int16_t divident, int16_t divisor);
+
+extern "C" DLLEXPORT int16_t AddInt16RetNull(bool *isNull, int16_t left, int16_t right);
+
+extern "C" DLLEXPORT int16_t SubtractInt16RetNull(bool *isNull, int16_t left, int16_t right);
+
+extern "C" DLLEXPORT int16_t MultiplyInt16RetNull(bool *isNull, int16_t left, int16_t right);
+
+extern "C" DLLEXPORT bool LessThanInt16(int16_t left, int16_t right);
+
+extern "C" DLLEXPORT bool LessThanEqualInt16(int16_t left, int16_t right);
+
+extern "C" DLLEXPORT bool GreaterThanInt16(int16_t left, int16_t right);
+
+extern "C" DLLEXPORT bool GreaterThanEqualInt16(int16_t left, int16_t right);
+
+extern "C" DLLEXPORT bool EqualInt16(int16_t left, int16_t right);
+
+extern "C" DLLEXPORT bool NotEqualInt16(int16_t left, int16_t right);
+
+// byte binary operations
+extern "C" DLLEXPORT int8_t AddInt8(int8_t left, int8_t right);
+
+extern "C" DLLEXPORT int8_t SubtractInt8(int8_t left, int8_t right);
+
+extern "C" DLLEXPORT int8_t MultiplyInt8(int8_t left, int8_t right);
+
+extern "C" DLLEXPORT int8_t DivideInt8(bool *isNull, int8_t divident, int8_t divisor);
+
+extern "C" DLLEXPORT int8_t ModulusInt8(bool *isNull, int8_t divident, int8_t divisor);
+
+extern "C" DLLEXPORT int8_t AddInt8RetNull(bool *isNull, int8_t left, int8_t right);
+
+extern "C" DLLEXPORT int8_t SubtractInt8RetNull(bool *isNull, int8_t left, int8_t right);
+
+extern "C" DLLEXPORT int8_t MultiplyInt8RetNull(bool *isNull, int8_t left, int8_t right);
+
+extern "C" DLLEXPORT bool LessThanInt8(int8_t left, int8_t right);
+
+extern "C" DLLEXPORT bool LessThanEqualInt8(int8_t left, int8_t right);
+
+extern "C" DLLEXPORT bool GreaterThanInt8(int8_t left, int8_t right);
+
+extern "C" DLLEXPORT bool GreaterThanEqualInt8(int8_t left, int8_t right);
+
+extern "C" DLLEXPORT bool EqualInt8(int8_t left, int8_t right);
+
+extern "C" DLLEXPORT bool NotEqualInt8(int8_t left, int8_t right);
+
+template  extern DLLEXPORT T Round(T num, int32_t decimals)
+{
+    if (std::isnan(num) || std::isinf(num)) {
+        return num;
+    }
+    int32_t tenthPower = 10;
+    double factor = std::pow(tenthPower, decimals);
+    if (num < 0) {
+        return -(std::round(-num * factor) / factor);
+    }
+
+    return std::round(num * factor) / factor;
+}
+
+template  extern DLLEXPORT T Greatest(T lValue, bool lIsNull, T rValue, bool rIsNull, bool *retIsNull)
+{
+    if (lIsNull && rIsNull) {
+        *retIsNull = true;
+        return lValue;
+    }
+    if (lIsNull || (!rIsNull && rValue > lValue)) {
+        return rValue;
+    }
+    return lValue;
+}
+}
+
+#endif
\ No newline at end of file
diff --git a/core/src/codegen/functions/md5.cpp b/core/src/codegen/functions/md5.cpp
new file mode 100644
index 0000000..5318ae1
--- /dev/null
+++ b/core/src/codegen/functions/md5.cpp
@@ -0,0 +1,229 @@
+/*
+ * @Copyright: Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
+ * @Description: md5 function implementations
+*/
+#include "md5.h"
+
+namespace omniruntime::codegen::function {
+using Fun = unsigned int (*)(unsigned int, unsigned int, unsigned int);
+
+static inline unsigned int Fun1(unsigned int x, unsigned int y, unsigned int z)
+{
+    return z ^ (x & (y ^ z));
+}
+
+static inline unsigned int Fun2(unsigned int x, unsigned int y, unsigned int z)
+{
+    return Fun1(z, x, y);
+}
+
+static inline unsigned int Fun3(unsigned int x, unsigned int y, unsigned int z)
+{
+    return x ^ y ^ z;
+}
+
+static inline unsigned int Fun4(unsigned int x, unsigned int y, unsigned int z)
+{
+    return y ^ (x | ~z);
+}
+
+static inline void Md5Fun(Fun f, unsigned int &w, unsigned int x, unsigned int y, unsigned int z,
+    unsigned int data, unsigned int s)
+{
+    w += f(x, y, z) + data;
+    w = w << s | w >> (32 - s);
+    w += x;
+}
+
+/**
+ * The core of the MD5 algorithm, this alters an existing MD5 hash to
+ * reflect the addition of 16 long words of new data.  MD5Update blocks
+ * the data and converts bytes into long words for this routine.
+*/
+static void MD5Transform(unsigned int buf[4], const unsigned int in[16])
+{
+    unsigned int b1 = buf[0];
+    unsigned int b2 = buf[1];
+    unsigned int b3 = buf[2];
+    unsigned int b4 = buf[3];
+
+    Md5Fun(Fun1, b1, b2, b3, b4, in[0] + 0xd76aa478, 7);
+    Md5Fun(Fun1, b4, b1, b2, b3, in[1] + 0xe8c7b756, 12);
+    Md5Fun(Fun1, b3, b4, b1, b2, in[2] + 0x242070db, 17);
+    Md5Fun(Fun1, b2, b3, b4, b1, in[3] + 0xc1bdceee, 22);
+    Md5Fun(Fun1, b1, b2, b3, b4, in[4] + 0xf57c0faf, 7);
+    Md5Fun(Fun1, b4, b1, b2, b3, in[5] + 0x4787c62a, 12);
+    Md5Fun(Fun1, b3, b4, b1, b2, in[6] + 0xa8304613, 17);
+    Md5Fun(Fun1, b2, b3, b4, b1, in[7] + 0xfd469501, 22);
+    Md5Fun(Fun1, b1, b2, b3, b4, in[8] + 0x698098d8, 7);
+    Md5Fun(Fun1, b4, b1, b2, b3, in[9] + 0x8b44f7af, 12);
+    Md5Fun(Fun1, b3, b4, b1, b2, in[10] + 0xffff5bb1, 17);
+    Md5Fun(Fun1, b2, b3, b4, b1, in[11] + 0x895cd7be, 22);
+    Md5Fun(Fun1, b1, b2, b3, b4, in[12] + 0x6b901122, 7);
+    Md5Fun(Fun1, b4, b1, b2, b3, in[13] + 0xfd987193, 12);
+    Md5Fun(Fun1, b3, b4, b1, b2, in[14] + 0xa679438e, 17);
+    Md5Fun(Fun1, b2, b3, b4, b1, in[15] + 0x49b40821, 22);
+
+    Md5Fun(Fun2, b1, b2, b3, b4, in[1] + 0xf61e2562, 5);
+    Md5Fun(Fun2, b4, b1, b2, b3, in[6] + 0xc040b340, 9);
+    Md5Fun(Fun2, b3, b4, b1, b2, in[11] + 0x265e5a51, 14);
+    Md5Fun(Fun2, b2, b3, b4, b1, in[0] + 0xe9b6c7aa, 20);
+    Md5Fun(Fun2, b1, b2, b3, b4, in[5] + 0xd62f105d, 5);
+    Md5Fun(Fun2, b4, b1, b2, b3, in[10] + 0x02441453, 9);
+    Md5Fun(Fun2, b3, b4, b1, b2, in[15] + 0xd8a1e681, 14);
+    Md5Fun(Fun2, b2, b3, b4, b1, in[4] + 0xe7d3fbc8, 20);
+    Md5Fun(Fun2, b1, b2, b3, b4, in[9] + 0x21e1cde6, 5);
+    Md5Fun(Fun2, b4, b1, b2, b3, in[14] + 0xc33707d6, 9);
+    Md5Fun(Fun2, b3, b4, b1, b2, in[3] + 0xf4d50d87, 14);
+    Md5Fun(Fun2, b2, b3, b4, b1, in[8] + 0x455a14ed, 20);
+    Md5Fun(Fun2, b1, b2, b3, b4, in[13] + 0xa9e3e905, 5);
+    Md5Fun(Fun2, b4, b1, b2, b3, in[2] + 0xfcefa3f8, 9);
+    Md5Fun(Fun2, b3, b4, b1, b2, in[7] + 0x676f02d9, 14);
+    Md5Fun(Fun2, b2, b3, b4, b1, in[12] + 0x8d2a4c8a, 20);
+
+    Md5Fun(Fun3, b1, b2, b3, b4, in[5] + 0xfffa3942, 4);
+    Md5Fun(Fun3, b4, b1, b2, b3, in[8] + 0x8771f681, 11);
+    Md5Fun(Fun3, b3, b4, b1, b2, in[11] + 0x6d9d6122, 16);
+    Md5Fun(Fun3, b2, b3, b4, b1, in[14] + 0xfde5380c, 23);
+    Md5Fun(Fun3, b1, b2, b3, b4, in[1] + 0xa4beea44, 4);
+    Md5Fun(Fun3, b4, b1, b2, b3, in[4] + 0x4bdecfa9, 11);
+    Md5Fun(Fun3, b3, b4, b1, b2, in[7] + 0xf6bb4b60, 16);
+    Md5Fun(Fun3, b2, b3, b4, b1, in[10] + 0xbebfbc70, 23);
+    Md5Fun(Fun3, b1, b2, b3, b4, in[13] + 0x289b7ec6, 4);
+    Md5Fun(Fun3, b4, b1, b2, b3, in[0] + 0xeaa127fa, 11);
+    Md5Fun(Fun3, b3, b4, b1, b2, in[3] + 0xd4ef3085, 16);
+    Md5Fun(Fun3, b2, b3, b4, b1, in[6] + 0x04881d05, 23);
+    Md5Fun(Fun3, b1, b2, b3, b4, in[9] + 0xd9d4d039, 4);
+    Md5Fun(Fun3, b4, b1, b2, b3, in[12] + 0xe6db99e5, 11);
+    Md5Fun(Fun3, b3, b4, b1, b2, in[15] + 0x1fa27cf8, 16);
+    Md5Fun(Fun3, b2, b3, b4, b1, in[2] + 0xc4ac5665, 23);
+
+    Md5Fun(Fun4, b1, b2, b3, b4, in[0] + 0xf4292244, 6);
+    Md5Fun(Fun4, b4, b1, b2, b3, in[7] + 0x432aff97, 10);
+    Md5Fun(Fun4, b3, b4, b1, b2, in[14] + 0xab9423a7, 15);
+    Md5Fun(Fun4, b2, b3, b4, b1, in[5] + 0xfc93a039, 21);
+    Md5Fun(Fun4, b1, b2, b3, b4, in[12] + 0x655b59c3, 6);
+    Md5Fun(Fun4, b4, b1, b2, b3, in[3] + 0x8f0ccc92, 10);
+    Md5Fun(Fun4, b3, b4, b1, b2, in[10] + 0xffeff47d, 15);
+    Md5Fun(Fun4, b2, b3, b4, b1, in[1] + 0x85845dd1, 21);
+    Md5Fun(Fun4, b1, b2, b3, b4, in[8] + 0x6fa87e4f, 6);
+    Md5Fun(Fun4, b4, b1, b2, b3, in[15] + 0xfe2ce6e0, 10);
+    Md5Fun(Fun4, b3, b4, b1, b2, in[6] + 0xa3014314, 15);
+    Md5Fun(Fun4, b2, b3, b4, b1, in[13] + 0x4e0811a1, 21);
+    Md5Fun(Fun4, b1, b2, b3, b4, in[4] + 0xf7537e82, 6);
+    Md5Fun(Fun4, b4, b1, b2, b3, in[11] + 0xbd3af235, 10);
+    Md5Fun(Fun4, b3, b4, b1, b2, in[2] + 0x2ad7d2bb, 15);
+    Md5Fun(Fun4, b2, b3, b4, b1, in[9] + 0xeb86d391, 21);
+
+    buf[0] += b1;
+    buf[1] += b2;
+    buf[2] += b3;
+    buf[3] += b4;
+}
+
+static inline void ByteReverse(unsigned char *buf, unsigned longs)
+{
+    unsigned int t;
+    do {
+        t = (static_cast(static_cast(buf[3]) << 8 | buf[2]) << 16) |
+            (static_cast(buf[1]) << 8 | buf[0]);
+        *reinterpret_cast(buf) = t;
+        buf += 4;
+    } while (--longs);
+}
+
+void Md5Function::Finish(unsigned char *outDigest)
+{
+    unsigned char *ptr;
+    unsigned bitsCount;
+
+    // Compute number of bytes mod 64
+    bitsCount = (bits[0] >> 3) & 0x3F;
+
+    // Set the first char of padding to 0x80.  This is safe since there is
+    // always at least one byte free
+    ptr = in + bitsCount;
+    *ptr++ = 0x80;
+
+    // Bytes of padding needed to make 64 bytes
+    bitsCount = 64 - 1 - bitsCount;
+
+    // Pad out to 56 mod 64
+    if (bitsCount < 8) {
+        // Two lots of padding:  Pad the first block to 64 bytes
+        memset_s(ptr, bitsCount, 0, bitsCount);
+        ByteReverse(in, 16);
+        MD5Transform(buf, reinterpret_cast(in));
+
+        // Now fill the next block with 56 bytes
+        memset_s(in, 56, 0, 56);
+    } else {
+        // Pad block to 56 bytes
+        memset_s(ptr, bitsCount - 8, 0, bitsCount - 8);
+    }
+    ByteReverse(in, 14);
+
+    // Append length in bits and transform
+    (reinterpret_cast(in))[14] = bits[0];
+    (reinterpret_cast(in))[15] = bits[1];
+    MD5Transform(buf, reinterpret_cast(in));
+    ByteReverse(reinterpret_cast(buf), 4);
+    memcpy_s(outDigest, 16, buf, 16);
+}
+
+void Md5Function::FinishHex(char *outDigest)
+{
+    unsigned char digest[MD5_HASH_LENGTH_BINARY];
+    Finish(digest);
+    DigestToBase16(digest, outDigest);
+}
+
+void Md5Function::MD5Update(const char *data, uint64_t len)
+{
+    unsigned int temp = bits[0];
+    if ((bits[0] = temp + (static_cast(len) << 3)) < temp) {
+        bits[1]++; // Carry from low to high
+    }
+    bits[1] += len >> 29;
+    temp = (temp >> 3) & 0x3f; // Bytes already in shsInfo->data
+
+    // Handle any leading odd-sized chunks
+    if (temp) {
+        unsigned char *p = in + temp;
+
+        temp = 64 - temp;
+        if (len < temp) {
+            memcpy_s(p, len, data, len);
+            return;
+        }
+        memcpy_s(p, temp, data, temp);
+        ByteReverse(in, 16);
+        MD5Transform(buf, reinterpret_cast(in));
+        data += temp;
+        len -= temp;
+    }
+
+    // Process data in 64-byte chunks
+    while (len >= 64) {
+        memcpy_s(in, 64, data, 64);
+        ByteReverse(in, 16);
+        MD5Transform(buf, reinterpret_cast(in));
+        data += 64;
+        len -= 64;
+    }
+
+    // Handle any remaining bytes of data.
+    memcpy_s(in, len, data, len);
+}
+
+void Md5Function::DigestToBase16(const unsigned char *digest, char *zBuf)
+{
+    static char const HEX_CODES[] = "0123456789abcdef";
+    int i, j;
+    for (j = i = 0; i < MD5_HASH_LENGTH_BINARY; i++) {
+        int a = digest[i];
+        zBuf[j++] = HEX_CODES[(a >> 4) & 0xf];
+        zBuf[j++] = HEX_CODES[a & 0xf];
+    }
+}
+}
diff --git a/core/src/codegen/functions/md5.h b/core/src/codegen/functions/md5.h
new file mode 100644
index 0000000..fe81dd5
--- /dev/null
+++ b/core/src/codegen/functions/md5.h
@@ -0,0 +1,39 @@
+/*
+ * @Copyright: Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
+ * @Description: md5 function implementations
+*/
+
+#ifndef OMNI_RUNTIME_MD5_H
+#define OMNI_RUNTIME_MD5_H
+
+#include 
+#include 
+#include 
+
+namespace omniruntime::codegen::function {
+class Md5Function {
+public:
+    static constexpr int MD5_HASH_LENGTH_BINARY = 16;
+
+    Md5Function(const char *data, uint64_t len)
+    {
+        MD5Update(data, len);
+    }
+
+    //! Write the 16-byte (binary) digest to the specified location
+    void Finish(unsigned char *outDigest);
+
+    // Write the 32-character digest (in hexadecimal format) to the specified location
+    void FinishHex(char *outDigest);
+
+private:
+    void MD5Update(const char *data, uint64_t len);
+
+    static void DigestToBase16(const unsigned char *digest, char *zBuf);
+
+    unsigned int buf[4] = {0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476};
+    unsigned int bits[2] = {0, 0};
+    unsigned char in[64] = {0};
+};
+}
+#endif // OMNI_RUNTIME_MD5_H
diff --git a/core/src/codegen/functions/mightcontain.cpp b/core/src/codegen/functions/mightcontain.cpp
new file mode 100644
index 0000000..abcb744
--- /dev/null
+++ b/core/src/codegen/functions/mightcontain.cpp
@@ -0,0 +1,21 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * Description: registry  function  implementation
+ */
+#include "mightcontain.h"
+#include "codegen/bloom_filter.h"
+
+namespace omniruntime::codegen::function {
+extern "C" DLLEXPORT bool MightContain(int64_t bloomFilterAddr, int64_t hashValue, bool isNull)
+{
+    /*
+     * Limited by the current processing framework, bloomFilterAddr is set to null by the engine when the value of
+     * bloomFilterAddr is 0.
+     */
+    if (!isNull) {
+        auto bloomFilter = reinterpret_cast(bloomFilterAddr);
+        return bloomFilter->MightContainLong(hashValue);
+    }
+    return false;
+}
+}
\ No newline at end of file
diff --git a/core/src/codegen/functions/mightcontain.h b/core/src/codegen/functions/mightcontain.h
new file mode 100644
index 0000000..16874ca
--- /dev/null
+++ b/core/src/codegen/functions/mightcontain.h
@@ -0,0 +1,21 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * Description: registry  function  implementation
+ */
+#ifndef OMNI_RUNTIME_MIGHTCONTAIN_H
+#define OMNI_RUNTIME_MIGHTCONTAIN_H
+
+#include 
+#include "xxhash64_hash.h"
+
+namespace omniruntime::codegen::function {
+// All extern functions go here temporarily
+#ifdef _WIN32
+#define DLLEXPORT __declspec(dllexport)
+#else
+#define DLLEXPORT
+#endif
+
+extern "C" DLLEXPORT bool MightContain(int64_t bloomFilterAddr, int64_t hashValue, bool isNull);
+}
+#endif // OMNI_RUNTIME_MIGHTCONTAIN_H
\ No newline at end of file
diff --git a/core/src/codegen/functions/murmur3_hash.cpp b/core/src/codegen/functions/murmur3_hash.cpp
new file mode 100644
index 0000000..e7b819c
--- /dev/null
+++ b/core/src/codegen/functions/murmur3_hash.cpp
@@ -0,0 +1,269 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: Murmur3 Hash function
+ */
+#include "util/compiler_util.h"
+#include "type/decimal128_utils.h"
+#include "murmur3_hash.h"
+
+namespace omniruntime::codegen::function {
+static const int COMBINE_HASH_VALUE = 31;
+static const uint32_t MM3_C1 = 0xcc9e2d51;
+static const uint32_t MM3_C2 = 0x1b873593;
+
+static const uint32_t MM3_BITS_INT = 32;
+
+static const uint32_t MIXK1_ROTATE_LEFT_NUM = 15;
+
+static const uint32_t MIXH1_ROTATE_LEFT_NUM = 13;
+static const uint32_t MIXH1_MULTIPLY_M = 5;
+static const uint32_t MIXH1_ADD_N = 0xe6546b64;
+
+static const uint32_t FMIX_RIGHT_SHIFT_M = 16;
+static const uint32_t FMIX_RIGHT_SHIFT_N = 13;
+static const uint32_t FMIX_MULTIPLY_M = 0x85ebca6b;
+static const uint32_t FMIX_MULTIPLY_N = 0xc2b2ae35;
+
+static const uint32_t HASH_LONG_RIGHT_SHIFT = 32;
+
+static const uint32_t MM3_SIZE_BYTE = 1;
+static const uint32_t MM3_SIZE_SHORT = 2;
+static const uint32_t MM3_SIZE_INT = 4;
+static const uint32_t MM3_SIZE_LONG = 8;
+
+static const uint32_t MM3_INT_ONE = 1;
+
+static const uint32_t REVERSE_SHIFT_M = 24;
+static const uint32_t REVERSE_SHIFT_N = 8;
+static const uint32_t REVERSE_AND_A = 0xff;
+static const uint32_t REVERSE_AND_B = 0xff0000;
+static const uint32_t REVERSE_AND_C = 0xff00;
+static const uint32_t REVERSE_AND_D = 0xff000000;
+
+uint32_t ALWAYS_INLINE RotateLeft(uint32_t i, uint32_t distance)
+{
+    return (i << distance) | (i >> (MM3_BITS_INT - distance));
+}
+
+uint32_t ALWAYS_INLINE MixK1(uint32_t k1)
+{
+    k1 *= MM3_C1;
+    k1 = RotateLeft(k1, MIXK1_ROTATE_LEFT_NUM);
+    k1 *= MM3_C2;
+    return k1;
+}
+
+uint32_t ALWAYS_INLINE MixH1(uint32_t h1, uint32_t k1)
+{
+    h1 ^= k1;
+    h1 = RotateLeft(h1, MIXH1_ROTATE_LEFT_NUM);
+    h1 = h1 * MIXH1_MULTIPLY_M + MIXH1_ADD_N;
+    return h1;
+}
+
+uint32_t ALWAYS_INLINE Fmix(uint32_t h1, uint32_t length)
+{
+    h1 ^= length;
+    h1 ^= h1 >> FMIX_RIGHT_SHIFT_M;
+    h1 *= FMIX_MULTIPLY_M;
+    h1 ^= h1 >> FMIX_RIGHT_SHIFT_N;
+    h1 *= FMIX_MULTIPLY_N;
+    h1 ^= h1 >> FMIX_RIGHT_SHIFT_M;
+    return h1;
+}
+
+bool ALWAYS_INLINE IsBigEndian()
+{
+    union {
+        uint32_t m;
+        char n;
+    } uval = { 0 };
+    uval.m = MM3_INT_ONE;
+    if (uval.n == MM3_INT_ONE) {
+        return false;
+    } else {
+        return true;
+    }
+}
+
+uint32_t ALWAYS_INLINE ReverseBytes(uint32_t x)
+{
+    return ((x >> REVERSE_SHIFT_M) & REVERSE_AND_A) | ((x << REVERSE_SHIFT_N) & REVERSE_AND_B) |
+        ((x >> REVERSE_SHIFT_N) & REVERSE_AND_C) | ((x << REVERSE_SHIFT_M) & REVERSE_AND_D);
+}
+
+uint32_t ALWAYS_INLINE HashBytesByInt(char *base, uint32_t lengthInBytes, uint32_t seed)
+{
+    uint32_t h1 = seed;
+    for (uint i = 0; i < lengthInBytes; i += MM3_SIZE_INT) {
+        uint32_t halfWord = *reinterpret_cast(base + i);
+        if (IsBigEndian()) {
+            halfWord = ReverseBytes(halfWord);
+        }
+        h1 = MixH1(h1, MixK1(halfWord));
+    }
+    return h1;
+}
+
+uint32_t HashShort(uint16_t input, uint32_t seed)
+{
+    uint32_t k1 = static_cast(input);
+    k1 = MixK1(k1);
+    uint32_t h1 = MixH1(seed, k1);
+
+    return Fmix(h1, MM3_SIZE_SHORT);
+}
+
+uint32_t HashByte(uint8_t input, uint32_t seed)
+{
+    uint32_t k1 = static_cast(input);
+    k1 = MixK1(k1);
+    uint32_t h1 = MixH1(seed, k1);
+
+    return Fmix(h1, MM3_SIZE_BYTE);
+}
+
+uint32_t HashInt(uint32_t input, uint32_t seed)
+{
+    uint32_t k1 = MixK1(input);
+    uint32_t h1 = MixH1(seed, k1);
+
+    return Fmix(h1, MM3_SIZE_INT);
+}
+
+uint32_t HashLong(uint64_t input, uint32_t seed)
+{
+    auto low = static_cast(input);
+    auto high = static_cast(input >> HASH_LONG_RIGHT_SHIFT);
+
+    uint32_t k1 = MixK1(low);
+    uint32_t h1 = MixH1(seed, k1);
+
+    k1 = MixK1(high);
+    h1 = MixH1(h1, k1);
+
+    return Fmix(h1, MM3_SIZE_LONG);
+}
+
+uint32_t HashUnsafeBytes(char *base, uint32_t lengthInBytes, uint32_t seed)
+{
+    uint32_t lengthAligned = lengthInBytes - lengthInBytes % MM3_SIZE_INT;
+    uint32_t h1 = HashBytesByInt(base, lengthAligned, seed);
+    for (uint i = lengthAligned; i < lengthInBytes; i++) {
+        auto charVal = *(base + i);
+        auto halfWord = static_cast(charVal);
+        halfWord &= 0x000000FF; // get the lower eight bits
+        uint32_t k1 = MixK1(halfWord);
+        h1 = MixH1(h1, k1);
+    }
+    return Fmix(h1, lengthInBytes);
+}
+
+extern "C" DLLEXPORT int32_t Mm3Int32(int32_t val, bool isValNull, int32_t seed, bool isSeedNull)
+{
+    if (isSeedNull) {
+        seed = 0;
+    }
+    if (isValNull) {
+        return seed;
+    }
+
+    return static_cast(HashInt(static_cast(val * !isValNull), static_cast(seed)));
+}
+
+extern "C" DLLEXPORT int32_t Mm3Int64(int64_t val, bool isValNull, int32_t seed, bool isSeedNull)
+{
+    if (isSeedNull) {
+        seed = 0;
+    }
+    if (isValNull) {
+        return seed;
+    }
+
+    return static_cast(HashLong(static_cast(val * !isValNull), static_cast(seed)));
+}
+
+extern "C" DLLEXPORT int32_t Mm3String(char *val, int32_t valLen, bool isValNull, int32_t seed, bool isSeedNull)
+{
+    if (isSeedNull) {
+        seed = 0;
+    }
+    if (isValNull) {
+        return seed;
+    }
+
+    valLen = valLen * !isValNull;
+    return static_cast(HashUnsafeBytes(val, static_cast(valLen), static_cast(seed)));
+}
+
+extern "C" DLLEXPORT int32_t Mm3Double(double val, bool isValNull, int32_t seed, bool isSeedNull)
+{
+    union {
+        uint64_t lVal;
+        double dVal;
+    } uVal = { 0 };
+    uVal.dVal = val * !isValNull;
+    if (isSeedNull) {
+        seed = 0;
+    }
+    if (isValNull) {
+        return seed;
+    }
+
+    return static_cast(HashLong(uVal.lVal, static_cast(seed)));
+}
+
+extern "C" DLLEXPORT int32_t Mm3Decimal64(int64_t val, int32_t precision, int32_t scale, bool isValNull, int32_t seed,
+    bool isSeedNull)
+{
+    if (isSeedNull) {
+        seed = 0;
+    }
+    if (isValNull) {
+        return seed;
+    }
+
+    return static_cast(HashLong(val * !isValNull, seed));
+}
+
+extern "C" DLLEXPORT int32_t Mm3Decimal128(int64_t xHigh, uint64_t xLow, int32_t precision, int32_t scale,
+    bool isValNull, int32_t seed, bool isSeedNull)
+{
+    if (isSeedNull) {
+        seed = 0;
+    }
+    if (isValNull) {
+        return seed;
+    }
+
+    int32_t byteLen = 0;
+    auto bytes = omniruntime::type::Decimal128Utils::Decimal128ToBytes(xHigh, xLow, byteLen);
+    auto result = static_cast(HashUnsafeBytes(reinterpret_cast(bytes), byteLen, seed));
+    delete[] bytes;
+    return result;
+}
+
+extern "C" DLLEXPORT int32_t Mm3Boolean(bool val, bool isValNull, int32_t seed, bool isSeedNull)
+{
+    if (isSeedNull) {
+        seed = 0;
+    }
+    if (isValNull) {
+        return seed;
+    }
+
+    uint32_t intVal = val ? 1 : 0;
+    return static_cast(HashInt(static_cast(intVal * !isValNull), static_cast(seed)));
+}
+
+extern "C" DLLEXPORT int64_t CombineHash(int64_t prevHashVal, bool isPrevHashValNull, int64_t val, bool isValNull)
+{
+    if (isPrevHashValNull) {
+        prevHashVal = 0;
+    }
+    if (isValNull) {
+        val = 0;
+    }
+    return COMBINE_HASH_VALUE * prevHashVal + val;
+}
+}
\ No newline at end of file
diff --git a/core/src/codegen/functions/murmur3_hash.h b/core/src/codegen/functions/murmur3_hash.h
new file mode 100644
index 0000000..ea8e682
--- /dev/null
+++ b/core/src/codegen/functions/murmur3_hash.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: Murmur3 Hash function
+ */
+#ifndef __MURMUR3HASH_H__
+#define __MURMUR3HASH_H__
+
+#include 
+
+namespace omniruntime::codegen::function {
+// All extern functions go here temporarily
+#ifdef _WIN32
+#define DLLEXPORT __declspec(dllexport)
+#else
+#define DLLEXPORT
+#endif
+
+extern "C" DLLEXPORT int64_t CombineHash(int64_t prevHashVal, bool isPrevHashValNull, int64_t val, bool isValNull);
+
+extern "C" DLLEXPORT int32_t Mm3Int32(int32_t val, bool isValNull, int32_t seed, bool isSeedNull);
+
+extern "C" DLLEXPORT int32_t Mm3Int64(int64_t val, bool isValNull, int32_t seed, bool isSeedNull);
+
+extern "C" DLLEXPORT int32_t Mm3String(char *val, int32_t valLen, bool isValNull, int32_t seed, bool isSeedNull);
+
+extern "C" DLLEXPORT int32_t Mm3Double(double val, bool isValNull, int32_t seed, bool isSeedNull);
+
+extern "C" DLLEXPORT int32_t Mm3Decimal64(int64_t val, int32_t precision, int32_t scale, bool isValNull, int32_t seed,
+    bool isSeedNull);
+
+extern "C" DLLEXPORT int32_t Mm3Decimal128(int64_t xHigh, uint64_t xLow, int32_t precision, int32_t scale,
+    bool isValNull, int32_t seed, bool isSeedNull);
+
+extern "C" DLLEXPORT int32_t Mm3Boolean(bool val, bool isValNull, int32_t seed, bool isSeedNull);
+
+uint32_t HashUnsafeBytes(char *base, uint32_t lengthInBytes, uint32_t seed);
+uint32_t HashLong(uint64_t input, uint32_t seed);
+uint32_t HashInt(uint32_t input, uint32_t seed);
+uint32_t HashShort(uint16_t input, uint32_t seed);
+uint32_t HashByte(uint8_t input, uint32_t seed);
+}
+// OMNI_RUNTIME_MURMUR3_HASH_H
+#endif
\ No newline at end of file
diff --git a/core/src/codegen/functions/stringfunctions.cpp b/core/src/codegen/functions/stringfunctions.cpp
new file mode 100644
index 0000000..aceccf6
--- /dev/null
+++ b/core/src/codegen/functions/stringfunctions.cpp
@@ -0,0 +1,1203 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved.
+ * Description: registry  function  implementation
+ */
+#include "stringfunctions.h"
+#include "md5.h"
+#include "dtoa.h"
+#include "type/string_Impl.h"
+
+namespace omniruntime::codegen::function {
+/**
+ * This function is only called when apLen is equal to bpLen. When apLen and bpLen are different,
+ * it will directly return false instead of calling StrEquals.
+ */
+extern "C" DLLEXPORT bool StrEquals(const char *ap, int32_t apLen, const char *bp, int32_t bpLen)
+{
+    for (int i = 0; i < apLen; ++i) {
+        if (ap[i] != bp[i]) {
+            return false;
+        }
+    }
+    return true;
+}
+
+extern "C" DLLEXPORT int32_t StrCompare(const char *ap, int32_t apLen, const char *bp, int32_t bpLen)
+{
+    int min = bpLen;
+    if (apLen < min) {
+        min = apLen;
+    }
+
+    int32_t result = memcmp(ap, bp, min);
+    if (result != 0) {
+        return result;
+    } else {
+        return apLen - bpLen;
+    }
+}
+
+extern "C" DLLEXPORT bool LikeStr(const char *str, int32_t strLen, const char *regexToMatch, int32_t regexLen,
+    bool isNull)
+{
+    if (isNull) {
+        return false;
+    }
+    std::string s = std::string(str, strLen);
+    std::string r = std::string(regexToMatch, regexLen);
+
+    std::wregex re(StringUtil::ToWideString(r));
+    return regex_match(StringUtil::ToWideString(s), re);
+}
+
+extern "C" DLLEXPORT bool LikeChar(const char *str, int32_t strWidth, int32_t strLen, const char *regexToMatch,
+    int32_t regexLen, bool isNull)
+{
+    int32_t paddingCount = strWidth - omniruntime::Utf8Util::CountCodePoints(str, strLen);
+    std::string originalStr;
+    originalStr.reserve(strLen + paddingCount);
+    originalStr.append(str, strLen);
+    for (int i = 0; i < paddingCount; i++) {
+        originalStr.append(" ");
+    }
+    std::string r = std::string(regexToMatch, regexLen);
+    std::wregex re(StringUtil::ToWideString(r));
+    return regex_match(StringUtil::ToWideString(originalStr), re);
+}
+
+extern "C" DLLEXPORT const char *ConcatStrStr(int64_t contextPtr, const char *ap, int32_t apLen, const char *bp,
+    int32_t bpLen, bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+
+    bool hasErr = false;
+    const char *ret = StringUtil::ConcatStrDiffWidths(contextPtr, ap, apLen, bp, bpLen, &hasErr, outLen);
+    if (hasErr) {
+        SetError(contextPtr, CONCAT_ERR_MSG);
+        return nullptr;
+    }
+    return ret;
+}
+
+extern "C" DLLEXPORT const char *ConcatCharChar(int64_t contextPtr, const char *ap, int32_t aWidth, int32_t apLen,
+    const char *bp, int32_t bWidth, int32_t bpLen, bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+    bool hasErr = false;
+    const char *ret = StringUtil::ConcatCharDiffWidths(contextPtr, ap, aWidth, apLen, bp, bpLen, &hasErr, outLen);
+    if (hasErr) {
+        SetError(contextPtr, CONCAT_ERR_MSG);
+        return nullptr;
+    }
+    return ret;
+}
+
+extern "C" DLLEXPORT const char *ConcatCharStr(int64_t contextPtr, const char *ap, int32_t aWidth, int32_t apLen,
+    const char *bp, int32_t bpLen, bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+    bool hasErr = false;
+    const char *ret = StringUtil::ConcatCharDiffWidths(contextPtr, ap, aWidth, apLen, bp, bpLen, &hasErr, outLen);
+    if (hasErr) {
+        SetError(contextPtr, CONCAT_ERR_MSG);
+        return nullptr;
+    }
+    return ret;
+}
+
+extern "C" DLLEXPORT const char *ConcatStrChar(int64_t contextPtr, const char *ap, int32_t apLen, const char *bp,
+    int32_t bWidth, int32_t bpLen, bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+
+    bool hasErr = false;
+    const char *ret = StringUtil::ConcatStrDiffWidths(contextPtr, ap, apLen, bp, bpLen, &hasErr, outLen);
+    if (hasErr) {
+        SetError(contextPtr, CONCAT_ERR_MSG);
+        return nullptr;
+    }
+    return ret;
+}
+
+extern "C" DLLEXPORT const char *ConcatWsStr(int64_t contextPtr, const char *separator, int32_t separatorLen,
+    bool separatorIsNull, const char *ap, int32_t apLen, const char *bp, int32_t bpLen, bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+    if (separatorIsNull) {
+        *outLen = 0;
+        isNull = true;
+        return nullptr;
+    }
+
+    bool hasErr = false;
+    const char *ret = StringUtil::ConcatWsStrDiffWidths(contextPtr, separator, separatorLen, ap, apLen, bp, bpLen, &hasErr, outLen);
+    if (hasErr) {
+        SetError(contextPtr, CONCAT_ERR_MSG);
+        return nullptr;
+    }
+    return ret;
+}
+
+extern "C" DLLEXPORT int32_t CastStringToDateNotAllowReducePrecison(int64_t contextPtr, const char *str, int32_t strLen,
+    bool isNull)
+{
+    if (isNull) {
+        return 0;
+    }
+    // Date is in the format 1996-02-28
+    // Doesn't account for leap seconds or daylight savings
+    // Should be ok just for dates
+    int64_t result = 0;
+    std::string s(str, strLen);
+    StringUtil::TrimString(s);
+    if (!regex_match(s, g_dateRegex)) {
+        SetError(contextPtr, "Only support cast date\'YYYY-MM-DD\' to integer");
+        return -1;
+    }
+    if (Date32::StringToDate32(str, strLen, result) != Status::CONVERT_SUCCESS) {
+        SetError(contextPtr, "Value cannot be cast to date: " + std::string(str, strLen));
+        return -1;
+    }
+    return static_cast(result);
+}
+
+extern "C" DLLEXPORT int32_t CastStringToDateAllowReducePrecison(int64_t contextPtr, const char *str, int32_t strLen,
+    bool isNull)
+{
+    if (isNull) {
+        return 0;
+    }
+    // Date is in the format 1996-02-28
+    // Doesn't account for leap seconds or daylight savings
+    // Should be ok just for dates
+    int64_t result = 0;
+    if (Date32::StringToDate32(str, strLen, result) != Status::CONVERT_SUCCESS) {
+        SetError(contextPtr, "Value cannot be cast to date: " + std::string(str, strLen));
+        return -1;
+    }
+    return static_cast(result);
+}
+
+extern "C" DLLEXPORT const char *ToUpperStr(int64_t contextPtr, const char *str, int32_t strLen, bool isNull,
+    int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+    auto ret = ArenaAllocatorMalloc(contextPtr, strLen);
+    for (int32_t i = 0; i < strLen; i++) {
+        auto currItem = *(str + i);
+        if (currItem >= static_cast('a') && currItem <= static_cast('z')) {
+            *(ret + i) = static_cast(currItem - STEP);
+        } else {
+            *(ret + i) = currItem;
+        }
+    }
+    *outLen = strLen;
+    return ret;
+}
+
+extern "C" DLLEXPORT const char *ToUpperChar(int64_t contextPtr, const char *str, int32_t width, int32_t strLen,
+    bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+    return ToUpperStr(contextPtr, str, strLen, isNull, outLen);
+}
+
+extern "C" DLLEXPORT const char *ToLowerStr(int64_t contextPtr, const char *str, int32_t strLen, bool isNull,
+    int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+    auto ret = ArenaAllocatorMalloc(contextPtr, strLen);
+    for (int32_t i = 0; i < strLen; i++) {
+        auto currItem = *(str + i);
+        if (currItem >= static_cast('A') && currItem <= static_cast('Z')) {
+            *(ret + i) = static_cast(currItem + STEP);
+        } else {
+            *(ret + i) = currItem;
+        }
+    }
+    *outLen = strLen;
+    return ret;
+}
+
+extern "C" DLLEXPORT const char *ToLowerChar(int64_t contextPtr, const char *str, int32_t width, int32_t strLen,
+    bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+    return ToLowerStr(contextPtr, str, strLen, isNull, outLen);
+}
+
+extern "C" DLLEXPORT int64_t LengthChar(const char *str, int32_t width, int32_t strLen, bool isNull)
+{
+    return isNull ? 0 : width;
+}
+
+extern "C" DLLEXPORT int32_t LengthCharReturnInt32(const char *str, int32_t width, int32_t strLen, bool isNull)
+{
+    return isNull ? 0 : width;
+}
+
+extern "C" DLLEXPORT int32_t LengthStrReturnInt32(const char *str, int32_t strLen, bool isNull)
+{
+    return isNull ? 0 : omniruntime::Utf8Util::CountCodePoints(str, strLen);
+}
+
+extern "C" DLLEXPORT int64_t LengthStr(const char *str, int32_t strLen, bool isNull)
+{
+    return isNull ? 0 : omniruntime::Utf8Util::CountCodePoints(str, strLen);
+}
+
+extern "C" DLLEXPORT const char *ReplaceStrStrStrWithRepNotReplace(int64_t contextPtr, const char *str, int32_t strLen,
+    const char *searchStr, int32_t searchLen, const char *replaceStr, int32_t replaceLen, bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+
+    bool hasErr = false;
+    char *ret;
+    if (searchLen == 0) {
+        *outLen = strLen;
+        ret = const_cast(str);
+    } else {
+        auto result = StringUtil::ReplaceWithSearchNotEmpty(contextPtr, str, strLen, searchStr, searchLen, replaceStr,
+            replaceLen, &hasErr, outLen);
+        ret = const_cast(result);
+    }
+
+    if (hasErr) {
+        SetError(contextPtr, REPLACE_ERR_MSG);
+    }
+    return ret;
+}
+
+extern "C" DLLEXPORT const char *ReplaceStrStrStrWithRepReplace(int64_t contextPtr, const char *str, int32_t strLen,
+    const char *searchStr, int32_t searchLen, const char *replaceStr, int32_t replaceLen, bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+
+    bool hasErr = false;
+    char *ret;
+    if (searchLen == 0) {
+        auto result =
+            StringUtil::ReplaceWithSearchEmpty(contextPtr, str, strLen, replaceStr, replaceLen, &hasErr, outLen);
+        ret = (const_cast(result));
+    } else {
+        auto result = StringUtil::ReplaceWithSearchNotEmpty(contextPtr, str, strLen, searchStr, searchLen, replaceStr,
+            replaceLen, &hasErr, outLen);
+        ret = const_cast(result);
+    }
+
+    if (hasErr) {
+        SetError(contextPtr, REPLACE_ERR_MSG);
+    }
+    return ret;
+}
+
+extern "C" DLLEXPORT const char *ReplaceStrStrWithoutRepNotReplace(int64_t contextPtr, const char *str, int32_t strLen,
+    const char *searchStr, int32_t searchLen, bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+    return ReplaceStrStrStrWithRepNotReplace(contextPtr, str, strLen, searchStr, searchLen, "", 0, isNull, outLen);
+}
+
+extern "C" DLLEXPORT const char *ReplaceStrStrWithoutRepReplace(int64_t contextPtr, const char *str, int32_t strLen,
+    const char *searchStr, int32_t searchLen, bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+    return ReplaceStrStrStrWithRepReplace(contextPtr, str, strLen, searchStr, searchLen, "", 0, isNull, outLen);
+}
+
+// Cast numeric type to std::string
+extern "C" DLLEXPORT const char *CastIntToString(int64_t contextPtr, int32_t value, bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+    std::string str = std::to_string(value);
+    *outLen = static_cast(str.size());
+    auto ret = ArenaAllocatorMalloc(contextPtr, *outLen);
+    errno_t res = memcpy_s(ret, *outLen, str.c_str(), *outLen);
+    if (res != EOK) {
+        SetError(contextPtr, "cast failed");
+        *outLen = 0;
+        return nullptr;
+    }
+    return ret;
+}
+
+extern "C" DLLEXPORT const char *CastInt16ToString(int64_t contextPtr, int16_t value, bool isNull, int32_t *outLen)
+{
+    return CastIntToString(contextPtr, static_cast(value), isNull, outLen);
+}
+
+extern "C" DLLEXPORT const char *CastInt8ToString(int64_t contextPtr, int8_t value, bool isNull, int32_t *outLen)
+{
+    return CastIntToString(contextPtr, static_cast(value), isNull, outLen);
+}
+
+extern "C" DLLEXPORT const char *CastLongToString(int64_t contextPtr, int64_t value, bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+    std::string str = std::to_string(value);
+    *outLen = static_cast(strlen(str.c_str()));
+    auto ret = ArenaAllocatorMalloc(contextPtr, *outLen);
+    errno_t res = memcpy_s(ret, *outLen, str.c_str(), *outLen);
+    if (res != EOK) {
+        SetError(contextPtr, "cast failed");
+        *outLen = 0;
+        return nullptr;
+    }
+    return ret;
+}
+
+extern "C" DLLEXPORT const char *CastDoubleToString(int64_t contextPtr, double value, bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+    auto ret = ArenaAllocatorMalloc(contextPtr, MAX_DATA_LENGTH);
+    *outLen = static_cast(DoubleToString::DoubleToStringConverter(value, ret));
+    return ret;
+}
+
+extern "C" DLLEXPORT const char *CastDecimal64ToString(int64_t contextPtr, int64_t x, int32_t precision, int32_t scale,
+    bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+    std::string str = Decimal64(x).SetScale(scale).ToString();
+    *outLen = static_cast(str.size());
+    auto ret = ArenaAllocatorMalloc(contextPtr, *outLen);
+    errno_t res = memcpy_s(ret, *outLen, str.c_str(), *outLen);
+    if (res != EOK) {
+        SetError(contextPtr, "cast failed");
+        *outLen = 0;
+        return nullptr;
+    }
+    return ret;
+}
+
+extern "C" DLLEXPORT const char *CastDecimal128ToString(int64_t contextPtr, int64_t high, uint64_t low,
+    int32_t precision, int32_t scale, bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+    std::string stringDecimal = Decimal128Wrapper(high, low).SetScale(scale).ToString();
+    *outLen = static_cast(stringDecimal.length());
+    auto ret = ArenaAllocatorMalloc(contextPtr, *outLen);
+    errno_t res = memcpy_s(ret, *outLen, stringDecimal.c_str(), *outLen);
+    if (res != EOK) {
+        SetError(contextPtr, "cast failed");
+        *outLen = 0;
+        return nullptr;
+    }
+    return ret;
+}
+
+extern "C" DLLEXPORT const char *CastStrWithDiffWidths(int64_t contextPtr, const char *srcStr, int32_t srcLen,
+    int32_t srcWidth, bool isNull, int32_t dstWidth, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+    bool hasErr = false;
+    const char *ret = StringUtil::CastStrStr(&hasErr, srcStr, srcWidth, srcLen, outLen, dstWidth);
+    if (hasErr) {
+        std::ostringstream errorMessage;
+        errorMessage << "cast varchar[" << srcWidth << "] to varchar[" << dstWidth << "] failed.";
+        SetError(contextPtr, errorMessage.str());
+    }
+    return ret;
+}
+
+// Cast std::string to numeric type
+extern "C" DLLEXPORT int16_t CastStringToShort(int64_t contextPtr, const char *str, int32_t strLen, bool isNull)
+{
+    if (isNull) {
+        return 0;
+    }
+    int16_t result;
+    Status status = ConvertStringToInteger(result, str, strLen);
+    if (status == Status::IS_NOT_A_NUMBER) {
+        std::string s(str, strLen);
+        std::ostringstream errorMessage;
+        errorMessage << "Cannot cast '" << s << "' to INTEGER. Value is not a number.";
+        SetError(contextPtr, errorMessage.str());
+        return 0;
+    }
+
+    if (status == Status::CONVERT_OVERFLOW) {
+        std::string s(str, strLen);
+        std::ostringstream errorMessage;
+        errorMessage << "Cannot cast '" << s << "' to INTEGER. Value too large or too small.";
+        SetError(contextPtr, errorMessage.str());
+        return 0;
+    }
+    return result;
+}
+
+extern "C" DLLEXPORT int8_t CastStringToByte(int64_t contextPtr, const char *str, int32_t strLen, bool isNull)
+{
+    if (isNull) {
+        return 0;
+    }
+    int8_t result = 0;
+    Status status = ConvertStringToInteger(result, str, strLen);
+    std::ostringstream errorMessage;
+    SetError(contextPtr, errorMessage.str());
+    if (status == Status::IS_NOT_A_NUMBER) {
+        std::string s(str, strLen);
+        std::ostringstream errorMessage;
+        errorMessage << "Cannot cast '" << s << "' to BYTE. Value is not a number.";
+        SetError(contextPtr, errorMessage.str());
+        return 0;
+    }
+
+    if (status == Status::CONVERT_OVERFLOW) {
+        std::string s(str, strLen);
+        std::ostringstream errorMessage;
+        errorMessage << "Cannot cast '" << s << "' to BYTE. Value too large or too small.";
+        SetError(contextPtr, errorMessage.str());
+        return 0;
+    }
+
+    return result;
+}
+
+extern "C" DLLEXPORT int32_t CastStringToInt(int64_t contextPtr, const char *str, int32_t strLen, bool isNull)
+{
+    if (isNull) {
+        return 0;
+    }
+    int32_t result;
+    Status status = ConvertStringToInteger(result, str, strLen);
+    if (status == Status::IS_NOT_A_NUMBER) {
+        std::string s(str, strLen);
+        std::ostringstream errorMessage;
+        errorMessage << "Cannot cast '" << s << "' to INTEGER. Value is not a number.";
+        SetError(contextPtr, errorMessage.str());
+        return 0;
+    }
+
+    if (status == Status::CONVERT_OVERFLOW) {
+        std::string s(str, strLen);
+        std::ostringstream errorMessage;
+        errorMessage << "Cannot cast '" << s << "' to INTEGER. Value too large or too small.";
+        SetError(contextPtr, errorMessage.str());
+        return 0;
+    }
+    return result;
+}
+
+extern "C" DLLEXPORT int64_t CastStringToLong(int64_t contextPtr, const char *str, int32_t strLen, bool isNull)
+{
+    if (isNull) {
+        return 0;
+    }
+    int64_t result;
+    Status status = ConvertStringToInteger(result, str, strLen);
+    if (status == Status::IS_NOT_A_NUMBER) {
+        std::string s = std::string(str, strLen);
+        std::ostringstream errorMessage;
+        errorMessage << "Cannot cast '" << s << "' to BIGINT. Value is not a number.";
+        SetError(contextPtr, errorMessage.str());
+        return 0;
+    }
+
+    if (status == Status::CONVERT_OVERFLOW) {
+        std::string s = std::string(str, strLen);
+        std::ostringstream errorMessage;
+        errorMessage << "Cannot cast '" << s << "' to BIGINT. Value too large or too small.";
+        SetError(contextPtr, errorMessage.str());
+        return 0;
+    }
+
+    return result;
+}
+
+extern "C" DLLEXPORT double CastStringToDouble(int64_t contextPtr, const char *str, int32_t strLen, bool isNull)
+{
+    if (isNull) {
+        return 0;
+    }
+
+    double result;
+    Status status = ConvertStringToDouble(result, str, strLen);
+    if (status == Status::IS_NOT_A_NUMBER) {
+        std::ostringstream errorMessage;
+        errorMessage << "Cannot cast '" << std::string(str, strLen) << "' to DOUBLE. Value is not a number.";
+        SetError(contextPtr, errorMessage.str());
+        return 0;
+    }
+    if (status == Status::CONVERT_OVERFLOW) {
+        std::ostringstream errorMessage;
+        errorMessage << "Cannot cast '" << std::string(str, strLen) << "' to DOUBLE. Value is not a number.";
+        SetError(contextPtr, errorMessage.str());
+        return 0;
+    }
+    return result;
+}
+
+extern "C" DLLEXPORT int64_t CastStringToDecimal64(int64_t contextPtr, const char *str, int32_t strLen, bool isNull,
+    int32_t outPrecision, int32_t outScale)
+{
+    if (isNull) {
+        return 0;
+    }
+    std::string s = std::string(str, strLen);
+    StringUtil::TrimString(s);
+    if (!regex_match(s, g_decimalRegex)) {
+        std::ostringstream errorMessage;
+        errorMessage << "Cannot cast VARCHAR '" << s << "' to DECIMAL(" << outPrecision << ", " << outScale <<
+            "). Value is not a number.";
+        SetError(contextPtr, errorMessage.str());
+        return 0;
+    }
+    Decimal64 result(s);
+    result.ReScale(outScale);
+    if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) {
+        std::ostringstream errorMessage;
+        errorMessage << "Cannot cast VARCHAR '" << std::string(str, strLen) << "' to DECIMAL(" << outPrecision <<
+            ", " << outScale << "). Value too large.";
+        SetError(contextPtr, errorMessage.str());
+        return 0;
+    }
+    return result.GetValue();
+}
+
+extern "C" DLLEXPORT int64_t CastStringToDecimal64RoundUp(int64_t contextPtr, const char *str, int32_t strLen,
+    bool isNull, int32_t outPrecision, int32_t outScale)
+{
+    if (isNull) {
+        return 0;
+    }
+    std::string s = std::string(str, strLen);
+    Decimal64 result(s);
+    result.ReScale(outScale);
+    if (result.IsOverflow(outPrecision) == OpStatus::OP_OVERFLOW) {
+        std::ostringstream errorMessage;
+        errorMessage << "Cannot cast VARCHAR '" << std::string(str, strLen) << "' to DECIMAL(" << outPrecision <<
+            ", " << outScale << "). Value too large.";
+        SetError(contextPtr, errorMessage.str());
+        return 0;
+    }
+    if (result.IsOverflow(outPrecision) == OpStatus::FAIL) {
+        std::ostringstream errorMessage;
+        errorMessage << "Cannot cast VARCHAR '" << s << "' to DECIMAL(" << outPrecision << ", " << outScale <<
+            "). Value is not a number.";
+        SetError(contextPtr, errorMessage.str());
+        return 0;
+    }
+    return result.GetValue();
+}
+
+extern "C" DLLEXPORT void CastStringToDecimal128(int64_t contextPtr, const char *str, int32_t strLen, bool isNull,
+    int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr)
+{
+    if (isNull) {
+        return;
+    }
+    std::string s = std::string(str, strLen);
+    StringUtil::TrimString(s);
+    if (!regex_match(s, g_decimalRegex)) {
+        std::ostringstream errorMessage;
+        errorMessage << "Cannot cast VARCHAR '" << s << "' to DECIMAL(" << outPrecision << ", " << outScale <<
+            "). Value is not a number.";
+        SetError(contextPtr, errorMessage.str());
+        return;
+    }
+    Decimal128Wrapper result(s.c_str());
+    result.ReScale(outScale);
+    OpStatus status = result.IsOverflow(outPrecision);
+    if (status != OpStatus::SUCCESS) {
+        SetError(contextPtr, CastErrorMessage(OMNI_VARCHAR, OMNI_DECIMAL128, std::string(str, strLen).c_str(), status,
+            outPrecision, outScale));
+        return;
+    }
+    *outHighPtr = result.HighBits();
+    *outLowPtr = result.LowBits();
+}
+
+extern "C" DLLEXPORT void CastStringToDecimal128RoundUp(int64_t contextPtr, const char *str, int32_t strLen,
+    bool isNull, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr)
+{
+    if (isNull) {
+        return;
+    }
+    std::string s = std::string(str, strLen);
+    StringUtil::TrimString(s);
+    if (!regex_match(s, g_decimalRegex)) {
+        std::ostringstream errorMessage;
+        errorMessage << "Cannot cast VARCHAR '" << s << "' to DECIMAL(" << outPrecision << ", " << outScale <<
+            "). Value is not a number.";
+        SetError(contextPtr, errorMessage.str());
+        return;
+    }
+    Decimal128Wrapper result(s.c_str());
+    result.ReScale(outScale);
+    OpStatus status = result.IsOverflow(outPrecision);
+    if (status != OpStatus::SUCCESS) {
+        SetError(contextPtr, CastErrorMessage(OMNI_VARCHAR, OMNI_DECIMAL128, std::string(str, strLen).c_str(), status,
+            outPrecision, outScale));
+        return;
+    }
+    *outHighPtr = result.HighBits();
+    *outLowPtr = result.LowBits();
+}
+
+extern "C" DLLEXPORT const char *ConcatStrStrRetNull(int64_t contextPtr, bool *isNull, const char *ap, int32_t apLen,
+    const char *bp, int32_t bpLen, int32_t *outLen)
+{
+    return StringUtil::ConcatStrDiffWidths(contextPtr, ap, apLen, bp, bpLen, isNull, outLen);
+}
+
+extern "C" DLLEXPORT const char *ConcatCharCharRetNull(int64_t contextPtr, bool *isNull, const char *ap, int32_t aWidth,
+    int32_t apLen, const char *bp, int32_t bWidth, int32_t bpLen, int32_t *outLen)
+{
+    return StringUtil::ConcatCharDiffWidths(contextPtr, ap, aWidth, apLen, bp, bpLen, isNull, outLen);
+}
+
+extern "C" DLLEXPORT const char *ConcatCharStrRetNull(int64_t contextPtr, bool *isNull, const char *ap, int32_t aWidth,
+    int32_t apLen, const char *bp, int32_t bpLen, int32_t *outLen)
+{
+    return StringUtil::ConcatCharDiffWidths(contextPtr, ap, aWidth, apLen, bp, bpLen, isNull, outLen);
+}
+
+extern "C" DLLEXPORT const char *ConcatStrCharRetNull(int64_t contextPtr, bool *isNull, const char *ap, int32_t apLen,
+    const char *bp, int32_t bWidth, int32_t bpLen, int32_t *outLen)
+{
+    return StringUtil::ConcatStrDiffWidths(contextPtr, ap, apLen, bp, bpLen, isNull, outLen);
+}
+
+extern "C" DLLEXPORT int32_t CastStringToDateRetNullNotAllowReducePrecison(bool *isNull, const char *str,
+    int32_t strLen)
+{
+    // Date is in the format 1996-02-28
+    // Doesn't account for leap seconds or daylight savings
+    // Should be ok just for dates
+    int64_t result = 0;
+    std::string s(str, strLen);
+    StringUtil::TrimString(s);
+    if (!regex_match(s, g_dateRegex)) {
+        *isNull = true;
+        return -1;
+    }
+    if (Date32::StringToDate32(str, strLen, result) != Status::CONVERT_SUCCESS) {
+        *isNull = true;
+        return -1;
+    }
+    return static_cast(result);
+}
+
+extern "C" DLLEXPORT int32_t CastStringToDateRetNullAllowReducePrecison(bool *isNull, const char *str, int32_t strLen)
+{
+    // Date is in the format 1996-02-28
+    // Doesn't account for leap seconds or daylight savings
+    // Should be ok just for dates
+    int64_t result = 0;
+    if (Date32::StringToDate32(str, strLen, result) != Status::CONVERT_SUCCESS) {
+        *isNull = true;
+        return -1;
+    }
+    return static_cast(result);
+}
+
+extern "C" DLLEXPORT const char *CastIntToStringRetNull(int64_t contextPtr, bool *isNull, int32_t value,
+    int32_t *outLen)
+{
+    std::string str = std::to_string(value);
+    *outLen = static_cast(str.size());
+    auto ret = ArenaAllocatorMalloc(contextPtr, *outLen);
+    errno_t res = memcpy_s(ret, *outLen, str.c_str(), *outLen);
+    if (res != EOK) {
+        *isNull = true;
+        *outLen = 0;
+        return nullptr;
+    }
+    return ret;
+}
+
+extern "C" DLLEXPORT const char *CastInt16ToStringRetNull(int64_t contextPtr, bool *isNull, int16_t value,
+    int32_t *outLen)
+{
+    return CastIntToStringRetNull(contextPtr, isNull, static_cast(value), outLen);
+}
+
+extern "C" DLLEXPORT const char *CastInt8ToStringRetNull(int64_t contextPtr, bool *isNull, int8_t value,
+    int32_t *outLen)
+{
+    return CastIntToStringRetNull(contextPtr, isNull, static_cast(value), outLen);
+}
+
+extern "C" DLLEXPORT const char *CastLongToStringRetNull(int64_t contextPtr, bool *isNull, int64_t value,
+    int32_t *outLen)
+{
+    std::string str = std::to_string(value);
+    *outLen = static_cast(strlen(str.c_str()));
+    auto ret = ArenaAllocatorMalloc(contextPtr, *outLen);
+    errno_t res = memcpy_s(ret, *outLen, str.c_str(), *outLen);
+    if (res != EOK) {
+        *isNull = true;
+        *outLen = 0;
+        return nullptr;
+    }
+    return ret;
+}
+
+extern "C" DLLEXPORT const char *CastDoubleToStringRetNull(int64_t contextPtr, bool *isNull, double value,
+    int32_t *outLen)
+{
+    auto ret = ArenaAllocatorMalloc(contextPtr, MAX_DATA_LENGTH);
+    *outLen = static_cast(DoubleToString::DoubleToStringConverter(value, ret));
+    return ret;
+}
+
+extern "C" DLLEXPORT const char *CastDecimal64ToStringRetNull(int64_t contextPtr, bool *isNull, int64_t x,
+    int32_t precision, int32_t scale, int32_t *outLen)
+{
+    std::string str = Decimal64(x).SetScale(scale).ToString();
+    *outLen = static_cast(str.size());
+    auto ret = ArenaAllocatorMalloc(contextPtr, *outLen);
+    errno_t res = memcpy_s(ret, *outLen, str.c_str(), *outLen);
+    if (res != EOK) {
+        *isNull = true;
+        *outLen = 0;
+        return nullptr;
+    }
+    return ret;
+}
+
+extern "C" DLLEXPORT const char *CastDecimal128ToStringRetNull(int64_t contextPtr, bool *isNull, int64_t high,
+    uint64_t low, int32_t precision, int32_t scale, int32_t *outLen)
+{
+    Decimal128Wrapper inputDecimal(high, low);
+    std::string stringDecimal = inputDecimal.SetScale(scale).ToString();
+    *outLen = static_cast(stringDecimal.length());
+    auto ret = ArenaAllocatorMalloc(contextPtr, *outLen);
+    errno_t res = memcpy_s(ret, *outLen, stringDecimal.c_str(), *outLen);
+    if (res != EOK) {
+        *isNull = true;
+        *outLen = 0;
+        return nullptr;
+    }
+    return ret;
+}
+
+extern "C" DLLEXPORT int8_t CastStringToByteRetNull(bool *isNull, const char *str, int32_t strLen)
+{
+    int8_t result = -6;
+    Status status = ConvertStringToInteger(result, str, strLen);
+    *isNull = status != Status::CONVERT_SUCCESS;
+    return result;
+}
+
+extern "C" DLLEXPORT int16_t CastStringToShortRetNull(bool *isNull, const char *str, int32_t strLen)
+{
+    int16_t result = 0;
+    Status status = ConvertStringToInteger(result, str, strLen);
+    *isNull = status != Status::CONVERT_SUCCESS;
+    return result;
+}
+
+extern "C" DLLEXPORT int32_t CastStringToIntRetNull(bool *isNull, const char *str, int32_t strLen)
+{
+    int32_t result = 0;
+    Status status = ConvertStringToInteger(result, str, strLen);
+    *isNull = status != Status::CONVERT_SUCCESS;
+    return result;
+}
+
+extern "C" DLLEXPORT int64_t CastStringToLongRetNull(bool *isNull, const char *str, int32_t strLen)
+{
+    int64_t result = 0;
+    Status status = ConvertStringToInteger(result, str, strLen);
+    *isNull = status != Status::CONVERT_SUCCESS;
+    return result;
+}
+
+extern "C" DLLEXPORT double CastStringToDoubleRetNull(bool *isNull, const char *str, int32_t strLen)
+{
+    double result;
+    Status status = ConvertStringToDouble(result, str, strLen);
+    if (status != Status::CONVERT_SUCCESS) {
+        *isNull = true;
+        return 0;
+    }
+    return result;
+}
+
+extern "C" DLLEXPORT int64_t CastStringToDecimal64RetNull(bool *isNull, const char *str, int32_t strLen,
+    int32_t outPrecision, int32_t outScale)
+{
+    std::string s = std::string(str, strLen);
+    StringUtil::TrimString(s);
+    if (!regex_match(s, g_decimalRegex)) {
+        *isNull = true;
+        return 0;
+    }
+    Decimal64 result(std::string(str, strLen));
+    result.ReScale(outScale);
+    if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) {
+        *isNull = true;
+        return 0;
+    }
+    return result.GetValue();
+}
+
+extern "C" DLLEXPORT int64_t CastStringToDecimal64RoundUpRetNull(bool *isNull, const char *str, int32_t strLen,
+    int32_t outPrecision, int32_t outScale)
+{
+    std::string s = std::string(str, strLen);
+    Decimal64 result(std::string(str, strLen));
+    result.ReScale(outScale);
+    if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) {
+        *isNull = true;
+        return 0;
+    }
+    return result.GetValue();
+}
+
+extern "C" DLLEXPORT void CastStringToDecimal128RetNull(bool *isNull, const char *str, int32_t strLen,
+    int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr)
+{
+    std::string s = std::string(str, strLen);
+    StringUtil::TrimString(s);
+    if (!regex_match(s, g_decimalRegex)) {
+        *isNull = true;
+        return;
+    }
+    Decimal128Wrapper result(s.c_str());
+    result.ReScale(outScale);
+    if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) {
+        *isNull = true;
+        return;
+    }
+    *outHighPtr = result.HighBits();
+    *outLowPtr = result.LowBits();
+}
+
+extern "C" DLLEXPORT void CastStringToDecimal128RoundUpRetNull(bool *isNull, const char *str, int32_t strLen,
+    int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr)
+{
+    std::string s = std::string(str, strLen);
+    StringUtil::TrimString(s);
+    if (!regex_match(s, g_decimalRegex)) {
+        *isNull = true;
+        return;
+    }
+    Decimal128Wrapper result(s.c_str());
+    result.ReScale(outScale);
+    if (result.IsOverflow(outPrecision) != OpStatus::SUCCESS) {
+        *isNull = true;
+        return;
+    }
+    *outHighPtr = result.HighBits();
+    *outLowPtr = result.LowBits();
+}
+
+extern "C" DLLEXPORT const char *CastStrWithDiffWidthsRetNull(int64_t contextPtr, bool *isNull, const char *srcStr,
+    int32_t srcLen, int32_t srcWidth, int32_t dstWidth, int32_t *outLen)
+{
+    return StringUtil::CastStrStr(isNull, srcStr, srcWidth, srcLen, outLen, dstWidth);
+}
+
+extern "C" DLLEXPORT int32_t InStr(const char *srcStr, int32_t srcLen, const char *subStr, int32_t subLen, bool isNull)
+{
+    // currently return 0 if not found that means 1-based
+    if (isNull || subLen > srcLen) {
+        return 0;
+    }
+    if (subLen == 0) {
+        return 1;
+    }
+
+    int32_t tailPos = srcLen - subLen;
+    int32_t cmpLen = subLen - 1;
+    for (int32_t pos = 0; pos <= tailPos; ++pos) {
+        if (srcStr[pos] == subStr[0] && memcmp(srcStr + pos + 1, subStr + 1, cmpLen) == 0) {
+            auto result = omniruntime::Utf8Util::CountCodePoints(srcStr, pos);
+            return (result + 1);
+        }
+    }
+    return 0;
+}
+
+extern "C" DLLEXPORT bool StartsWithStr(const char *srcStr, int32_t srcLen, const char *matchStr, int32_t matchLen,
+    bool isNull)
+{
+    if (isNull || matchLen > srcLen) {
+        return false;
+    }
+    if (matchLen == 0) {
+        return true;
+    }
+    return memcmp(srcStr, matchStr, matchLen) == 0;
+}
+
+extern "C" DLLEXPORT bool EndsWithStr(const char *srcStr, int32_t srcLen, const char *matchStr, int32_t matchLen,
+    bool isNull)
+{
+    if (isNull || matchLen > srcLen) {
+        return false;
+    }
+    if (matchLen == 0) {
+        return true;
+    }
+    return memcmp(srcStr + srcLen - matchLen, matchStr, matchLen) == 0;
+}
+
+extern "C" DLLEXPORT bool RegexMatch(const char *srcStr, int32_t srcLen, const char *matchStr, int32_t matchLen,
+    bool isNull)
+{
+    if (isNull) {
+        return false;
+    }
+    if (matchLen == 0) {
+        return true;
+    }
+    if (srcLen == 0) {
+        return false;
+    }
+    for (int32_t i = 0; i < srcLen; i++) {
+        char c = srcStr[i];
+        if (c < '0' || c > '9') {
+            return false;
+        }
+    }
+    return true;
+}
+
+extern "C" DLLEXPORT const char *CastDateToStringRetNull(int64_t contextPtr, bool *isNull, int32_t value,
+    int32_t *outLen)
+{
+    Date32 date(value);
+    auto ret = ArenaAllocatorMalloc(contextPtr, MAX_DAY_ONLY_LENGTH);
+    *outLen = static_cast(date.ToString(ret, MAX_DAY_ONLY_LENGTH));
+    return ret;
+}
+
+extern "C" DLLEXPORT const char *CastDateToString(int64_t contextPtr, int32_t value, bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+    Date32 date(value);
+    auto ret = ArenaAllocatorMalloc(contextPtr, MAX_DAY_ONLY_LENGTH);
+    *outLen = static_cast(date.ToString(ret, MAX_DAY_ONLY_LENGTH));
+    return ret;
+}
+
+extern "C" DLLEXPORT char *Md5Str(int64_t contextPtr, const char *str, int32_t len, bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+    Md5Function md5(str, len);
+    *outLen = 32;
+    char *mdString = ArenaAllocatorMalloc(contextPtr, *outLen);
+    md5.FinishHex(mdString);
+    return mdString;
+}
+
+extern "C" DLLEXPORT bool ContainsStr(const char *srcStr, int32_t srcLen, const char *matchStr, int32_t matchLen,
+    bool isNull)
+{
+    if (isNull || matchLen > srcLen) {
+        return false;
+    }
+    if (matchLen == 0) {
+        return true;
+    }
+    return StringUtil::StrContainsStr(srcStr, srcLen, matchStr, matchLen);
+}
+
+extern "C" DLLEXPORT const char *GreatestStr(const char *lValue, int32_t lLen, bool lIsNull, const char *rValue,
+    int32_t rLen, bool rIsNull, bool *retIsNull, int32_t *outLen)
+{
+    if (lIsNull && rIsNull) {
+        *retIsNull = true;
+        *outLen = 0;
+        return nullptr;
+    }
+    if (lIsNull) {
+        *outLen = rLen;
+        return rValue;
+    }
+    if (!rIsNull) {
+        int32_t cmpRet = memcmp(lValue, rValue, std::min(lLen, rLen));
+        if (cmpRet < 0 || (cmpRet == 0 && rLen > lLen)) {
+            *outLen = rLen;
+            return rValue;
+        }
+    }
+    *outLen = lLen;
+    return lValue;
+}
+
+extern "C" DLLEXPORT const char *EmptyToNull(const char *str, int32_t len, bool isNull, int32_t *outLen)
+{
+    if (len == 0 || isNull) {
+        *outLen = 0;
+        return nullptr;
+    }
+
+    *outLen = len;
+    return str;
+}
+
+extern "C" DLLEXPORT const char *StaticInvokeVarcharTypeWriteSideCheck(int64_t contextPtr, const char *str, int32_t len,
+    int32_t limit, bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        *outLen = 0;
+        return nullptr;
+    }
+    int32_t ssLen = StringUtil::NumChars(str, len);
+    if (ssLen <= limit) {
+        *outLen = len;
+        return str;
+    }
+    int32_t numTailSpacesToTrim = ssLen - limit;
+    int32_t endIdx = len - 1;
+    int32_t trimTo = len - numTailSpacesToTrim;
+    while (endIdx >= trimTo && str[endIdx] == 0x20) {
+        endIdx--;
+    }
+    int32_t outByteNum = endIdx + 1;
+    ssLen = StringUtil::NumChars(str, outByteNum);
+    if (ssLen > limit) {
+        std::ostringstream errorMessage;
+        errorMessage << "Exceeds varchar type length limitation: " << limit;
+        SetError(contextPtr, errorMessage.str());
+        *outLen = 0;
+        return nullptr;
+    }
+
+    auto padded = ArenaAllocatorMalloc(contextPtr, outByteNum);
+    errno_t res = memcpy_s(padded, outByteNum, str, outByteNum);
+    if (res != EOK) {
+        SetError(contextPtr, "varcharTypeWriteSideCheck failed:memcpy_s error");
+        *outLen = 0;
+        return nullptr;
+    }
+    padded[outByteNum] = '\0';
+    *outLen = outByteNum;
+    return padded;
+}
+
+extern "C" DLLEXPORT const char *StaticInvokeCharReadPadding(int64_t contextPtr, const char *str, int32_t len,
+    int32_t limit, bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        *outLen = 0;
+        return nullptr;
+    } else if (len == 0) {
+        *outLen = 0;
+        return "";
+    }
+    int32_t ssLen = StringUtil::NumChars(str, len);
+    if (ssLen >= limit) {
+        *outLen = len;
+        return str;
+    }
+    int32_t diff = limit - ssLen;
+    int32_t outByteNum = len + diff + 1;
+    auto padded = ArenaAllocatorMalloc(contextPtr, outByteNum);
+    errno_t res = memcpy_s(padded, len, str, len);
+    if (res != EOK) {
+        SetError(contextPtr, "charReadPadding failed:memcpy_s error");
+        *outLen = 0;
+        return nullptr;
+    }
+    res = memset_s(padded + len, diff, ' ', diff);
+    if (res != EOK) {
+        SetError(contextPtr, "charReadPadding failed:memset_s error");
+        *outLen = 0;
+        return nullptr;
+    }
+    padded[outByteNum] = '\0';
+    *outLen = outByteNum - 1;
+    return padded;
+}
+
+extern "C" DLLEXPORT const char *SubstringIndex(int64_t contextPtr, const char *str, int32_t strLen, const char *delim,
+    int32_t delimLen, int32_t count, bool isNull, int32_t *outLen)
+{
+    if (count == 0 || isNull) {
+        *outLen = 0;
+        return nullptr;
+    }
+
+    int64_t index;
+    if (count > 0) {
+        index = stringImpl::StringPosition(std::string_view(str, strLen), std::string_view(delim, delimLen),
+            count);
+    } else {
+        index = stringImpl::StringPosition(std::string_view(str, strLen),
+            std::string_view(delim, delimLen), -count);
+    }
+
+    // If 'delim' is not found or found fewer than 'count' times,
+    // return the input string directly.
+    if (index == 0) {
+        auto result = ArenaAllocatorMalloc(contextPtr, strLen);
+        errno_t res = memcpy_s(result, strLen, str, strLen);
+        if (res != EOK) {
+            SetError(contextPtr, "charReadPadding failed:memcpy_s error");
+            *outLen = 0;
+            return nullptr;
+        }
+        *outLen = strLen;
+        return result;
+    }
+
+    auto start = 0;
+    auto length = strLen;
+    const auto delimLength = delimLen;
+    if (count > 0) {
+        length = index - 1;
+    } else {
+        start = index + delimLength - 1;
+        length -= start;
+    }
+
+    auto result = ArenaAllocatorMalloc(contextPtr, length);
+    errno_t res = memcpy_s(result, length, str + start, length);
+    if (res != EOK) {
+        SetError(contextPtr, "charReadPadding failed:memcpy_s error");
+        *outLen = 0;
+        return nullptr;
+    }
+    *outLen = length;
+    return result;
+}
+}
+
diff --git a/core/src/codegen/functions/stringfunctions.h b/core/src/codegen/functions/stringfunctions.h
new file mode 100644
index 0000000..407d961
--- /dev/null
+++ b/core/src/codegen/functions/stringfunctions.h
@@ -0,0 +1,364 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved.
+ * Description: registry  function
+ */
+#ifndef __STRINGFUNCTIONS_H__
+#define __STRINGFUNCTIONS_H__
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include "codegen/context_helper.h"
+#include "codegen/functions/decimal_arithmetic_functions.h"
+#include "codegen/functions/decimal_cast_functions.h"
+#include "util/utf8_util.h"
+#include "codegen/string_util.h"
+#include "type/date32.h"
+
+// All extern functions go here temporarily
+#ifdef _WIN32
+#define DLLEXPORT __declspec(dllexport)
+#else
+#define DLLEXPORT
+#endif
+
+namespace omniruntime::codegen::function {
+extern "C" DLLEXPORT bool StrEquals(const char *ap, int32_t apLen, const char *bp, int32_t bpLen);
+
+extern "C" DLLEXPORT int32_t StrCompare(const char *ap, int32_t apLen, const char *bp, int32_t bpLen);
+
+extern "C" DLLEXPORT bool LikeStr(const char *str, int32_t strLen, const char *regexToMatch, int32_t regexLen,
+    bool isNull);
+
+extern "C" DLLEXPORT bool LikeChar(const char *str, int32_t strWidth, int32_t strLen, const char *regexToMatch,
+    int32_t regexLen, bool isNull);
+
+extern "C" DLLEXPORT const char *ConcatStrStr(int64_t contextPtr, const char *ap, int32_t apLen, const char *bp,
+    int32_t bpLen, bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *ConcatCharChar(int64_t contextPtr, const char *ap, int32_t aWidth, int32_t apLen,
+    const char *bp, int32_t bWidth, int32_t bpLen, bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *ConcatCharStr(int64_t contextPtr, const char *ap, int32_t aWidth, int32_t apLen,
+    const char *bp, int32_t bpLen, bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *ConcatStrChar(int64_t contextPtr, const char *ap, int32_t apLen, const char *bp,
+    int32_t bWidth, int32_t bpLen, bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *ConcatWsStr(int64_t contextPtr, const char *separator, int32_t separatorLen,
+    bool separatorIsNull, const char *ap, int32_t apLen, const char *bp, int32_t bpLen, bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT int32_t CastStringToDateNotAllowReducePrecison(int64_t contextPtr, const char *str, int32_t strLen,
+    bool isNull);
+
+extern "C" DLLEXPORT int32_t CastStringToDateAllowReducePrecison(int64_t contextPtr, const char *str, int32_t strLen,
+    bool isNull);
+
+// Cast numeric type to string
+extern "C" DLLEXPORT const char *CastIntToString(int64_t contextPtr, int32_t value, bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *CastInt16ToString(int64_t contextPtr, int16_t value, bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *CastInt8ToString(int64_t contextPtr, int8_t value, bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *CastLongToString(int64_t contextPtr, int64_t value, bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *CastDoubleToString(int64_t contextPtr, double value, bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *CastDecimal64ToString(int64_t contextPtr, int64_t x, int32_t precision, int32_t scale,
+    bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *CastDecimal128ToString(int64_t contextPtr, int64_t high, uint64_t low,
+    int32_t precision, int32_t scale, bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *CastStrWithDiffWidths(int64_t contextPtr, const char *srcStr, int32_t srcLen,
+    int32_t srcWidth, bool isNull, int32_t dstWidth, int32_t *outLen);
+
+// Cast string to numeric type
+
+extern "C" DLLEXPORT int8_t CastStringToByte(int64_t contextPtr, const char *str, int32_t strLen, bool isNull);
+
+extern "C" DLLEXPORT int16_t CastStringToShort(int64_t contextPtr, const char *str, int32_t strLen, bool isNull);
+
+extern "C" DLLEXPORT int32_t CastStringToInt(int64_t contextPtr, const char *str, int32_t strLen, bool isNull);
+
+extern "C" DLLEXPORT int64_t CastStringToLong(int64_t contextPtr, const char *str, int32_t strLen, bool isNull);
+
+extern "C" DLLEXPORT double CastStringToDouble(int64_t contextPtr, const char *str, int32_t strLen, bool isNull);
+
+extern "C" DLLEXPORT int64_t CastStringToDecimal64(int64_t contextPtr, const char *str, int32_t strLen, bool isNull,
+    int32_t precision, int32_t scale);
+
+extern "C" DLLEXPORT int64_t CastStringToDecimal64RoundUp(int64_t contextPtr, const char *str, int32_t strLen,
+    bool isNull, int32_t outPrecision, int32_t outScale);
+
+extern "C" DLLEXPORT void CastStringToDecimal128(int64_t contextPtr, const char *str, int32_t strLen, bool isNull,
+    int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr);
+
+extern "C" DLLEXPORT void CastStringToDecimal128RoundUp(int64_t contextPtr, const char *str, int32_t strLen,
+    bool isNull, int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr);
+
+extern "C" DLLEXPORT const char *EmptyToNull(const char *str, int32_t len, bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *StaticInvokeVarcharTypeWriteSideCheck(int64_t contextPtr, const char *str, int32_t len,
+    int32_t limit, bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *StaticInvokeCharReadPadding(int64_t contextPtr, const char *str,
+    int32_t len, int32_t limit, bool isNull, int32_t *outLen);
+
+/* *
+ * If isSupportNegativeIndex is false,the result of substr is "" when start index is negative
+ * If isSupportNegativeIndex is true,the substr rule is as follows:
+ * e.g., str="apple", strLength=5, startIndex=-7, subStringLength=3, Result="a".
+ * If isSupportZeroIndex is false,the result of substr is "" when start index is 0
+ * If isSupportZeroIndex is true,it refers to the first element when the start index is 0
+ */
+template 
+extern DLLEXPORT const char *SubstrVarcharWithStart(int64_t contextPtr, const char *str, int32_t strLen, T startIdx,
+    bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+    if constexpr (isSupportZeroIndex) {
+        startIdx = (startIdx == 0) ? 1 : startIdx;
+    }
+    if (startIdx == 0 || strLen == 0 || startIdx > strLen) {
+        *outLen = 0;
+        return reinterpret_cast(EMPTY);
+    }
+
+    int64_t startCodePoint = startIdx;
+    int64_t startIndex;
+    if (startCodePoint > 0) {
+        startIndex = omniruntime::Utf8Util::OffsetOfCodePoint(str, strLen, startCodePoint - 1);
+        if (startIndex < 0) {
+            *outLen = 0;
+            return reinterpret_cast(EMPTY);
+        }
+    } else {
+        // negative start is relative to end of string
+        int32_t codePoints = omniruntime::Utf8Util::CountCodePoints(str, strLen);
+        startCodePoint += codePoints;
+        if (startCodePoint < 0) {
+            if constexpr (isSupportNegativeIndex) {
+                startCodePoint = 0;
+            } else {
+                *outLen = 0;
+                return reinterpret_cast(EMPTY);
+            }
+        }
+
+        startIndex = omniruntime::Utf8Util::OffsetOfCodePoint(str, strLen, startCodePoint);
+    }
+
+    *outLen = strLen - startIndex;
+    return str + startIndex;
+}
+
+template 
+extern DLLEXPORT const char *SubstrVarchar(int64_t contextPtr, const char *str, int32_t strLen, T startIdx, T length,
+    bool isNull, int32_t *outLen)
+{
+    if (isNull) {
+        return nullptr;
+    }
+
+    if constexpr (isSupportZeroIndex) {
+        startIdx = (startIdx == 0) ? 1 : startIdx;
+    }
+    if (startIdx == 0 || (length <= 0) || (strLen == 0) || startIdx > strLen) {
+        *outLen = 0;
+        return reinterpret_cast(EMPTY);
+    }
+
+    int64_t endIdx;
+    int64_t startIndex;
+    int64_t startCodePoint = startIdx;
+    int64_t lengthCodePoint = length;
+    if (startCodePoint > 0) {
+        startIndex = omniruntime::Utf8Util::OffsetOfCodePoint(str, strLen, startCodePoint - 1);
+        if (startIndex < 0) {
+            // before beginning of string
+            *outLen = 0;
+            return reinterpret_cast(EMPTY);
+        }
+        endIdx = omniruntime::Utf8Util::OffsetOfCodePoint(str, strLen, startIndex, lengthCodePoint);
+        if (endIdx < 0) {
+            // after end of string
+            endIdx = strLen;
+        }
+    } else {
+        // negative start is relative to end of string
+        int32_t codePoints = omniruntime::Utf8Util::CountCodePoints(str, strLen);
+        startCodePoint += codePoints;
+        // before beginning of string
+        if (startCodePoint < 0) {
+            if constexpr (!isSupportNegativeIndex) {
+                *outLen = 0;
+                return reinterpret_cast(EMPTY);
+            }
+            if (startCodePoint + lengthCodePoint <= 0) {
+                *outLen = 0;
+                return reinterpret_cast(EMPTY);
+            }
+            lengthCodePoint += startCodePoint;
+            startCodePoint = 0;
+        }
+        startIndex = omniruntime::Utf8Util::OffsetOfCodePoint(str, strLen, startCodePoint);
+        if (startCodePoint + lengthCodePoint < codePoints) {
+            endIdx = omniruntime::Utf8Util::OffsetOfCodePoint(str, strLen, startIndex, lengthCodePoint);
+        } else {
+            endIdx = strLen;
+        }
+    }
+
+    *outLen = endIdx - startIndex;
+    return str + startIndex;
+}
+
+template 
+extern DLLEXPORT const char *SubstrChar(int64_t contextPtr, const char *str, int32_t width, int32_t strLen, T startIdx,
+    T length, bool isNull, int32_t *outLen)
+{
+    return SubstrVarchar(contextPtr, str, strLen, startIdx, length,
+        isNull, outLen);
+}
+
+template 
+extern DLLEXPORT const char *SubstrCharWithStart(int64_t contextPtr, const char *str, int32_t width, int32_t strLen,
+    T startIdx, bool isNull, int32_t *outLen)
+{
+    return SubstrVarcharWithStart(contextPtr, str, strLen, startIdx,
+        isNull, outLen);
+}
+
+extern "C" DLLEXPORT const char *ToUpperStr(int64_t contextPtr, const char *str, int32_t strLen, bool isNull,
+    int32_t *outLen);
+
+extern "C" DLLEXPORT const char *ToUpperChar(int64_t contextPtr, const char *str, int32_t width, int32_t strLen,
+    bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *ToLowerStr(int64_t contextPtr, const char *str, int32_t strLen, bool isNull,
+    int32_t *outLen);
+
+extern "C" DLLEXPORT const char *ToLowerChar(int64_t contextPtr, const char *str, int32_t width, int32_t strLen,
+    bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT int64_t LengthChar(const char *str, int32_t width, int32_t strLen, bool isNull);
+
+extern "C" DLLEXPORT int32_t LengthCharReturnInt32(const char *str, int32_t width, int32_t strLen, bool isNull);
+
+extern "C" DLLEXPORT int32_t LengthStrReturnInt32(const char *str, int32_t strLen, bool isNull);
+
+extern "C" DLLEXPORT int64_t LengthStr(const char *str, int32_t strLen, bool isNull);
+
+extern "C" DLLEXPORT const char *ReplaceStrStrStrWithRepNotReplace(int64_t contextPtr, const char *str, int32_t strLen,
+    const char *searchStr, int32_t searchLen, const char *replaceStr, int32_t replaceLen, bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *ReplaceStrStrStrWithRepReplace(int64_t contextPtr, const char *str, int32_t strLen,
+    const char *searchStr, int32_t searchLen, const char *replaceStr, int32_t replaceLen, bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *ReplaceStrStrWithoutRepReplace(int64_t contextPtr, const char *str, int32_t strLen,
+    const char *searchStr, int32_t searchLen, bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *ReplaceStrStrWithoutRepNotReplace(int64_t contextPtr, const char *str, int32_t strLen,
+    const char *searchStr, int32_t searchLen, bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *ConcatStrStrRetNull(int64_t contextPtr, bool *isNull, const char *ap, int32_t apLen,
+    const char *bp, int32_t bpLen, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *ConcatCharCharRetNull(int64_t contextPtr, bool *isNull, const char *ap, int32_t aWidth,
+    int32_t apLen, const char *bp, int32_t bWidth, int32_t bpLen, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *ConcatCharStrRetNull(int64_t contextPtr, bool *isNull, const char *ap, int32_t aWidth,
+    int32_t apLen, const char *bp, int32_t bpLen, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *ConcatStrCharRetNull(int64_t contextPtr, bool *isNull, const char *ap, int32_t apLen,
+    const char *bp, int32_t bWidth, int32_t bpLen, int32_t *outLen);
+
+extern "C" DLLEXPORT int32_t CastStringToDateRetNullAllowReducePrecison(bool *isNull, const char *str, int32_t strLen);
+
+extern "C" DLLEXPORT int32_t CastStringToDateRetNullNotAllowReducePrecison(bool *isNull, const char *str,
+    int32_t strLen);
+
+extern "C" DLLEXPORT const char *CastIntToStringRetNull(int64_t contextPtr, bool *isNull, int32_t value,
+    int32_t *outLen);
+
+extern "C" DLLEXPORT const char *CastInt16ToStringRetNull(int64_t contextPtr, bool *isNull, int16_t value,
+    int32_t *outLen);
+
+extern "C" DLLEXPORT const char *CastInt8ToStringRetNull(int64_t contextPtr, bool *isNull, int8_t value,
+    int32_t *outLen);
+
+extern "C" DLLEXPORT const char *CastLongToStringRetNull(int64_t contextPtr, bool *isNull, int64_t value,
+    int32_t *outLen);
+
+extern "C" DLLEXPORT const char *CastDoubleToStringRetNull(int64_t contextPtr, bool *isNull, double value,
+    int32_t *outLen);
+
+extern "C" DLLEXPORT const char *CastDecimal64ToStringRetNull(int64_t contextPtr, bool *isNull, int64_t x,
+    int32_t precision, int32_t scale, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *CastDecimal128ToStringRetNull(int64_t contextPtr, bool *isNull, int64_t high,
+    uint64_t low, int32_t precision, int32_t scale, int32_t *outLen);
+
+
+extern "C" DLLEXPORT int8_t CastStringToByteRetNull(bool *isNull, const char *str, int32_t strLen);
+
+extern "C" DLLEXPORT int16_t CastStringToShortRetNull(bool *isNull, const char *str, int32_t strLen);
+
+extern "C" DLLEXPORT int32_t CastStringToIntRetNull(bool *isNull, const char *str, int32_t strLen);
+
+extern "C" DLLEXPORT int64_t CastStringToLongRetNull(bool *isNull, const char *str, int32_t strLen);
+
+extern "C" DLLEXPORT double CastStringToDoubleRetNull(bool *isNull, const char *str, int32_t strLen);
+
+extern "C" DLLEXPORT int64_t CastStringToDecimal64RetNull(bool *isNull, const char *str, int32_t strLen,
+    int32_t outPrecision, int32_t outScale);
+
+extern "C" DLLEXPORT int64_t CastStringToDecimal64RoundUpRetNull(bool *isNull, const char *str, int32_t strLen,
+    int32_t outPrecision, int32_t outScale);
+
+extern "C" DLLEXPORT void CastStringToDecimal128RetNull(bool *isNull, const char *str, int32_t strLen,
+    int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr);
+
+extern "C" DLLEXPORT void CastStringToDecimal128RoundUpRetNull(bool *isNull, const char *str, int32_t strLen,
+    int32_t outPrecision, int32_t outScale, int64_t *outHighPtr, uint64_t *outLowPtr);
+
+extern "C" DLLEXPORT const char *CastStrWithDiffWidthsRetNull(int64_t contextPtr, bool *isNull, const char *srcStr,
+    int32_t srcLen, int32_t srcWidth, int32_t dstWidth, int32_t *outLen);
+
+extern "C" DLLEXPORT int32_t InStr(const char *srcStr, int32_t srcLen, const char *subStr, int32_t subLen, bool isNull);
+
+extern "C" DLLEXPORT bool StartsWithStr(const char *srcStr, int32_t srcLen, const char *matchStr, int32_t matchLen,
+    bool isNull);
+
+extern "C" DLLEXPORT bool EndsWithStr(const char *srcStr, int32_t srcLen, const char *matchStr, int32_t matchLen,
+    bool isNull);
+
+extern "C" DLLEXPORT bool RegexMatch(const char *srcStr, int32_t srcLen, const char *matchStr, int32_t matchLen,
+    bool isNull);
+
+extern "C" DLLEXPORT const char *CastDateToStringRetNull(int64_t contextPtr, bool *isNull, int32_t value,
+    int32_t *outLen);
+
+extern "C" DLLEXPORT const char *CastDateToString(int64_t contextPtr, int32_t value, bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT char *Md5Str(int64_t contextPtr, const char *str, int32_t len, bool isNull, int32_t *outLen);
+
+extern "C" DLLEXPORT bool ContainsStr(const char *srcStr, int32_t srcLen, const char *matchStr, int32_t matchLen,
+    bool isNull);
+
+extern "C" DLLEXPORT const char *GreatestStr(const char *lValue, int32_t lLen, bool lIsNull, const char *rValue,
+    int32_t rLen, bool rIsNull, bool *retIsNull, int32_t *outLen);
+
+extern "C" DLLEXPORT const char *SubstringIndex(int64_t contextPtr, const char *str, int32_t strLen, const char *delim,
+    int32_t delimLen, int32_t count, bool isNull, int32_t *outLen);
+}
+#endif
\ No newline at end of file
diff --git a/core/src/codegen/functions/udffunctions.cpp b/core/src/codegen/functions/udffunctions.cpp
new file mode 100644
index 0000000..bd06558
--- /dev/null
+++ b/core/src/codegen/functions/udffunctions.cpp
@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
+ * Description: udf functions.
+ */
+#include 
+#include "codegen/context_helper.h"
+#include "udf/cplusplus/java_udf_functions.h"
+#include "udffunctions.h"
+
+using namespace omniruntime::udf;
+
+namespace omniruntime::codegen::function {
+namespace {
+std::once_flag init_udf_flag;
+bool g_isUdfInited;
+const std::string INIT_UDF_FAILED = "Init UDF failed";
+}
+
+static void InitHiveUdf()
+{
+    auto ret = InitUdf();
+    if (ret != omniruntime::op::ErrorCode::SUCCESS) {
+        g_isUdfInited = false;
+    } else {
+        g_isUdfInited = true;
+    }
+}
+
+extern DLLEXPORT void EvaluateHiveUdfSingle(int64_t contextPtr, const char *udfClass, int32_t *inputTypes,
+    int32_t retType, int32_t vecCount, int64_t inputValue, int64_t inputNull, int64_t inputLength, int64_t outputValue,
+    int64_t outputNull, int64_t outputLength)
+{
+    std::call_once(init_udf_flag, InitHiveUdf);
+    if (!g_isUdfInited) {
+        SetError(contextPtr, INIT_UDF_FAILED);
+        return;
+    }
+    ExecuteHiveUdfSingle(contextPtr, udfClass, inputTypes, retType, vecCount, inputValue, inputNull, inputLength,
+        outputValue, outputNull, outputLength);
+}
+
+extern DLLEXPORT void EvaluateHiveUdfBatch(int64_t contextPtr, const char *udfClass, int32_t *inputTypes,
+    int32_t retType, int32_t vecCount, int32_t rowCount, int64_t *inputValues, int64_t *inputNulls,
+    int64_t *inputLengths, int64_t outputValue, int64_t outputNull, int64_t outputLength)
+{
+    std::call_once(init_udf_flag, InitHiveUdf);
+    if (!g_isUdfInited) {
+        SetError(contextPtr, INIT_UDF_FAILED);
+        return;
+    }
+    ExecuteHiveUdfBatch(contextPtr, udfClass, inputTypes, retType, vecCount, rowCount, inputValues, inputNulls,
+        inputLengths, outputValue, outputNull, outputLength);
+}
+}
\ No newline at end of file
diff --git a/core/src/codegen/functions/udffunctions.h b/core/src/codegen/functions/udffunctions.h
new file mode 100644
index 0000000..43f0245
--- /dev/null
+++ b/core/src/codegen/functions/udffunctions.h
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
+ * Description: udf functions.
+ */
+#ifndef OMNI_RUNTIME_UDFFUNCTIONS_H
+#define OMNI_RUNTIME_UDFFUNCTIONS_H
+
+#include 
+#include "util/error_code.h"
+
+#ifdef _WIN32
+#define DLLEXPORT __declspec(dllexport)
+#else
+#define DLLEXPORT
+#endif
+
+namespace omniruntime::codegen::function {
+extern DLLEXPORT void EvaluateHiveUdfSingle(int64_t contextPtr, const char *udfClass, int32_t *inputTypes,
+    int32_t retType, int32_t vecCount, int64_t inputValue, int64_t inputNull, int64_t inputLength, int64_t outputValue,
+    int64_t outputNull, int64_t outputLength);
+
+extern DLLEXPORT void EvaluateHiveUdfBatch(int64_t contextPtr, const char *udfClass, int32_t *inputTypes,
+    int32_t retType, int32_t vecCount, int32_t rowCount, int64_t *inputValues, int64_t *inputNulls,
+    int64_t *inputLengths, int64_t outputValue, int64_t outputNull, int64_t outputLength);
+}
+#endif // OMNI_RUNTIME_UDFFUNCTIONS_H
diff --git a/core/src/codegen/functions/varcharVectorfunctions.cpp b/core/src/codegen/functions/varcharVectorfunctions.cpp
new file mode 100644
index 0000000..3cb0a03
--- /dev/null
+++ b/core/src/codegen/functions/varcharVectorfunctions.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: registry varcharVector functions
+ */
+
+#include "varcharVectorfunctions.h"
+#include "vector/vector.h"
+
+using namespace omniruntime::vec;
+using namespace std;
+namespace omniruntime::codegen::function {
+extern DLLEXPORT int32_t WrapVarcharVector(int64_t vectorAddr, int32_t index, uint8_t *data, int32_t dataLen)
+{
+    auto vec = reinterpret_cast> *>(vectorAddr);
+    if (data == nullptr) {
+        vec->SetNull(index);
+    } else {
+        std::string_view strView(reinterpret_cast(data), dataLen);
+        vec->SetValue(index, strView);
+    }
+    return 0;
+}
+
+extern DLLEXPORT void WrapSetBitNull(int32_t *bits, int32_t index, bool isNull)
+{
+    // bits最初始的时候已经全都赋值0,只有isNull是true的时候才调用BitUtil::SetBit
+    if (UNLIKELY(isNull)) {
+        BitUtil::SetBit(bits, index);
+    }
+}
+
+extern DLLEXPORT bool WrapIsBitNull(int32_t *bits, int32_t index)
+{
+    return BitUtil::IsBitSet(bits, index);
+}
+}
\ No newline at end of file
diff --git a/core/src/codegen/functions/varcharVectorfunctions.h b/core/src/codegen/functions/varcharVectorfunctions.h
new file mode 100644
index 0000000..b8992c6
--- /dev/null
+++ b/core/src/codegen/functions/varcharVectorfunctions.h
@@ -0,0 +1,24 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: registry varcharVector functions
+ */
+
+#ifndef OMNI_RUNTIME_VARCHARVECTORFUNCTIONS_H
+#define OMNI_RUNTIME_VARCHARVECTORFUNCTIONS_H
+
+#ifdef _WIN32
+#define DLLEXPORT __declspec(dllexport)
+#else
+#define DLLEXPORT
+#endif
+
+#include 
+
+namespace omniruntime::codegen::function {
+extern DLLEXPORT int32_t WrapVarcharVector(int64_t vectorAddr, int32_t index, uint8_t *data, int32_t dataLen);
+
+extern DLLEXPORT void WrapSetBitNull(int32_t *bits, int32_t index, bool isNull);
+
+extern DLLEXPORT bool WrapIsBitNull(int32_t *bits, int32_t index);
+}
+#endif // OMNI_RUNTIME_VARCHARVECTORFUNCTIONS_H
diff --git a/core/src/codegen/functions/xxhash64_hash.cpp b/core/src/codegen/functions/xxhash64_hash.cpp
new file mode 100644
index 0000000..df4cdd5
--- /dev/null
+++ b/core/src/codegen/functions/xxhash64_hash.cpp
@@ -0,0 +1,121 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * Description: registry  function  implementation
+ */
+#include "xxhash64_hash.h"
+#include "operator/hash_util.h"
+#include "type/decimal128_utils.h"
+
+using namespace std;
+using namespace omniruntime::op;
+
+namespace omniruntime::codegen::function {
+static const int64_t PRIME64_5 = 0x27D4EB2F165667C5L;
+static const int64_t SIZE_OF_INT = 4L;
+static const int64_t SIZE_OF_LONG = 8L;
+
+int64_t ALWAYS_INLINE HashInt(int32_t val, int64_t seed)
+{
+    auto hash = seed + PRIME64_5 + SIZE_OF_INT;
+    hash = HashUtil::XxHash64UpdateTail(hash, val);
+    hash = static_cast(HashUtil::XxHash64FinalShuffle(static_cast(hash)));
+    return hash;
+}
+
+int64_t ALWAYS_INLINE HashLong(int64_t val, int64_t seed)
+{
+    auto hash = seed + PRIME64_5 + SIZE_OF_LONG;
+    hash = HashUtil::XxHash64UpdateTail(hash, val);
+    hash = static_cast(HashUtil::XxHash64FinalShuffle(static_cast(hash)));
+    return hash;
+}
+
+extern "C" DLLEXPORT int64_t XxH64Int16(int16_t val, bool isValNull, int64_t seed, bool isSeedNull)
+{
+    if (isSeedNull) {
+        seed = 0;
+    }
+    return isValNull ? seed : HashInt(static_cast(val), seed);
+}
+
+extern "C" DLLEXPORT int64_t XxH64Int32(int32_t val, bool isValNull, int64_t seed, bool isSeedNull)
+{
+    if (isSeedNull) {
+        seed = 0;
+    }
+    return isValNull ? seed : HashInt(val, seed);
+}
+
+extern "C" DLLEXPORT int64_t XxH64Int64(int64_t val, bool isValNull, int64_t seed, bool isSeedNull)
+{
+    if (isSeedNull) {
+        seed = 0;
+    }
+    return isValNull ? seed : HashLong(val, seed);
+}
+
+extern "C" DLLEXPORT int64_t XxH64String(const char *val, int32_t valLen, bool isValNull, int64_t seed, bool isSeedNull)
+{
+    if (isSeedNull) {
+        seed = 0;
+    }
+    if (isValNull) {
+        return seed;
+    }
+
+    auto data = reinterpret_cast(const_cast(val));
+    return HashUtil::XxHash64Hash(seed, data, 0, valLen);
+}
+
+extern "C" DLLEXPORT int64_t XxH64Double(double val, bool isValNull, int64_t seed, bool isSeedNull)
+{
+    if (isSeedNull) {
+        seed = 0;
+    }
+    if (isValNull) {
+        return seed;
+    }
+
+    return HashLong(HashUtil::DoubleToLongBits(val), seed);
+}
+
+extern "C" DLLEXPORT int64_t XxH64Decimal64(int64_t val, int32_t precision, int32_t scale, bool isValNull, int64_t seed,
+    bool isSeedNull)
+{
+    if (isSeedNull) {
+        seed = 0;
+    }
+    return isValNull ? seed : HashLong(val, seed);
+}
+
+extern "C" DLLEXPORT int64_t XxH64Decimal128(int64_t xHigh, uint64_t xLow, int32_t precision, int32_t scale,
+    bool isValNull, int64_t seed, bool isSeedNull)
+{
+    if (isSeedNull) {
+        seed = 0;
+    }
+    if (isValNull) {
+        return seed;
+    }
+
+    int32_t byteLen = 0;
+    auto bytes = omniruntime::type::Decimal128Utils::Decimal128ToBytes(xHigh, xLow, byteLen);
+    auto result = op::HashUtil::XxHash64Hash(seed, bytes, 0, byteLen);
+    delete[] bytes;
+    bytes = nullptr;
+    return result;
+}
+
+extern "C" DLLEXPORT int64_t XxH64Boolean(bool val, bool isValNull, int64_t seed, bool isSeedNull)
+{
+    if (isSeedNull) {
+        seed = 0;
+    }
+    if (isValNull) {
+        return seed;
+    }
+
+    int32_t intVal = val ? 1 : 0;
+    return HashInt(intVal, seed);
+}
+}
\ No newline at end of file
diff --git a/core/src/codegen/functions/xxhash64_hash.h b/core/src/codegen/functions/xxhash64_hash.h
new file mode 100644
index 0000000..c00e598
--- /dev/null
+++ b/core/src/codegen/functions/xxhash64_hash.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * Description: registry  function  implementation
+ */
+#ifndef OMNI_RUNTIME_XXHASH64_HASH_H
+#define OMNI_RUNTIME_XXHASH64_HASH_H
+#include 
+
+namespace omniruntime::codegen::function {
+// All extern functions go here temporarily
+#ifdef _WIN32
+#define DLLEXPORT __declspec(dllexport)
+#else
+#define DLLEXPORT
+#endif
+extern "C" DLLEXPORT int64_t XxH64Int16(int16_t val, bool isValNull, int64_t seed, bool isSeedNull);
+
+extern "C" DLLEXPORT int64_t XxH64Int32(int32_t val, bool isValNull, int64_t seed, bool isSeedNull);
+
+extern "C" DLLEXPORT int64_t XxH64Int64(int64_t val, bool isValNull, int64_t seed, bool isSeedNull);
+
+extern "C" DLLEXPORT int64_t XxH64String(const char *val, int32_t valLen, bool isValNull, int64_t seed,
+    bool isSeedNull);
+
+extern "C" DLLEXPORT int64_t XxH64Double(double val, bool isValNull, int64_t seed, bool isSeedNull);
+
+extern "C" DLLEXPORT int64_t XxH64Decimal64(int64_t val, int32_t precision, int32_t scale, bool isValNull, int64_t seed,
+    bool isSeedNull);
+
+extern "C" DLLEXPORT int64_t XxH64Decimal128(int64_t xHigh, uint64_t xLow, int32_t precision, int32_t scale,
+    bool isValNull, int64_t seed, bool isSeedNull);
+
+extern "C" DLLEXPORT int64_t XxH64Boolean(bool val, bool isValNull, int64_t seed, bool isSeedNull);
+}
+#endif // OMNI_RUNTIME_XXHASH64_HASH_H
\ No newline at end of file
diff --git a/core/src/codegen/llvm_engine.cpp b/core/src/codegen/llvm_engine.cpp
new file mode 100644
index 0000000..ccae424
--- /dev/null
+++ b/core/src/codegen/llvm_engine.cpp
@@ -0,0 +1,331 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: Expression code generation utilities
+ */
+
+#include "llvm_engine.h"
+
+#include 
+#include "llvm/Pass.h"
+#include "llvm/Transforms/Scalar/SimpleLoopUnswitch.h"
+
+#include "expr_info_extractor.h"
+#include "func_registry.h"
+#include "util/config_util.h"
+
+namespace omniruntime {
+namespace codegen {
+namespace {
+std::once_flag g_codegenTargetInitFlag;
+constexpr unsigned SMALL_VECTOR_DEFAULT_INLINED_ELEMENTS_COUNT = 20;
+static llvm::StringRef CPU_NAME;
+static llvm::SmallVector CPU_ATTRS;
+}
+
+LLVMEngine::LLVMEngine()
+{
+    std::call_once(g_codegenTargetInitFlag, InitializeCodegenTargets);
+    llvm::ExitOnError eoe;
+    context = std::make_unique();
+    jit = eoe(LLJITBuilder().create());
+    builder = std::make_unique>(*context);
+    auto module = std::make_unique("the_module", *context);
+    module->setDataLayout(jit->getDataLayout());
+    modulePtr = module.get();
+    llvmTypes = std::make_unique(*context);
+    fpm = std::make_unique(module.get());
+
+    auto optLevel = llvm::CodeGenOpt::Aggressive;
+    std::string builderError;
+    llvm::EngineBuilder engine_builder(std::move(module));
+    engine_builder.setEngineKind(llvm::EngineKind::JIT).setOptLevel(optLevel).setErrorStr(&builderError);
+    engine_builder.setMCPU(CPU_NAME);
+    engine_builder.setMAttrs(CPU_ATTRS);
+    std::unique_ptr exec_engine { engine_builder.create() };
+    execution_engine = std::move(exec_engine);
+    // Although ConfigUtil::IsEnableBatchExprEvaluate() = true, RowProjection also need row functions.
+    RegisterFunctions(FunctionRegistry::GetBatchFunctions());
+    RegisterFunctions(FunctionRegistry::GetRowFunctions());
+}
+
+llvm::IRBuilder<> *LLVMEngine::GetIRBuilder()
+{
+    return builder.get();
+}
+
+Module *LLVMEngine::GetModule()
+{
+    return modulePtr;
+}
+
+LLVMContext *LLVMEngine::GetContext()
+{
+    return context.get();
+}
+
+LLVMTypes *LLVMEngine::GetTypes()
+{
+    return llvmTypes.get();
+}
+
+int64_t LLVMEngine::Compile()
+{
+    jit->getMainJITDylib().addGenerator(
+        eoe(DynamicLibrarySearchGenerator::GetForCurrentProcess(jit->getDataLayout().getGlobalPrefix())));
+    auto resTracker = jit->getMainJITDylib().createResourceTracker();
+    MakeThreadSafe(&resTracker);
+    rt = resTracker;
+    auto sym = eoe(jit->lookup("WRAPPER_FUNC"));
+    return sym.getValue();
+}
+
+void LLVMEngine::MakeThreadSafe(ResourceTrackerSP *resTracker)
+{
+    execution_engine->removeModule(modulePtr);
+    std::unique_ptr module(modulePtr);
+    auto threadSafeModule = llvm::orc::ThreadSafeModule(move(module), move(context));
+    eoe(jit->addIRModule(*resTracker, std::move(threadSafeModule)));
+}
+
+void LLVMEngine::InitializeCodegenTargets()
+{
+    llvm::InitializeNativeTarget();
+    llvm::InitializeNativeTargetAsmPrinter();
+    llvm::InitializeNativeTargetAsmParser();
+    llvm::InitializeNativeTargetDisassembler();
+    llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
+
+    CPU_NAME = llvm::sys::getHostCPUName();
+    llvm::StringMap host_features;
+    if (llvm::sys::getHostCPUFeatures(host_features)) {
+        for (auto &f : host_features) {
+            std::string attr = f.second ? std::string("+") + f.first().str() : std::string("-") + f.first().str();
+            CPU_ATTRS.push_back(attr);
+        }
+    }
+}
+
+void LLVMEngine::RegisterFunctions(const std::vector &functions)
+{
+    for (auto &func : functions) {
+        auto &jd = jit->getMainJITDylib();
+        auto &dl = jit->getDataLayout();
+        llvm::orc::MangleAndInterner mangle(jit->getExecutionSession(), dl);
+        std::vector params = func.GetParamTypes();
+        DataTypeId retType = func.GetReturnType();
+        std::vector args = this->GetFunctionArgTypeVector(params, retType, func.IsExecutionContextSet());
+        auto s = llvm::orc::absoluteSymbols({ { mangle(func.GetId()),
+            JITEvaluatedSymbol(pointerToJITTargetAddress(func.GetAddress()), JITSymbolFlags::Exported) } });
+        auto ign = jd.define(s);
+        if (ign) {
+            LogError("Error while defining absolute symbol in jd");
+        }
+        llvm::Type *ret = (retType == OMNI_DECIMAL128) ? llvmTypes->VoidType() : llvmTypes->ToLLVMType(retType);
+        llvm::FunctionType *ft = llvm::FunctionType::get(ret, args, false);
+        auto linkage = llvm::Function::ExternalLinkage;
+        llvm::Function::Create(ft, linkage, func.GetId(), *modulePtr);
+        modulePtr->getOrInsertFunction(func.GetId(), ft);
+    }
+}
+
+std::vector LLVMEngine::GetFunctionArgTypeVector(std::vector ¶ms, DataTypeId &retTypeId,
+    bool needsContext)
+{
+    std::vector args;
+    if (needsContext) {
+        args.push_back(llvmTypes->I64Type());
+    }
+    for (auto type : params) {
+        if (type == OMNI_DECIMAL128) {
+            args.push_back(llvmTypes->I64Type());
+            args.push_back(llvmTypes->I64Type());
+        } else {
+            args.push_back(llvmTypes->ToLLVMType(type));
+            if (TypeUtil::IsStringType(type)) {
+                if (type == OMNI_CHAR) {
+                    args.push_back(llvmTypes->I32Type());
+                }
+                args.push_back(llvmTypes->I32Type());
+            }
+        }
+    }
+    // return arguments
+    if (TypeUtil::IsStringType(retTypeId)) {
+        args.push_back(llvmTypes->I32PtrType());
+    } else if (retTypeId == OMNI_DECIMAL128) {
+        // Add high and low output pointers
+        args.push_back(llvmTypes->I64PtrType());
+        args.push_back(llvmTypes->I64PtrType());
+    }
+    return args;
+}
+
+void LLVMEngine::OptimizeFunctionsAndModule()
+{
+    auto machine = execution_engine->getTargetMachine();
+    llvm::TargetIRAnalysis target_analysis = machine->getTargetIRAnalysis();
+
+    mpm.add(llvm::createTargetTransformInfoWrapperPass(target_analysis));
+    mpm.add(llvm::createFunctionInliningPass());
+    mpm.add(llvm::createInstructionCombiningPass());
+    mpm.add(llvm::createPromoteMemoryToRegisterPass());
+    mpm.add(llvm::createGVNPass());
+    mpm.add(llvm::createNewGVNPass());
+    mpm.add(llvm::createCFGSimplificationPass());
+    mpm.add(llvm::createLoopVectorizePass());
+    mpm.add(llvm::createSLPVectorizerPass());
+    mpm.add(llvm::createGlobalOptimizerPass());
+    mpm.add(llvm::createStripDeadPrototypesPass());
+
+    fpm->add(llvm::createTargetTransformInfoWrapperPass(target_analysis));
+
+    // run the optimiser
+    llvm::PassManagerBuilder pass_builder;
+    pass_builder.OptLevel = llvm::CodeGenOpt::Aggressive;
+
+    pass_builder.populateFunctionPassManager(*fpm);
+
+    pass_builder.populateModulePassManager(mpm);
+
+    fpm->doInitialization();
+    for (auto &f : *modulePtr)
+        fpm->run(f);
+    fpm->doFinalization();
+
+    mpm.run(*modulePtr);
+}
+
+void LLVMEngine::OptimizeModule()
+{
+    mpm.add(createFunctionInliningPass());
+    mpm.add(createPruneEHPass());
+
+    mpm.run(*modulePtr);
+}
+
+CallInst *LLVMEngine::CreateCall(llvm::Function *func, const std::vector &argsVals,
+    const std::string &name)
+{
+    return builder->CreateCall(func, argsVals, name);
+}
+
+llvm::Value *LLVMEngine::CallExternFunction(const std::string &fn_name, const std::vector ¶ms,
+    const DataTypeId returnType, const std::vector &args, llvm::Value *executionContextPtr,
+    const std::string &msg, omniruntime::op::OverflowConfig *overflowConfig, llvm::Value *overflowNull)
+{
+    std::vector funcArgs;
+    funcArgs.insert(funcArgs.begin(), args.begin(), args.end());
+    if (executionContextPtr != nullptr) {
+        if (overflowConfig != nullptr &&
+            overflowConfig->GetOverflowConfigId() == omniruntime::op::OVERFLOW_CONFIG_NULL) {
+            funcArgs.insert(funcArgs.begin(), overflowNull);
+        } else {
+            funcArgs.insert(funcArgs.begin(), executionContextPtr);
+        }
+    }
+
+    std::string funcId = FunctionSignature(fn_name, params, returnType).ToString(overflowConfig);
+    auto f = modulePtr->getFunction(funcId);
+    auto ret = CreateCall(f, funcArgs, msg);
+    return ret;
+}
+
+void LLVMEngine::RecordMainFunction(llvm::Function *func)
+{
+    this->function = func;
+}
+
+void LLVMEngine::RemoveUnusedFunctions()
+{
+    llvm::Function *preserved = function;
+    mpm.add(llvm::createInternalizePass(
+        [preserved](const llvm::GlobalValue &func) { return (func.getName().str() == preserved->getName().str()); }));
+    mpm.add(llvm::createGlobalDCEPass());
+}
+
+DecimalSplitValue LLVMEngine::Split(llvm::Value *fullValue)
+{
+    LLVMTypes types(*context);
+    const int32_t intValue = 64;
+    auto high = builder->CreateLShr(fullValue, types.CreateConstant128(intValue), "split_high");
+    high = builder->CreateTrunc(high, types.I64Type(), "split_high");
+    auto low = builder->CreateTrunc(fullValue, types.I64Type(), "split_low");
+    return DecimalSplitValue(high, low);
+}
+
+llvm::Value *LLVMEngine::ToInt128(llvm::Value *high, llvm::Value *low) const
+{
+    LLVMTypes types(*context);
+    auto value = builder->CreateSExt(high, types.I128Type());
+    const int32_t intValue = 64;
+    value = builder->CreateShl(value, types.CreateConstant128(intValue));
+    value = builder->CreateAdd(value, builder->CreateZExt(low, types.I128Type()));
+    return value;
+}
+
+std::shared_ptr LLVMEngine::BuildDecimalValue(llvm::Value *data, omniruntime::type::DataType &retType,
+    llvm::Value *isNull)
+{
+    LLVMTypes llvmTypes(*context);
+    llvm::Value *precision;
+    llvm::Value *scale;
+    if (TypeUtil::IsDecimalType(retType.GetId())) {
+        precision = llvmTypes.CreateConstantInt(static_cast(retType).GetPrecision());
+        scale = llvmTypes.CreateConstantInt(static_cast(retType).GetScale());
+    } else {
+        precision = llvmTypes.CreateConstantInt(0);
+        scale = llvmTypes.CreateConstantInt(0);
+    }
+    return std::make_shared(data, isNull, precision, scale);
+}
+
+llvm::Value *LLVMEngine::CallDecimalFunction(const std::string &fnName, llvm::Type *retType,
+    const std::vector &args, llvm::Value *executionContextPtr,
+    omniruntime::op::OverflowConfig *overflowConfig, llvm::Value *overflowNull)
+{
+    LLVMTypes llvmTypes(*context);
+    std::vector disassembledArgs;
+
+    if (executionContextPtr != nullptr) {
+        if (overflowConfig != nullptr &&
+            overflowConfig->GetOverflowConfigId() == omniruntime::op::OVERFLOW_CONFIG_NULL) {
+            disassembledArgs.push_back(overflowNull);
+        } else {
+            disassembledArgs.push_back(executionContextPtr);
+        }
+    }
+    for (auto &arg : args) {
+        if (arg->getType() == llvmTypes.I128Type()) {
+            auto split = Split(arg);
+            disassembledArgs.push_back(const_cast(split.GetHigh()));
+            disassembledArgs.push_back(const_cast(split.GetLow()));
+        } else {
+            disassembledArgs.push_back(arg);
+        }
+    }
+    auto f = modulePtr->getFunction(fnName);
+    llvm::Value *result = nullptr;
+    if (f) {
+        if (retType == llvmTypes.I128Type()) {
+            auto outHighPtr = builder->CreateAlloca(llvmTypes.I64Type(), nullptr, "out_high");
+            auto outLowPtr = builder->CreateAlloca(llvmTypes.I64Type(), nullptr, "out_low");
+            disassembledArgs.push_back(outHighPtr);
+            disassembledArgs.push_back(outLowPtr);
+
+            CreateCall(f, disassembledArgs, const_cast(fnName));
+
+            auto outHigh = builder->CreateLoad(llvmTypes.I64Type(), outHighPtr);
+            auto outLow = builder->CreateLoad(llvmTypes.I64Type(), outLowPtr);
+            result = ToInt128(outHigh, outLow);
+        } else {
+            result = CreateCall(f, disassembledArgs, const_cast(fnName));
+        }
+        llvm::InlineFunctionInfo inlineFunctionInfo;
+        llvm::InlineFunction(*((llvm::CallInst *)result), inlineFunctionInfo);
+    } else {
+        LogWarn("Unable to generate function : %s", fnName.c_str());
+    }
+    return result;
+}
+}
+}
\ No newline at end of file
diff --git a/core/src/codegen/llvm_engine.h b/core/src/codegen/llvm_engine.h
new file mode 100644
index 0000000..5dc3e37
--- /dev/null
+++ b/core/src/codegen/llvm_engine.h
@@ -0,0 +1,131 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: Expression code generation utilities
+ */
+#ifndef OMNI_RUNTIME_LLVM_ENGINE_H
+#define OMNI_RUNTIME_LLVM_ENGINE_H
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/IPO.h"
+#include "llvm/Transforms/Utils.h"
+#include "llvm/Transforms/IPO/PassManagerBuilder.h"
+#include "llvm/Transforms/InstCombine/InstCombine.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Scalar/GVN.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Vectorize.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/Host.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IRReader/IRReader.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/ExecutionEngine/ExecutionEngine.h"
+#include "llvm/ExecutionEngine/Orc/LLJIT.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+
+#include "expression/expressions.h"
+#include "llvm_types.h"
+#include "codegen_value.h"
+
+namespace omniruntime {
+namespace codegen {
+using namespace llvm;
+using namespace orc;
+using namespace omniruntime;
+using namespace omniruntime::expressions;
+using namespace omniruntime::type;
+
+class LLVMEngine {
+public:
+    LLVMEngine();
+
+    virtual ~LLVMEngine() = default;
+
+    void OptimizeFunctionsAndModule();
+
+    void OptimizeModule();
+
+    llvm::CallInst *CreateCall(llvm::Function *func, const std::vector &argsVals,
+        const std::string &name);
+
+    llvm::Value *CallExternFunction(const std::string &fn_name,
+        const std::vector ¶ms, const omniruntime::type::DataTypeId returnType,
+        const std::vector &args, llvm::Value *executionContextPtr, const std::string &msg = "",
+        omniruntime::op::OverflowConfig *overflowConfig = nullptr, llvm::Value *overflowNull = nullptr);
+
+    static void InitializeCodegenTargets();
+
+    void RegisterFunctions(const std::vector &func);
+
+    void MakeThreadSafe(llvm::orc::ResourceTrackerSP *res);
+
+    void RecordMainFunction(llvm::Function *func);
+
+    void RemoveUnusedFunctions();
+
+    std::shared_ptr BuildDecimalValue(llvm::Value *data, omniruntime::type::DataType &retType,
+        llvm::Value *isNull = nullptr);
+
+    // Make from i128 value
+    DecimalSplitValue Split(llvm::Value *fullValue);
+
+    // Combine the two parts into an i128
+    llvm::Value *ToInt128(llvm::Value *high, llvm::Value *low) const;
+
+    llvm::Value *CallDecimalFunction(const std::string &function_name, llvm::Type *return_type,
+        const std::vector &args, llvm::Value *executionContextPtr = nullptr,
+        omniruntime::op::OverflowConfig *overflowConfig = nullptr, llvm::Value *overflowNull = nullptr);
+
+    /* *
+     * optimize and compiles the generated module contained in this LLVM Engine instance
+     * @return function address
+     */
+    int64_t Compile();
+
+    llvm::IRBuilder<> *GetIRBuilder();
+
+    llvm::Module *GetModule();
+
+    llvm::LLVMContext *GetContext();
+
+    LLVMTypes *GetTypes();
+
+protected:
+    std::unique_ptr context;
+    std::unique_ptr jit;
+    std::unique_ptr> builder;
+    llvm::Module *modulePtr;
+    std::unique_ptr llvmTypes;
+    std::unique_ptr fpm = nullptr;
+    std::unique_ptr execution_engine;
+    llvm::Function *function = nullptr;
+    llvm::ExitOnError eoe;
+    llvm::orc::ResourceTrackerSP rt;
+    llvm::legacy::PassManager mpm;
+
+private:
+    std::vector GetFunctionArgTypeVector(std::vector ¶ms,
+        omniruntime::type::DataTypeId &retTypeId, bool needsContext);
+};
+}
+}
+
+#endif // OMNI_RUNTIME_LLVM_ENGINE_H
\ No newline at end of file
diff --git a/core/src/codegen/llvm_types.cpp b/core/src/codegen/llvm_types.cpp
new file mode 100644
index 0000000..5d4a2ff
--- /dev/null
+++ b/core/src/codegen/llvm_types.cpp
@@ -0,0 +1,236 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: Expression code generator
+ */
+#include "llvm_types.h"
+#include 
+#include 
+#include 
+#include 
+
+namespace omniruntime::codegen {
+using namespace omniruntime::type;
+using namespace llvm;
+
+namespace {
+const int INT8_VALUE = 8;
+const int INT16_VALUE = 16;
+const int INT32_VALUE = 32;
+const int INT64_VALUE = 64;
+const int INT128_VALUE = 128;
+}
+
+LLVMTypes::LLVMTypes(llvm::LLVMContext &context) : context(context)
+{
+    VectorToLLVMTypeMap = { { OMNI_INT, I32Type() },
+        { OMNI_LONG, I64Type() },
+        { OMNI_DOUBLE, DoubleType() },
+        { OMNI_BOOLEAN, I1Type() },
+        { OMNI_BYTE, I8Type() },
+        { OMNI_SHORT, I16Type() },
+        { OMNI_DECIMAL64, I64Type() },
+        { OMNI_DECIMAL128, I128Type() },
+        { OMNI_DATE32, I32Type() },
+        { OMNI_DATE64, I64Type() },
+        { OMNI_TIMESTAMP, I64Type() },
+        { OMNI_INTERVAL_MONTHS, I32Type() },
+        { OMNI_INTERVAL_DAY_TIME, I32Type() },
+        { OMNI_VARCHAR, I8PtrType() },
+        { OMNI_CHAR, I8PtrType() } };
+}
+
+LLVMTypes::~LLVMTypes() = default;
+
+Value *LLVMTypes::CreateConstantBool(bool v)
+{
+    return ConstantInt::get(context, APInt(1, v));
+}
+
+Value *LLVMTypes::CreateConstantByte(int8_t v)
+{
+    return ConstantInt::get(context, APInt(INT8_VALUE, v, true));
+}
+
+Value *LLVMTypes::CreateConstantShort(int16_t v)
+{
+    return ConstantInt::get(context, APInt(INT16_VALUE, v, true));
+}
+
+Value *LLVMTypes::CreateConstantInt(int32_t v)
+{
+    return ConstantInt::get(context, APInt(INT32_VALUE, v, true));
+}
+
+Value *LLVMTypes::CreateConstantLong(int64_t v)
+{
+    return ConstantInt::get(context, APInt(INT64_VALUE, v, true));
+}
+
+Value *LLVMTypes::CreateConstantDouble(double v)
+{
+    return ConstantFP::get(context, APFloat(v));
+}
+
+Value *LLVMTypes::CreateConstant128(int64_t v)
+{
+    return ConstantInt::get(context, APInt(INT128_VALUE, v, true));
+}
+
+llvm::Type *LLVMTypes::VoidType()
+{
+    return llvm::Type::getVoidTy(context);
+}
+
+llvm::Type *LLVMTypes::I1Type()
+{
+    return llvm::Type::getInt1Ty(context);
+}
+
+llvm::Type *LLVMTypes::I8Type()
+{
+    return llvm::Type::getInt8Ty(context);
+}
+
+llvm::Type *LLVMTypes::I16Type()
+{
+    return llvm::Type::getInt16Ty(context);
+}
+
+llvm::Type *LLVMTypes::I32Type()
+{
+    return llvm::Type::getInt32Ty(context);
+}
+
+llvm::Type *LLVMTypes::I64Type()
+{
+    return llvm::Type::getInt64Ty(context);
+}
+
+llvm::Type *LLVMTypes::I128Type()
+{
+    return llvm::Type::getInt128Ty(context);
+}
+
+llvm::Type *LLVMTypes::DoubleType()
+{
+    return llvm::Type::getDoubleTy(context);
+}
+
+llvm::PointerType *LLVMTypes::PtrType(llvm::Type *type)
+{
+    return type->getPointerTo();
+}
+
+llvm::PointerType *LLVMTypes::I1PtrType()
+{
+    return PtrType(I1Type());
+}
+
+llvm::PointerType *LLVMTypes::I8PtrType()
+{
+    return PtrType(I8Type());
+}
+
+llvm::PointerType *LLVMTypes::I16PtrType()
+{
+    return PtrType(I16Type());
+}
+
+llvm::PointerType *LLVMTypes::I32PtrType()
+{
+    return PtrType(I32Type());
+}
+
+llvm::PointerType *LLVMTypes::I64PtrType()
+{
+    return PtrType(I64Type());
+}
+
+llvm::PointerType *LLVMTypes::I128PtrType()
+{
+    return PtrType(I128Type());
+}
+
+llvm::PointerType *LLVMTypes::DoublePtrType()
+{
+    return PtrType(DoubleType());
+}
+
+llvm::Type *LLVMTypes::ToLLVMType(DataTypeId id)
+{
+    auto result = VectorToLLVMTypeMap.find(id);
+    return (result == VectorToLLVMTypeMap.end()) ? NULL : result->second;
+}
+
+llvm::Type *LLVMTypes::VectorToLLVMType(const DataType &type)
+{
+    return ToLLVMType(type.GetId());
+}
+
+llvm::Type *LLVMTypes::ToPointerType(DataTypeId typeId)
+{
+    switch (typeId) {
+        case OMNI_BOOLEAN:
+            return I1PtrType();
+        case OMNI_BYTE:
+            return I8PtrType();
+        case OMNI_SHORT:
+            return I16PtrType();
+        case OMNI_INT:
+        case OMNI_DATE32:
+            return I32PtrType();
+        case OMNI_LONG:
+        case OMNI_TIMESTAMP:
+        case OMNI_DECIMAL64:
+            return I64PtrType();
+        case OMNI_DOUBLE:
+            return DoublePtrType();
+        case OMNI_CHAR:
+        case OMNI_VARCHAR:
+            return I8PtrType();
+        case OMNI_DECIMAL128:
+            return I128PtrType();
+        default:
+            LLVM_DEBUG_LOG("Unsupported column data type %d", typeId);
+            return I64PtrType();
+    }
+}
+
+llvm::Type *LLVMTypes::ToBatchDataPointerType(DataTypeId typeId)
+{
+    switch (typeId) {
+        case OMNI_BOOLEAN:
+            return I1PtrType();
+        case OMNI_BYTE:
+            return I8PtrType();
+        case OMNI_SHORT:
+            return I16PtrType();
+        case OMNI_INT:
+        case OMNI_DATE32:
+            return I32PtrType();
+        case OMNI_LONG:
+        case OMNI_TIMESTAMP:
+        case OMNI_DECIMAL64:
+            return I64PtrType();
+        case OMNI_DOUBLE:
+            return DoublePtrType();
+        case OMNI_CHAR:
+        case OMNI_VARCHAR:
+            return PtrType(I8PtrType());
+        case OMNI_DECIMAL128:
+            return I128PtrType();
+        default:
+            LLVM_DEBUG_LOG("Unsupported column data type %d", typeId);
+            return I64PtrType();
+    }
+}
+
+llvm::Type *LLVMTypes::GetFunctionReturnType(DataTypeId typeId)
+{
+    if (TypeUtil::IsStringType(typeId)) {
+        return Type::getInt64Ty(context);
+    } else {
+        return ToLLVMType(typeId);
+    }
+}
+}
diff --git a/core/src/codegen/llvm_types.h b/core/src/codegen/llvm_types.h
new file mode 100644
index 0000000..b417d2e
--- /dev/null
+++ b/core/src/codegen/llvm_types.h
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: Expression code generator
+ */
+#ifndef OMNI_RUNTIME_LLVM_TYPES_H
+#define OMNI_RUNTIME_LLVM_TYPES_H
+
+#include 
+#include 
+#include 
+#include "util/type_util.h"
+#include "type/data_type.h"
+
+namespace omniruntime::codegen {
+class LLVMTypes {
+public:
+    explicit LLVMTypes(llvm::LLVMContext &context);
+
+    LLVMTypes();
+
+    llvm::Type *VoidType();
+
+    llvm::Type *I1Type();
+
+    llvm::Type *I8Type();
+
+    llvm::Type *I16Type();
+
+    llvm::Type *I32Type();
+
+    llvm::Type *I64Type();
+
+    llvm::Type *I128Type();
+
+    llvm::Type *DoubleType();
+
+    llvm::PointerType *PtrType(llvm::Type *type);
+
+    llvm::PointerType *I1PtrType();
+
+    llvm::PointerType *I8PtrType();
+
+    llvm::PointerType *I16PtrType();
+
+    llvm::PointerType *I32PtrType();
+
+    llvm::PointerType *I64PtrType();
+
+    llvm::PointerType *DoublePtrType();
+
+    llvm::PointerType *I128PtrType();
+
+    llvm::Value *CreateConstantBool(bool n);
+
+    llvm::Value *CreateConstantByte(int8_t n);
+
+    llvm::Value *CreateConstantShort(int16_t n);
+
+    llvm::Value *CreateConstantInt(int32_t n);
+
+    llvm::Value *CreateConstantLong(int64_t n);
+
+    llvm::Value *CreateConstantDouble(double n);
+
+    llvm::Value *CreateConstant128(int64_t v);
+
+    // / For a given Vector type, find the corresponding ir type.
+    llvm::Type *ToLLVMType(omniruntime::type::DataTypeId id);
+
+    llvm::Type *VectorToLLVMType(const omniruntime::type::DataType &type);
+
+    llvm::Type *ToPointerType(omniruntime::type::DataTypeId typeId);
+
+    llvm::Type *ToBatchDataPointerType(omniruntime::type::DataTypeId typeId);
+
+    llvm::Type *GetFunctionReturnType(omniruntime::type::DataTypeId typeId);
+
+    virtual ~LLVMTypes();
+
+private:
+    std::map VectorToLLVMTypeMap;
+    llvm::LLVMContext &context;
+};
+}
+
+#endif // OMNI_RUNTIME_LLVM_TYPES_H
\ No newline at end of file
diff --git a/core/src/codegen/projection_codegen.cpp b/core/src/codegen/projection_codegen.cpp
new file mode 100644
index 0000000..49469c5
--- /dev/null
+++ b/core/src/codegen/projection_codegen.cpp
@@ -0,0 +1,236 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: project  codegen
+ */
+#include "projection_codegen.h"
+
+namespace omniruntime {
+namespace codegen {
+using namespace llvm;
+using namespace orc;
+using namespace omniruntime::expressions;
+using namespace omniruntime::type;
+
+namespace {
+const int INPUT_TABLE_INDEX = 0;
+const int NUM_ROWS_INDEX = 1;
+const int OUTPUT_ADDRESS_INDEX = 2;
+const int SELECTED = 3;
+const int NUM_SELECTED = 4;
+const int BITMAP = 5;
+const int OFFSETS_INDEX = 6;
+const int NEW_NULL_VALUES_INDEX = 7;
+const int OUTPUT_OFFSETS_INDEX = 8;
+const int EXECUTION_CONTEXT_IDX = 9;
+const int DICTIONARY_VECTORS_IDX = 10;
+}
+
+intptr_t ProjectionCodeGen::GetFunction(const DataTypes &inputDataTypes)
+{
+    llvm::Function *func = CreateFunction(inputDataTypes);
+    if (func == nullptr) {
+        return 0;
+    }
+    return CreateWrapper();
+}
+
+intptr_t ProjectionCodeGen::CreateWrapper()
+{
+    // The args indicates the type of the function parameter list.
+    std::vector args {
+        llvmTypes->I64PtrType(), // data address array
+        llvmTypes->I32Type(),    // the num of rows
+        llvmTypes->I64Type(),    // output array address
+        llvmTypes->I32PtrType(), // selected array
+        llvmTypes->I32Type(),    // the num of selected rows
+        llvmTypes->I64PtrType(), // bitmap address array
+        llvmTypes->I64PtrType(), // offset address array
+        llvmTypes->I32PtrType(),  // output null values array
+        llvmTypes->I32PtrType(), // output offset array
+        llvmTypes->I64Type(),    // execution content address
+        llvmTypes->I64PtrType()  // dictionary address array
+    };
+
+    FunctionType *funcSignature = FunctionType::get(llvmTypes->I32Type(), args, false);
+    llvm::Function *funcDecl =
+        llvm::Function::Create(funcSignature, llvm::Function::ExternalLinkage, "WRAPPER_FUNC", modulePtr);
+    BasicBlock *preLoop = BasicBlock::Create(*context, "PRE_LOOP", funcDecl);
+    BasicBlock *loopBody = BasicBlock::Create(*context, "LOOP_BODY", funcDecl);
+    BasicBlock *addToOutput = BasicBlock::Create(*context, "ADD_OUTPUT", funcDecl);
+    BasicBlock *incrementCounter = BasicBlock::Create(*context, "INCREMENT_COUNTER", funcDecl);
+    BasicBlock *endBlock = BasicBlock::Create(*context, "END_BLOCK", funcDecl);
+    // preprocessing
+    Argument *input = funcDecl->getArg(INPUT_TABLE_INDEX);
+    input->setName("INPUT_TABLE");
+    Argument *numRows = funcDecl->getArg(NUM_ROWS_INDEX);
+    numRows->setName("NUM_ROWS");
+    Argument *outputAddress = funcDecl->getArg(OUTPUT_ADDRESS_INDEX);
+    outputAddress->setName("OUTPUT_ADDRESS");
+
+    RecordMainFunction(funcDecl);
+
+    // Only use these values if filter enabled
+    Argument *selected = nullptr;
+    Argument *numSelected = nullptr;
+    if (filter) {
+        selected = funcDecl->getArg(SELECTED);
+        selected->setName("SELECTED_ARRAY");
+        numSelected = funcDecl->getArg(NUM_SELECTED);
+        numSelected->setName("NUM_SELECTED");
+    }
+
+    Argument *bitmap = funcDecl->getArg(BITMAP);
+    bitmap->setName("BITMAP");
+
+    Argument *offsets = funcDecl->getArg(OFFSETS_INDEX);
+    offsets->setName("OFFSETS");
+
+    Argument *nullValuesAddress = funcDecl->getArg(NEW_NULL_VALUES_INDEX);
+    nullValuesAddress->setName("NULL_VALUES_ADDRESS");
+
+    Argument *outputOffsetsAddress = funcDecl->getArg(OUTPUT_OFFSETS_INDEX);
+    outputOffsetsAddress->setName("OUTPUT_OFFSETS_ADDRESS");
+
+    Argument *executionContext = funcDecl->getArg(EXECUTION_CONTEXT_IDX);
+    executionContext->setName("EXECUTION_CONTEXT_ADDRESS");
+
+    Argument *dictionaryVectors = funcDecl->getArg(DICTIONARY_VECTORS_IDX);
+    dictionaryVectors->setName("DICTIONARY_VECTORS");
+
+    Value *zero = llvmTypes->CreateConstantInt(0);
+    Value *one = llvmTypes->CreateConstantInt(1);
+
+    // pre loop body
+    builder->SetInsertPoint(preLoop);
+    // i32* ptrToCounter,Pointer to counter
+    AllocaInst *indexStore = builder->CreateAlloca(llvmTypes->I32Type(), nullptr, "INDEX_COUNTER");
+    // Initialize row index to 0.
+    builder->CreateStore(zero, indexStore);
+    AllocaInst *offsetStore = builder->CreateAlloca(llvmTypes->I32Type(), nullptr, "CURRENT_OFFSET");
+    // Initialize offset to 0.
+    builder->CreateStore(zero, offsetStore);
+    // i32 counter, Counter variable value.
+    Value *curIndexVal;
+    // i32 rowIndex,Index of row to be processed.
+    Value *rowIndexVal;
+    // i32 nextCounterValue,Temp value for next row index.
+    Value *nextIndexVal;
+
+    // i64 selectedAddress,Only use if filter enabled
+    Value *selectedAddress;
+
+    // set bits null func
+    FunctionSignature setBitNullFuncSignature = FunctionSignature("WrapSetBitNull", { OMNI_INT }, OMNI_BOOLEAN);
+    llvm::Function *setBitNullFunc =
+        modulePtr->getFunction(FunctionRegistry::LookupFunction(&setBitNullFuncSignature)->GetId());
+
+    // Type of output column
+    llvm::Function *varcharVectorFunc = nullptr;
+    if (expr->GetReturnTypeId() == OMNI_CHAR || expr->GetReturnTypeId() == OMNI_VARCHAR) {
+        std::vector paramTypes = { OMNI_LONG, OMNI_INT, OMNI_VARCHAR };
+        FunctionSignature varcharVectorFuncSignature = FunctionSignature("WrapVarcharVector", paramTypes, OMNI_INT);
+        varcharVectorFunc =
+            modulePtr->getFunction(FunctionRegistry::LookupFunction(&varcharVectorFuncSignature)->GetId());
+    }
+    Type *outPtrType = llvmTypes->ToPointerType(expr->GetReturnTypeId());
+    if (outPtrType == nullptr) {
+        return 0;
+    }
+    Value *outColPtr = builder->CreateIntToPtr(outputAddress, outPtrType);
+    // Create a integer pointer to store output length value
+    AllocaInst *outputLenPtr = builder->CreateAlloca(llvmTypes->I32Type(), nullptr, "OUTPUT_LENGTH");
+    auto isNullPtr = builder->CreateAlloca(llvmTypes->I1Type(), nullptr, "IS_NULL");
+
+    auto columnArgs = exprFunc->ToColumnArgs(input);
+    auto dicArgs = exprFunc->ToDicArgs(dictionaryVectors);
+    auto nullArgs = exprFunc->ToNullArgs(bitmap);
+    auto offsetArgs = exprFunc->ToOffsetArgs(offsets);
+
+    builder->CreateBr(loopBody);
+    // loop body
+    builder->SetInsertPoint(loopBody);
+    // i32 counter = *ptrToCounter, Get the value of the current row index to process.
+    curIndexVal = builder->CreateLoad(llvmTypes->I32Type(), indexStore, "CUR_INDEX");
+    if (filter) {
+        // i32* selectedAddress = gep i32* selected, i32 counter, Get address of selected index.
+        selectedAddress = builder->CreateGEP(llvmTypes->I32Type(), selected, curIndexVal, "SELECTED_ADDRESS");
+        // i32 rowIndexVal = *selectedAddress
+        rowIndexVal = builder->CreateLoad(llvmTypes->I32Type(), selectedAddress);
+    } else {
+        // i32 rowIndexVal = counter
+        rowIndexVal = curIndexVal;
+    }
+
+    builder->CreateStore(llvmTypes->CreateConstantBool(false), isNullPtr);
+
+    // projFuncArgs contains the values of the arguments to the projection function
+    std::vector projFuncArgs;
+    int32_t argsSize = exprFunc->GetArgumentCount() + exprFunc->GetInputColumnCount() * 4;
+    projFuncArgs.reserve(argsSize);
+
+    projFuncArgs.push_back(rowIndexVal);
+    projFuncArgs.push_back(outputLenPtr);
+    projFuncArgs.push_back(executionContext);
+    projFuncArgs.push_back(isNullPtr);
+
+    projFuncArgs.insert(projFuncArgs.end(), columnArgs.begin(), columnArgs.end());
+    projFuncArgs.insert(projFuncArgs.end(), dicArgs.begin(), dicArgs.end());
+    projFuncArgs.insert(projFuncArgs.end(), nullArgs.begin(), nullArgs.end());
+    projFuncArgs.insert(projFuncArgs.end(), offsetArgs.begin(), offsetArgs.end());
+
+    // Get the boolean response for this row from the filter function.
+    // ret = column value after applying projection
+    CallInst *ret = builder->CreateCall(func, projFuncArgs, "ROW_PROCESS");
+
+    // Add the processed value to output column.
+    builder->CreateBr(addToOutput);
+    // Add row index to results array
+    builder->SetInsertPoint(addToOutput);
+
+    Value *gep;
+    Type *ty = llvmTypes->VectorToLLVMType(*(expr->GetReturnType()));
+    if (TypeUtil::IsStringType(expr->GetReturnTypeId())) {
+        auto outputLen = builder->CreateLoad(llvmTypes->I32Type(), outputLenPtr, "OUTPUT_LENGTH");
+        auto stringPtr = builder->CreateIntToPtr(ret, Type::getInt8PtrTy(*context));
+        // call wrap_varchar_vector function
+        std::vector argVals { outColPtr, curIndexVal, stringPtr, outputLen };
+        auto call = builder->CreateCall(varcharVectorFunc, argVals, "wrap_varchar_vector");
+        InlineFunctionInfo inlineFunctionInfo;
+        InlineFunction(*call, inlineFunctionInfo);
+    } else {
+        // x* gep = gep x* outColPtr, i32 counter
+        gep = builder->CreateGEP(ty, outColPtr, curIndexVal, "OUTPUT_ADDRESS");
+        // *gep = ret
+        builder->CreateStore(ret, gep);
+    }
+
+    auto setNullRet = builder->CreateCall(setBitNullFunc,
+        { nullValuesAddress, curIndexVal, builder->CreateLoad(llvmTypes->I1Type(), isNullPtr) }, "wrap_set_bit_null");
+    InlineFunctionInfo inlineSetNullFuncInfo;
+    InlineFunction(*setNullRet, inlineSetNullFuncInfo);
+
+    builder->CreateBr(incrementCounter);
+    // Increment loop counter
+    builder->SetInsertPoint(incrementCounter);
+    // Increment counter.
+    nextIndexVal = builder->CreateAdd(curIndexVal, one, "NEXT_INDEX");
+    builder->CreateStore(nextIndexVal, indexStore);
+    // If there are rows remaining, repeat, otherwise, exit.
+    Value *sentinel;
+    if (filter) {
+        sentinel = numSelected;
+    } else {
+        sentinel = numRows;
+    }
+    Value *cond = builder->CreateICmpSLT(nextIndexVal, sentinel, "END_LOOP_COND");
+    builder->CreateCondBr(cond, loopBody, endBlock);
+
+    // Return results
+    builder->SetInsertPoint(endBlock);
+    builder->CreateRet(nextIndexVal);
+    OptimizeFunctionsAndModule();
+
+    return Compile();
+}
+}
+}
diff --git a/core/src/codegen/projection_codegen.h b/core/src/codegen/projection_codegen.h
new file mode 100644
index 0000000..db8e2c6
--- /dev/null
+++ b/core/src/codegen/projection_codegen.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: project  codegen
+ */
+#ifndef PROJECTION_CODEGEN_H
+#define PROJECTION_CODEGEN_H
+
+#include "expression_codegen.h"
+#include "util/type_util.h"
+#include "vector/vector_batch.h"
+
+namespace omniruntime {
+namespace codegen {
+class ProjectionCodeGen : public ExpressionCodeGen {
+public:
+    /**
+     * Method to initialize a ProjectionCodeGen instance
+     * @param name ProjectionCodeGen module name
+     * @param expr the projection expression to code generation
+     * @param filter whether to support filter
+     * @param overflowConfig config of overflow
+     */
+    ProjectionCodeGen(std::string name, const omniruntime::expressions::Expr &expr, bool filter,
+        omniruntime::op::OverflowConfig *overflowConfig)
+        : ExpressionCodeGen(std::move(name), expr, overflowConfig), filter(filter)
+    {}
+
+    ~ProjectionCodeGen() override = default;
+
+    /**
+     * Method to get function of processing projection expression
+     * @param inputDataTypes is used to provide data type when preload data
+     * @return the address of function
+     */
+    intptr_t GetFunction(const DataTypes &inputDataTypes) override;
+
+private:
+    /**
+     * Method to generate function by using LLVM API which processes projection expression line by line
+     * @return the address of function
+     */
+    intptr_t CreateWrapper();
+
+    bool filter;
+};
+}
+}
+
+#endif
\ No newline at end of file
diff --git a/core/src/codegen/simple_filter_codegen.cpp b/core/src/codegen/simple_filter_codegen.cpp
new file mode 100644
index 0000000..2d4eed7
--- /dev/null
+++ b/core/src/codegen/simple_filter_codegen.cpp
@@ -0,0 +1,169 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: simple filter code generator
+ */
+#include "simple_filter_codegen.h"
+
+namespace omniruntime {
+namespace codegen {
+using namespace llvm;
+using namespace llvm::orc;
+using namespace omniruntime::expressions;
+
+namespace {
+const std::string FUNCTION_NAME = "WRAPPER_FUNC";
+const int SIMPLE_FILTER_OUTPUT_LENGTH_INDEX = 4;
+const int SIMPLE_FILTER_OUTPUT_IS_NULL_INDEX = 3;
+}
+
+void SimpleFilterCodeGen::Visit(const omniruntime::expressions::FieldExpr &fieldExpr)
+{
+    Value *data = this->codegenContext->data;
+    Value *isNulls = this->codegenContext->nullBitmap;
+    Value *lengths = this->codegenContext->offsets;
+
+    Value *colIdx = llvmTypes->CreateConstantInt(fieldExpr.colVal);
+    // Find address of this column in the addresses array argument.
+    Value *gep = builder->CreateGEP(llvmTypes->I64Type(), data, colIdx);
+    // Load the address value.
+    Value *elementAddr = builder->CreateLoad(llvmTypes->I64Type(), gep);
+    Value *elementPtr = GetPtrTypeFromInt(fieldExpr.GetReturnTypeId(), elementAddr);
+
+    Type *ty = llvmTypes->VectorToLLVMType(*(fieldExpr.GetReturnType()));
+    Value *dataValue = nullptr;
+    Value *length = nullptr;
+    if (TypeUtil::IsStringType(fieldExpr.GetReturnTypeId())) {
+        // Get length for varchar/char type
+        auto lengthGEP = builder->CreateGEP(llvmTypes->I32Type(), lengths, colIdx);
+        length = builder->CreateLoad(llvmTypes->I32Type(), lengthGEP);
+        // For varchar, only need to get the pointer
+        dataValue = elementPtr;
+    } else {
+        dataValue = builder->CreateLoad(ty, elementPtr);
+    }
+
+    // Get isNull value
+    auto isNullGEP = builder->CreateGEP(llvmTypes->I1Type(), isNulls, colIdx);
+    Value *isNull = builder->CreateLoad(llvmTypes->I1Type(), isNullGEP);
+
+    if (TypeUtil::IsDecimalType(fieldExpr.GetReturnTypeId())) {
+        Value *precision = llvmTypes->CreateConstantInt(
+            dynamic_cast(fieldExpr.GetReturnType().get())->GetPrecision());
+        Value *scale =
+            llvmTypes->CreateConstantInt(dynamic_cast(fieldExpr.GetReturnType().get())->GetScale());
+        this->value = std::make_shared(dataValue, isNull, precision, scale);
+    } else {
+        this->value = std::make_shared(dataValue, isNull, length);
+    }
+}
+
+bool SimpleFilterCodeGen::InitCodegenContext(iterator_range args)
+{
+    this->codegenContext = std::make_unique();
+    for (auto &arg : args) {
+        auto argName = arg.getName().str();
+        if (argName == "data") {
+            codegenContext->data = &arg;
+        } else if (argName == "isNulls") {
+            codegenContext->nullBitmap = &arg;
+        } else if (argName == "lengths") {
+            codegenContext->offsets = &arg;
+        } else if (argName == "executionContext") {
+            codegenContext->executionContext = &arg;
+        } else if (argName == "dataLength" || argName == "isResultNull") {
+            continue;
+        } else {
+            LLVM_DEBUG_LOG("Invalid argument %s", argName.c_str());
+            return false;
+        }
+    }
+
+    codegenContext->print = modulePtr->getOrInsertFunction("printf",
+        FunctionType::get(IntegerType::getInt32Ty(*context), PointerType::get(Type::getInt8Ty(*context), 0), true));
+
+    return true;
+}
+
+llvm::Function *SimpleFilterCodeGen::CreateFunction()
+{
+    // The args indicates the type of the function parameter list.
+    std::vector args {
+        llvmTypes->I64PtrType(), // valueArray*
+        llvmTypes->I1PtrType(),  // isNullArray*
+        llvmTypes->I32PtrType(), // lengthArray*
+        llvmTypes->I1PtrType(),  // isResultNull*
+        llvmTypes->I32PtrType(), // outputLength*
+        llvmTypes->I64Type()     // executionContext
+    };
+
+    FunctionType *prototype = FunctionType::get(llvmTypes->GetFunctionReturnType(expr->GetReturnTypeId()), args, false);
+    func = llvm::Function::Create(prototype, llvm::Function::ExternalLinkage, FUNCTION_NAME, modulePtr);
+
+    std::string argNames[] = {
+        "data", "isNulls", "lengths", "isResultNull",
+        "dataLength", "executionContext"
+    };
+    int32_t idx = 0;
+    for (auto &arg : func->args()) {
+        arg.setName(argNames[idx]);
+        idx++;
+    }
+
+    RecordMainFunction(func);
+
+    BasicBlock *body = BasicBlock::Create(*context, "FUNC_BODY", func);
+    builder->SetInsertPoint(body);
+
+    if (!InitCodegenContext(func->args())) {
+        return nullptr;
+    }
+
+    // Generate code
+    auto result = VisitExpr(*expr);
+    if (!result->IsValidValue()) {
+        return nullptr;
+    }
+
+    // Update final output Length
+    if (result->length != nullptr) {
+        Argument *outputLength = func->getArg(SIMPLE_FILTER_OUTPUT_LENGTH_INDEX);
+        Value *lengthGep = builder->CreateGEP(llvmTypes->I32Type(), outputLength, llvmTypes->CreateConstantInt(0),
+            "OUTPUT_LENGTH_ADDRESS");
+        builder->CreateStore(result->length, lengthGep);
+    }
+
+    Argument *isResultNull = this->func->getArg(SIMPLE_FILTER_OUTPUT_IS_NULL_INDEX);
+    Value *nullGep =
+        builder->CreateGEP(llvmTypes->I1Type(), isResultNull, llvmTypes->CreateConstantInt(0), "OUTPUT_NULL_ADDRESS");
+    builder->CreateStore(result->isNull, nullGep);
+
+    // Return value
+    builder->CreateRet(result->data);
+
+    OptimizeModule();
+
+    verifyFunction(*func);
+    return func;
+}
+
+intptr_t SimpleFilterCodeGen::GetFunction()
+{
+#ifdef DEBUG
+    std::cout << "Row Expression: " << std::endl;
+    ExprPrinter p;
+    expr->Accept(p);
+    std::cout << std::endl;
+#endif
+
+    auto func = this->CreateFunction();
+    if (func == nullptr) {
+        return 0;
+    }
+
+#ifdef DEBUG_LLVM
+    modulePtr->print(errs(), nullptr);
+#endif
+    return Compile();
+}
+}
+}
\ No newline at end of file
diff --git a/core/src/codegen/simple_filter_codegen.h b/core/src/codegen/simple_filter_codegen.h
new file mode 100644
index 0000000..ef2f307
--- /dev/null
+++ b/core/src/codegen/simple_filter_codegen.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: simple filter code generator
+ */
+#ifndef OMNI_RUNTIME_SIMPLE_FILTER_CODEGEN_H
+#define OMNI_RUNTIME_SIMPLE_FILTER_CODEGEN_H
+
+#include 
+
+#include "expression_codegen.h"
+#include "util/type_util.h"
+
+namespace omniruntime {
+namespace codegen {
+class SimpleFilterCodeGen : public ExpressionCodeGen {
+public:
+    /* *
+     * Method to initialize a SimpleFilterCodeGen instance
+     * @param name Name for SimpleFilterCodeGen module
+     * @param expression the expression
+     * @param overflowConfig
+     */
+    SimpleFilterCodeGen(std::string name, const omniruntime::expressions::Expr &expression,
+        omniruntime::op::OverflowConfig *overflowConfig)
+        : ExpressionCodeGen(std::move(name), expression, overflowConfig)
+    {
+        this->ExtractVectorIndexes();
+    }
+
+    ~SimpleFilterCodeGen() override = default;
+
+    intptr_t GetFunction();
+
+    llvm::Function *CreateFunction();
+
+    void Visit(const omniruntime::expressions::FieldExpr &e) override;
+
+private:
+    bool InitCodegenContext(llvm::iterator_range args);
+};
+}
+}
+
+#endif
diff --git a/core/src/codegen/string_util.h b/core/src/codegen/string_util.h
new file mode 100644
index 0000000..c07ab8a
--- /dev/null
+++ b/core/src/codegen/string_util.h
@@ -0,0 +1,258 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
+ * Description: string some common operators
+ */
+
+#ifndef OMNI_RUNTIME_STRING_UTIL_H
+#define OMNI_RUNTIME_STRING_UTIL_H
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include "util/utf8_util.h"
+
+namespace omniruntime::codegen::function {
+static std::string REPLACE_ERR_MSG = "Replace failed";
+static std::string CONCAT_ERR_MSG = "Concat failed";
+static constexpr uint8_t EMPTY[] = "";
+static int32_t STEP = static_cast('a') - static_cast('A');
+static uint8_t BytesOfCodePointInUTF8[] = {
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x00..0F
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x10..1F
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x20..2F
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x30..3F
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x40..4F
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x50..5F
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x60..6F
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x70..7F
+    // Consecutive bytes cannot be used as the first byte
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x80..8F
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x90..9F
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0xA0..AF
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0xB0..BF
+    0, 0, // 0xC0..C1 - disallowed in UTF-8
+    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // 0xC2..CF
+    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // 0xD0..DF
+    3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, // 0xE0..EF
+    4, 4, 4, 4, 4, // 0xF0..F4
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 // 0xF5..FF - disallowed in UTF-8
+};
+
+class StringUtil {
+public:
+    static inline std::wstring ToWideString(std::string &s)
+    {
+        std::wstring_convert> convert;
+        return convert.from_bytes(s);
+    }
+
+    static inline const char *CastStrStr(bool *hasErr, const char *str, int32_t srcWidth, int32_t strLen,
+        int32_t *outLen, int32_t dstWidth)
+    {
+        int32_t chCount = std::min(srcWidth, dstWidth);
+        int32_t dstLen = 0;
+        int32_t count = 0;
+        while (dstLen < strLen && count < chCount) {
+            int32_t charLen = omniruntime::Utf8Util::LengthOfCodePoint(str[dstLen]);
+            if (charLen < 0) {
+                *hasErr = true;
+                *outLen = 0;
+                return nullptr;
+            }
+            dstLen += charLen;
+            count++;
+        }
+        *outLen = dstLen;
+        return str;
+    }
+
+    static inline const char *ConcatCharDiffWidths(int64_t contextPtr, const char *ap, int32_t aWidth, int32_t apLen,
+        const char *bp, int32_t bpLen, bool *hasErr, int32_t *outLen)
+    {
+        int32_t aPaddingCount = bpLen > 0 ? aWidth - omniruntime::Utf8Util::CountCodePoints(ap, apLen) : 0;
+        *outLen = apLen + aPaddingCount + bpLen;
+        if (*outLen <= 0) {
+            *outLen = 0;
+            return reinterpret_cast(EMPTY);
+        }
+
+        // allocate one more byte is mainly for memcpy_s, when the copy source and destination are
+        // both empty strings, the security function will not return an error.
+        auto ret = ArenaAllocatorMalloc(contextPtr, *outLen + 1);
+        errno_t res1 = memcpy_s(ret, *outLen + 1, ap, apLen);
+        errno_t res2 = memset_s(ret + apLen, *outLen - apLen + 1, ' ', aPaddingCount);
+        errno_t res3 = memcpy_s(ret + apLen + aPaddingCount, *outLen - (apLen + aPaddingCount) + 1, bp, bpLen);
+        if (res1 != EOK || res2 != EOK || res3 != EOK) {
+            *hasErr = true;
+            *outLen = 0;
+            return nullptr;
+        }
+        return ret;
+    }
+
+    static inline const char *ConcatStrDiffWidths(int64_t contextPtr, const char *ap, int32_t apLen, const char *bp,
+        int32_t bpLen, bool *hasErr, int32_t *outLen)
+    {
+        *outLen = apLen + bpLen;
+        if (*outLen <= 0) {
+            *outLen = 0;
+            return reinterpret_cast(EMPTY);
+        }
+        // allocate one more byte is mainly for memcpy_s, when the copy source and destination are
+        // both empty strings, the security function will not return an error.
+        auto ret = ArenaAllocatorMalloc(contextPtr, *outLen + 1);
+        errno_t res1 = memcpy_s(ret, *outLen + 1, ap, apLen);
+        errno_t res2 = memcpy_s(ret + apLen, *outLen - apLen + 1, bp, bpLen);
+        if (res1 != EOK || res2 != EOK) {
+            *hasErr = true;
+            *outLen = 0;
+            return nullptr;
+        }
+        return ret;
+    }
+
+    static inline const char *ConcatWsStrDiffWidths(int64_t contextPtr, const char *separator, int32_t separatorLen,
+        const char *ap, int32_t apLen, const char *bp, int32_t bpLen, bool *hasErr, int32_t *outLen)
+    {
+        *outLen = apLen + separatorLen + bpLen;
+        if (*outLen <= 0) {
+            *outLen = 0;
+            return reinterpret_cast(EMPTY);
+        }
+        auto ret = ArenaAllocatorMalloc(contextPtr, *outLen + 1);
+        errno_t res1 = memcpy_s(ret, *outLen + 1, ap, apLen);
+        errno_t res2 = memcpy_s(ret + apLen, *outLen + 1 - apLen, separator, separatorLen);
+        errno_t res3 = memcpy_s(ret + apLen + separatorLen, *outLen + 1 - apLen - separatorLen, bp, bpLen);
+        if (res1 != EOK || res2 != EOK || res3 != EOK) {
+            *hasErr = true;
+            *outLen = 0;
+            return nullptr;
+        }
+        return ret;
+    }
+
+    static inline const char *ReplaceWithSearchNotEmpty(int64_t contextPtr, const char *str, int32_t strLen,
+        const char *searchStr, int32_t searchLen, const char *replaceStr, int32_t replaceLen, bool *hasErr,
+        int32_t *outLen)
+    {
+        if (strLen == 0) {
+            *outLen = 0;
+            return reinterpret_cast(EMPTY);
+        }
+        std::string s = std::string(str, strLen);
+        std::string search = std::string(searchStr, searchLen);
+        std::string replace = std::string(replaceStr, replaceLen);
+        std::string::size_type matchIndex = 0;
+        if (replaceLen == 0) {
+            while ((matchIndex = s.find(search, matchIndex)) != std::string::npos) {
+                s = s.substr(0, matchIndex) + s.substr(matchIndex + searchLen);
+            }
+        } else {
+            while ((matchIndex = s.find(search, matchIndex)) != std::string::npos) {
+                s.replace(matchIndex, searchLen, replace);
+                matchIndex += replaceLen;
+            }
+        }
+
+        *outLen = static_cast(s.length());
+        auto ret = ArenaAllocatorMalloc(contextPtr, *outLen + 1);
+        error_t res = memcpy_s(ret, *outLen + 1, s.c_str(), s.length());
+        if (res != EOK) {
+            *hasErr = true;
+            *outLen = 0;
+            return nullptr;
+        }
+        return ret;
+    }
+
+    static inline const char *ReplaceWithSearchEmpty(int64_t contextPtr, const char *str, int32_t strLen,
+        const char *replaceStr, int32_t replaceLen, bool *hasErr, int32_t *outLen)
+    {
+        int32_t strCodePoints = omniruntime::Utf8Util::CountCodePoints(str, strLen);
+        *outLen = strLen + (strCodePoints + 1) * replaceLen;
+        auto ret = ArenaAllocatorMalloc(contextPtr, *outLen + 1);
+        int32_t indexBuffer = 0;
+        errno_t res;
+        for (int32_t index = 0; index < strLen;) {
+            res = memcpy_s(ret + indexBuffer, *outLen - indexBuffer + 1, replaceStr, replaceLen);
+            if (res != EOK) {
+                *hasErr = true;
+                *outLen = 0;
+                return nullptr;
+            }
+            indexBuffer += replaceLen;
+            int32_t codePointLength = omniruntime::Utf8Util::LengthOfCodePoint(*(str + index));
+            res = memcpy_s(ret + indexBuffer, *outLen - indexBuffer + 1, str + index, codePointLength);
+            if (res != EOK) {
+                *hasErr = true;
+                *outLen = 0;
+                return nullptr;
+            }
+            indexBuffer += codePointLength;
+            index += codePointLength;
+        }
+        res = memcpy_s(ret + indexBuffer, *outLen - indexBuffer + 1, replaceStr, replaceLen);
+        if (res != EOK) {
+            *hasErr = true;
+            *outLen = 0;
+            return nullptr;
+        }
+        return ret;
+    }
+
+    static inline void TrimString(std::string &str)
+    {
+        str.erase(0, str.find_first_not_of(' '));
+        str.erase(str.find_last_not_of(' ') + 1);
+    }
+
+    static inline bool StrContainsStr(const char *srcStr, int32_t srcLen, const char *matchStr, int32_t matchLen)
+    {
+        int next[matchLen];
+        next[0] = -1;
+        int i = 0;
+        int j = -1;
+        while (i < matchLen - 1) {
+            if (j == -1 || matchStr[i] == matchStr[j]) {
+                i++;
+                j++;
+                next[i] = j;
+            } else {
+                j = next[j];
+            }
+        }
+
+        i = 0;
+        j = 0;
+        while (i < srcLen && j < matchLen) {
+            if (j == -1 || srcStr[i] == matchStr[j]) {
+                i++;
+                j++;
+            } else {
+                j = next[j];
+            }
+        }
+
+        return j == matchLen;
+    }
+
+    static inline int32_t NumChars(const char *str, int32_t strLen)
+    {
+        int32_t len = 0;
+        int32_t i = 0;
+
+        while (i < strLen) {
+            len += 1;
+            int32_t offset = str[i] & 0xFF;
+            uint8_t numBytes = BytesOfCodePointInUTF8[offset];
+            i += numBytes == 0 ? 1 : numBytes;
+        }
+        return len;
+    }
+}; // class stringUtils
+} // namespace codegen function
+
+#endif // OMNI_RUNTIME_STRING_UTIL_H
diff --git a/core/src/codegen/time_util.h b/core/src/codegen/time_util.h
new file mode 100644
index 0000000..e58e0cc
--- /dev/null
+++ b/core/src/codegen/time_util.h
@@ -0,0 +1,234 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ * Description: timezone util
+ */
+
+#ifndef OMNI_RUNTIME_TIMEZONE_UTIL_H
+#define OMNI_RUNTIME_TIMEZONE_UTIL_H
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include "type/date_time_utils.h"
+#include "type/date32.h"
+
+namespace omniruntime::codegen::function {
+static const int YEAR_LENGTH = 4;
+static const int MONTH_LENGTH = 2;
+static const int DAY_LENGTH = 2;
+static const int HOUR_LENGTH = 2;
+static const int MINUTE_LENGTH = 2;
+static const int SECOND_LENGTH = 2;
+// for example: "2020-12-12 00:00:00"
+static const int TIME_LENGTH = 19;
+// for example: "2020-12-12"
+static const int DATE_LENGTH = 10;
+// for example: "%Y-%m-%d %H:%M:%S"
+static const int TIME_FORMAT_LENGTH = 17;
+// for example: "%Y-%m-%d"
+static const int DATE_FORMAT_LENGTH = 8;
+static const int MIN_YEAR = 0;
+static const int MAX_YEAR = 9999;
+static const int MIN_MONTH = 1;
+static const int MAX_MONTH = 12;
+static const int MIN_DAY = 1;
+static const int DAYS_PER_MONTH[] = {31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31};
+static const int FEBRUARY_DAY_IN_LEAP_YEAR = 29;
+static const int FEBRUARY = 2;
+static const int MIN_HOUR = 0;
+static const int MAX_HOUR = 23;
+static const int MIN_MINUTE = 0;
+static const int MAX_MINUTE = 59;
+static const int MIN_SECOND = 0;
+static const int MAX_SECOND = 59;
+static const int GREGORIAN_CALENDAR_START_YEAR = 1582;
+
+static const std::set UNIX_TIMESTAMP_FROM_DATE_SHANGHAI_NON_DST_SET = {
+    -18526 * type::SECOND_OF_DAY, -10806 * type::SECOND_OF_DAY,
+    -10519 * type::SECOND_OF_DAY, -10197 * type::SECOND_OF_DAY,
+    -8632 * type::SECOND_OF_DAY, -8297 * type::SECOND_OF_DAY,
+    -7915 * type::SECOND_OF_DAY, -7550 * type::SECOND_OF_DAY,
+    5967 * type::SECOND_OF_DAY, 6310 * type::SECOND_OF_DAY,
+    6681 * type::SECOND_OF_DAY, 7045 * type::SECOND_OF_DAY,
+    7409 * type::SECOND_OF_DAY, 7773 * type::SECOND_OF_DAY
+};
+static const std::set UNIX_TIMESTAMP_FROM_DATE_SHANGHAI_DST_SET = {
+    6100 * type::SECOND_OF_DAY, 6464 * type::SECOND_OF_DAY,
+    6828 * type::SECOND_OF_DAY, 7199 * type::SECOND_OF_DAY,
+    7563 * type::SECOND_OF_DAY, 7927 * type::SECOND_OF_DAY
+};
+static const std::set UNIX_TIMESTAMP_FROM_STR_SHANGHAI_NON_DST_SET = {
+    "1919-04-13", "1940-06-01", "1941-03-15", "1942-01-31", "1946-05-15", "1947-04-15", "1948-05-01", "1949-05-01"
+};
+
+using JudgeDSTActionByUnixTimestampFromDate = std::function;
+static std::map AdjustDSTByUnixTimestampFromDateMap = {
+    {"GMT+08:00", [](const struct tm *timeInfo, time_t desiredTime) {return timeInfo->tm_isdst;}},
+    {"Asia/Shanghai", [](const struct tm *timeInfo, time_t desiredTime) {
+            short flag = 0;
+            flag = UNIX_TIMESTAMP_FROM_DATE_SHANGHAI_NON_DST_SET.find(desiredTime) !=
+                    UNIX_TIMESTAMP_FROM_DATE_SHANGHAI_NON_DST_SET.end() ? 1 : flag;
+            flag = UNIX_TIMESTAMP_FROM_DATE_SHANGHAI_DST_SET.find(desiredTime) !=
+                    UNIX_TIMESTAMP_FROM_DATE_SHANGHAI_DST_SET.end() ? -1 : flag;
+            return flag;
+        }
+    }
+};
+
+using JudgeDSTActionByUnixTimestampFromStr =
+        std::function;
+static std::map JudgeDSTByUnixTimestampFromStrMap = {
+    {"GMT+08:00", [](const struct tm *timeInfo, const char *timeStr, int32_t timeLen,
+            const char *fmtStr, int32_t fmtLen) {return false;}},
+    {"Asia/Shanghai", [](const struct tm *timeInfo, const char *timeStr, int32_t timeLen,
+            const char *fmtStr, int32_t fmtLen) {
+            bool flag = false;
+            flag = timeInfo->tm_isdst > 0;
+            std::string substr(timeStr, timeLen);
+            flag = UNIX_TIMESTAMP_FROM_STR_SHANGHAI_NON_DST_SET.find(substr) ==
+                    UNIX_TIMESTAMP_FROM_STR_SHANGHAI_NON_DST_SET.end() && flag;
+            return flag;
+        }
+    }
+};
+
+using JudgeDSTActionByFromUnixTime = std::function;
+static std::map JudgeDSTByFromUnixTimeMap = {
+    {"GMT+08:00", [](const struct tm *timeInfo) {return timeInfo->tm_isdst > 0 ? false : true;}},
+    {"Asia/Shanghai", [](const struct tm *timeInfo) {return true;}}
+};
+
+class TimeZoneUtil {
+public:
+    static inline short AdjustDSTByUnixTimestampFromDate(const char *tzStr,
+        int32_t tzLen, const struct tm *timeInfo, time_t desiredTime)
+    {
+        std::string timeZoneStr(tzStr, tzLen);
+        auto it = AdjustDSTByUnixTimestampFromDateMap.find(timeZoneStr);
+        return it ->second(timeInfo, desiredTime);
+    }
+
+    static inline bool JudgeDSTByUnixTimestampFromStr(const char *tzStr, int32_t tzLen, const struct tm *timeInfo,
+        const char *timeStr, int32_t timeLen, const char *fmtStr, int32_t fmtLen)
+    {
+        std::string timeZoneStr(tzStr, tzLen);
+        auto it = JudgeDSTByUnixTimestampFromStrMap.find(timeZoneStr);
+        return it->second(timeInfo, timeStr, timeLen, fmtStr, fmtLen);
+    }
+
+    static inline bool JudgeDSTByFromUnixTime(const char *tzStr, int32_t tzLen, const struct tm * timeInfo)
+    {
+        std::string timeZoneStr(tzStr, tzLen);
+        auto it = JudgeDSTByFromUnixTimeMap.find(timeZoneStr);
+        return it->second(timeInfo);
+    }
+
+    static inline const char* GetTZ(const char *tzStr)
+    {
+        if (strcmp(tzStr, "GMT+08:00") == 0) {
+            return "Etc/GMT-8";
+        } else {
+            return tzStr;
+        }
+    }
+}; // class TimeZoneUtil
+class TimeUtil {
+public:
+    // Verify that the format is %Y-%m-%d %H:%M:%S and %Y-%m-%d in the blue zone.
+    static bool IsTimeValid(const char *timeStr, int timeLen, const char *fmtStr, int fmtLen, const char *policyStr)
+    {
+        if ((timeLen != TIME_LENGTH || fmtLen != TIME_FORMAT_LENGTH) &&
+                (timeLen != DATE_LENGTH || fmtLen != DATE_FORMAT_LENGTH)) {
+            return false;
+        }
+        int year = 0;
+        int month = 0;
+        int day = 0;
+        int offset = 0;
+        bool retYear = CheckAndGetNonNegativeInteger(timeStr, timeLen, offset, YEAR_LENGTH, year);
+        if (!retYear || year < MIN_YEAR || year > MAX_YEAR) {
+            return false;
+        }
+        offset = YEAR_LENGTH + 1;
+        bool retMonth = CheckAndGetNonNegativeInteger(timeStr, timeLen, offset, MONTH_LENGTH, month);
+        if (!retMonth || month < MIN_MONTH || month > MAX_MONTH) {
+            return false;
+        }
+        offset += MONTH_LENGTH + 1;
+        bool retDay = CheckAndGetNonNegativeInteger(timeStr, timeLen, offset, DAY_LENGTH, day);
+        if (!retDay) {
+            return false;
+        }
+        bool hasException = false;
+        if (month == FEBRUARY) {
+            if (LocalDate::IsGregorianLeapYear(year)) {
+                if (day < MIN_DAY || day > FEBRUARY_DAY_IN_LEAP_YEAR) {
+                    return false;
+                }
+            } else if (strcmp(policyStr, "EXCEPTION") == 0 && LocalDate::IsJulianLeapYear(year) &&
+                year < GREGORIAN_CALENDAR_START_YEAR && year > 0 && day == FEBRUARY_DAY_IN_LEAP_YEAR) {
+                hasException = true;
+            } else {
+                if (day < MIN_DAY || day > DAYS_PER_MONTH[month-1]) {
+                    return false;
+                }
+            }
+        } else {
+            if (day < MIN_DAY || day > DAYS_PER_MONTH[month-1]) {
+                return false;
+            }
+        }
+        // It means that the format is "%Y-%m-%d %H:%M:%S“
+        if (fmtLen == TIME_FORMAT_LENGTH) {
+            int hour = 0;
+            int minute = 0;
+            int second = 0;
+            offset += DAY_LENGTH + 1;
+            bool retHour = CheckAndGetNonNegativeInteger(timeStr, timeLen, offset, HOUR_LENGTH, hour);
+            if (!retHour || hour < MIN_HOUR || hour > MAX_HOUR) {
+                return false;
+            }
+            offset += HOUR_LENGTH + 1;
+            bool retMinute = CheckAndGetNonNegativeInteger(timeStr, timeLen, offset, MINUTE_LENGTH, minute);
+            if (!retMinute || minute < MIN_MINUTE || minute > MAX_MINUTE) {
+                return false;
+            }
+            offset += MINUTE_LENGTH + 1;
+            bool retSecond = CheckAndGetNonNegativeInteger(timeStr, timeLen, offset, SECOND_LENGTH, second);
+            if (!retSecond || second < MIN_SECOND || second > MAX_SECOND) {
+                return false;
+            }
+        }
+        if (hasException) {
+            throw exception::OmniException("OPERATOR_RUNTIME_ERROR",
+                "Invalid date 'February 29' as '" + std::to_string(year) + "' is not a leap year");
+        }
+        return true;
+    }
+
+    static size_t GetWallTimeMillis()
+    {
+        return std::chrono::duration_cast(
+            std::chrono::steady_clock::now().time_since_epoch()).count();
+    }
+
+private:
+    static bool CheckAndGetNonNegativeInteger(const char *str, int strLen, int start, int substrLen, int &outValue)
+    {
+        outValue = 0;
+        auto ptr = str + start;
+        for (int i = 0; i < substrLen; i++) {
+            auto value = ptr[i] - '0';
+            if (value < 0 || value > 9) {
+                return false;
+            }
+            outValue = outValue * 10 + value;
+        }
+        return true;
+    }
+}; // class TimeUtil
+} // namespace codegen function
+
+#endif // OMNI_RUNTIME_TIMEZONE_UTIL_H
diff --git a/core/src/compute/CMakeLists.txt b/core/src/compute/CMakeLists.txt
new file mode 100644
index 0000000..b348e0d
--- /dev/null
+++ b/core/src/compute/CMakeLists.txt
@@ -0,0 +1,4 @@
+include_directories(${CMAKE_CURRENT_SOURCE_DIR})
+
+file(GLOB COMPUTE_FILES ${SOURCE_ROOT}/src/compute/*.h)
+install(FILES ${COMPUTE_FILES} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/compute)
diff --git a/core/src/compute/ColumnarBatchIterator.h b/core/src/compute/ColumnarBatchIterator.h
new file mode 100644
index 0000000..3bb137d
--- /dev/null
+++ b/core/src/compute/ColumnarBatchIterator.h
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ */
+
+#pragma once
+
+#include 
+
+namespace omniruntime {
+
+class ColumnarBatchIterator {
+public:
+    ColumnarBatchIterator() {}
+
+    virtual ~ColumnarBatchIterator() = default;
+
+    virtual vec::VectorBatch* Next() = 0;
+};
+}
\ No newline at end of file
diff --git a/core/src/compute/ResultIterator.cpp b/core/src/compute/ResultIterator.cpp
new file mode 100644
index 0000000..5b74064
--- /dev/null
+++ b/core/src/compute/ResultIterator.cpp
@@ -0,0 +1,9 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ */
+
+#include "ResultIterator.h"
+
+namespace omniruntime {
+
+} // namespace gluten
diff --git a/core/src/compute/ResultIterator.h b/core/src/compute/ResultIterator.h
new file mode 100644
index 0000000..f57eb26
--- /dev/null
+++ b/core/src/compute/ResultIterator.h
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ */
+
+#pragma once
+
+#include "ColumnarBatchIterator.h"
+#include "metrics/omni_metrics.h"
+
+namespace omniruntime {
+
+class ResultIterator {
+public:
+    explicit ResultIterator(std::unique_ptr iter)
+        : iter_(std::move(iter)), next_(nullptr) {}
+
+    // copy constructor and copy assignment (deleted)
+    ResultIterator(const ResultIterator &in) = delete;
+
+    ResultIterator &operator=(const ResultIterator &) = delete;
+
+    // move constructor and move assignment
+    ResultIterator(ResultIterator &&in) = default;
+
+    ResultIterator &operator=(ResultIterator &&in) = default;
+
+    bool HasNext()
+    {
+        CheckValid();
+        GetNext();
+        return next_ != nullptr;
+    }
+
+    vec::VectorBatch *Next()
+    {
+        CheckValid();
+        GetNext();
+        auto tmp = next_;
+        next_ = nullptr;
+        return tmp;
+    }
+
+    // For testing and benchmarking.
+    ColumnarBatchIterator *GetInputIter()
+    {
+        return iter_.get();
+    }
+
+    void SetExportNanos(int64_t exportNanos)
+    {
+        exportNanos_ = exportNanos;
+    }
+
+    int64_t GetExportNanos() const
+    {
+        return exportNanos_;
+    }
+
+    OmniMetrics* getMetrics();
+
+private:
+    void CheckValid() const
+    {
+        if (iter_ == nullptr) {
+            throw std::runtime_error("ResultIterator: the underlying iterator has expired.");
+        }
+    }
+
+    void GetNext()
+    {
+        if (next_ == nullptr) {
+            next_ = iter_->Next();
+        }
+    }
+
+    std::unique_ptr iter_;
+    vec::VectorBatch *next_;
+    int64_t exportNanos_{0};
+};
+}
diff --git a/core/src/compute/cpuWall_timer.cpp b/core/src/compute/cpuWall_timer.cpp
new file mode 100644
index 0000000..528c0b9
--- /dev/null
+++ b/core/src/compute/cpuWall_timer.cpp
@@ -0,0 +1,22 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ */
+
+#include "cpuWall_timer.h"
+
+namespace omniruntime::compute {
+
+    CpuWallTimer::CpuWallTimer(CpuWallTiming& timing) : timing_(timing)
+    {
+        ++timing_.count;
+        cpuTimeStart_ = ThreadCpuNanos();
+        wallTimeStart_ = 0;
+    }
+
+    CpuWallTimer::~CpuWallTimer()
+    {
+        timing_.cpuNanos += ThreadCpuNanos() - cpuTimeStart_;
+        timing_.wallNanos += 0;
+    }
+
+} // namespace omniruntime::compute
\ No newline at end of file
diff --git a/core/src/compute/cpuWall_timer.h b/core/src/compute/cpuWall_timer.h
new file mode 100644
index 0000000..6599171
--- /dev/null
+++ b/core/src/compute/cpuWall_timer.h
@@ -0,0 +1,102 @@
+/*
+* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ */
+ 
+#ifndef CPUWALL_TIMER_H
+#define CPUWALL_TIMER_H
+
+#pragma once
+ 
+#include 
+
+#include "process_base.h"
+#include "util/format.h"
+ 
+namespace omniruntime::compute {
+ 
+// Tracks call count and elapsed CPU and wall time for a repeating operation.
+struct CpuWallTiming {
+    int64_t  count{0};
+    int64_t  wallNanos{0};
+    int64_t  cpuNanos{0};
+ 
+    void Add(const CpuWallTiming& other)
+    {
+        count += other.count;
+        cpuNanos += other.cpuNanos;
+        wallNanos += other.wallNanos;
+    }
+ 
+    void Clear()
+    {
+        count = 0;
+        wallNanos = 0;
+        cpuNanos = 0;
+    }
+ 
+    std::string toString() const
+    {
+        return Format(
+            "count: {}, wallTime: {}, cpuTime: {}",
+            count,
+            wallNanos,
+            cpuNanos);
+    }
+};
+ 
+// Adds elapsed CPU and wall time to a CpuWallTiming.
+class CpuWallTimer {
+public:
+    explicit CpuWallTimer(CpuWallTiming& timing);
+    ~CpuWallTimer();
+ 
+private:
+    int64_t cpuTimeStart_;
+    int64_t wallTimeStart_;
+    CpuWallTiming& timing_;
+};
+ 
+/// Keeps track of elapsed CPU and wall time from construction time.
+class DeltaCpuWallTimeStopWatch {
+public:
+    explicit DeltaCpuWallTimeStopWatch()
+        : wallTimeStart_(0),
+        cpuTimeStart_(ThreadCpuNanos()) {}
+ 
+    CpuWallTiming Elapsed() const
+    {
+        // NOTE: End the cpu-time timing first, and then end the wall-time timing,
+        // so as to avoid the counter-intuitive phenomenon that the final calculated
+        // cpu-time is slightly larger than the wall-time.
+        int64_t cpuTimeDuration = ThreadCpuNanos() - cpuTimeStart_;
+        int64_t wallTimeDuration = 0;
+        return CpuWallTiming{1, wallTimeDuration, cpuTimeDuration};
+    }
+ 
+private:
+    // NOTE: Put `wallTimeStart_` before `cpuTimeStart_`, so that wall-time starts
+    // counting earlier than cpu-time.
+    const int64_t wallTimeStart_;
+    const int64_t cpuTimeStart_;
+};
+ 
+/// Composes delta CpuWallTiming upon destruction and passes it to the user
+/// callback, where it can be added to the user's CpuWallTiming using
+/// CpuWallTiming::add().
+template 
+class DeltaCpuWallTimer {
+public:
+    explicit DeltaCpuWallTimer(F&& func) : func_(std::move(func)) {}
+ 
+    ~DeltaCpuWallTimer()
+    {
+        func_(timer_.Elapsed());
+    }
+ 
+private:
+    DeltaCpuWallTimeStopWatch timer_;
+    F func_;
+};
+ 
+} // namespace omniruntime
+#endif
diff --git a/core/src/compute/driver.cpp b/core/src/compute/driver.cpp
new file mode 100644
index 0000000..d28ac6b
--- /dev/null
+++ b/core/src/compute/driver.cpp
@@ -0,0 +1,249 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ */
+ 
+#include "driver.h"
+#include "codegen/time_util.h"
+#include 
+
+namespace omniruntime::compute {
+std::atomic_uint64_t BlockingState::numBlockdDrivers_{0};
+
+BlockingState::BlockingState(
+    std::shared_ptr driver,
+    ContinueFuture &&future,
+    omniruntime::op::Operator *op,
+    BlockingReason reason)
+    : driver_(std::move(driver)),
+      future_(std::move(future)),
+      operator_(op),
+      reason_(reason)
+{
+    numBlockdDrivers_++;
+}
+
+vec::VectorBatch *OmniDriver::Next(ContinueFuture *future, StopReason *stopReason)
+{
+    auto self = shared_from_this();
+    std::shared_ptr blockingState;
+    vec::VectorBatch *result = nullptr;
+    *stopReason = RunInternal(self, blockingState, &result);
+
+    if (blockingState != nullptr) {
+        *future = blockingState->Future();
+        return nullptr;
+    }
+
+    if (*stopReason == StopReason::kPause) {
+        return nullptr;
+    }
+
+    return result;
+}
+
+void OmniDriver::close()
+{
+    if (closed_) {
+        return;
+    }
+    for (auto &op : operators_) {
+        op->Close();
+        op = nullptr;
+    }
+    closed_ = true;
+}
+
+// Call an Operator method. record silenced throws, but not a query
+// terminating throw. Annotate exceptions with Operator info.
+#define CALL_OPERATOR(call, operatorPtr, operatorId, operatorMethod)                                    \
+        opCallStatus_.Start(operatorId, operatorMethod);                                                \
+        call;                                                                                           \
+        opCallStatus_.TimeSegmentStatistic(operatorPtr, operatorMethod);                                \
+        opCallStatus_.Stop();                                                                           \
+
+void OpCallStatus::Start(int32_t operatorId, const char* operatorMethod)
+{
+    opId = operatorId;
+    method = operatorMethod;
+    cpuTimeStartNs = ThreadCpuNanos();
+}
+
+void OpCallStatus::Stop()
+{
+    cpuTimeStartNs = 0;
+}
+
+CpuWallTiming OmniDriver::processLazyIoStats(op::Operator& op, const CpuWallTiming& timing)
+{
+    if (&op == operators_[0].get()) {
+        return timing;
+    }
+    auto lockStats = op.stats();
+
+    int64_t wallDelta = 0;
+    uint64_t inputBytesDelta = 0;
+    wallDelta = std::min(wallDelta, timing.wallNanos);
+    lockStats = operators_[0]->stats();
+    lockStats.getOutputTime.Add(CpuWallTiming{
+        1, wallDelta, 0
+    });
+    lockStats.inputBytes += inputBytesDelta;
+    lockStats.outputBytes += inputBytesDelta;
+    return CpuWallTiming{
+        1,
+        timing.wallNanos - wallDelta,
+        timing.cpuNanos - 0,
+    };
+}
+
+void OpCallStatus::TimeSegmentStatistic(op::Operator* op, const char* operatorMethod) const
+{
+    const int64_t cpuTimeSegment = ThreadCpuNanos() - cpuTimeStartNs;
+    std::string_view opMethod(operatorMethod);
+    if (opMethod != kOpMethodAddInput && opMethod != kOpMethodGetOutput) {
+        LogDebug("not input or output for operator");
+        return;
+    }
+    auto &lockedStats = op->stats();
+    if (opMethod == kOpMethodAddInput) {
+        lockedStats.addInputTime.cpuNanos = cpuTimeSegment / static_cast(1e6);
+        lockedStats.addInputTime.count = 1;
+    } else if (opMethod == kOpMethodGetOutput) {
+        lockedStats.getOutputTime.cpuNanos = cpuTimeSegment / static_cast(1e6);
+        lockedStats.getOutputTime.count = 1;
+    }
+}
+
+StopReason OmniDriver::RunInternal(
+    std::shared_ptr &self,
+    std::shared_ptr &blockingState,
+    vec::VectorBatch **result)
+{
+    try {
+        const uint32_t numOperators = operators_.size();
+        ContinueFuture future = OmniFuture::makeEmpty();
+        for (;;) {
+            for (int32_t i = numOperators - 1; i >= 0; --i) {
+                if (shouldStop) {
+                    return StopReason::kAtEnd;
+                }
+                auto *op = operators_[i].get();
+                curOperatorId_ = i;
+
+                blockingReason_ = op->IsBlocked(&future);
+                if (blockingReason_ != BlockingReason::kNotBlocked) {
+                    return BlockDriver(self, i, std::move(future), blockingState);
+                }
+
+                if (i < numOperators - 1) {
+                    auto *nextOp = operators_[i + 1].get();
+                    blockingReason_ = nextOp->IsBlocked(&future);
+                    if (blockingReason_ != BlockingReason::kNotBlocked) {
+                        return BlockDriver(self, i + 1, std::move(future), blockingState);
+                    }
+
+                    bool needsInput;
+                    CALL_OPERATOR(needsInput = nextOp->needsInput(), nextOp, curOperatorId_ + 1, kOpMethodNeedsInput);
+                    if (needsInput) {
+                        uint64_t resultBytes = 0;
+                        vec::VectorBatch *intermediateResult = nullptr;
+                        withDeltaCpuWallTimer(op, &OperatorStats::getOutputTime, [&]() {
+                            CALL_OPERATOR(op->GetOutput(&intermediateResult), op, curOperatorId_, kOpMethodGetOutput);
+                            if (intermediateResult) {
+                                resultBytes = intermediateResult->CalculateTotalSize();
+                                {
+                                    auto &lockedStats = op->stats();
+                                    lockedStats.AddOutputVector(resultBytes, intermediateResult->GetVectorCount(), intermediateResult->GetRowCount());
+                                }
+                            }
+                        });
+                        if (intermediateResult != nullptr) {
+                            withDeltaCpuWallTimer(nextOp, &OperatorStats::addInputTime, [&]() {
+                                {
+                                    auto &lockedStats = nextOp->stats();
+                                    lockedStats.AddInputVector(resultBytes, intermediateResult->GetVectorCount(), intermediateResult->GetRowCount());
+                                }
+                                CALL_OPERATOR(nextOp->AddInput(intermediateResult), nextOp, curOperatorId_ + 1,
+                                              kOpMethodAddInput);
+                            });
+
+                            // The next iteration will see if operators_[i + 1] has
+                            // output now that it got input
+                            i += 2;
+                            continue;
+                        } else {
+                            blockingReason_ = op->IsBlocked(&future);
+                            if (blockingReason_ != BlockingReason::kNotBlocked) {
+                                return BlockDriver(self, i, std::move(future), blockingState);
+                            }
+                            if (op->isFinished()) {
+                                nextOp->noMoreInput();
+                                break;
+                            }
+                        }
+                    }
+                } else {
+                    withDeltaCpuWallTimer(op, &OperatorStats::getOutputTime, [&]() {
+                        CALL_OPERATOR(op->GetOutput(result), op, curOperatorId_, kOpMethodGetOutput);
+                        if (*result != nullptr) {
+                            {
+                                auto &lockedStats = op->stats();
+                                lockedStats.AddOutputVector((*result)->CalculateTotalSize(), (*result)->GetVectorCount(), (*result)->GetRowCount());
+                            }
+                        }
+                    });
+                    if (*result != nullptr  && !op->isFinished()) {
+                        blockingReason_ = BlockingReason::kWaitForConsumer;
+                        return StopReason::kBlock;
+                    }
+
+                    bool finished{false};
+                    finished = op->isFinished();
+                    if (finished) {
+                        finished_ = true;
+                        return StopReason::kAtEnd;
+                    }
+                }
+            }
+        }
+    } catch (const std::exception &e) {
+        throw std::runtime_error(e.what());
+    }
+}
+#undef CALL_OPERATOR
+
+StopReason OmniDriver::BlockDriver(
+    const std::shared_ptr &self,
+    size_t blockedOperatorId,
+    ContinueFuture &&future,
+    std::shared_ptr &blockingState)
+{
+    auto *op = operators_[blockedOperatorId].get();
+    blockedOperatorId_ = blockedOperatorId;
+    blockingState = std::make_shared(
+        self, std::move(future), op, blockingReason_);
+    return StopReason::kBlock;
+}
+
+template 
+void OmniDriver::withDeltaCpuWallTimer(op::Operator* op, TimingMemberPtr opTimingMember, Func&& opFunction)
+{
+    // If 'trackOperatorCpuUsage_' is true, create and initialize the timer object
+    // to track cpu and wall time of the opFunction.
+    if (!trackOperatorCpuUsage_) {
+        opFunction();
+        return;
+    }
+
+    // The delta CpuWallTiming object would be recorded to the corresponding
+    // 'opTimingMember' upon destruction of the timer when withDeltaCpuWallTimer
+    // ends. The timer is created on the stack to avoid heap allocation
+    auto f = [op, opTimingMember, this](const CpuWallTiming& elapsedTime) {
+        auto elapsedSelfTime = processLazyIoStats(*op, elapsedTime);
+        (op->stats().*opTimingMember).Add(elapsedSelfTime);
+    };
+    DeltaCpuWallTimer timer(std::move(f));
+
+    opFunction();
+}
+} // end of omniruntime
\ No newline at end of file
diff --git a/core/src/compute/driver.h b/core/src/compute/driver.h
new file mode 100644
index 0000000..9eec5ac
--- /dev/null
+++ b/core/src/compute/driver.h
@@ -0,0 +1,176 @@
+/*
+ * @Copyright: Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved.
+ */
+#ifndef __DRIVER_H__
+#define __DRIVER_H__
+ 
+#include 
+#include 
+#include 
+#include 
+ 
+#include "operator/operator.h"
+#include "operator/operator_factory.h"
+#include "vector/vector_batch.h"
+#include "compute/reason.h"
+#include "plannode/planNode.h"
+#include "plannode/RowVectorStream.h"
+
+namespace omniruntime {
+
+namespace compute {
+
+using OperatorSupplier = std::function<
+    std::unique_ptr(const OperatorConfig& operatorConfig)>;
+
+class BlockingState;
+class OmniTask;
+
+constexpr const char* kOpMethodNone = "";
+constexpr const char* kOpMethodIsBlocked = "isBlocked";
+constexpr const char* kOpMethodNeedsInput = "needsInput";
+constexpr const char* kOpMethodGetOutput = "getOutput";
+constexpr const char* kOpMethodAddInput = "addInput";
+constexpr const char* kOpMethodNoMoreInput = "noMoreInput";
+constexpr const char* kOpMethodIsFinished = "isFinished";
+
+
+/// Same as the structure below, but does not have atomic members.
+/// Used to return the status from the struct with atomics.
+struct OpCallStatusRaw {
+    /// cpu time (ms) when the operator call started.
+    clock_t cpuTimeStartMs{0};
+    /// wall Time (ms) when the operator call started.
+    size_t timeStartMs{0};
+    /// Id of the operator, method of which is currently running. It is index into
+    /// the vector of Driver's operators.
+    int32_t opId{0};
+    /// Method of the operator, which is currently running.
+    const char* method{kOpMethodNone};
+};
+
+/// Structure holds the information about the current operator call the driver
+/// is in. Can be used to detect deadlocks and otherwise blocked calls.
+/// If timeStartMs is zero, then we aren't in an operator call.
+struct OpCallStatus {
+    OpCallStatus()
+    {
+    }
+
+    /// The status accessor.
+    OpCallStatusRaw operator()() const
+    {
+        return OpCallStatusRaw{cpuTimeStartNs, timeStartMs, opId, method};
+    }
+
+    void Start(int32_t operatorId, const char* operatorMethod);
+    void Stop();
+    void TimeSegmentStatistic(op::Operator* op, const char* operatorMethod) const;
+
+private:
+    /// cpu time (ns) when the operator call started.
+    int64_t cpuTimeStartNs{0};
+    /// wall Time (ms) when the operator call started.
+    size_t  timeStartMs{0};
+    /// Id of the operator, method of which is currently running. It is index into
+    /// the vector of Driver's operators.
+    std::atomic_int32_t opId{0};
+    /// Method of the operator, which is currently running.
+    std::atomic method{kOpMethodNone};
+};
+
+class OmniDriver : public std::enable_shared_from_this {
+public:
+    OmniDriver()
+        : curOperatorId_(0),
+          blockingReason_(BlockingReason::kNotBlocked),
+          blockedOperatorId_(0) {}
+ 
+    // Run this pipeline until it produces a batch of data or get blocked.
+    vec::VectorBatch* Next(ContinueFuture* future, StopReason* stopReason);
+
+    void addOperator(std::unique_ptr operatorPtr)
+    {
+        operators_.emplace_back(std::move(operatorPtr));
+    }
+
+    void close();
+
+    std::vector>* operators()
+    {
+        return &operators_;
+    }
+
+    ALWAYS_INLINE bool isFinished() const
+    {
+        return finished_;
+    }
+
+public:
+    bool inputDriver{false};
+    bool outputDriver{false};
+
+    bool shouldStop{false};
+
+private:
+ 
+    StopReason RunInternal(
+        std::shared_ptr& self,
+        std::shared_ptr& blockingState,
+        vec::VectorBatch** result);
+
+    ALWAYS_INLINE StopReason BlockDriver(
+        const std::shared_ptr& self,
+        size_t blockedOperatorId,
+        ContinueFuture&& future,
+        std::shared_ptr& blockingState);
+
+    // Index of the current operator to run (or the 1st one if we haven't stated yet).
+    size_t curOperatorId_;
+ 
+    std::vector> operators_;
+ 
+    BlockingReason blockingReason_;
+    size_t blockedOperatorId_;
+    bool trackOperatorCpuUsage_ = true;
+
+    OpCallStatus opCallStatus_;
+    CpuWallTiming processLazyIoStats(omniruntime::op::Operator& op, const CpuWallTiming& timing);
+    using TimingMemberPtr = CpuWallTiming OperatorStats::*;
+    template 
+    void withDeltaCpuWallTimer(omniruntime::op::Operator* op, TimingMemberPtr opTimingMember, Func&& opFunction);
+
+    bool closed_{false};
+    bool finished_{false};
+};
+
+class BlockingState {
+public:
+    BlockingState(
+        std::shared_ptr driver,
+        ContinueFuture&& future,
+        omniruntime::op::Operator* op,
+        BlockingReason reason);
+
+    ~BlockingState()
+    {
+        numBlockdDrivers_--;
+    }
+
+    ContinueFuture Future()
+    {
+        return std::move(future_);
+    }
+
+private:
+    std::shared_ptr driver_;
+    ContinueFuture future_;
+    omniruntime::op::Operator* operator_;
+    BlockingReason reason_;
+
+    static std::atomic_uint64_t numBlockdDrivers_;
+};
+
+} // end of compute
+} // end of omniruntime
+#endif
\ No newline at end of file
diff --git a/core/src/compute/local_planner.cpp b/core/src/compute/local_planner.cpp
new file mode 100644
index 0000000..c5179ee
--- /dev/null
+++ b/core/src/compute/local_planner.cpp
@@ -0,0 +1,215 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ */
+
+#include "local_planner.h"
+
+#include 
+
+#include "operator/join/hash_builder_expr.h"
+#include "operator/join/lookup_join_expr.h"
+#include "operator/join/lookup_join_wrapper.h"
+#include "operator/join/nest_loop_join_builder.h"
+#include "operator/join/nest_loop_join_lookup_wrapper.h"
+#include "operator/limit/limit.h"
+#include "operator/sort/sort_expr.h"
+#include "operator/topn/topn_expr.h"
+#include "operator/topnsort/topn_sort_expr.h"
+#include "operator/union/union.h"
+#include "operator/window/window_expr.h"
+#include "operator/join/sortmergejoin/sort_merge_join_expr_v2.h"
+#include 
+#include "operator/expand/expand.h"
+#include "operator/grouping/grouping.h"
+
+namespace omniruntime::compute {
+
+// Returns ture if source nodes must run in a separate pipeline
+bool MustStartNewPipeline(int sourceId) { return sourceId != 0; }
+
+std::shared_ptr createOperator(
+    OperatorFactory* factory, const std::shared_ptr& planNode)
+{
+    std::shared_ptr operatorPtr(factory->CreateOperator());
+    operatorPtr->setNoMoreInput(false);
+    operatorPtr->SetPlanNodeId(planNode->Id());
+    if (operatorPtr->operatorType().empty()) {
+        operatorPtr->SetOperatorType(string(planNode->Name()));
+    }
+    return std::move(operatorPtr);
+}
+
+std::shared_ptr createUnionBuildOperator(
+    std::shared_ptr unionOperator,
+    const std::shared_ptr& planNode)
+{
+    std::shared_ptr unionBuildOperator = std::make_shared(unionOperator);
+    unionBuildOperator->setNoMoreInput(false);
+    unionBuildOperator->SetPlanNodeId(planNode->Id());
+    if (unionBuildOperator->operatorType().empty()) {
+        unionBuildOperator->SetOperatorType(string(planNode->Name()));
+    }
+    return std::move(unionBuildOperator);
+}
+
+OperatorFactory* createOperatorFactory(
+    const std::shared_ptr& planNode,
+    const config::QueryConfig& queryConfig)
+{
+    if (auto orderByNode = std::dynamic_pointer_cast(planNode)) {
+        return SortWithExprOperatorFactory::CreateSortWithExprOperatorFactory(orderByNode, queryConfig);
+    } else if (auto projectNode = std::dynamic_pointer_cast(planNode)) {
+        return CreateProjectOperatorFactory(projectNode, queryConfig);
+    } else if (auto filterNode = std::dynamic_pointer_cast(planNode)) {
+        return CreateFilterOperatorFactory(filterNode, queryConfig);
+    } else if (auto windowNode = std::dynamic_pointer_cast(planNode)) {
+        return WindowWithExprOperatorFactory::CreateWindowWithExprOperatorFactory(windowNode, queryConfig);
+    } else if (auto topNNode = std::dynamic_pointer_cast(planNode)) {
+        return TopNWithExprOperatorFactory::CreateTopNWithExprOperatorFactory(topNNode, queryConfig);
+    } else if (auto topNSortNode = std::dynamic_pointer_cast(planNode)) {
+        return TopNSortWithExprOperatorFactory::CreateTopNSortWithExprOperatorFactory(
+            topNSortNode, queryConfig);
+    } else if (auto limitNode = std::dynamic_pointer_cast(planNode)) {
+        return LimitOperatorFactory::CreateLimitOperatorFactory(limitNode);
+    } else if (auto unionNode = std::dynamic_pointer_cast(planNode)) {
+        return UnionOperatorFactory::CreateUnionOperatorFactory(unionNode);
+    } else if (auto valueStreamNode = std::dynamic_pointer_cast(planNode)) {
+        return ValueStreamFactory::CreateValueStreamFactory(valueStreamNode);
+    } else if (auto aggregationNode = std::dynamic_pointer_cast(planNode)) {
+        if (aggregationNode->GetGroupByKeys().empty()) {
+            return AggregationWithExprOperatorFactory::CreateAggregationWithExprOperatorFactory(
+                aggregationNode, queryConfig);
+        }
+        return HashAggregationWithExprOperatorFactory::CreateAggregationWithExprOperatorFactory(
+            aggregationNode, queryConfig);
+    } else if (auto expandNode = std::dynamic_pointer_cast(planNode)) {
+        return CreateExpandOperatorFactory(expandNode, queryConfig);
+    } else if (auto groupingNode = std::dynamic_pointer_cast(planNode)) {
+        return GroupingOperatorFactory::CreateGroupingOperatorFactory(groupingNode, queryConfig);
+    } else {
+        throw omniruntime::exception::OmniException(
+            "PLANNODE_NOT_SUPPORT", "The plannode is not supported yet." + planNode->Id());
+    }
+}
+
+void planDetail(
+    const std::shared_ptr& planNode,
+    std::vector>* currentOperators,
+    std::vector>* drivers,
+    std::vector* factories,
+    const config::QueryConfig& queryConfig)
+{
+    OperatorFactory* factory = nullptr;
+    if (!currentOperators) {
+        drivers->emplace_back(std::make_unique());
+        currentOperators = drivers->back()->operators();
+    }
+
+    const auto &sources = planNode->Sources();
+    std::vector>> builderDrivers(sources.size());
+    if (sources.empty()) {
+        drivers->back()->inputDriver = true;
+    } else {
+        for (int32_t i = 0; i < sources.size(); ++i) {
+            planDetail(sources[i], MustStartNewPipeline(i) ? nullptr : currentOperators,
+                MustStartNewPipeline(i) ? &builderDrivers[i] : drivers, factories, queryConfig);
+        }
+    }
+
+    // JoinNode and UnionNode has multiple sources, so we need to create a builder driver for each source
+    if (auto joinNode = std::dynamic_pointer_cast(planNode)) {
+        auto hashBuilderOperatorFactory =
+            HashBuilderWithExprOperatorFactory::CreateHashBuilderWithExprOperatorFactory(joinNode, queryConfig);
+        auto builderDriver = builderDrivers[1][0];
+        builderDriver->operators()->emplace_back(createOperator(hashBuilderOperatorFactory, joinNode));
+        factories->emplace_back(hashBuilderOperatorFactory);
+
+        auto joinType = joinNode->GetJoinType();
+        if (joinNode->IsFullJoin() || (joinNode->IsLeftJoin() && joinNode->IsBuildLeft()) || (joinNode->IsRightJoin() && joinNode->IsBuildRight())) {
+            factory =
+                LookupJoinWrapperOperatorFactory::CreateLookupJoinWrapperOperatorFactory(joinNode, hashBuilderOperatorFactory, queryConfig);
+        } else {
+            factory =
+                LookupJoinWithExprOperatorFactory::CreateLookupJoinWithExprOperatorFactory(joinNode, hashBuilderOperatorFactory, queryConfig);
+        }
+    } else if (auto sortMergejoinNode = std::dynamic_pointer_cast(planNode)) {
+        auto streamedTableWithExprOperatorFactoryV2 =
+            StreamedTableWithExprOperatorFactoryV2::CreateStreamedTableWithExprOperatorFactoryV2(
+                sortMergejoinNode, queryConfig);
+        auto bufferedTableWithExprOperatorFactoryV2 =
+            BufferedTableWithExprOperatorFactoryV2::CreateBufferedTableWithExprOperatorFactoryV2(
+                sortMergejoinNode, reinterpret_cast(streamedTableWithExprOperatorFactoryV2), queryConfig);
+        auto builderDriver = builderDrivers[1][0];
+        builderDriver->operators()->emplace_back(createOperator(bufferedTableWithExprOperatorFactoryV2, sortMergejoinNode));
+        factories->emplace_back(bufferedTableWithExprOperatorFactoryV2);
+        factory = streamedTableWithExprOperatorFactoryV2;
+    } else if (auto nestedLoopJoinNode = std::dynamic_pointer_cast(planNode)) {
+        auto nestedLoopJoinBuilderOperatorFactory =
+            NestedLoopJoinBuildOperatorFactory::CreateNestedLoopJoinBuildOperatorFactory(nestedLoopJoinNode);
+        auto nestedLoopJoinLookupWrapperOperatorFactory =
+            NestLoopJoinLookupWrapperOperatorFactory::CreateNestLoopJoinLookupWrapperOperatorFactory(nestedLoopJoinNode, nestedLoopJoinBuilderOperatorFactory, queryConfig);
+        auto builderDriver = builderDrivers[1][0];
+        builderDriver->operators()->emplace_back(createOperator(nestedLoopJoinBuilderOperatorFactory, nestedLoopJoinNode));
+        factories->emplace_back(nestedLoopJoinBuilderOperatorFactory);
+        factory = nestedLoopJoinLookupWrapperOperatorFactory;
+    } else {
+        factory = createOperatorFactory(planNode, queryConfig);
+    }
+
+    auto currentOperator = createOperator(factory, planNode);
+    currentOperator->setInputOperatorCnt(sources.size());
+    currentOperators->emplace_back(currentOperator);
+    factories->emplace_back(factory);
+
+    if (auto unionNode = std::dynamic_pointer_cast(planNode)) {
+        for (auto i = 1; i < sources.size(); i++) {
+            auto builderDriver = builderDrivers[i][0];
+            builderDriver->operators()->emplace_back(createUnionBuildOperator(currentOperator, planNode));
+        }
+    }
+
+    for (auto builderDriver : builderDrivers) {
+        drivers->insert(drivers->end(), builderDriver.begin(), builderDriver.end());
+    }
+}
+
+void LocalPlanner::buildOperatorStats(std::vector>* drivers)
+{
+    if (drivers->empty()) {
+        LogError("drivers is empty");
+        return;
+    }
+
+    for (auto index = 0; index size(); ++index) {
+        const auto operators = (*drivers)[index]->operators();
+        if (operators->empty()) {
+            LogError("operators is empty");
+            return;
+        }
+        for (auto i = 0; i < operators->size(); ++i) {
+            (*operators)[i]->SetOperatorId(i);
+            (*operators)[i]->stats_ = OperatorStats(
+                i,
+                index,
+                (*operators)[i]->planNodeId(),
+                (*operators)[i]->operatorType());
+        }
+    }
+}
+
+void LocalPlanner::plan(
+    const PlanFragment& fragment,
+    std::vector>* drivers,
+    std::vector* factories,
+    const config::QueryConfig& queryConfig)
+{
+    planDetail(fragment.planNode,
+        nullptr,
+        drivers,
+        factories,
+        queryConfig);
+    (*drivers)[0]->outputDriver = true;
+
+    buildOperatorStats(drivers);
+}
+} // namespace omniruntime::compute
diff --git a/core/src/compute/local_planner.h b/core/src/compute/local_planner.h
new file mode 100644
index 0000000..9fb0177
--- /dev/null
+++ b/core/src/compute/local_planner.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ */
+
+#ifndef __LOCAL_PLANNER_H__
+#define __LOCAL_PLANNER_H__
+
+#include 
+#include 
+#include "compute/driver.h"
+#include "plannode/planFragment.h"
+#include "plannode/planNode.h"
+
+
+namespace omniruntime {
+namespace compute {
+
+std::shared_ptr createOperator(OperatorFactory* factory);
+OperatorFactory* createOperatorFactory(
+    const std::shared_ptr& planNode,
+    const config::QueryConfig& queryConfig);
+
+class LocalPlanner {
+public:
+    static void buildOperatorStats(std::vector>* drivers);
+    static void plan(
+        const omniruntime::PlanFragment& fragment,
+        std::vector>* drivers,
+        std::vector* factories,
+        const config::QueryConfig& queryConfig);
+};
+}  // namespace compute
+}  // namespace omniruntime
+#endif  // __LOCAL_PLANNER_H__
\ No newline at end of file
diff --git a/core/src/compute/operator_stats.h b/core/src/compute/operator_stats.h
new file mode 100644
index 0000000..264c2e1
--- /dev/null
+++ b/core/src/compute/operator_stats.h
@@ -0,0 +1,207 @@
+/*
+* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ */
+ 
+#ifndef OPERATOR_STATS_H
+#define OPERATOR_STATS_H
+
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+#include "cpuWall_timer.h"
+#include "metrics/metrics_config.h"
+
+namespace omniruntime::compute {
+using namespace std;
+using PlanNodeId = std::string;
+ 
+struct OperatorStats {
+    /// Initial ordinal position in the operator's pipeline.
+    int32_t operatorId{0};
+    int32_t pipelineId{0};
+    PlanNodeId planNodeId;
+ 
+    /// Some operators perform the logic describe in multiple consecutive plan
+    /// nodes. For example, FilterProject operator maps to Filter node followed by
+    /// Project node. In this case, runtime stats are collected for the combined
+    /// operator and attached to the "main" plan node ID chosen by the operator.
+    /// (Project node ID in case of FilterProject operator.) The operator can then
+    /// provide a function to split the stats among all plan nodes that are being
+    /// represented. For example, FilterProject would split the stats but moving
+    /// cardinality reduction to Filter and making Project cardinality neutral.
+    using StatsSplitter = std::function(
+        const OperatorStats& combinedStats)>;
+ 
+    std::optional statsSplitter;
+ 
+    /// Name for reporting. We use Presto compatible names set at
+    /// construction of the Operator where applicable.
+    std::string operatorType;
+ 
+    /// Number of splits (or chunks of work). Split can be a part of data file to
+    /// read.
+    int64_t numSplits{0};
+ 
+    CpuWallTiming isBlockedTiming;
+ 
+    /// For Scan
+    uint64_t rawInputBytes{0};
+    uint64_t rawInputRows{0};
+ 
+    /// Bytes of input in terms of retained size of input vectors.
+    uint64_t inputRows{0};
+    uint64_t inputBytes{0};
+    uint64_t numInputVecBatches{0};
+
+    CpuWallTiming addInputTime;
+ 
+    /// Bytes of output in terms of retained size of vectors.
+    uint64_t outputBytes{0};
+    uint64_t outputRows{0};
+    uint64_t numOutputVecBatches{0};
+
+    CpuWallTiming getOutputTime;
+ 
+    // Total bytes written to file for spilling.
+    uint64_t spilledBytes{0};
+ 
+    // Total rows written for spilling.
+    uint64_t spilledRows{0};
+ 
+    // Total spilled partitions.
+    uint32_t spilledPartitions{0};
+ 
+    // Total current spilled files.
+    uint32_t spilledFiles{0};
+
+    CpuWallTiming finishTiming;
+
+    int numDrivers = 0;
+
+    // For BHJ/SHJ
+    uint64_t buildInputRows;
+    uint64_t buildNumInputVecBatches;
+    CpuWallTiming buildAddInputTime;
+    CpuWallTiming buildGetOutputTime;
+
+    uint64_t lookupInputRows;
+    uint64_t lookupNumInputVecBatches;
+    uint64_t lookupOutputRows;
+    uint64_t lookupNumOutputVecBatches;
+    CpuWallTiming lookupAddInputTime;
+    CpuWallTiming lookupGetOutputTime;
+ 
+    OperatorStats() = default;
+ 
+    OperatorStats(
+        int32_t _operatorId,
+        int32_t _pipelineId,
+        PlanNodeId _planNodeId,
+        std::string _operatorType)
+        : operatorId(_operatorId),
+          pipelineId(_pipelineId),
+          planNodeId(std::move(_planNodeId)),
+          operatorType(std::move(_operatorType)) {}
+ 
+    void setStatSplitter(StatsSplitter splitter)
+    {
+        statsSplitter = std::move(splitter);
+    }
+ 
+    void AddInputVector(uint64_t bytes, uint64_t inputVecBatches, uint64_t rowCount)
+    {
+        inputBytes += bytes;
+        numInputVecBatches += 1;
+        inputRows += rowCount;
+    }
+ 
+    void AddOutputVector(uint64_t bytes, uint64_t outputVecBatches, uint64_t rowCount)
+    {
+        outputBytes += bytes;
+        numOutputVecBatches += 1;
+        outputRows += rowCount;
+    }
+
+    void HashJoinOperator(const OperatorStats& stats)
+    {
+        const std::string& opType = stats.operatorType;
+
+        if (opType == opNameForHashBuilder) {
+            buildInputRows += stats.inputRows;
+            buildAddInputTime.Add(stats.addInputTime);
+            buildGetOutputTime.Add(stats.getOutputTime);
+            buildNumInputVecBatches += stats.numInputVecBatches;
+        }
+
+        if (opType == opNameForLookUpJoin) {
+            lookupInputRows += stats.inputRows;
+            lookupOutputRows += stats.outputRows;
+            lookupAddInputTime.Add(stats.addInputTime);
+            lookupGetOutputTime.Add(stats.getOutputTime);
+            lookupNumInputVecBatches += stats.numInputVecBatches;
+            lookupNumOutputVecBatches += stats.numOutputVecBatches;
+        }
+    }
+
+    void Add(const OperatorStats& other)
+    {
+        HashJoinOperator(other);
+
+        operatorType = other.operatorType;
+        planNodeId = other.planNodeId;
+        numSplits += other.numSplits;
+        rawInputBytes += other.rawInputBytes;
+        rawInputRows += other.rawInputRows;
+ 
+        addInputTime.Add(other.addInputTime);
+        inputBytes += other.inputBytes;
+        inputRows += other.inputRows;
+        numInputVecBatches += other.numInputVecBatches;
+ 
+        getOutputTime.Add(other.getOutputTime);
+        outputBytes += other.outputBytes;
+        numOutputVecBatches += other.numOutputVecBatches;
+        outputRows += other.outputRows;
+ 
+        isBlockedTiming.Add(other.isBlockedTiming);
+
+        numDrivers += other.numDrivers;
+        spilledBytes += other.spilledBytes;
+        spilledRows += other.spilledRows;
+        spilledPartitions += other.spilledPartitions;
+        spilledFiles += other.spilledFiles;
+
+        finishTiming.Add(other.finishTiming);
+    }
+ 
+    void Clear()
+    {
+        numSplits = 0;
+        rawInputBytes = 0;
+        rawInputRows = 0;
+ 
+        addInputTime.Clear();
+        inputBytes = 0;
+        inputRows = 0;
+        numInputVecBatches = 0;
+ 
+        getOutputTime.Clear();
+        outputBytes = 0;
+        outputRows = 0;
+        numOutputVecBatches = 0;
+ 
+        numDrivers = 0;
+        spilledBytes = 0;
+        spilledRows = 0;
+        spilledPartitions = 0;
+        spilledFiles = 0;
+
+        finishTiming.Clear();
+    }
+};
+} // omniruntime::compute
+#endif
diff --git a/core/src/compute/plannode_stats.cpp b/core/src/compute/plannode_stats.cpp
new file mode 100644
index 0000000..24a4096
--- /dev/null
+++ b/core/src/compute/plannode_stats.cpp
@@ -0,0 +1,206 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ * Description: This file declares for plannode_stats.cpp
+ */
+ 
+#include "plannode_stats.h"
+#include "compute/task.h"
+#include "plannode/planNode.h"
+#include "util/format.h"
+#include "metrics/metrics_config.h"
+#include 
+
+namespace omniruntime::compute {
+PlanNodeStats& PlanNodeStats::operator+=(const PlanNodeStats& another)
+{
+    inputRows += another.inputRows;
+    inputBytes += another.inputBytes;
+    numInputVecBatches += another.numInputVecBatches;
+
+    rawInputRows += another.inputRows;
+    rawInputBytes += another.rawInputBytes;
+
+    outputRows += another.outputRows;
+    outputBytes += another.outputBytes;
+    numOutputVecBatches += another.numOutputVecBatches;
+
+    isBlockedTiming.Add(another.isBlockedTiming);
+    addInputTime.Add(another.addInputTime);
+    getOutputTime.Add(another.getOutputTime);
+    finishTiming.Add(another.finishTiming);
+    cpuWallTiming.Add(another.addInputTime);
+    cpuWallTiming.Add(another.getOutputTime);
+    cpuWallTiming.Add(another.finishTiming);
+    cpuWallTiming.Add(another.isBlockedTiming);
+
+    backgroundTiming.Add(another.backgroundTiming);
+
+    blockedWallNanos += another.blockedWallNanos;
+
+    physicalWrittenBytes += another.physicalWrittenBytes;
+
+    // Populating number of drivers for plan nodes with multiple operators is not
+    // useful. Each operator could have been executed in different pipelines with
+    // different number of drivers.
+    if (!IsMultiOperatorTypeNode()) {
+        numDrivers += another.numDrivers;
+    } else {
+        numDrivers = 0;
+    }
+
+    numSplits += another.numSplits;
+
+    spilledBytes += another.spilledBytes;
+    spilledRows += another.spilledRows;
+    spilledPartitions += another.spilledPartitions;
+    spilledFiles += another.spilledFiles;
+
+    return *this;
+}
+
+// Returns true if an operator is a hash join operator given 'operatorType'.
+void PlanNodeStats::HashJoinOperator(const OperatorStats& stats)
+{
+    const std::string& opType = stats.operatorType;
+
+    if (opType == opNameForHashBuilder) {
+        buildInputRows += stats.inputRows;
+        buildAddInputTime.Add(stats.addInputTime);
+        buildGetOutputTime.Add(stats.getOutputTime);
+        buildNumInputVecBatches += stats.numInputVecBatches;
+    }
+
+    if (opType == opNameForLookUpJoin) {
+        lookupInputRows += stats.inputRows;
+        lookupOutputRows += stats.outputRows;
+        lookupAddInputTime.Add(stats.addInputTime);
+        lookupGetOutputTime.Add(stats.getOutputTime);
+        lookupNumInputVecBatches += stats.numInputVecBatches;
+        lookupNumOutputVecBatches += stats.numOutputVecBatches;
+    }
+}
+
+
+void PlanNodeStats::Add(const OperatorStats& stats)
+{
+    auto it = operatorStats.find(stats.operatorType);
+    if (it != operatorStats.end()) {
+        it->second->AddTotals(stats);
+    } else {
+        auto opStats = std::make_unique();
+        opStats->AddTotals(stats);
+        operatorStats.emplace(stats.operatorType, std::move(opStats));
+    }
+    AddTotals(stats);
+}
+
+void PlanNodeStats::AddTotals(const OperatorStats& stats)
+{
+    inputRows += stats.inputRows;
+    inputBytes += stats.inputBytes;
+    numInputVecBatches += stats.numInputVecBatches;
+
+    rawInputRows += stats.rawInputRows;
+    rawInputBytes += stats.rawInputBytes;
+
+    outputRows += stats.outputRows;
+    outputBytes += stats.outputBytes;
+    numOutputVecBatches += stats.numOutputVecBatches;
+
+    isBlockedTiming.Add(stats.isBlockedTiming);
+    addInputTime.Add(stats.addInputTime);
+    getOutputTime.Add(stats.getOutputTime);
+    finishTiming.Add(stats.finishTiming);
+    cpuWallTiming.Add(stats.addInputTime);
+    cpuWallTiming.Add(stats.getOutputTime);
+    cpuWallTiming.Add(stats.finishTiming);
+    cpuWallTiming.Add(stats.isBlockedTiming);
+
+    // Populating number of drivers for plan nodes with multiple operators is not
+    // useful. Each operator could have been executed in different pipelines with
+    // different number of drivers.
+    if (!IsMultiOperatorTypeNode()) {
+        numDrivers += stats.numDrivers;
+    } else {
+        numDrivers = 0;
+    }
+
+    numSplits += stats.numSplits;
+
+    spilledBytes += stats.spilledBytes;
+    spilledRows += stats.spilledRows;
+    spilledPartitions += stats.spilledPartitions;
+    spilledFiles += stats.spilledFiles;
+
+    HashJoinOperator(stats);
+}
+
+void appendOperatorStats(
+    const OperatorStats& stats,
+    std::unordered_map& planStats)
+{
+    const auto& planNodeId = stats.planNodeId;
+    auto it = planStats.find(planNodeId);
+    if (it != planStats.end()) {
+        it->second.Add(stats);
+    } else {
+        PlanNodeStats nodeStats;
+        nodeStats.Add(stats);
+        planStats.emplace(planNodeId, std::move(nodeStats));
+    }
+}
+
+std::unordered_map ToPlanStats(
+    const TaskStats& taskStats)
+{
+    std::unordered_map planStats;
+
+    for (const auto& pipelineStats : taskStats.pipelineStats) {
+        for (const auto& opStats : pipelineStats.operatorStats) {
+            if (opStats.statsSplitter.has_value()) {
+                const auto& multiNodeStats = opStats.statsSplitter.value()(opStats);
+                for (const auto& stats : multiNodeStats) {
+                    appendOperatorStats(stats, planStats);
+                }
+            } else {
+                appendOperatorStats(opStats, planStats);
+            }
+        }
+    }
+    return planStats;
+}
+
+std::string PlanNodeStats::ToString(
+    bool includeInputStats) const
+{
+    std::stringstream out;
+    if (includeInputStats) {
+        out << "Input: " << inputRows << " rows (" << inputBytes
+            << ", " << numInputVecBatches << " batches), ";
+        if ((rawInputRows > 0) && (rawInputRows != inputRows)) {
+            out << "Raw Input: " << rawInputRows << " rows ("
+                << rawInputBytes << "), ";
+        }
+    }
+    out << "Output: " << outputRows << " rows (" << outputBytes
+        << ", " << numOutputVecBatches << " batches)";
+    if (physicalWrittenBytes > 0) {
+        out << ", Physical written output: " << physicalWrittenBytes;
+    }
+    out << ", Cpu time: " << cpuWallTiming.cpuNanos
+        << ", Wall time: " << cpuWallTiming.wallNanos
+        << ", Blocked wall time: " << blockedWallNanos
+        << ", Peak memory: " << peakMemoryBytes
+        << ", Memory allocations: " << numMemoryAllocations
+        << ", Threads: " << numDrivers
+        << ", Splits: " << numSplits
+        <<", Spilled: " << spilledRows << " rows ("
+            << spilledBytes << ", " << spilledFiles << " files)";
+    out << ", CPU breakdown: B/I/O/F "
+        << Format(
+            "({}/{}/{}/{})", isBlockedTiming.cpuNanos, addInputTime.cpuNanos,
+            getOutputTime.cpuNanos, finishTiming.cpuNanos);
+    return out.str();
+}
+
+}
\ No newline at end of file
diff --git a/core/src/compute/plannode_stats.h b/core/src/compute/plannode_stats.h
new file mode 100644
index 0000000..a61c344
--- /dev/null
+++ b/core/src/compute/plannode_stats.h
@@ -0,0 +1,142 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ * Description: This file declares for plannode_stats.h
+ */
+ 
+#pragma once
+
+#include 
+#include 
+
+#include "cpuWall_timer.h"
+#include "operator_stats.h"
+#include "task_stats.h"
+
+namespace omniruntime::compute {
+/// Aggregated runtime statistics per plan node.
+///
+/// Runtime statistics are collected on a per-operator instance basis. There can
+/// be multiple operator types and multiple instances of each operator type that
+/// correspond to a given plan node. For example, ProjectNode corresponds to
+/// a single operator type, FilterProject, but HashJoinNode corresponds to two
+/// operator types, HashProbe and HashBuild. Each operator type may have
+/// different runtime parallelism, e.g. there can be multiple instances of each
+/// operator type. Plan node statistics are calculated by adding up
+/// operator-level statistics for all corresponding operator instances.
+struct PlanNodeStats {
+    explicit PlanNodeStats() = default;
+
+    PlanNodeStats(const PlanNodeStats&) = delete;
+    PlanNodeStats& operator=(const PlanNodeStats&) = delete;
+
+    PlanNodeStats(PlanNodeStats&&) = default;
+    PlanNodeStats& operator=(PlanNodeStats&&) = default;
+
+    PlanNodeStats& operator+=(const struct omniruntime::compute::PlanNodeStats& another);
+
+    /// Sum of input rows for all corresponding operators. Useful primarily for
+    /// leaf plan nodes or plan nodes that correspond to a single operator type.
+    uint64_t inputRows{0};
+    size_t numInputVecBatches{0};
+    uint64_t inputBytes{0};
+    std::string operatorType;
+    /// Sum of raw input rows for all corresponding operators. Applies primarily
+    /// to TableScan operator which reports rows before pushed down filter as raw
+    /// input.
+    uint64_t rawInputRows{0};
+    uint64_t rawInputBytes{0};
+
+    /// Sum of output rows for all corresponding operators. When
+    /// plan node corresponds to multiple operator types, operators of only one of
+    /// these types report non-zero output rows.
+    uint64_t outputRows{0};
+    size_t numOutputVecBatches{0};
+    uint64_t outputBytes{0};
+
+    // Sum of CPU, scheduled and wall times for isBLocked call for all
+    // corresponding operators.
+    CpuWallTiming isBlockedTiming;
+
+    // Sum of CPU, scheduled and wall times for addInput call for all
+    // corresponding operators.
+    CpuWallTiming addInputTime;
+
+    // Sum of CPU, scheduled and wall times for noMoreInput call for all
+    // corresponding operators.
+    CpuWallTiming finishTiming;
+
+    // Sum of CPU, scheduled and wall times for getOutput call for all
+    // corresponding operators.
+    CpuWallTiming getOutputTime;
+
+    /// Sum of CPU, scheduled and wall times for all corresponding operators. For
+    /// each operator, timing of addInput, getOutput and finish calls are added
+    /// up.
+    CpuWallTiming cpuWallTiming;
+
+    /// Sum of CPU, scheduled and wall times spent on background activities
+    /// (activities that are not running on driver threads) for all corresponding
+    /// operators.
+    CpuWallTiming backgroundTiming;
+
+    /// Sum of blocked wall time for all corresponding operators.
+    uint64_t blockedWallNanos{0};
+
+    /// Max of peak memory usage for all corresponding operators. Assumes that all
+    /// operator instances were running concurrently.
+    uint64_t peakMemoryBytes{0};
+
+    uint64_t numMemoryAllocations{0};
+
+    uint64_t physicalWrittenBytes{0};
+
+    /// Breakdown of stats by operator type.
+    std::unordered_map> operatorStats;
+
+    /// Number of drivers that executed the pipeline.
+    int numDrivers{0};
+
+    /// Number of total splits.
+    int numSplits{0};
+
+    /// Total bytes written for spilling.
+    uint64_t spilledBytes{0};
+    uint64_t spilledRows{0};
+    uint32_t spilledPartitions{0};
+    uint32_t spilledFiles{0};
+
+    // For BHJ/SHJ
+    uint64_t buildInputRows;
+    uint64_t buildNumInputVecBatches;
+    CpuWallTiming buildAddInputTime;
+    CpuWallTiming buildGetOutputTime;
+
+    uint64_t lookupInputRows;
+    uint64_t lookupNumInputVecBatches;
+    uint64_t lookupOutputRows;
+    uint64_t lookupNumOutputVecBatches;
+    CpuWallTiming lookupAddInputTime;
+    CpuWallTiming lookupGetOutputTime;
+
+    /// Add stats for a single operator instance.
+    void Add(const OperatorStats& stats);
+
+    std::string ToString(
+        bool includeInputStats = false) const;
+
+    bool IsMultiOperatorTypeNode() const
+    {
+        return operatorStats.size() > 1;
+    }
+
+private:
+    void HashJoinOperator(const OperatorStats& stats);
+    void AddTotals(const OperatorStats& stats);
+};
+
+std::unordered_map ToPlanStats(
+    const TaskStats& taskStats);
+
+using PlanNodeAnnotation =
+    std::function;
+} // namespace omniruntime:exec
\ No newline at end of file
diff --git a/core/src/compute/process_base.cpp b/core/src/compute/process_base.cpp
new file mode 100644
index 0000000..860c2d5
--- /dev/null
+++ b/core/src/compute/process_base.cpp
@@ -0,0 +1,24 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ * Description: This file declares for process_base.cpp
+ */
+#include "compute/process_base.h"
+#include 
+
+namespace omniruntime::compute {
+
+const int64_t NANOS_PER_SEC = 1'000'000'000;
+
+static int64_t ThreadCpuNanos()
+{
+#ifdef __aarch64__
+        int64_t time;
+        asm volatile("mrs %0, cntvct_el0" : "=r" (time));
+        return time;
+#else
+        timespec ts{};
+        clock_gettime(CLOCK_THREAD_CPUTIME_ID, &ts);
+        return ts.tv_sec * NANOS_PER_SEC + ts.tv_nsec;
+#endif
+}
+} // namespace omniruntime::compute
\ No newline at end of file
diff --git a/core/src/compute/process_base.h b/core/src/compute/process_base.h
new file mode 100644
index 0000000..a2136ab
--- /dev/null
+++ b/core/src/compute/process_base.h
@@ -0,0 +1,17 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ */
+
+#ifndef PROCESS_BASE_H
+#define PROCESS_BASE_H
+
+#pragma once
+
+#include 
+
+namespace omniruntime::compute {
+
+    /// Returns elapsed CPU nanoseconds on the calling thread
+    int64_t ThreadCpuNanos();
+} // namespace omniruntime
+#endif
diff --git a/core/src/compute/reason.h b/core/src/compute/reason.h
new file mode 100644
index 0000000..9b921b8
--- /dev/null
+++ b/core/src/compute/reason.h
@@ -0,0 +1,118 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ */
+#ifndef __REASON_H__
+#define __REASON_H__
+
+#include 
+#include 
+
+namespace omniruntime {
+namespace compute {
+
+using ContinueFuture = std::future;
+
+class OmniFuture {
+public:
+    /// Create an invalid Future
+    /// RESULT.valid() == false
+    static ContinueFuture makeEmpty()
+    {
+        ContinueFuture emptyFuture;
+        return std::move(emptyFuture);
+    }
+
+    static bool isReady(const ContinueFuture& future)
+    {
+        try {
+            return future.wait_for(std::chrono::seconds(0)) == std::future_status::ready;
+        } catch (const std::future_error& e) {
+            return false;
+        }
+    }
+
+    static ContinueFuture collectAll(std::vector futures)
+    {
+        for (auto& future : futures) {
+            future.get();
+        }
+
+        ContinueFuture future;
+        return std::move(future);
+    }
+
+    /// create valid Futures
+    /// RESULT.valid == true
+    /// RESULT.isReady == true
+    static std::vector createValidFutures(uint32_t count)
+    {
+        std::vector futures;
+        futures.reserve(count);
+
+        for (auto i = 0; i < count; ++i) {
+            std::promise promise;
+            promise.set_value();
+            futures.emplace_back(promise.get_future());
+        }
+
+        return std::move(futures);
+    }
+};
+
+ 
+enum class StopReason {
+    /// Keep running.
+    kNone,
+    /// Go off thread and do not schedule more activity.
+    kPause,
+    /// Stop and free all. This is returned once and the thread that gets this
+    /// value is responsible for freeing the state associated with the thread.
+    /// Other threads will get kAlreadyTerminated after the first thread has
+    /// received kTerminate.
+    kTerminate,
+    kAlreadyTerminated,
+    /// Go off thread and then enqueue to the back of the runnable queue.
+    kYield,
+    /// Must wait for external events.
+    kBlock,
+    /// No more data to produce.
+    kAtEnd,
+    kAlreadyOnThread
+};
+ 
+enum class BlockingReason {
+    kNotBlocked,
+    kWaitForConsumer,
+    kWaitForSplit,
+    /// Some operators can get blocked due to the producer(s) (they are
+    /// currently waiting data from) not having anything produced. Used by
+    /// LocalExchange, LocalMergeExchange, Exchange and MergeExchange operators.
+    kWaitForProducer,
+    kWaitForJoinBuild,
+    /// For a build operator, it is blocked waiting for the probe operators to
+    /// finish probing before build the next hash table from one of the
+    /// previously spilled partition data. For a probe operator, it is blocked
+    /// waiting for all its peer probe operators to finish probing before
+    /// notifying the build operators to build the next hash table from the
+    /// previously spilled data.
+    kWaitForJoinProbe,
+    /// Used by MergeJoin operator, indicating that it was blocked by the right
+    /// side input being unavailable.
+    kWaitForMergeJoinRightSide,
+    kWaitForMemory,
+    kWaitForConnector,
+    /// Build operator is blocked waiting for all its peers to stop to run group
+    /// spill on all of them.
+    kWaitForSpill,
+    /// Some operators (like Table Scan) may run long loops and can 'voluntarily'
+    /// exit them because Task requested to yield or stop or after a certain time.
+    /// This is the blocking reason used in such cases.
+    kYield,
+    /// Operator is blocked waiting for its associated query memory arbitration to
+    /// finish.
+    kWaitForArbitration,
+    kWaitForUnionBuild,
+};
+} // end of compute
+} // end of omniruntime
+#endif
\ No newline at end of file
diff --git a/core/src/compute/task.cpp b/core/src/compute/task.cpp
new file mode 100644
index 0000000..7632408
--- /dev/null
+++ b/core/src/compute/task.cpp
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ */
+#include "task.h"
+#include "local_planner.h"
+#include "codegen/time_util.h"
+
+namespace omniruntime::compute {
+ 
+vec::VectorBatch* OmniTask::Next(ContinueFuture* future)
+{
+    if (drivers_.empty()) {
+        taskStats_.executionStartTimeMs = static_cast(ThreadCpuNanos());
+        LocalPlanner::plan(
+            planFragment_, &drivers_, &operatorFactories_, queryConfig_);
+        std::reverse(drivers_.begin(), drivers_.end());
+    }
+    const auto numDrivers = drivers_.size();
+    auto futures = OmniFuture::createValidFutures(numDrivers);
+    for (;;) {
+        int runableDrivers = 0;
+        for (auto i = 0; i < numDrivers; ++i) {
+            if (drivers_[i]->isFinished()) {
+                // This driver has finished processing.
+                continue;
+            }
+ 
+            ++runableDrivers;
+ 
+            ContinueFuture driverFuture = OmniFuture::makeEmpty();
+            StopReason stopReason = StopReason::kNone;
+            auto result = drivers_[i]->Next(&driverFuture, &stopReason);
+            if (result) {
+                return result;
+            }
+ 
+            futures[i] = std::move(driverFuture);
+        }
+ 
+        if (runableDrivers == 0) {
+            return nullptr;
+        }
+    }
+}
+
+TaskStats OmniTask::GetTaskStats() const
+{
+    // 'taskStats_' contains task stats plus stats for the completed drivers
+    // (their operators).
+    TaskStats taskStats = taskStats_;
+
+    taskStats.numTotalDrivers = drivers_.size();
+    LogDebug("total driver num is %d", taskStats.numTotalDrivers);
+    // Add stats of the drivers (their operators) that are still running.
+    for (const auto& driver : drivers_) {
+        // Driver can be null.
+        if (driver == nullptr) {
+            ++taskStats.numCompletedDrivers;
+            continue;
+        }
+        auto operators = driver->operators();
+        for (auto& op : *operators) {
+            auto opStatsCopy = op->stats(false);
+            int32_t pipelineId = opStatsCopy.pipelineId;
+            int32_t operatorId = opStatsCopy.operatorId;
+            PlanNodeId planNodeId = opStatsCopy.planNodeId;
+            if (taskStats.pipelineStats.size() <= static_cast(pipelineId)) {
+                taskStats.pipelineStats.resize(pipelineId + 1);
+            }
+            if (taskStats.pipelineStats[pipelineId].operatorStats.size() <= static_cast(operatorId)) {
+                taskStats.pipelineStats[pipelineId].operatorStats.resize(operatorId + 1);
+            }
+            taskStats.pipelineStats[pipelineId]
+                .operatorStats[operatorId]
+                .Add(opStatsCopy);
+        }
+    }
+    return taskStats;
+}
+} // end of omniruntime
\ No newline at end of file
diff --git a/core/src/compute/task.h b/core/src/compute/task.h
new file mode 100644
index 0000000..9f2dce3
--- /dev/null
+++ b/core/src/compute/task.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ */
+#ifndef __TASK_H__
+#define __TASK_H__
+ 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include "compute/driver.h"
+#include "compute/task_stats.h"
+#include "vector/vector_batch.h"
+#include "operator/operator.h"
+#include "plannode/planFragment.h"
+ 
+namespace omniruntime {
+namespace compute {
+
+class OmniDriver;
+
+class OmniTask {
+public:
+    OmniTask(const PlanFragment& planFragment,  config::QueryConfig&& queryConfig)
+        : planFragment_(planFragment), queryConfig_(std::move(queryConfig))
+    {
+        taskStats_ = TaskStats();
+    }
+
+    ~OmniTask()
+    {
+        for (auto& driver : drivers_) {
+            if (driver) {
+                driver->shouldStop = true;
+                driver->close();
+            }
+        }
+        for (auto& factory : operatorFactories_) {
+            delete factory;
+        }
+    }
+
+    vec::VectorBatch* Next(ContinueFuture* future = nullptr);
+
+    /// Returns Task Stats by copy as other threads might be updating the
+    /// structure.
+    TaskStats GetTaskStats() const;
+
+private:
+    std::vector> drivers_;
+    std::vector operatorFactories_;
+    PlanFragment planFragment_;
+    OperatorConfig operatorConfig_;
+    TaskStats taskStats_;
+    const config::QueryConfig queryConfig_;
+};
+} // end of compute
+} // end of omniruntime
+#endif
\ No newline at end of file
diff --git a/core/src/compute/task_stats.h b/core/src/compute/task_stats.h
new file mode 100644
index 0000000..69d3b46
--- /dev/null
+++ b/core/src/compute/task_stats.h
@@ -0,0 +1,108 @@
+/*
+* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ * Description: This file declares for operator_utils.cpp
+ */
+#ifndef TASK_STATS_H
+#define TASK_STATS_H
+
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+#include "driver.h"
+#include "operator_stats.h"
+
+namespace omniruntime::compute {
+struct OperatorStats;
+
+    /// Stores execution stats per pipeline.
+struct PipelineStats {
+    /// Cumulative OperatorStats for finished Drivers. The subscript is the
+    /// operator id, which is the initial ordinal position of the operator in the
+    /// DriverFactory.
+    std::vector operatorStats;
+
+    /// True if contains the source node for the task.
+    bool inputPipeline;
+
+    /// True if contains the sync node for the task.
+    bool outputPipeline;
+
+    explicit PipelineStats() = default;
+
+    PipelineStats(bool _inputPipeline, bool _outputPipeline)
+        : inputPipeline{_inputPipeline}, outputPipeline{_outputPipeline} {}
+};
+
+/// Stores execution stats per task.
+struct TaskStats {
+    int32_t numTotalSplits{0};
+    int32_t numFinishedSplits{0};
+    int32_t numRunningSplits{0};
+    int32_t numQueuedSplits{0};
+    std::unordered_set completedSplitGroups;
+
+    // The number of barriers that have been processed by the task.
+    int32_t numBarriers{0};
+
+    /// Table scan split stats.
+    int32_t numRunningTableScanSplits{0};
+    int32_t numQueuedTableScanSplits{0};
+    int64_t runningTableScanSplitWeights{0};
+    int64_t queuedTableScanSplitWeights{0};
+
+    /// The subscript is given by each Operator's
+    /// DriverCtx::pipelineId. This is a sum total reflecting fully
+    /// processed Splits for Drivers of this pipeline.
+    std::vector pipelineStats;
+
+    /// Epoch time (ms) when task starts to run
+    uint64_t executionStartTimeMs{0};
+
+    /// Epoch time (ms) when last split is processed. For some tasks there might
+    /// be some additional time to send buffered results before the task finishes.
+    uint64_t executionEndTimeMs{0};
+
+    /// Epoch time (ms) when first split is fetched from the task by an operator.
+    uint64_t firstSplitStartTimeMs{0};
+
+    /// Epoch time (ms) when last split is fetched from the task by an operator.
+    uint64_t lastSplitStartTimeMs{0};
+
+    /// Epoch time (ms) when the task completed, e.g. all splits were processed
+    /// and results have been consumed.
+    uint64_t endTimeMs{0};
+
+    /// Epoch time (ms) when the task was terminated, i.e. its terminal state
+    /// has been set, whether by finishing successfully or with an error, or
+    /// being cancelled or aborted.
+    uint64_t terminationTimeMs{0};
+
+    /// Total number of drivers.
+    uint64_t numTotalDrivers{0};
+    /// Total number of drivers queued on an executor but not on thread.
+    uint64_t numQueuedDrivers{0};
+    /// The number of completed drivers (which slots are null in Task 'drivers_'
+    /// list).
+    uint64_t numCompletedDrivers{0};
+    /// The number of drivers that are terminating or terminated (isTerminated()
+    /// returns true).
+    uint64_t numTerminatedDrivers{0};
+    /// The number of drivers that are currently running on driver thread.
+    uint64_t numRunningDrivers{0};
+
+    /// Drivers blocked for various reasons. Based on enum BlockingReason.
+    std::unordered_map numBlockedDrivers;
+
+    /// The longest still running operator call in "op::call" format.
+    std::string longestRunningOpCall;
+    /// The longest still running operator call's duration in ms.
+    size_t longestRunningOpCallMs{0};
+};
+
+} // namespace omniruntime::compute
+
+#endif
diff --git a/core/src/cpu_checker/CMakeLists.txt b/core/src/cpu_checker/CMakeLists.txt
new file mode 100644
index 0000000..a7bb8e6
--- /dev/null
+++ b/core/src/cpu_checker/CMakeLists.txt
@@ -0,0 +1,9 @@
+aux_source_directory(${CMAKE_CURRENT_LIST_DIR} CPU_CHECKER_LIST)
+set(CPU_CHECKER_TARGET cpu_checker)
+add_library(${CPU_CHECKER_TARGET} ${CPU_CHECKER_LIST})
+
+# dependent include
+target_link_libraries(${CPU_CHECKER_TARGET})
+
+file(GLOB CPU_CHECKER_FILES ${SOURCE_ROOT}/src/cpu_checker/*.h)
+install(FILES ${CPU_CHECKER_FILES} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/cpu_checker)
diff --git a/core/src/cpu_checker/omniruntime_cpu_checker.cpp b/core/src/cpu_checker/omniruntime_cpu_checker.cpp
new file mode 100644
index 0000000..4d55087
--- /dev/null
+++ b/core/src/cpu_checker/omniruntime_cpu_checker.cpp
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
+ * Description: omniruntime_cpu_checker
+ */
+
+#include "omniruntime_cpu_checker.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include "config.h"
+
+int KunpengCpuCheck()
+{
+#ifdef DISABLE_CPU_CHECKER
+    return 0;
+#endif
+    unsigned long int midrEl1;
+    asm("mrs %0, MIDR_EL1" : "=r"(midrEl1));
+    unsigned int partId = (midrEl1 >> PART_NUM_SHIFT) & PART_NUM_MASK;
+    unsigned int vendorId = (midrEl1 >> IMPLEMENTER_SHIFT) & IMPLEMENTER_MASK;
+    if (vendorId == HISILICON_VENDOR_ID || (vendorId == ARM_VENDOR_ID && partId == KUNPENG_PART_ID)) {
+        return 0;
+    }
+    return -1;
+}
+
+int DoQingsongCpuCheck(char *file)
+{
+#ifdef DISABLE_CPU_CHECKER
+    return 0;
+#endif
+    int fd;
+    int rsp;
+
+    // register address
+    unsigned char regAddr = REG_ADDR;
+    // read result from I2C device
+    unsigned char readResult[DEVICE_ID_LEN] = {0};
+
+    unsigned int result;
+
+    struct i2c_rdwr_ioctl_data packets;
+    struct i2c_msg messages[I2C_MSG_NUM];
+
+    fd = open(file, O_RDONLY);
+    if (fd < 0) {
+        /* Fail to open i2c devices */
+        return -1;
+    }
+
+    rsp = ioctl(fd, I2C_TENBIT, 0); // 7bit
+    if (rsp != 0) {
+        close(fd);
+        /* Fail to set 7 bit addresse */
+        return -1;
+    }
+
+    rsp = ioctl(fd, I2C_SLAVE, SLAVE_ADDR); // set slave device
+    if (rsp != 0) {
+        close(fd);
+        /* Fail to set slave address */
+        return -1;
+    }
+
+    messages[0].addr = SLAVE_ADDR;
+    messages[0].flags = 0; // write data
+    messages[0].len = REG_ADDR_LEN;
+    messages[0].buf = ®Addr;
+
+    messages[1].addr = SLAVE_ADDR;
+    messages[1].flags = I2C_M_RD; // read data
+    messages[1].len = DEVICE_ID_LEN;
+    messages[1].buf = readResult;
+
+    packets.msgs = messages;
+    packets.nmsgs = I2C_MSG_NUM;
+
+    rsp = ioctl(fd, I2C_RDWR, &packets);
+    if (rsp < 0) {
+        close(fd);
+        /* failed to WRITE & READ I2C Device */
+        return -1;
+    }
+
+    close(fd);
+    result = *(unsigned int *)readResult;
+
+    if (result == DEVICE_ID_VALUE) {
+        /* matched */
+        return 0;
+    }
+
+    return -1;
+}
+
+int QingsongCpuCheck()
+{
+    std::string dev0 = I2C_DEV_0;
+    std::string dev1 = I2C_DEV_1;
+    return DoQingsongCpuCheck((char *)dev0.data()) == 0 || DoQingsongCpuCheck((char *)dev1.data()) == 0 ? 0 : -1;
+}
diff --git a/core/src/cpu_checker/omniruntime_cpu_checker.h b/core/src/cpu_checker/omniruntime_cpu_checker.h
new file mode 100644
index 0000000..bee6af8
--- /dev/null
+++ b/core/src/cpu_checker/omniruntime_cpu_checker.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
+ * Description: omniruntime_cpu_checker
+ */
+
+#ifndef OMNI_RUNTIME_OMNIRUNTIME_CPU_CHECKER_H
+#define OMNI_RUNTIME_OMNIRUNTIME_CPU_CHECKER_H
+
+#define IMPLEMENTER_SHIFT 24
+#define IMPLEMENTER_MASK 0xFF
+#define PART_NUM_SHIFT 4
+#define PART_NUM_MASK 0xFFF
+
+// 0x41 is ARM
+#define ARM_VENDOR_ID 0x41
+
+// 0x48 is HiSilicon
+#define HISILICON_VENDOR_ID 0x48
+
+// 0xd01 is Kunpeng 920, 0xd08 is Kunpeng 916
+#define KUNPENG_PART_ID 0xd08
+
+// for qingsong CPU check
+#define I2C_DEV_0 "/dev/i2c-0"     // I2C-0 device name
+#define I2C_DEV_1 "/dev/i2c-1"     // I2C-1 device name
+#define SLAVE_ADDR 0x43            // slave I2C slave address
+#define REG_ADDR 0x2b              // I2C register address
+#define REG_ADDR_LEN 1             // I2C register address len
+#define DEVICE_ID_VALUE 0xEAC76903 // value of device ID
+#define DEVICE_ID_LEN 4            // length of device ID
+
+#define FORTIFY_SRC_STR_LEN 10
+#define FORTIFY_SRC_IDX 1
+
+#define I2C_MSG_NUM 2
+int KunpengCpuCheck();
+int QingsongCpuCheck();
+
+#endif // OMNI_RUNTIME_OMNIRUNTIME_CPU_CHECKER_H
diff --git a/core/src/expression/CMakeLists.txt b/core/src/expression/CMakeLists.txt
new file mode 100644
index 0000000..ce7d303
--- /dev/null
+++ b/core/src/expression/CMakeLists.txt
@@ -0,0 +1,8 @@
+file(GLOB_RECURSE EXPRESSION_LIST ${CMAKE_CURRENT_LIST_DIR}/*.cpp)
+set(EXPRESSION_TARGET expression)
+add_library(${EXPRESSION_TARGET} ${EXPRESSION_LIST})
+
+file(GLOB EXPRESSION_HEAD_FILES ${SOURCE_ROOT}/src/expression/*.h)
+install(FILES ${EXPRESSION_HEAD_FILES} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/expression)
+file(GLOB PARSER_HEAD_FILES ${SOURCE_ROOT}/src/expression/parser/*.h)
+install(FILES ${PARSER_HEAD_FILES} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/expression/parser)
\ No newline at end of file
diff --git a/core/src/expression/README.md b/core/src/expression/README.md
new file mode 100644
index 0000000..da09b53
--- /dev/null
+++ b/core/src/expression/README.md
@@ -0,0 +1,32 @@
+## Parser and Expressions
+
+### Parser Specifications
+The current parser must take in a `string` as input, and return a `Expr*` in a method with the following signature: 
+```c++
+Expr *parseRowExpression(string input, DataType *inputDataTypes, int32_t veccount);
+```
+`DataType` is an enum which is defined in `core/src/expression/expressions.h`, and covers all the possible types for a column. These are: `STRINGD`, `INT32D`, `INT64D`, `DOUBLED`, `BOOLD`. The `inputVecTypes` array contains the type of each column in order, and there are `veccount` columns in total.
+
+ All the different kinds of `Expr*` are listed in `core/src/expression/expressions.h`. When parsing, each `Expr*` must have its `dataType` member set correctly (usually with the constructor for the `Expr` which takes in a data type. The different types of `Expr` are:
+ `DataExpr`: Holds a piece of literal data or a column index.
+ `UnaryExpr`: Represents a unary expression (such as NOT).
+ `BinaryExpr`: Represents a binary expression (such as +, -, *, /, %, AND, OR, <, <=, >, >=, =, <>)
+ `InExpr`: Represents an `IN` SQL operator. Holds a vector of arguments which are `Expr*`, and the first argument is the value to be compared to every other argument.
+ `BetweenExpr`: Represents a `BETWEEN` SQL operator. Holds a value, lower bound, and upper bound, all of which are `Expr*`.
+ `IfExpr`: Represents an SQL conditional. Holds a condition, an expression to be evaluated if the condition is true, and an expression to be evaluated if the condition is false; all of these are `Expr*`.
+`CoalesceExpr`: Represents an SQL `COALESCE`. Has a value 1 and a value 2. If value 1 has a null value then value 2 is returned, otherwise value 1 is returned.
+`FuncExpr`: Represents any function, whether internal or user-defined. Has a name and an argument vector of `Expr*`.
+
+### Using `parserhelper`
+In `core/src/expression/parserhelper.h` and its corresponding `.cpp` file, there are a few methods to help with parsing.
+
+```c++
+DataType FuncRetTypeMap(string funcID, vector args);
+```
+The `FuncRetTypeMap` method takes in a function name `funcID` and a vector containing its arguments (which are all `Expr*`), and returns a `DataType` denoting the returned type of the function.
+
+```
+bool HasValidArguments(string funcID, vector args, bool checkTypes);
+```
+
+The `HasValidArguments` method checks whether or not a `FuncExpr*` (i.e. an expression representing a function call) is valid, in terms of whether the number of arguments match up with the function name. If `true` is passed to the `checkTypes` argument, then the types will be checked as well, although types will not be checked for external developer functions.
diff --git a/core/src/expression/expr_printer.cpp b/core/src/expression/expr_printer.cpp
new file mode 100644
index 0000000..c775492
--- /dev/null
+++ b/core/src/expression/expr_printer.cpp
@@ -0,0 +1,469 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: print expression tree methods
+ */
+#include "expr_printer.h"
+#include 
+#include 
+#include 
+#include 
+#include 
+#include "util/debug.h"
+
+using namespace omniruntime::expressions;
+using namespace omniruntime::type;
+using namespace std;
+
+string ExprPrinter::BinaryExprPrinterHelper(const Operator &op, const DataType &type) const
+{
+    string typeStr = TypeUtil::TypeToString(type.GetId());
+    if (TypeUtil::IsDecimalType(type.GetId())) {
+        typeStr += "(";
+        typeStr += to_string(static_cast(type).GetPrecision());
+        typeStr += ", ";
+        typeStr += to_string(static_cast(type).GetScale());
+        typeStr += ")";
+    }
+
+    switch (op) {
+        case Operator::EQ:
+            return "Cmp:" + typeStr + "(EQ ";
+        case Operator::NEQ:
+            return "Cmp:" + typeStr + "(NEQ ";
+        case Operator::LT:
+            return "Cmp:" + typeStr + "(LT ";
+        case Operator::LTE:
+            return "Cmp:" + typeStr + "(LTE ";
+        case Operator::GT:
+            return "Cmp:" + typeStr + "(GT ";
+        case Operator::GTE:
+            return "Cmp:" + typeStr + "(GTE ";
+        case Operator::AND:
+            return "Bin:" + typeStr + "(AND ";
+        case Operator::OR:
+            return "Bin:" + typeStr + "(OR ";
+        case Operator::ADD:
+            return "Arith:" + typeStr + "(ADD ";
+        case Operator::SUB:
+            return "Arith:" + typeStr + "(SUB ";
+        case Operator::MUL:
+            return "Arith:" + typeStr + "(MUL ";
+        case Operator::DIV:
+            return "Arith:" + typeStr + "(DIV ";
+        case Operator::MOD:
+            return "Arith:" + typeStr + "(MOD ";
+        default:
+            return "Invalid";
+    }
+}
+
+string ExprPrinter::GenerateIndentation() const
+{
+    string indent = "";
+    for (int i = 0; i < this->indentationDepth; i++) {
+        indent.append("\t");
+    }
+    return indent;
+}
+
+
+std::string GetBoolValOutput(const LiteralExpr &e)
+{
+    string output = "Literal:bool:";
+    e.boolVal ? output += "true" : output += "false";
+    return output;
+}
+
+std::string GetShortValOutput(const LiteralExpr &e)
+{
+    string output = "Literal:" + TypeUtil::TypeToString(e.GetReturnTypeId()) + ":" + to_string(e.shortVal);
+    return output;
+}
+
+std::string GetByteValOutput(const LiteralExpr &e)
+{
+    string output = "Literal:" + TypeUtil::TypeToString(e.GetReturnTypeId()) + ":" + to_string(e.byteVal);
+    return output;
+}
+
+std::string GetIntValOutput(const LiteralExpr &e)
+{
+    string output = "Literal:" + TypeUtil::TypeToString(e.GetReturnTypeId()) + ":" + to_string(e.intVal);
+    return output;
+}
+
+std::string GetLongValOutput(const LiteralExpr &e)
+{
+    string output = "Literal:" + TypeUtil::TypeToString(e.GetReturnTypeId()) + ":" + to_string(e.longVal);
+    return output;
+}
+
+std::string GetDoubleValOutput(const LiteralExpr &e)
+{
+    string output = "Literal:" + TypeUtil::TypeToString(e.GetReturnTypeId()) + ":" + to_string(e.doubleVal);
+    return output;
+}
+
+std::string GetCharValOutput(const LiteralExpr &e)
+{
+    string output = "Literal:";
+    if (e.GetReturnTypeId() == OMNI_CHAR) {
+        // meant to look like "%s[%d]:'%s'"
+        output += TypeUtil::TypeToString(e.GetReturnTypeId()) + +"[" +
+            to_string(static_cast(e.dataType.get())->GetWidth()) + "]" + ":'" + *(e.stringVal) + "'";
+    } else {
+        std::string tmp = e.stringVal == nullptr ? "null" : *e.stringVal;
+        // meant to look like "%s:'%s'"
+        output += TypeUtil::TypeToString(e.GetReturnTypeId()) + ":'" + tmp + "'";
+    }
+    return output;
+}
+
+std::string GetDecimal64ValOutput(const LiteralExpr &e)
+{
+    // meant to look like "Literal:%s(%d, %d):%ld"
+    string output = "Literal:";
+    output += TypeUtil::TypeToString(e.GetReturnTypeId());
+    output += "(";
+    output += to_string(static_cast(e.dataType.get())->GetPrecision());
+    output += ", ";
+    output += to_string(static_cast(e.dataType.get())->GetScale());
+    output += "):";
+    output += to_string(e.longVal);
+    return output;
+}
+
+std::string GetDecimal128ValOutput(const LiteralExpr &e)
+{
+    // meant to look like "%s(%d, %d):'%s'"
+    string output = "Literal:";
+    output += TypeUtil::TypeToString(e.GetReturnTypeId());
+    output += "(";
+    output += to_string(static_cast(e.dataType.get())->GetPrecision());
+    output += ", ";
+    output += to_string(static_cast(e.dataType.get())->GetScale());
+    output += "):";
+    output += "'";
+    output += *(e.stringVal);
+    output += "'";
+    return output;
+}
+
+/*
+ * EXAMPLE
+ *
+ * Bin:bool(OR,
+ * Cmp:bool(EQ,
+ * #7,
+ * 'Winter
+ * ),
+ * Cmp:bool(EQ,
+ * #2,
+ * 'Summer'
+ * )
+ * )
+ *
+ */
+void ExprPrinter::Visit(const BinaryExpr &e)
+{
+    string indent = GenerateIndentation();
+    string message = BinaryExprPrinterHelper(e.op, *(e.GetReturnType()));
+    if (message == "Invalid") {
+        message = "InvalidBinaryOperator:" + to_string(static_cast(e.op)) + "(";
+    }
+    message = indent + message;
+    printf("%s\n", message.c_str());
+
+    this->indentationDepth++;
+    (e.left)->Accept(*this);
+
+    (e.right)->Accept(*this);
+    string lastParentheses = indent + ")";
+    printf("%s\n", lastParentheses.c_str());
+    this->indentationDepth--;
+}
+
+/*
+ * EXAMPLE
+ *
+ * Unary:bool(NOT,
+ * IsNull:bool(
+ * #0
+ * )
+ * )
+ *
+ */
+void ExprPrinter::Visit(const UnaryExpr &e)
+{
+    string indent = GenerateIndentation();
+    string output = indent;
+    switch (e.op) {
+        case Operator::NOT:
+            output += "Unary:" + TypeUtil::TypeToString(e.GetReturnTypeId()) + "(NOT ";
+            break;
+        default:
+            output += "InvalidUnaryOperator:" + to_string(static_cast(e.op)) + "(";
+            break;
+    }
+    printf("%s\n", output.c_str());
+    this->indentationDepth++;
+    (e.exp)->Accept(*this);
+    string lastParentheses = indent + ")";
+    printf("%s\n", lastParentheses.c_str());
+    this->indentationDepth--;
+}
+
+void ExprPrinter::Visit(const LiteralExpr &e)
+{
+    string output = GenerateIndentation();
+    switch (e.GetReturnTypeId()) {
+        case OMNI_BOOLEAN:
+            output += GetBoolValOutput(e);
+            break;
+        case OMNI_BYTE:
+            output += GetByteValOutput(e);
+            break;
+        case OMNI_SHORT:
+            output += GetShortValOutput(e);
+            break;
+        case OMNI_INT:
+        case OMNI_DATE32:
+            output += GetIntValOutput(e);
+            break;
+        case OMNI_TIMESTAMP:
+        case OMNI_LONG:
+            output += GetLongValOutput(e);
+            break;
+        case OMNI_DOUBLE:
+            output += GetDoubleValOutput(e);
+            break;
+        case OMNI_CHAR:
+            output += GetCharValOutput(e);
+            break;
+        case OMNI_VARCHAR:
+            output += GetCharValOutput(e);
+            break;
+        case OMNI_DECIMAL64:
+            output += GetDecimal64ValOutput(e);
+            break;
+        case OMNI_DECIMAL128:
+            output += GetDecimal128ValOutput(e);
+            break;
+        default:
+            output += "Literal:invalid DataType " + to_string(e.GetReturnTypeId());
+    }
+    printf("%s\n", output.c_str());
+}
+
+void ExprPrinter::Visit(const FieldExpr &e)
+{
+    string output = GenerateIndentation() + "Field:";
+    output += TypeUtil::TypeToString(e.GetReturnTypeId());
+    if (e.GetReturnTypeId() == OMNI_CHAR) {
+        output += '[' + to_string(static_cast(*(e.GetReturnType())).GetWidth()) + ']';
+    } else if (e.GetReturnTypeId() == OMNI_DECIMAL64 || e.GetReturnTypeId() == OMNI_DECIMAL128) {
+        output += "(";
+        output += to_string(static_cast(e.dataType.get())->GetPrecision());
+        output += ", ";
+        output += to_string(static_cast(e.dataType.get())->GetScale());
+        output += ")";
+    }
+    output += ":#" + to_string(e.colVal);
+    printf("%s\n", output.c_str());
+}
+
+/*
+ * EXAMPLE
+ *
+ * In:bool(
+ * 1,
+ * 2,
+ * 3,
+ * 4
+ * )
+ *
+ */
+void ExprPrinter::Visit(const InExpr &e)
+{
+    string indent = GenerateIndentation();
+    string output = indent + "In:" + TypeUtil::TypeToString(e.GetReturnTypeId()) + "(";
+    printf("%s\n", output.c_str());
+    this->indentationDepth++;
+    for (uint32_t i = 0; i < e.arguments.size(); i++) {
+        (e.arguments[i])->Accept(*this);
+        if (i == e.arguments.size() - 1) {
+            printf("%s\n", (indent + ")").c_str());
+        }
+    }
+    this->indentationDepth--;
+}
+/*
+ * Example:switch:bool(
+ * Cmp:bool(EQ
+ * #0,
+ * 100,
+ * ),
+ * Cmp:bool(EQ
+ * #0,
+ * 200,
+ * ),
+ * Cmp:bool(),
+ *
+ *
+ * )
+ */
+void ExprPrinter::Visit(const SwitchExpr &e)
+{
+    string indent = GenerateIndentation();
+    string output = indent + "Switch:" + TypeUtil::TypeToString(e.GetReturnTypeId()) + "(";
+    printf("%s\n", output.c_str());
+    this->indentationDepth++;
+    for (const auto &i : e.whenClause) {
+        (i.first)->Accept(*this);
+        (i.second)->Accept(*this);
+    }
+    e.falseExpr->Accept(*this);
+    printf("%s\n", (indent + ")").c_str());
+    this->indentationDepth--;
+}
+/*
+ * EXAMPLE
+ *
+ * Between:bool(
+ * 5,
+ * -1,
+ * 7,
+ * )
+ *
+ */
+void ExprPrinter::Visit(const BetweenExpr &e)
+{
+    string indent = GenerateIndentation();
+    string output = indent + "Between:" + TypeUtil::TypeToString(e.GetReturnTypeId()) + "(";
+    printf("%s\n", output.c_str());
+    this->indentationDepth++;
+    (e.value)->Accept(*this);
+
+    (e.lowerBound)->Accept(*this);
+
+    (e.upperBound)->Accept(*this);
+    printf("%s\n", (indent + ")").c_str());
+    this->indentationDepth--;
+}
+
+/*
+ * EXAMPLE
+ *
+ * If:bool(
+ * Cmp:bool(GT,
+ * #0,
+ * 100,
+ * ),
+ * Cmp:bool(GT,
+ * #0,
+ * 200
+ * ),
+ * Cmp:bool(LT,
+ * #0,
+ * 0
+ * )
+ * )
+ *
+ */
+void ExprPrinter::Visit(const IfExpr &e)
+{
+    string indent = GenerateIndentation();
+    string output = indent + "If:" + TypeUtil::TypeToString(e.GetReturnTypeId()) + "(";
+    printf("%s\n", output.c_str());
+    this->indentationDepth++;
+    e.condition->Accept(*this);
+
+    e.trueExpr->Accept(*this);
+
+    e.falseExpr->Accept(*this);
+    printf("%s\n", (indent + ")").c_str());
+    this->indentationDepth--;
+}
+
+/*
+ * EXAMPLE
+ *
+ * Cmp:bool(EQ,
+ * Coalesce:int64(
+ * #0,
+ * 0
+ * ),
+ * 123
+ * )
+ *
+ */
+void ExprPrinter::Visit(const CoalesceExpr &e)
+{
+    string indent = GenerateIndentation();
+    string output = indent + "Coalesce:" + TypeUtil::TypeToString(e.GetReturnTypeId()) + "(";
+    printf("%s\n", output.c_str());
+    this->indentationDepth++;
+    e.value1->Accept(*this);
+
+    e.value2->Accept(*this);
+    printf("%s\n", (indent + ")").c_str());
+    this->indentationDepth--;
+}
+
+/*
+ * EXAMPLE
+ *
+ * Printed Pared Expression
+ *
+ * IsNull:bool(
+ * #0
+ * )
+ *
+ */
+void ExprPrinter::Visit(const IsNullExpr &e)
+{
+    string indent = GenerateIndentation();
+    string output = indent + "IsNull:" + TypeUtil::TypeToString(e.GetReturnTypeId()) + "(";
+    printf("%s\n", output.c_str());
+    this->indentationDepth++;
+    e.value->Accept(*this);
+    printf("%s\n", (indent + ")").c_str());
+    this->indentationDepth--;
+}
+
+/*
+ * EXAMPLE
+ * concat:string(
+ * #1,
+ * #2
+ * )
+ *
+ */
+void ExprPrinter::Visit(const FuncExpr &e)
+{
+    string indent = GenerateIndentation();
+    string typeStr = TypeUtil::TypeToString(e.GetReturnTypeId());
+    if (TypeUtil::IsDecimalType(e.GetReturnTypeId())) {
+        auto decimalDataType = static_cast(e.GetReturnType().get());
+        typeStr += "(";
+        typeStr += to_string(decimalDataType->GetPrecision());
+        typeStr += ", ";
+        typeStr += to_string(decimalDataType->GetScale());
+        typeStr += ")";
+    } else if (TypeUtil::IsStringType(e.GetReturnTypeId())) {
+        typeStr += "[";
+        typeStr += to_string(static_cast(e.GetReturnType().get())->GetWidth());
+        typeStr += "]";
+    }
+
+    string output = indent + "Function:" + ":" + e.funcName + ":" + typeStr + "(";
+    printf("%s\n", output.c_str());
+    this->indentationDepth++;
+    for (uint32_t i = 0; i < e.arguments.size(); i++) {
+        (e.arguments[i])->Accept(*this);
+        if (i == e.arguments.size() - 1) {
+            printf("%s\n", (indent + ")").c_str());
+        }
+    }
+    this->indentationDepth--;
+}
diff --git a/core/src/expression/expr_printer.h b/core/src/expression/expr_printer.h
new file mode 100644
index 0000000..8a0c517
--- /dev/null
+++ b/core/src/expression/expr_printer.h
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: print expression tree visitor for expressions
+ */
+#ifndef __OMNI_RUNTIME_EXPRESSION_PRINTER_H__
+#define __OMNI_RUNTIME_EXPRESSION_PRINTER_H__
+
+#include "expr_visitor.h"
+#include "util/type_util.h"
+
+class ExprPrinter : public ExprVisitor {
+public:
+    void Visit(const omniruntime::expressions::LiteralExpr &e) override;
+    void Visit(const omniruntime::expressions::FieldExpr &e) override;
+    void Visit(const omniruntime::expressions::UnaryExpr &e) override;
+    void Visit(const omniruntime::expressions::BinaryExpr &e) override;
+    void Visit(const omniruntime::expressions::InExpr &e) override;
+    void Visit(const omniruntime::expressions::BetweenExpr &e) override;
+    void Visit(const omniruntime::expressions::IfExpr &e) override;
+    void Visit(const omniruntime::expressions::CoalesceExpr &e) override;
+    void Visit(const omniruntime::expressions::IsNullExpr &e) override;
+    void Visit(const omniruntime::expressions::FuncExpr &e) override;
+    void Visit(const omniruntime::expressions::SwitchExpr &e) override;
+
+private:
+    std::string BinaryExprPrinterHelper(const omniruntime::expressions::Operator &op,
+        const omniruntime::type::DataType &type) const;
+    std::string GenerateIndentation() const;
+    int indentationDepth = 0;
+};
+
+#endif
\ No newline at end of file
diff --git a/core/src/expression/expr_verifier.cpp b/core/src/expression/expr_verifier.cpp
new file mode 100644
index 0000000..477cd80
--- /dev/null
+++ b/core/src/expression/expr_verifier.cpp
@@ -0,0 +1,308 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: Expression Verifier
+ */
+#include "expr_verifier.h"
+#include "codegen/func_registry.h"
+
+using namespace omniruntime::expressions;
+using namespace omniruntime::type;
+
+namespace omniruntime {
+namespace expressions {
+bool ExprVerifier::VisitExpr(const Expr &e)
+{
+    e.Accept(*this);
+    return this->supportedFlag;
+}
+
+bool ExprVerifier::VisitExpr(const std::shared_ptr &e)
+{
+    e->Accept(*this);
+    return this->supportedFlag;
+}
+
+bool ExprVerifier::AreInvalidDataTypes(DataTypeId type1, DataTypeId type2)
+{
+    return type1 != type2 && !(TypeUtil::IsStringType(type1) && TypeUtil::IsStringType(type2)) &&
+        !(TypeUtil::IsDecimalType(type1) && TypeUtil::IsDecimalType(type2));
+}
+
+void ExprVerifier::Visit(const LiteralExpr &literalExpr)
+{
+    switch (literalExpr.GetReturnTypeId()) {
+        case OMNI_BYTE:
+        case OMNI_SHORT:
+        case OMNI_INT:
+        case OMNI_DATE32:
+        case OMNI_LONG:
+        case OMNI_TIMESTAMP:
+        case OMNI_DOUBLE:
+        case OMNI_CHAR:
+        case OMNI_VARCHAR:
+        case OMNI_BOOLEAN:
+        case OMNI_DECIMAL64:
+        case OMNI_DECIMAL128:
+            this->supportedFlag = true;
+            break;
+        default:
+            this->supportedFlag = false;
+            break;
+    }
+}
+
+void ExprVerifier::Visit(const FieldExpr &fieldExpr)
+{
+    switch (fieldExpr.GetReturnTypeId()) {
+        case OMNI_BYTE:
+        case OMNI_SHORT:
+        case OMNI_INT:
+        case OMNI_DATE32:
+        case OMNI_LONG:
+        case OMNI_TIMESTAMP:
+        case OMNI_DOUBLE:
+        case OMNI_CHAR:
+        case OMNI_VARCHAR:
+        case OMNI_BOOLEAN:
+        case OMNI_DECIMAL64:
+        case OMNI_DECIMAL128:
+            this->supportedFlag = true;
+            break;
+        default:
+            this->supportedFlag = false;
+            break;
+    }
+}
+
+void ExprVerifier::Visit(const UnaryExpr &unaryExpr)
+{
+    if (!VisitExpr(*(unaryExpr.exp))) {
+        this->supportedFlag = false;
+        return;
+    }
+    switch (unaryExpr.op) {
+        case omniruntime::expressions::Operator::NOT:
+            this->supportedFlag = true;
+            break;
+        default:
+            this->supportedFlag = false;
+            break;
+    }
+}
+
+void ExprVerifier::Visit(const BinaryExpr &binaryExpr)
+{
+    const type::DataType &leftType = *(binaryExpr.left->GetReturnType());
+    const type::DataType &rightType = *(binaryExpr.right->GetReturnType());
+
+    if (AreInvalidDataTypes(leftType.GetId(), rightType.GetId())) {
+        this->supportedFlag = false;
+        return;
+    }
+
+    if (!VisitExpr(*(binaryExpr.left))) {
+        this->supportedFlag = false;
+        return;
+    }
+    if (!VisitExpr(*(binaryExpr.right))) {
+        this->supportedFlag = false;
+        return;
+    }
+
+    if (binaryExpr.op == omniruntime::expressions::Operator::AND ||
+        binaryExpr.op == omniruntime::expressions::Operator::OR) {
+        this->supportedFlag = (binaryExpr.left->GetReturnTypeId() == binaryExpr.right->GetReturnTypeId() &&
+            binaryExpr.left->GetReturnTypeId() == DataTypeId::OMNI_BOOLEAN);
+        return;
+    }
+
+    if (binaryExpr.left->GetReturnTypeId() == OMNI_BYTE ||binaryExpr.left->GetReturnTypeId() == OMNI_SHORT ||
+        binaryExpr.left->GetReturnTypeId() == OMNI_INT || binaryExpr.left->GetReturnTypeId() == OMNI_LONG ||
+        binaryExpr.left->GetReturnTypeId() == OMNI_DATE32 || binaryExpr.left->GetReturnTypeId() == OMNI_DOUBLE) {
+        this->supportedFlag = true;
+        return;
+    } else if (TypeUtil::IsStringType(binaryExpr.left->GetReturnTypeId()) ||
+        binaryExpr.left->GetReturnTypeId() == OMNI_TIMESTAMP) {
+        switch (binaryExpr.op) {
+            case omniruntime::expressions::Operator::LT:
+            case omniruntime::expressions::Operator::GT:
+            case omniruntime::expressions::Operator::LTE:
+            case omniruntime::expressions::Operator::GTE:
+            case omniruntime::expressions::Operator::EQ:
+            case omniruntime::expressions::Operator::NEQ:
+                this->supportedFlag = true;
+                break;
+            default:
+                this->supportedFlag = false;
+                break;
+        }
+        return;
+    } else if (TypeUtil::IsDecimalType(binaryExpr.left->GetReturnTypeId())) {
+        this->supportedFlag = true;
+        return;
+    }
+    this->supportedFlag = false;
+}
+
+void ExprVerifier::Visit(const InExpr &inExpr)
+{
+    Expr *toCompare = inExpr.arguments[0];
+    switch (toCompare->GetReturnTypeId()) {
+        case OMNI_BYTE:
+        case OMNI_SHORT:
+        case OMNI_INT:
+        case OMNI_DATE32:
+        case OMNI_LONG:
+        case OMNI_TIMESTAMP:
+        case OMNI_DOUBLE:
+        case OMNI_CHAR:
+        case OMNI_VARCHAR:
+        case OMNI_DECIMAL64:
+        case OMNI_DECIMAL128:
+            break;
+        default:
+            this->supportedFlag = false;
+            return;
+    }
+
+    if (!VisitExpr(*toCompare)) {
+        this->supportedFlag = false;
+        return;
+    }
+    for (size_t i = 1; i < inExpr.arguments.size(); i++) {
+        if (AreInvalidDataTypes(toCompare->GetReturnTypeId(), inExpr.arguments[i]->GetReturnTypeId())) {
+            this->supportedFlag = false;
+            return;
+        }
+        if (!VisitExpr(*(inExpr.arguments[i]))) {
+            this->supportedFlag = false;
+            return;
+        }
+    }
+    this->supportedFlag = true;
+}
+
+void ExprVerifier::Visit(const BetweenExpr &betweenExpr)
+{
+    DataTypeId valueTypeId = betweenExpr.value->GetReturnTypeId();
+    if (AreInvalidDataTypes(valueTypeId, betweenExpr.lowerBound->GetReturnTypeId()) &&
+        AreInvalidDataTypes(valueTypeId, betweenExpr.upperBound->GetReturnTypeId())) {
+        this->supportedFlag = false;
+        return;
+    }
+
+    if (!VisitExpr(*betweenExpr.value)) {
+        this->supportedFlag = false;
+        return;
+    }
+    if (!VisitExpr(*betweenExpr.lowerBound)) {
+        this->supportedFlag = false;
+        return;
+    }
+    if (!VisitExpr(*betweenExpr.upperBound)) {
+        this->supportedFlag = false;
+        return;
+    }
+
+    this->supportedFlag = true;
+}
+
+void ExprVerifier::Visit(const IfExpr &ifExpr)
+{
+    Expr *cond = ifExpr.condition;
+    Expr *ifTrue = ifExpr.trueExpr;
+    Expr *ifFalse = ifExpr.falseExpr;
+
+    if (!VisitExpr(*cond)) {
+        this->supportedFlag = false;
+        return;
+    }
+    if (!VisitExpr(*ifTrue)) {
+        this->supportedFlag = false;
+        return;
+    }
+    if (!VisitExpr(*ifFalse)) {
+        this->supportedFlag = false;
+        return;
+    }
+    this->supportedFlag = true;
+}
+
+void ExprVerifier::Visit(const CoalesceExpr &coalesceExpr)
+{
+    Expr *value1Expr = coalesceExpr.value1;
+    Expr *value2Expr = coalesceExpr.value2;
+    if (!VisitExpr(*value1Expr)) {
+        this->supportedFlag = false;
+        return;
+    }
+    if (!VisitExpr(*value2Expr)) {
+        this->supportedFlag = false;
+        return;
+    }
+
+    this->supportedFlag = true;
+}
+
+void ExprVerifier::Visit(const IsNullExpr &isNullExpr)
+{
+    Expr *valueExpr = isNullExpr.value;
+    if (!VisitExpr(*valueExpr)) {
+        this->supportedFlag = false;
+        return;
+    }
+    this->supportedFlag = true;
+}
+
+void ExprVerifier::Visit(const FuncExpr &funcExpr)
+{
+    if (funcExpr.funcName == "LIKE") {
+        this->supportedFlag = false;
+        return;
+    }
+    int numArgs = funcExpr.arguments.size();
+    std::vector params;
+    for (int i = 0; i < numArgs; i++) {
+        params.push_back(funcExpr.arguments[i]->GetReturnTypeId());
+        if (!VisitExpr(*funcExpr.arguments[i])) {
+            this->supportedFlag = false;
+            return;
+        }
+    }
+    auto signature = FunctionSignature(funcExpr.funcName, params, funcExpr.GetReturnTypeId());
+    auto function = codegen::FunctionRegistry::LookupFunction(&signature);
+    if (function == nullptr) {
+        this->supportedFlag = false;
+        return;
+    }
+    this->supportedFlag = true;
+}
+
+void ExprVerifier::Visit(const SwitchExpr &switchExpr)
+{
+    std::vector> whenClause = switchExpr.whenClause;
+    auto size = whenClause.size();
+
+    for (size_t i = 0; i < size; i++) {
+        Expr *cond = whenClause[i].first;
+        Expr *resExpr = whenClause[i].second;
+        if (!VisitExpr(*cond)) {
+            this->supportedFlag = false;
+            return;
+        }
+        if (!VisitExpr(*resExpr)) {
+            this->supportedFlag = false;
+            return;
+        }
+    }
+
+    Expr *elseExpr = switchExpr.falseExpr;
+    if (!VisitExpr(*elseExpr)) {
+        this->supportedFlag = false;
+        return;
+    }
+
+    this->supportedFlag = true;
+}
+}
+}
diff --git a/core/src/expression/expr_verifier.h b/core/src/expression/expr_verifier.h
new file mode 100644
index 0000000..130e394
--- /dev/null
+++ b/core/src/expression/expr_verifier.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: Expression Verifier
+ */
+#ifndef OMNI_RUNTIME_EXPR_VERIFIER_H
+#define OMNI_RUNTIME_EXPR_VERIFIER_H
+
+#include "expression/expr_visitor.h"
+
+namespace omniruntime::expressions {
+class ExprVerifier : public ExprVisitor {
+public:
+    void Visit(const omniruntime::expressions::LiteralExpr &literalExpr) override;
+    void Visit(const omniruntime::expressions::FieldExpr &fieldExpr) override;
+    void Visit(const omniruntime::expressions::UnaryExpr &unaryExpr) override;
+    void Visit(const omniruntime::expressions::BinaryExpr &binaryExpr) override;
+    void Visit(const omniruntime::expressions::InExpr &inExpr) override;
+    void Visit(const omniruntime::expressions::BetweenExpr &betweenExpr) override;
+    void Visit(const omniruntime::expressions::IfExpr &ifExpr) override;
+    void Visit(const omniruntime::expressions::CoalesceExpr &coalesceExpr) override;
+    void Visit(const omniruntime::expressions::IsNullExpr &isNullExpr) override;
+    void Visit(const omniruntime::expressions::FuncExpr &funcExpr) override;
+    void Visit(const omniruntime::expressions::SwitchExpr &switchExpr) override;
+    bool VisitExpr(const omniruntime::expressions::Expr &e);
+    bool VisitExpr(const std::shared_ptr &e);
+
+private:
+    bool supportedFlag = false;
+    static bool AreInvalidDataTypes(omniruntime::type::DataTypeId type1, omniruntime::type::DataTypeId type2);
+};
+}
+
+#endif // OMNI_RUNTIME_EXPR_VERIFIER_H
diff --git a/core/src/expression/expr_visitor.cpp b/core/src/expression/expr_visitor.cpp
new file mode 100644
index 0000000..d06d996
--- /dev/null
+++ b/core/src/expression/expr_visitor.cpp
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: visitor accept methods
+ */
+#include "expr_visitor.h"
+
+using namespace omniruntime::expressions;
+
+void LiteralExpr::Accept(ExprVisitor &visitor) const
+{
+    return visitor.Visit(*this);
+}
+
+void FieldExpr::Accept(ExprVisitor &visitor) const
+{
+    return visitor.Visit(*this);
+}
+
+void BinaryExpr::Accept(ExprVisitor &visitor) const
+{
+    return visitor.Visit(*this);
+}
+
+void InExpr::Accept(ExprVisitor &visitor) const
+{
+    return visitor.Visit(*this);
+}
+
+void BetweenExpr::Accept(ExprVisitor &visitor) const
+{
+    return visitor.Visit(*this);
+}
+
+void SwitchExpr::Accept(ExprVisitor &visitor) const
+{
+    return visitor.Visit(*this);
+}
+
+void IfExpr::Accept(ExprVisitor &visitor) const
+{
+    return visitor.Visit(*this);
+}
+
+void CoalesceExpr::Accept(ExprVisitor &visitor) const
+{
+    return visitor.Visit(*this);
+}
+
+void IsNullExpr::Accept(ExprVisitor &visitor) const
+{
+    return visitor.Visit(*this);
+}
+
+void FuncExpr::Accept(ExprVisitor &visitor) const
+{
+    return visitor.Visit(*this);
+}
+
+void UnaryExpr::Accept(ExprVisitor &visitor) const
+{
+    return visitor.Visit(*this);
+}
diff --git a/core/src/expression/expr_visitor.h b/core/src/expression/expr_visitor.h
new file mode 100644
index 0000000..bd9593f
--- /dev/null
+++ b/core/src/expression/expr_visitor.h
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: visitor class for expressions
+ */
+#ifndef __OMNI_RUNTIME_EXPRESSION_VISITOR_H__
+#define __OMNI_RUNTIME_EXPRESSION_VISITOR_H__
+
+#include "expressions.h"
+
+class ExprVisitor {
+public:
+    virtual ~ExprVisitor() = default;
+    virtual void Visit(const omniruntime::expressions::LiteralExpr &e) = 0;
+    virtual void Visit(const omniruntime::expressions::FieldExpr &e) = 0;
+    virtual void Visit(const omniruntime::expressions::UnaryExpr &e) = 0;
+    virtual void Visit(const omniruntime::expressions::BinaryExpr &e) = 0;
+    virtual void Visit(const omniruntime::expressions::InExpr &e) = 0;
+    virtual void Visit(const omniruntime::expressions::BetweenExpr &e) = 0;
+    virtual void Visit(const omniruntime::expressions::IfExpr &e) = 0;
+    virtual void Visit(const omniruntime::expressions::CoalesceExpr &e) = 0;
+    virtual void Visit(const omniruntime::expressions::IsNullExpr &e) = 0;
+    virtual void Visit(const omniruntime::expressions::FuncExpr &e) = 0;
+    virtual void Visit(const omniruntime::expressions::SwitchExpr &e) = 0;
+};
+
+#endif
\ No newline at end of file
diff --git a/core/src/expression/expressions.cpp b/core/src/expression/expressions.cpp
new file mode 100644
index 0000000..eed4248
--- /dev/null
+++ b/core/src/expression/expressions.cpp
@@ -0,0 +1,379 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description:
+ */
+#include "expressions.h"
+#include 
+#include 
+#include 
+#include "type/data_type.h"
+#include "codegen/func_registry.h"
+#include "util/type_util.h"
+#include "expr_verifier.h"
+#include "expr_printer.h"
+
+using namespace std;
+using namespace omniruntime::type;
+using namespace omniruntime::codegen;
+
+namespace omniruntime {
+namespace expressions {
+
+// Prevent ExprVerifier from being optimized out by the compiler.
+static ExprVerifier globalExprVerifier;
+static ExprPrinter globalExprPrinter;
+
+bool IsNullLiteral(const std::string &value)
+{
+    const std::string loweredNullValue = "null";
+    if (value.size() != loweredNullValue.size()) {
+        return false;
+    }
+    for (uint32_t i = 0; i < loweredNullValue.size(); i++) {
+        if (tolower(value[i]) != loweredNullValue[i]) {
+            return false;
+        }
+    }
+    return true;
+}
+
+bool IsComparisonOperator(Operator op)
+{
+    return op == Operator::GT || op == Operator::GTE || op == Operator::LT || op == Operator::LTE ||
+        op == Operator::EQ || op == Operator::NEQ;
+}
+
+bool IsLogicalOperator(Operator op)
+{
+    return op == Operator::AND || op == Operator::OR || op == Operator::NOT;
+}
+
+Operator StringToOperator(const std::string &opStr)
+{
+    auto opItr = OPERATOR_FROM_STRING.find(opStr);
+    if (opItr != OPERATOR_FROM_STRING.end()) {
+        return opItr->second;
+    }
+    return Operator::INVALIDOP;
+}
+
+ExprType Expr::GetType() const
+{
+    return ExprType::INVALID_E;
+}
+
+DataTypePtr Expr::GetReturnType() const
+{
+    return dataType;
+}
+
+DataTypeId Expr::GetReturnTypeId() const
+{
+    return dataType->GetId();
+}
+
+void Expr::DeleteExprs(const std::vector &exprs)
+{
+    for (Expr *exp : exprs) {
+        delete exp;
+    }
+}
+
+void Expr::DeleteExprs(const std::vector> &exprs)
+{
+    for (const std::vector &expr : exprs) {
+        Expr::DeleteExprs(expr);
+    }
+}
+
+// Literal Expression methods
+LiteralExpr::LiteralExpr() = default;
+
+LiteralExpr::~LiteralExpr()
+{
+    delete stringVal;
+}
+
+ExprType LiteralExpr::GetType() const
+{
+    return ExprType::LITERAL_E;
+}
+
+// Helper constructors for different data types
+LiteralExpr::LiteralExpr(bool val, DataTypePtr dt)
+{
+    dataType = std::move(dt);
+    boolVal = val;
+}
+
+LiteralExpr::LiteralExpr(int8_t val, DataTypePtr dt, bool isNulls)
+{
+    dataType = std::move(dt);
+    byteVal = val;
+    isNull = isNulls;
+}
+
+LiteralExpr::LiteralExpr(int16_t val, DataTypePtr dt, bool isNulls)
+{
+    dataType = std::move(dt);
+    shortVal = val;
+    isNull = isNulls;
+}
+
+LiteralExpr::LiteralExpr(int32_t val, DataTypePtr dt, bool isNulls)
+{
+    dataType = std::move(dt);
+    intVal = val;
+    isNull = isNulls;
+}
+
+LiteralExpr::LiteralExpr(int64_t val, DataTypePtr dt)
+{
+    dataType = std::move(dt);
+    longVal = val;
+}
+
+LiteralExpr::LiteralExpr(double val, DataTypePtr dt)
+{
+    dataType = std::move(dt);
+    doubleVal = val;
+}
+
+LiteralExpr::LiteralExpr(std::string *val, DataTypePtr dt)
+{
+    dataType = std::move(dt);
+    stringVal = val;
+}
+
+// FieldExpr
+FieldExpr::FieldExpr() = default;
+
+FieldExpr::~FieldExpr() = default;
+
+ExprType FieldExpr::GetType() const
+{
+    return ExprType::FIELD_E;
+}
+
+// Helper constructors
+FieldExpr::FieldExpr(int32_t colIdx, DataTypePtr colType)
+{
+    dataType = std::move(colType);
+    colVal = colIdx;
+}
+
+BinaryExpr::BinaryExpr()
+{
+    dataType = BooleanType();
+}
+
+BinaryExpr::BinaryExpr(Operator bop, Expr *leftExpr, Expr *rightExpr, DataTypePtr dt)
+{
+    op = bop;
+    left = leftExpr;
+    right = rightExpr;
+    dataType = std::move(dt);
+}
+
+BinaryExpr::~BinaryExpr()
+{
+    delete left;
+    delete right;
+}
+
+ExprType BinaryExpr::GetType() const
+{
+    return ExprType::BINARY_E;
+}
+
+UnaryExpr::UnaryExpr()
+{
+    dataType = BooleanType();
+}
+
+UnaryExpr::UnaryExpr(Operator logOp, Expr *bodyExpr) : op(logOp), exp(bodyExpr) {}
+
+UnaryExpr::UnaryExpr(Operator uop, Expr *expr, DataTypePtr dt) : op(uop), exp(expr)
+{
+    dataType = std::move(dt);
+}
+
+UnaryExpr::~UnaryExpr()
+{
+    delete exp;
+}
+
+ExprType UnaryExpr::GetType() const
+{
+    return ExprType::UNARY_E;
+}
+
+InExpr::InExpr()
+{
+    dataType = BooleanType();
+}
+
+InExpr::~InExpr()
+{
+    DeleteExprs(arguments);
+}
+
+InExpr::InExpr(std::vector args)
+{
+    dataType = BooleanType();
+    arguments = std::move(args);
+}
+
+ExprType InExpr::GetType() const
+{
+    return ExprType::IN_E;
+}
+
+BetweenExpr::BetweenExpr()
+{
+    dataType = BooleanType();
+}
+
+BetweenExpr::~BetweenExpr()
+{
+    delete value;
+    delete lowerBound;
+    delete upperBound;
+}
+
+BetweenExpr::BetweenExpr(Expr *val, Expr *lowBound, Expr *upBound)
+{
+    dataType = BooleanType();
+    value = val;
+    lowerBound = lowBound;
+    upperBound = upBound;
+}
+
+ExprType BetweenExpr::GetType() const
+{
+    return ExprType::BETWEEN_E;
+}
+
+SwitchExpr::SwitchExpr() : whenClause(), falseExpr() {}
+
+SwitchExpr::~SwitchExpr()
+{
+    for (std::pair &vec : whenClause) {
+        delete vec.first;
+        delete vec.second;
+    }
+    delete falseExpr;
+}
+
+SwitchExpr::SwitchExpr(const std::vector> &whens, Expr *fexp)
+{
+    dataType = fexp->GetReturnType();
+    whenClause = whens;
+    falseExpr = fexp;
+}
+
+ExprType SwitchExpr::GetType() const
+{
+    return ExprType::SWITCH_E;
+}
+
+IfExpr::IfExpr() : condition(), trueExpr(), falseExpr() {}
+
+IfExpr::~IfExpr()
+{
+    delete condition;
+    delete trueExpr;
+    delete falseExpr;
+}
+
+IfExpr::IfExpr(Expr *cond, Expr *texp, Expr *fexp)
+{
+    dataType = texp->GetReturnType();
+    condition = cond;
+    trueExpr = texp;
+    falseExpr = fexp;
+}
+
+ExprType IfExpr::GetType() const
+{
+    return ExprType::IF_E;
+}
+
+CoalesceExpr::CoalesceExpr() : value1(), value2() {}
+
+CoalesceExpr::~CoalesceExpr()
+{
+    delete value1;
+    delete value2;
+}
+
+CoalesceExpr::CoalesceExpr(Expr *val1, Expr *val2)
+{
+    dataType = val1->GetReturnType();
+    value1 = val1;
+    value2 = val2;
+}
+
+ExprType CoalesceExpr::GetType() const
+{
+    return ExprType::COALESCE_E;
+}
+
+IsNullExpr::IsNullExpr() : value() {}
+
+IsNullExpr::~IsNullExpr()
+{
+    delete value;
+}
+
+IsNullExpr::IsNullExpr(Expr *value)
+{
+    dataType = BooleanType();
+
+    this->value = value;
+}
+
+ExprType IsNullExpr::GetType() const
+{
+    return ExprType::IS_NULL_E;
+}
+
+FuncExpr::FuncExpr() : function(nullptr) {}
+
+FuncExpr::~FuncExpr()
+{
+    DeleteExprs(arguments);
+}
+
+FuncExpr::FuncExpr(const std::string &fnName, const std::vector &args, DataTypePtr returnType)
+    : funcName(fnName), arguments(args), functionType(BUILTIN)
+{
+    dataType = std::move(returnType);
+
+    std::vector argTypes(arguments.size());
+    std::transform(arguments.begin(), arguments.end(), argTypes.begin(),
+        [](Expr *expr) -> DataTypeId { return expr->GetReturnTypeId(); });
+    auto signature = FunctionSignature(funcName, argTypes, dataType->GetId());
+    this->function = FunctionRegistry::LookupFunction(&signature);
+}
+
+FuncExpr::FuncExpr(const std::string &fnName, const std::vector &args, DataTypePtr returnType,
+    const Function *function)
+    : funcName(fnName), arguments(args), function(function), functionType(BUILTIN)
+{
+    dataType = std::move(returnType);
+}
+
+FuncExpr::FuncExpr(const std::string &fnName, const std::vector &args, DataTypePtr returnType,
+    ExprFunctionType functionType)
+    : funcName(fnName), arguments(args), function(nullptr), functionType(functionType)
+{
+    dataType = std::move(returnType);
+}
+
+ExprType FuncExpr::GetType() const
+{
+    return ExprType::FUNC_E;
+}
+}
+}
diff --git a/core/src/expression/expressions.h b/core/src/expression/expressions.h
new file mode 100644
index 0000000..a099fd4
--- /dev/null
+++ b/core/src/expression/expressions.h
@@ -0,0 +1,256 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description:
+ */
+#ifndef __EXPRESSIONS_H__
+#define __EXPRESSIONS_H__
+
+#include 
+#include 
+#include 
+#include 
+#include "type/data_type.h"
+#include "type/decimal128.h"
+
+class ExprVisitor;
+
+namespace omniruntime {
+namespace expressions {
+using namespace type;
+using namespace codegen;
+
+enum class Operator {
+    // Comparison
+    EQ,
+    NEQ,
+    LT,
+    LTE,
+    GT,
+    GTE,
+    // Logical
+    AND,
+    OR,
+    NOT,
+    // Arithmetic
+    ADD,
+    SUB,
+    MUL,
+    DIV,
+    MOD,
+    TRY_ADD,
+    TRY_SUB,
+    TRY_MUL,
+    TRY_DIV,
+    INVALIDOP
+};
+
+enum class OperatorType { COMPARISON, LOGICAL, ARITHMETIC, INVALIDOPTYPE };
+
+enum class ExprType {
+    LITERAL_E,
+    FIELD_E,
+    BINARY_E,
+    UNARY_E,
+    IN_E,
+    BETWEEN_E,
+    IF_E,
+    SWITCH_E,
+    COALESCE_E,
+    IS_NULL_E,
+    FUNC_E,
+    INVALID_E,
+};
+
+const std::map OPERATOR_FROM_STRING = {{"EQUAL", Operator::EQ}, {"LESS_THAN", Operator::LT},
+    {"LESS_THAN_OR_EQUAL", Operator::LTE}, {"GREATER_THAN_OR_EQUAL", Operator::GTE}, {"GREATER_THAN", Operator::GT},
+    {"NOT_EQUAL", Operator::NEQ}, {"AND", Operator::AND}, {"OR", Operator::OR}, {"NOT", Operator::NOT},
+    {"not", Operator::NOT}, {"ADD", Operator::ADD}, {"SUBTRACT", Operator::SUB}, {"MULTIPLY", Operator::MUL},
+    {"DIVIDE", Operator::DIV}, {"MODULUS", Operator::MOD}, {"TRY_ADD", Operator::TRY_ADD},
+    {"TRY_SUBTRACT", Operator::TRY_SUB}, {"TRY_MULTIPLY", Operator::TRY_MUL}, {"TRY_DIVIDE", Operator::TRY_DIV}};
+
+bool IsNullLiteral(const std::string &value);
+bool IsComparisonOperator(Operator op);
+bool IsLogicalOperator(Operator op);
+Operator StringToOperator(const std::string &opStr);
+
+enum ExprFunctionType { BUILTIN = 0, HIVE_UDF };
+
+class Expr {
+public:
+    DataTypePtr dataType; // dataType of returned value
+    DataTypePtr GetReturnType() const;
+    omniruntime::type::DataTypeId GetReturnTypeId() const;
+    virtual ExprType GetType() const;
+    virtual ~Expr() = default;
+    virtual void Accept(ExprVisitor &visitor) const = 0;
+    static void DeleteExprs(const std::vector &exprs);
+    static void DeleteExprs(const std::vector> &exprs);
+};
+
+class LiteralExpr : public Expr {
+public:
+    bool isNull = false;
+    bool boolVal = false;
+    int8_t byteVal = 0;
+    int16_t shortVal = 0;
+    int32_t intVal = 0;
+    int64_t longVal = 0;
+    double doubleVal = 0;
+    std::string *stringVal = nullptr;
+
+    LiteralExpr();
+    ~LiteralExpr() override;
+    explicit LiteralExpr(bool val, DataTypePtr colType);
+    explicit LiteralExpr(int8_t val, DataTypePtr colType, bool isNull = false);
+    explicit LiteralExpr(int16_t val, DataTypePtr colType, bool isNull = false);
+    explicit LiteralExpr(int32_t val, DataTypePtr colType, bool isNull = false);
+    explicit LiteralExpr(int64_t val, DataTypePtr colType);
+    explicit LiteralExpr(double val, DataTypePtr colType);
+    explicit LiteralExpr(std::string *val, DataTypePtr colType);
+    void Accept(ExprVisitor &visitor) const override;
+    ExprType GetType() const override;
+};
+
+class FieldExpr : public Expr {
+public:
+    bool isNull = false;
+    int32_t colVal = 0;
+
+    FieldExpr();
+    ~FieldExpr() override;
+    FieldExpr(int32_t colIdx, DataTypePtr colType);
+    void Accept(ExprVisitor &visitor) const override;
+    ExprType GetType() const override;
+};
+
+class UnaryExpr : public Expr {
+public:
+    Operator op = Operator::EQ;
+    Expr *exp = nullptr;
+
+    UnaryExpr();
+    ~UnaryExpr() override;
+    UnaryExpr(Operator logOp, Expr *bodyExpr);
+    UnaryExpr(Operator uop, Expr *expr, DataTypePtr dt);
+
+    void Accept(ExprVisitor &visitor) const override;
+    ExprType GetType() const override;
+};
+
+class BinaryExpr : public Expr {
+public:
+    Operator op = Operator::EQ;
+    Expr *left = nullptr;
+    Expr *right = nullptr;
+
+    BinaryExpr();
+    ~BinaryExpr() override;
+    BinaryExpr(Operator bop, Expr *leftExpr, Expr *rightExpr, DataTypePtr dt);
+    void Accept(ExprVisitor &visitor) const override;
+    ExprType GetType() const override;
+};
+
+class InExpr : public Expr {
+public:
+    // first element of arguments is the value to be compared to every other argument
+    std::vector arguments;
+
+    InExpr();
+    ~InExpr() override;
+    explicit InExpr(std::vector args);
+
+    void Accept(ExprVisitor &visitor) const override;
+    ExprType GetType() const override;
+};
+
+class BetweenExpr : public Expr {
+public:
+    Expr *value = nullptr;
+    Expr *lowerBound = nullptr;
+    Expr *upperBound = nullptr;
+
+    BetweenExpr();
+    ~BetweenExpr() override;
+    BetweenExpr(Expr *val, Expr *lowBound, Expr *upBound);
+
+    void Accept(ExprVisitor &visitor) const override;
+    ExprType GetType() const override;
+};
+
+class SwitchExpr : public Expr {
+public:
+    std::vector> whenClause;
+    Expr *falseExpr = nullptr;
+    SwitchExpr();
+    ~SwitchExpr() override;
+    SwitchExpr(const std::vector> &whens, Expr *fexp);
+
+    void Accept(ExprVisitor &visitor) const override;
+    ExprType GetType() const override;
+};
+
+class IfExpr : public Expr {
+public:
+    Expr *condition = nullptr;
+    Expr *trueExpr = nullptr;
+    Expr *falseExpr = nullptr;
+
+    IfExpr();
+    ~IfExpr() override;
+    IfExpr(Expr *cond, Expr *texp, Expr *fexp);
+
+    void Accept(ExprVisitor &visitor) const override;
+    ExprType GetType() const override;
+};
+
+class CoalesceExpr : public Expr {
+public:
+    Expr *value1 = nullptr;
+    Expr *value2 = nullptr;
+
+    CoalesceExpr();
+    ~CoalesceExpr() override;
+    CoalesceExpr(Expr *val1, Expr *val2);
+
+    void Accept(ExprVisitor &visitor) const override;
+    ExprType GetType() const override;
+};
+
+class IsNullExpr : public Expr {
+public:
+    Expr *value = nullptr;
+    IsNullExpr();
+    ~IsNullExpr() override;
+    explicit IsNullExpr(Expr *value);
+
+    void Accept(ExprVisitor &visitor) const override;
+    ExprType GetType() const override;
+};
+
+class FuncExpr : public Expr {
+public:
+    std::string funcName;
+    std::vector arguments;
+    const Function *function;
+    ExprFunctionType functionType;
+
+    FuncExpr();
+    ~FuncExpr() override;
+    FuncExpr(const std::string &fnName, const std::vector &args, DataTypePtr returnType);
+    FuncExpr(
+        const std::string &fnName, const std::vector &args, DataTypePtr returnType, const Function *function);
+    FuncExpr(const std::string &fnName, const std::vector &args, DataTypePtr returnType,
+        ExprFunctionType functionType);
+
+    void Accept(ExprVisitor &visitor) const override;
+    ExprType GetType() const override;
+    static inline bool IsCastStrStr(const omniruntime::expressions::FuncExpr &e)
+    {
+        return (e.funcName == "CAST" || e.funcName == "CAST_null") &&
+            e.arguments[0]->GetReturnTypeId() == omniruntime::type::OMNI_VARCHAR &&
+            e.GetReturnTypeId() == omniruntime::type::OMNI_VARCHAR;
+    }
+};
+} // namespace expressions
+} // namespace omniruntime
+#endif
diff --git a/core/src/expression/jsonparser/jsonparser.cpp b/core/src/expression/jsonparser/jsonparser.cpp
new file mode 100644
index 0000000..2a5baf4
--- /dev/null
+++ b/core/src/expression/jsonparser/jsonparser.cpp
@@ -0,0 +1,462 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2012-2018. All rights reserved.
+ * Description:
+ */
+#include "jsonparser.h"
+#include 
+#include  // for testing
+
+using namespace std;
+using namespace omniruntime::expressions;
+using namespace omniruntime::type;
+
+using Json = nlohmann::json;
+
+Expr *JSONParser::ParseJSONFieldRef(const Json &jsonExpr)
+{
+    auto typeId = static_cast(jsonExpr["dataType"].get());
+    DataTypePtr retType;
+    auto colVal = jsonExpr["colVal"].get();
+    if (TypeUtil::IsStringType(typeId)) {
+        int width = jsonExpr["width"].get();
+        if (typeId == OMNI_CHAR) {
+            retType = std::make_shared(width);
+        } else {
+            retType = std::make_shared(width);
+        }
+    } else if (TypeUtil::IsDecimalType(typeId)) {
+        int precision = jsonExpr["precision"].get();
+        int scale = jsonExpr["scale"].get();
+        if (typeId == OMNI_DECIMAL64) {
+            retType = std::make_shared(precision, scale);
+        } else {
+            retType = std::make_shared(precision, scale);
+        }
+    } else {
+        retType = std::make_shared(typeId);
+    }
+    return new FieldExpr(colVal, std::move(retType));
+}
+
+Expr *JSONParser::ParseJSONLiteral(const Json &jsonExpr)
+{
+    auto typeId = static_cast(jsonExpr["dataType"].get());
+    // Null check on Literals
+    if (jsonExpr["isNull"].get()) {
+        LiteralExpr *expr = nullptr;
+        if (TypeUtil::IsDecimalType(typeId)) {
+            auto precision = jsonExpr["precision"].get();
+            auto scale = jsonExpr["scale"].get();
+            expr = ParserHelper::GetDefaultValueForType(typeId, precision, scale);
+        } else {
+            expr = ParserHelper::GetDefaultValueForType(typeId);
+        }
+
+        if (expr == nullptr) {
+            return nullptr;
+        }
+        expr->isNull = true;
+        return expr;
+    }
+    // proceed with non-null value
+    switch (typeId) {
+        case OMNI_BOOLEAN: {
+            bool boolVal = jsonExpr["value"].get();
+            return new LiteralExpr(boolVal, std::make_shared());
+        }
+        case OMNI_BYTE: {
+            auto byteVal = jsonExpr["value"].get();
+            return new LiteralExpr(byteVal, std::make_shared());
+        }
+        case OMNI_SHORT: {
+            auto shortVal = jsonExpr["value"].get();
+            return new LiteralExpr(shortVal, std::make_shared());
+        }
+        case OMNI_INT: {
+            auto intVal = jsonExpr["value"].get();
+            return new LiteralExpr(intVal, std::make_shared());
+        }
+        case OMNI_DATE32: {
+            auto intVal = jsonExpr["value"].get();
+            return new LiteralExpr(intVal, std::make_shared());
+        }
+        case OMNI_LONG: {
+            auto longVal = jsonExpr["value"].get();
+            return new LiteralExpr(longVal, std::make_shared());
+        }
+        case OMNI_TIMESTAMP: {
+            auto timestampVal = jsonExpr["value"].get();
+            return new LiteralExpr(timestampVal, std::make_shared());
+        }
+        case OMNI_DOUBLE: {
+            auto doubleVal = jsonExpr["value"].get();
+            return new LiteralExpr(doubleVal, std::make_shared());
+        }
+        case OMNI_DECIMAL64: {
+            auto decimalVal = jsonExpr["value"].get();
+            return new LiteralExpr(decimalVal, std::make_shared(jsonExpr["precision"].get(),
+                jsonExpr["scale"].get()));
+        }
+        case OMNI_DECIMAL128: {
+            auto *dec128String = new string(jsonExpr["value"].get());
+            return new LiteralExpr(dec128String, std::make_shared(
+                jsonExpr["precision"].get(), jsonExpr["scale"].get()));
+        }
+        case OMNI_CHAR: {
+            auto *stringVal = new string(jsonExpr["value"].get());
+            auto width = jsonExpr["width"].get();
+            return new LiteralExpr(stringVal, std::make_shared(width));
+        }
+        case OMNI_VARCHAR: {
+            auto *stringVal = new string(jsonExpr["value"].get());
+            auto width = jsonExpr["width"].get();
+            return new LiteralExpr(stringVal, std::make_shared(width));
+        }
+        default:
+            return new LiteralExpr(0, std::make_shared());
+    }
+}
+
+Expr *JSONParser::ParseJSONBinary(const Json &jsonExpr)
+{
+    Operator op = StringToOperator(jsonExpr["operator"].get());
+    if (op == Operator::INVALIDOP) {
+        return nullptr;
+    }
+
+    DataTypePtr dataType = ParserHelper::GetReturnDataType(jsonExpr);
+    if (dataType == nullptr) {
+        return nullptr;
+    }
+
+    Expr *leftExpr = ParseJSON(jsonExpr["left"]);
+    if (leftExpr == nullptr) {
+        return nullptr;
+    }
+    Expr *rightExpr = ParseJSON(jsonExpr["right"]);
+    if (rightExpr == nullptr) {
+        delete leftExpr;
+        return nullptr;
+    }
+
+    return new BinaryExpr(op, leftExpr, rightExpr, std::move(dataType));
+}
+
+Expr *JSONParser::ParseJSONUnary(const Json &jsonExpr)
+{
+    Operator op = StringToOperator(jsonExpr["operator"].get());
+    if (op == Operator::INVALIDOP) {
+        return nullptr;
+    }
+
+    Expr *expr = ParseJSON(jsonExpr["expr"]);
+    if (expr == nullptr) {
+        return nullptr;
+    }
+    return new UnaryExpr(op, expr, std::make_shared());
+}
+
+Expr *JSONParser::ParseJSONIn(const Json &jsonExpr)
+{
+    std::vector args;
+    for (const auto &item : jsonExpr["arguments"].items()) {
+        Expr *arg = ParseJSON(item.value());
+        if (arg != nullptr) {
+            args.push_back(arg);
+        } else {
+            Expr::DeleteExprs(args);
+            return nullptr;
+        }
+    }
+    return new InExpr(args);
+}
+
+Expr *JSONParser::ParseJSONBetween(const Json &jsonExpr)
+{
+    Expr *val = ParseJSON(jsonExpr["value"]);
+    if (val == nullptr) {
+        return nullptr;
+    }
+    Expr *lowBoundExpr = ParseJSON(jsonExpr["lower_bound"]);
+    if (lowBoundExpr == nullptr) {
+        delete val;
+        return nullptr;
+    }
+    Expr *upBoundExpr = ParseJSON(jsonExpr["upper_bound"]);
+    if (upBoundExpr == nullptr) {
+        delete val;
+        delete lowBoundExpr;
+        return nullptr;
+    }
+
+    return new BetweenExpr(val, lowBoundExpr, upBoundExpr);
+}
+
+static void DeleteWhenClause(const std::vector> &whenClause)
+{
+    for (auto iter = whenClause.begin(); iter != whenClause.end(); iter++) {
+        delete iter->first;
+        delete iter->second;
+    }
+}
+
+Expr *JSONParser::ParseJSONSwitch(const Json &jsonExpr)
+{
+    auto numOfCases = jsonExpr["numOfCases"].get();
+    std::vector> whenClause;
+    for (int32_t i = 0; i < numOfCases; i++) {
+        Expr *left = ParseJSON(jsonExpr["input"]);
+        if (left == nullptr) {
+            DeleteWhenClause(whenClause);
+            return nullptr;
+        }
+        Expr *right = ParseJSON(jsonExpr["Case" + std::to_string(i + 1)]["when"]);
+        if (right == nullptr) {
+            delete left;
+            DeleteWhenClause(whenClause);
+            return nullptr;
+        }
+        Expr *result = ParseJSON(jsonExpr["Case" + std::to_string(i + 1)]["result"]);
+        if (result == nullptr) {
+            delete left;
+            delete right;
+            DeleteWhenClause(whenClause);
+            return nullptr;
+        }
+        auto *condition = new BinaryExpr(Operator::EQ, left, right, std::make_shared());
+        std::pair when = make_pair(condition, result);
+        whenClause.push_back(when);
+    }
+
+    Expr *elseExpr = ParseJSON(jsonExpr["else"]);
+    if (elseExpr == nullptr) {
+        DeleteWhenClause(whenClause);
+        return nullptr;
+    }
+    if (TypeUtil::IsStringType(elseExpr->GetReturnTypeId()) && elseExpr->GetType() == ExprType::LITERAL_E &&
+        static_cast(elseExpr)->stringVal->compare("null") == 0) {
+        auto literalExpr = ParserHelper::GetDefaultValueForType(elseExpr->GetReturnTypeId());
+        delete elseExpr;
+        if (literalExpr == nullptr) {
+            DeleteWhenClause(whenClause);
+            return nullptr;
+        }
+        return new SwitchExpr(whenClause, literalExpr);
+    }
+    return new SwitchExpr(whenClause, elseExpr);
+}
+
+Expr *JSONParser::ParseJSONIf(const Json &jsonExpr)
+{
+    Expr *cond = ParseJSON(jsonExpr["condition"]);
+    if (cond == nullptr) {
+        return nullptr;
+    }
+    Expr *trueExpr = ParseJSON(jsonExpr["if_true"]);
+    if (trueExpr == nullptr) {
+        delete cond;
+        return nullptr;
+    }
+    Expr *falseExpr = ParseJSON(jsonExpr["if_false"]);
+    if (falseExpr == nullptr) {
+        delete cond;
+        delete trueExpr;
+        return nullptr;
+    }
+    if (TypeUtil::IsStringType(falseExpr->GetReturnTypeId()) && falseExpr->GetType() == ExprType::LITERAL_E &&
+        static_cast(falseExpr)->stringVal->compare("null") == 0) {
+        delete falseExpr;
+        auto literalExpr = ParserHelper::GetDefaultValueForType(trueExpr->GetReturnTypeId());
+        if (literalExpr == nullptr) {
+            delete cond;
+            delete trueExpr;
+            return nullptr;
+        }
+        return new IfExpr(cond, trueExpr, literalExpr);
+    }
+
+    return new IfExpr(cond, trueExpr, falseExpr);
+}
+
+Expr *JSONParser::ParseJSONCoalesce(const Json &jsonExpr)
+{
+    Expr *val1 = ParseJSON(jsonExpr["value1"]);
+    if (val1 == nullptr) {
+        return nullptr;
+    }
+    Expr *val2 = ParseJSON(jsonExpr["value2"]);
+    if (val2 == nullptr) {
+        delete val1;
+        return nullptr;
+    }
+
+    return new CoalesceExpr(val1, val2);
+}
+
+Expr *JSONParser::ParseJsonIsNull(const Json &jsonExpr)
+{
+    Expr *val = ParseJSON(jsonExpr["arguments"].at(0));
+    if (val == nullptr) {
+        return nullptr;
+    }
+    return new IsNullExpr(val);
+}
+
+Expr *JSONParser::ParseJSONFunc(const Json &jsonExpr)
+{
+    string funcName = jsonExpr["function_name"];
+    auto retTypeId = static_cast(jsonExpr["returnType"].get());
+    std::vector args;
+    DataTypePtr retType;
+    int32_t width = INT32_MAX;
+    int32_t precision;
+    int32_t scale;
+
+    for (const auto &item : jsonExpr["arguments"].items()) {
+        Expr *arg = ParseJSON(item.value());
+        if (arg != nullptr) {
+            args.push_back(arg);
+        } else {
+            Expr::DeleteExprs(args);
+            return nullptr;
+        }
+    }
+
+    if (TypeUtil::IsStringType(retTypeId)) {
+        width = jsonExpr.contains("width") ? jsonExpr["width"].get() : width;
+        if (retTypeId == OMNI_CHAR) {
+            retType = std::make_shared(width);
+        } else {
+            retType = std::make_shared(width);
+        }
+    } else if (TypeUtil::IsDecimalType(retTypeId)) {
+        precision = jsonExpr["precision"].get();
+        scale = jsonExpr["scale"].get();
+        if (retTypeId == OMNI_DECIMAL64) {
+            retType = std::make_shared(precision, scale);
+        } else {
+            retType = std::make_shared(precision, scale);
+        }
+    } else {
+        retType = std::make_shared(retTypeId);
+    }
+
+    // CAST short-circuit - Convert CAST function of a type to its own type to DataExpr
+    if (funcName == "CAST" && args.size() == 1) {
+        auto argReturnType = args[0]->GetReturnType().get();
+        if (retTypeId == argReturnType->GetId()) {
+            if (TypeUtil::IsStringType(retTypeId)) {
+                auto argWidth = static_cast(argReturnType)->GetWidth();
+                auto retWidth = static_cast(retType.get())->GetWidth();
+                if (argWidth <= retWidth) {
+                    return args[0];
+                }
+            } else if (TypeUtil::IsDecimalType(retTypeId)) {
+                auto argScale = static_cast(argReturnType)->GetScale();
+                auto argPrecision = static_cast(argReturnType)->GetPrecision();
+                auto retScale = static_cast(retType.get())->GetScale();
+                auto retPrecision = static_cast(retType.get())->GetPrecision();
+                if (argScale == retScale && argPrecision <= retPrecision) {
+                    return args[0];
+                }
+            } else {
+                return args[0];
+            }
+        }
+    }
+
+    // check rlike since we only support ^d+$ currently, all other regex are fallback
+    if (funcName == "RLike" && args.size() == 2) {
+        auto secondArg = args[1];
+        if (secondArg->GetType() != ExprType::LITERAL_E) {
+            Expr::DeleteExprs(args);
+            return nullptr;
+        }
+
+        auto literalExpr = static_cast(secondArg);
+        if (*(literalExpr->stringVal) != "^\\d+$") {
+            Expr::DeleteExprs(args);
+            return nullptr;
+        }
+    }
+
+    // Check that the signature matches
+    vector argTypes(args.size());
+    std::transform(args.begin(), args.end(), argTypes.begin(),
+        [](Expr *expr) -> DataTypeId { return expr->GetReturnTypeId(); });
+    auto signature = FunctionSignature(funcName, argTypes, retTypeId);
+    auto function = omniruntime::codegen::FunctionRegistry::LookupFunction(&signature);
+    if (function != nullptr) {
+        return new FuncExpr(funcName, args, std::move(retType), function);
+    }
+
+    auto &hiveUdfClass = omniruntime::codegen::FunctionRegistry::LookupHiveUdf(funcName);
+    if (!hiveUdfClass.empty()) {
+        return new FuncExpr(hiveUdfClass, args, std::move(retType), HIVE_UDF);
+    }
+    LogWarn("Function not supported: %s", funcName.c_str());
+
+    Expr::DeleteExprs(args);
+    // if operator is not supported, return nullptr
+    return nullptr;
+}
+
+Expr *JSONParser::ParseJSON(const Json &jsonExpr)
+{
+    string exprTypeStr = jsonExpr["exprType"].get();
+    if (exprTypeStr == "FIELD_REFERENCE") {
+        return ParseJSONFieldRef(jsonExpr);
+    } else if (exprTypeStr == "LITERAL") {
+        return ParseJSONLiteral(jsonExpr);
+    } else if (exprTypeStr == "BINARY") {
+        return ParseJSONBinary(jsonExpr);
+    } else if (exprTypeStr == "UNARY") {
+        return ParseJSONUnary(jsonExpr);
+    } else if (exprTypeStr == "IN") {
+        return ParseJSONIn(jsonExpr);
+    } else if (exprTypeStr == "BETWEEN") {
+        return ParseJSONBetween(jsonExpr);
+    } else if (exprTypeStr == "IF") {
+        return ParseJSONIf(jsonExpr);
+    } else if (exprTypeStr == "COALESCE") {
+        return ParseJSONCoalesce(jsonExpr);
+    } else if (exprTypeStr == "IS_NULL") {
+        return ParseJsonIsNull(jsonExpr);
+    } else if (exprTypeStr == "FUNC" || exprTypeStr == "FUNCTION") {
+        return ParseJSONFunc(jsonExpr);
+    } else if (exprTypeStr == "SWITCH") {
+        return ParseJSONSwitch(jsonExpr);
+    } else {
+        // return nullptr if ExprType not supported
+        return nullptr;
+    }
+}
+
+std::vector JSONParser::ParseJSON(nlohmann::json *expressions,
+    int32_t numberOfExpressions)
+{
+    std::vector result;
+    for (int32_t i = 0; i < numberOfExpressions; i++) {
+        Expr *expression = ParseJSON(expressions[i]);
+        if (expression == nullptr) {
+            LogWarn("The %d-th expression is not supported: %s", i, expressions[i].dump(1).c_str());
+            Expr::DeleteExprs(result);
+            result.clear();
+            break;
+        }
+        result.push_back(expression);
+    }
+    return result;
+}
+
+Expr *JSONParser::ParseJSON(const std::string &exprStr)
+{
+    omniruntime::expressions::Expr *expr = nullptr;
+    if (!exprStr.empty()) {
+        expr = JSONParser::ParseJSON(nlohmann::json::parse(exprStr));
+        if (expr == nullptr) {
+            LogWarn("The expression is not supported: %s", exprStr.c_str());
+        }
+    }
+    return expr;
+}
\ No newline at end of file
diff --git a/core/src/expression/jsonparser/jsonparser.h b/core/src/expression/jsonparser/jsonparser.h
new file mode 100644
index 0000000..85f92f6
--- /dev/null
+++ b/core/src/expression/jsonparser/jsonparser.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2020-2020. All rights reserved.
+ */
+#ifndef __JSONPARSER_H__
+#define __JSONPARSER_H__
+
+#include 
+#include 
+
+#include 
+#include "expression/parserhelper.h"
+#include "codegen/func_registry.h"
+#include "util/type_util.h"
+#include "expression/expressions.h"
+
+class JSONParser {
+public:
+    static omniruntime::expressions::Expr *ParseJSON(const nlohmann::json &jsonExpr);
+    static std::vector ParseJSON(nlohmann::json expressions[],
+        int32_t numberOfExpressions);
+    static omniruntime::expressions::Expr *ParseJSON(const std::string &expression);
+
+private:
+    static omniruntime::expressions::Expr *ParseJSONFieldRef(const nlohmann::json &jsonExpr);
+    static omniruntime::expressions::Expr *ParseJSONLiteral(const nlohmann::json &jsonExpr);
+    static omniruntime::expressions::Expr *ParseJSONBinary(const nlohmann::json &jsonExpr);
+    static omniruntime::expressions::Expr *ParseJSONUnary(const nlohmann::json &jsonExpr);
+    static omniruntime::expressions::Expr *ParseJSONIn(const nlohmann::json &jsonExpr);
+    static omniruntime::expressions::Expr *ParseJSONBetween(const nlohmann::json &jsonExpr);
+    static omniruntime::expressions::Expr *ParseJSONIf(const nlohmann::json &jsonExpr);
+    static omniruntime::expressions::Expr *ParseJSONCoalesce(const nlohmann::json &jsonExpr);
+    static omniruntime::expressions::Expr *ParseJsonIsNull(const nlohmann::json &jsonExpr);
+    static omniruntime::expressions::Expr *ParseJSONFunc(const nlohmann::json &jsonExpr);
+    static omniruntime::expressions::Expr *ParseJSONSwitch(const nlohmann::json &jsonExpr);
+};
+
+#endif
\ No newline at end of file
diff --git a/core/src/expression/parser/parser.cpp b/core/src/expression/parser/parser.cpp
new file mode 100644
index 0000000..87583aa
--- /dev/null
+++ b/core/src/expression/parser/parser.cpp
@@ -0,0 +1,381 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: parser function
+ */
+#include "parser.h"
+#include 
+
+using namespace std;
+using namespace omniruntime::expressions;
+using namespace omniruntime::type;
+
+Parser::Parser() {}
+
+Parser::~Parser() {}
+
+namespace {
+const string OPERATOR_PREFIX = "$operator$";
+const int32_t SUBSTR_LEN = 10;
+const int32_t ARG2 = 2;
+}
+
+// Helper function to remove operator prefix if it is there
+string DemangleOperator(string opStr)
+{
+    if (opStr.size() > SUBSTR_LEN && opStr.substr(0, SUBSTR_LEN) == OPERATOR_PREFIX) {
+        return opStr.substr(SUBSTR_LEN);
+    }
+    return opStr;
+}
+
+OperatorType GetBinaryOperatorType(string opStr)
+{
+    vector allCmpOps { "LESS_THAN", "LESS_THAN_OR_EQUAL", "GREATER_THAN", "GREATER_THAN_OR_EQUAL",
+        "EQUAL",     "NOT_EQUAL" };
+    vector allLogOps { "AND", "OR" };
+    vector allArithOps { "ADD", "SUBTRACT", "MULTIPLY", "DIVIDE", "MODULUS" };
+    for (const string &cmpOp : allCmpOps) {
+        if (opStr == cmpOp) {
+            return OperatorType::COMPARISON;
+        }
+    }
+    for (const string &logOp : allLogOps) {
+        if (opStr == logOp) {
+            return OperatorType::LOGICAL;
+        }
+    }
+    for (const string &arithOp : allArithOps) {
+        if (opStr == arithOp) {
+            return OperatorType::ARITHMETIC;
+        }
+    }
+    return OperatorType::INVALIDOPTYPE;
+}
+
+bool IsUnaryOperator(const string &opStr)
+{
+    vector allUnaryOps { "NOT", "not" };
+    for (const string &unaryOp : allUnaryOps) {
+        if (opStr == unaryOp) {
+            return true;
+        }
+    }
+    return false;
+}
+
+string Parser::StripString(const string &input)
+{
+    // remove spaces from input but not from inside strings
+    string newInput;
+    bool isInString = false;
+    for (char i : input) {
+        if (i == '\'') {
+            isInString = !isInString;
+            newInput.push_back(i);
+        } else if (i == ' ') {
+            if (isInString) {
+                newInput.push_back(i);
+            }
+        } else {
+            newInput.push_back(i);
+        }
+    }
+    return newInput;
+}
+
+DataTypeId ParseReturnType(const string &typeString)
+{
+    int endIdx = 2;
+    auto widthIdx = typeString.find('[');
+    if (widthIdx != string::npos) {
+        if (stoi(typeString.substr(0, endIdx)) == OMNI_CHAR) {
+            return OMNI_CHAR;
+        }
+    }
+    if (typeString.find_first_not_of("0123456789") == string::npos && stoi(typeString) < INT32_MAX) {
+        int typeOrdinal = stoi(typeString);
+        return static_cast(typeOrdinal);
+    }
+    LogError("Invalid return type: %s", typeString.c_str());
+    return OMNI_INVALID;
+}
+
+std::vector Parser::ParseExpressions(const string expressions[],
+    int32_t numberOfExpressions, DataTypes &inputTypes)
+{
+    std::vector vExprs;
+    for (int32_t i = 0; i < numberOfExpressions; i++) {
+        Expr *expr = ParseRowExpression(expressions[i], inputTypes, inputTypes.GetSize());
+        if (expr == nullptr) {
+            Expr::DeleteExprs(vExprs);
+            return {};
+        }
+        vExprs.push_back(expr);
+    }
+    return vExprs;
+}
+
+Expr *Parser::ParseRowExpression(const string &inputStr, DataTypes &inputTypes, int32_t vecCount)
+{
+    string input = this->StripString(inputStr);
+    auto firstParenInd = input.find('(');
+    // Check if it is just data (i.e. 123, #4, 34.4)
+    if (firstParenInd == string::npos) {
+        if (input[0] == '#') {
+            return GenerateFieldExpr(input, inputTypes);
+        } else {
+            return GenerateLiteralExpr(input);
+        }
+    }
+
+    // demangled operator string
+    string opStr = DemangleOperator(input.substr(0, firstParenInd));
+    string exprStr = input.substr(firstParenInd + 1, input.size() - firstParenInd - 1 - 1);
+
+    // ensure that strings and parentheses are respected
+    vector commaPositions; // indices of commas in exprStr
+    int numCommas = 0;
+    int parenCount = 0;
+    bool outsideQuotes = true;
+    for (uint32_t i = 0; i < exprStr.size(); i++) {
+        if (exprStr[i] == ',' && parenCount == 0 && outsideQuotes) {
+            commaPositions.push_back(i);
+            numCommas++;
+        }
+        if (exprStr[i] == '\'') {
+            outsideQuotes = !outsideQuotes;
+        }
+        if (exprStr[i] == '(') {
+            parenCount++;
+        }
+        if (exprStr[i] == ')') {
+            parenCount--;
+        }
+    }
+    commaPositions.push_back(exprStr.size());
+
+    // Place all of the arguments into a vector first
+    vector args;
+    auto expr = ParseRowExpression(exprStr.substr(0, commaPositions[0]), inputTypes, vecCount);
+    if (expr == nullptr) {
+        return nullptr;
+    }
+    args.push_back(expr);
+    for (int i = 1; i <= numCommas; i++) {
+        string currVal = exprStr.substr(commaPositions[i - 1] + 1, commaPositions[i] - commaPositions[i - 1] - 1);
+        expr = ParseRowExpression(currVal, inputTypes, vecCount);
+        if (expr == nullptr) {
+            return nullptr;
+        }
+        args.push_back(expr);
+    }
+
+    return ParseRowExpressionHelper(opStr, args);
+}
+
+Expr *Parser::ParseRowExpressionHelper(string opStr, vector args)
+{
+    auto typeIdx = opStr.find(':');
+    int stepSize = 4;
+    int32_t width = INT32_MAX;
+    omniruntime::type::DataTypePtr type;
+    DataTypeId typeId;
+    if (typeIdx != string::npos) {
+        typeId = ParseReturnType(opStr.substr(typeIdx + 1));
+        if (typeId == OMNI_CHAR) {
+            width = stoi(opStr.substr(typeIdx + stepSize, opStr.size() - typeIdx - stepSize));
+            type = std::make_shared(width);
+        } else {
+            type = std::make_shared(typeId);
+        }
+        opStr = opStr.substr(0, typeIdx);
+    }
+
+    // BinaryExpr
+    OperatorType binRetType = GetBinaryOperatorType(opStr);
+    if (binRetType != OperatorType::INVALIDOPTYPE && args.size() == ARG2) {
+        return new BinaryExpr(StringToOperator(DemangleOperator(opStr)), args[0], args[1], std::move(type));
+    }
+
+    // UnaryExpr
+    // only handling NOT for now
+    if (IsUnaryOperator(opStr) && args.size() == 1) {
+        return new UnaryExpr(StringToOperator(DemangleOperator(opStr)), args[0], std::move(type));
+    }
+
+    // Special form
+    // Special forms are IN, BETWEEN, IF, COALESCE
+    if (opStr == "BETWEEN") {
+        return new BetweenExpr(args[0], args[1], args[ARG2]);
+    }
+    if (opStr == "IN") {
+        return new InExpr(args);
+    }
+    if (opStr == "COALESCE") {
+        return new CoalesceExpr(args[0], args[1]);
+    }
+    if (opStr == "IF") {
+        if (TypeUtil::IsStringType(args[ARG2]->GetReturnTypeId()) && args[ARG2]->GetType() == ExprType::LITERAL_E &&
+            static_cast(args[ARG2])->stringVal->compare("null") == 0) {
+            return new IfExpr(args[0], args[1], ParserHelper::GetDefaultValueForType(args[1]->GetReturnTypeId()));
+        }
+        return new IfExpr(args[0], args[1], args[ARG2]);
+    }
+    if (opStr == "IS_NULL") {
+        return new IsNullExpr(args[0]);
+    }
+    if (opStr == "IS_NOT_NULL") {
+        auto isNullExpr = new IsNullExpr(args[0]);
+        return new UnaryExpr(Operator::NOT, isNullExpr, std::move(type));
+    }
+    // When casting to the same type, the result is the argument itself
+    // Treat argument as constant DataExpr instead of returning FuncExpr
+    if (opStr == "CAST" && args.size() == 1 && (typeId == args[0]->GetReturnTypeId())) {
+        if (args[0]->GetType() == ExprType::LITERAL_E) {
+            return static_cast(args[0]);
+        } else if (args[0]->GetType() == ExprType::FIELD_E) {
+            return static_cast(args[0]);
+        } else {
+            return nullptr;
+        }
+    }
+
+    // Function
+    // Check that the signature matches
+    vector argTypes(args.size());
+    std::transform(args.begin(), args.end(), argTypes.begin(),
+        [](Expr *expr) -> DataTypeId { return expr->GetReturnTypeId(); });
+    for (size_t i = 0; i < argTypes.size(); i++) {
+        if (argTypes[i] == omniruntime::type::OMNI_DATE32) {
+            argTypes[i] = omniruntime::type::OMNI_INT;
+        }
+    }
+    auto signature = FunctionSignature(opStr, argTypes, type->GetId());
+    auto function = omniruntime::codegen::FunctionRegistry::LookupFunction(&signature);
+    if (function != nullptr) {
+        return new FuncExpr(opStr, args, std::move(type), function);
+    }
+
+    // No expression can be matched
+    LogWarn("operator is not supported: %s", opStr.c_str());
+    return nullptr;
+}
+
+// Helper function to turn all % to .* for regex wildcard matching
+string *FixString(const string &dataStr)
+{
+    auto *fixedStr = new string("");
+    for (char i : dataStr) {
+        if (i == '%') {
+            fixedStr->push_back('.');
+            fixedStr->push_back('*');
+        } else {
+            fixedStr->push_back(i);
+        }
+    }
+    return fixedStr;
+}
+
+LiteralExpr *Parser::GenerateLiteralExprHelper(const string &literalStr, DataTypePtr currType)
+{
+    switch (currType->GetId()) {
+        // handle boolean as int32
+        case OMNI_BOOLEAN: {
+            return new LiteralExpr(stoi(literalStr), std::move(currType));
+        }
+        case OMNI_BYTE: {
+            return new LiteralExpr(static_cast(stoi(literalStr)), std::move(currType));
+        }
+        case OMNI_SHORT: {
+            return new LiteralExpr(static_cast(stoi(literalStr)), std::move(currType));
+        }
+        case OMNI_INT:
+        case OMNI_DATE32: {
+            LiteralExpr *e = new LiteralExpr(stoi(literalStr), std::move(currType));
+            e->longVal = e->intVal;
+            e->doubleVal = e->intVal;
+            return e;
+        }
+        // Need to handle decimals properly
+        case OMNI_DECIMAL128: {
+            string *dec128String = new string(literalStr);
+            return new LiteralExpr(dec128String, std::move(currType));
+        }
+        case OMNI_DECIMAL64:
+        case OMNI_TIMESTAMP:
+        case OMNI_LONG: {
+            return new LiteralExpr(stol(literalStr), std::move(currType));
+        }
+        case OMNI_DOUBLE: {
+            return new LiteralExpr(stod(literalStr), std::move(currType));
+        }
+        case OMNI_CHAR:
+        case OMNI_VARCHAR: {
+            return new LiteralExpr(FixString(literalStr), std::move(currType));
+        }
+        case OMNI_NONE: {
+            return new LiteralExpr(0, std::move(currType));
+        }
+        default: {
+            LogError("type %u is not supported", currType->GetId());
+            return nullptr;
+        }
+    }
+}
+
+FieldExpr *Parser::GenerateFieldExpr(string fieldStr, const DataTypes &inputTypes)
+{
+    int colIdx = stoi(fieldStr.substr(1));
+    const DataTypePtr &colType = inputTypes.GetType(colIdx);
+    return new FieldExpr(colIdx, colType);
+}
+
+LiteralExpr *Parser::GenerateLiteralExpr(string literalStr)
+{
+    auto typeIdx = literalStr.find(':');
+    int stepSize = 4;
+    int32_t width = INT32_MAX;
+    DataTypePtr currType;
+    DataTypeId currTypeId;
+    if (typeIdx != string::npos) {
+        currTypeId = ParseReturnType(literalStr.substr(typeIdx + 1));
+        if (currTypeId == OMNI_CHAR) {
+            width = stoi(literalStr.substr(typeIdx + stepSize, literalStr.size() - typeIdx - stepSize));
+        }
+        literalStr = literalStr.substr(0, typeIdx);
+    } else {
+        LogError("Unknown constant type for expr: %s", literalStr.c_str());
+        return nullptr;
+    }
+
+    // Case with boolean true/false
+    if (literalStr == "true" || literalStr == "false") {
+        currType = BooleanType();
+        return new LiteralExpr(literalStr == "true", std::move(currType));
+    }
+
+    // trim the single quotes for string values if there is any
+    if (TypeUtil::IsStringType(currTypeId) && literalStr[0] == '\'' && literalStr[literalStr.size() - 1] == '\'') {
+        literalStr = literalStr.substr(1, literalStr.size() - 1 - 1);
+    }
+
+    // case with null constants
+    if (IsNullLiteral(literalStr)) {
+        auto expr = ParserHelper::GetDefaultValueForType(currTypeId);
+        expr->isNull = true;
+        return expr;
+    }
+
+    if (TypeUtil::IsStringType(currTypeId)) {
+        if (currTypeId == OMNI_CHAR) {
+            currType = std::make_shared(width);
+        } else {
+            currType = std::make_shared(width);
+        }
+    } else {
+        currType = std::make_shared(currTypeId);
+    }
+
+    // Case with regular data (int, long, double, string ...)
+    return GenerateLiteralExprHelper(literalStr, std::move(currType));
+}
diff --git a/core/src/expression/parser/parser.h b/core/src/expression/parser/parser.h
new file mode 100644
index 0000000..c3e0063
--- /dev/null
+++ b/core/src/expression/parser/parser.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: parser function
+ */
+#ifndef __PARSER_H__
+#define __PARSER_H__
+
+#include "expression/parserhelper.h"
+#include "type/data_types.h"
+#include "codegen/func_registry.h"
+
+enum ParserFormat {
+    STRING = 0,
+    JSON = 1
+};
+
+class Parser {
+public:
+    Parser();
+    ~Parser();
+
+    omniruntime::expressions::Expr *ParseRowExpression(const std::string &input,
+        omniruntime::type::DataTypes &inputTypes, int32_t vecCount);
+
+    std::vector ParseExpressions(const std::string expressions[],
+        int32_t numberOfExpressions, omniruntime::type::DataTypes &inputTypes);
+
+    omniruntime::expressions::Expr *ParseRowExpressionHelper(std::string opStr,
+        std::vector args);
+
+    static omniruntime::expressions::FieldExpr *GenerateFieldExpr(std::string fieldStr,
+        const omniruntime::type::DataTypes &vecTypePtr);
+    static omniruntime::expressions::LiteralExpr *GenerateLiteralExpr(std::string literalStr);
+    static omniruntime::expressions::LiteralExpr *GenerateLiteralExprHelper(const std::string &literalStr,
+        omniruntime::expressions::DataTypePtr inputType);
+
+private:
+    ParserHelper ph;
+    // Helper function to strip a string but keep spaces intact inside string literals
+    static std::string StripString(const std::string &input);
+};
+
+
+#endif
\ No newline at end of file
diff --git a/core/src/expression/parserhelper.cpp b/core/src/expression/parserhelper.cpp
new file mode 100644
index 0000000..e4ba71a
--- /dev/null
+++ b/core/src/expression/parserhelper.cpp
@@ -0,0 +1,110 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
+ */
+#include "parserhelper.h"
+
+using namespace std;
+using namespace omniruntime::expressions;
+using namespace omniruntime::type;
+
+constexpr int8_t BYTE_DEFAULT_VALUE = 0;
+constexpr int16_t SHORT_DEFAULT_VALUE = 0;
+constexpr int32_t INT_DEFAULT_VALUE = 0;
+constexpr int64_t LONG_DEFAULT_VALUE = 0L;
+constexpr double DOUBLE_DEFAULT_VALUE = 0.000;
+constexpr bool BOOL_DEFAULT_VALUE = true;
+constexpr char CHAR_DEFAULT_VALUE[] = "NULL";
+constexpr char DECIMAL128_DEFAULT_VALUE[] = "0";
+constexpr int32_t CHAR_DEFAULT_WIDTH = 50;
+
+omniruntime::expressions::LiteralExpr *ParserHelper::GetDefaultValueForType(DataTypeId destTypeId, int32_t precision,
+    int32_t scale)
+{
+    DataTypePtr destType;
+    if (TypeUtil::IsDecimalType(destTypeId)) {
+        switch (destTypeId) {
+            case OMNI_DECIMAL64: {
+                destType = std::make_shared(precision, scale);
+                return new LiteralExpr(LONG_DEFAULT_VALUE, std::move(destType));
+            }
+            case OMNI_DECIMAL128:
+            default: {
+                destType = std::make_shared(precision, scale);
+                return new LiteralExpr(new string(DECIMAL128_DEFAULT_VALUE), std::move(destType));
+            }
+        }
+    } else {
+        destType = std::make_shared(destTypeId);
+        switch (destTypeId) {
+            case OMNI_BYTE:
+                return new LiteralExpr(BYTE_DEFAULT_VALUE, std::move(destType));
+            case OMNI_SHORT:
+                return new LiteralExpr(SHORT_DEFAULT_VALUE, std::move(destType));
+            case OMNI_INT:
+            case OMNI_DATE32:
+                return new LiteralExpr(INT_DEFAULT_VALUE, std::move(destType));
+            case OMNI_TIMESTAMP:
+            case OMNI_LONG:
+                return new LiteralExpr(LONG_DEFAULT_VALUE, std::move(destType));
+            case OMNI_DOUBLE:
+                return new LiteralExpr(DOUBLE_DEFAULT_VALUE, std::move(destType));
+            case OMNI_BOOLEAN:
+                return new LiteralExpr(BOOL_DEFAULT_VALUE, std::move(destType));
+            case OMNI_CHAR:
+                return new LiteralExpr(new string(CHAR_DEFAULT_VALUE),
+                                       std::make_shared(CHAR_DEFAULT_WIDTH));
+            case OMNI_VARCHAR:
+                return new LiteralExpr(new string(CHAR_DEFAULT_VALUE),
+                                       std::make_shared(CHAR_DEFAULT_WIDTH));
+            case OMNI_NONE:
+                return new LiteralExpr(INT_DEFAULT_VALUE, std::move(destType));
+            default:
+                return nullptr;
+        }
+    }
+}
+
+DataTypePtr ParserHelper::GetReturnDataType(nlohmann::json jsonExpr)
+{
+    auto typeId = static_cast(jsonExpr["returnType"].get());
+    int32_t precision = 0;
+    int32_t scale = 0;
+    uint32_t width = 0;
+    switch (typeId) {
+        case OMNI_BOOLEAN:
+            return std::make_shared();
+        case OMNI_BYTE:
+            return std::make_shared();
+        case OMNI_SHORT:
+            return std::make_shared();
+        case OMNI_INT:
+            return std::make_shared();
+        case OMNI_DATE32:
+            return std::make_shared();
+        case OMNI_LONG:
+            return std::make_shared();
+        case OMNI_TIMESTAMP:
+            return std::make_shared();
+        case OMNI_DOUBLE:
+            return std::make_shared();
+        case OMNI_DECIMAL64:
+            precision = jsonExpr["precision"].get();
+            scale = jsonExpr["scale"].get();
+            return std::make_shared(precision, scale);
+        case OMNI_DECIMAL128:
+            precision = jsonExpr["precision"].get();
+            scale = jsonExpr["scale"].get();
+            return std::make_shared(precision, scale);
+        case OMNI_VARCHAR:
+            width = jsonExpr["width"].get();
+            return std::make_shared(width);
+        case OMNI_CHAR:
+            width = jsonExpr["width"].get();
+            return std::make_shared(width);
+        case OMNI_NONE:
+            return std::make_shared();
+        default:
+            LogError("Unsupported data type %d ", typeId);
+            return nullptr;
+    }
+}
diff --git a/core/src/expression/parserhelper.h b/core/src/expression/parserhelper.h
new file mode 100644
index 0000000..87a3871
--- /dev/null
+++ b/core/src/expression/parserhelper.h
@@ -0,0 +1,21 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
+ */
+#ifndef __PARSERHELPER_H__
+#define __PARSERHELPER_H__
+
+#include 
+#include 
+#include 
+#include 
+#include "expressions.h"
+#include "util/type_util.h"
+
+class ParserHelper {
+public:
+    static omniruntime::expressions::LiteralExpr *GetDefaultValueForType(omniruntime::type::DataTypeId destTypeId,
+        int32_t precision = 0, int32_t scale = 0);
+    static omniruntime::type::DataTypePtr GetReturnDataType(nlohmann::json jsonExpr);
+};
+
+#endif
\ No newline at end of file
diff --git a/core/src/memory/CMakeLists.txt b/core/src/memory/CMakeLists.txt
new file mode 100644
index 0000000..02259ed
--- /dev/null
+++ b/core/src/memory/CMakeLists.txt
@@ -0,0 +1,8 @@
+aux_source_directory(${CMAKE_CURRENT_LIST_DIR} MEM_POOL_LIST)
+set(MEM_TARGET memory)
+add_library(${MEM_TARGET} ${MEM_POOL_LIST})
+
+# dependent include
+target_include_directories(${MEM_TARGET} PUBLIC /usr/local/include/jemalloc/)
+target_link_libraries(${MEM_TARGET} PUBLIC jemalloc cpu_checker)
+
diff --git a/core/src/memory/aligned_buffer.h b/core/src/memory/aligned_buffer.h
new file mode 100644
index 0000000..015e60e
--- /dev/null
+++ b/core/src/memory/aligned_buffer.h
@@ -0,0 +1,87 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ */
+
+
+#ifndef OMNI_RUNTIME_ALIGNED_BUFFER_H
+#define OMNI_RUNTIME_ALIGNED_BUFFER_H
+#include "allocator.h"
+
+namespace omniruntime::mem {
+template  class AlignedBuffer {
+public:
+    AlignedBuffer() : buffer(nullptr), capacity(0)
+    {
+        allocator = Allocator::GetAllocator();
+    }
+
+    AlignedBuffer(size_t size, bool zerofill = false)
+    {
+        capacity = size * sizeof(RAW_DATA_TYPE);
+        allocator = Allocator::GetAllocator();
+        buffer = reinterpret_cast(allocator->Alloc(capacity, zerofill));
+    }
+
+    ~AlignedBuffer()
+    {
+        Release();
+    }
+
+    ALWAYS_INLINE RAW_DATA_TYPE *AllocateReuse(size_t size, bool zerofill)
+    {
+        if (buffer == nullptr) {
+            // no memory allocated, creating new one
+            capacity = size * sizeof(RAW_DATA_TYPE);
+            buffer = reinterpret_cast(allocator->Alloc(capacity, zerofill));
+            return buffer;
+        }
+
+        size_t newCapacity = size * sizeof(RAW_DATA_TYPE);
+        if (capacity < newCapacity) {
+            // memory already allocated, but cannot hold newCapacity
+            // releasing previous buffer and allocating new one
+            allocator->Free(buffer, capacity);
+            capacity = newCapacity;
+            buffer = reinterpret_cast(allocator->Alloc(capacity, zerofill));
+            return buffer;
+        }
+
+        // memory already allocated, and can hold newCapacity
+        // just set content to zero (if needed) and return previous buffer
+        if (zerofill) {
+            // since capacity is usually large (> 1024 bytes), we can use memset_sp.
+            // Based on benchmark memset_sp for large buffer sizes is not much worse than std::meset
+            // but for smaller buffers (i.e. less than 100 bytes), memset_sp is x2 to x4 times slower than std::meset
+            memset_sp(buffer, capacity, 0, newCapacity);
+        }
+        return buffer;
+    }
+
+    ALWAYS_INLINE RAW_DATA_TYPE *GetBuffer()
+    {
+        return buffer;
+    }
+
+    ALWAYS_INLINE RAW_DATA_TYPE GetValue(int32_t index)
+    {
+        return buffer[index];
+    }
+
+private:
+    ALWAYS_INLINE void Release()
+    {
+        if (buffer != nullptr) {
+            allocator->Free(buffer, capacity);
+        }
+        buffer = nullptr;
+        capacity = 0;
+    }
+
+    RAW_DATA_TYPE *buffer;
+    size_t capacity;
+    Allocator *allocator;
+};
+}
+
+
+#endif // OMNI_RUNTIME_ALIGNED_BUFFER_H
diff --git a/core/src/memory/allocator.h b/core/src/memory/allocator.h
new file mode 100644
index 0000000..41d8eb5
--- /dev/null
+++ b/core/src/memory/allocator.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
+ */
+
+#ifndef OMNI_RUNTIME_ALLOCATOR_H
+#define OMNI_RUNTIME_ALLOCATOR_H
+
+#include 
+#include "memory_pool.h"
+#include "thread_memory_manager.h"
+#include "memory_trace.h"
+
+namespace omniruntime::mem {
+class Allocator {
+public:
+    static ALWAYS_INLINE Allocator *GetAllocator()
+    {
+        static Allocator globalAllocator;
+        return &globalAllocator;
+    }
+
+    ALWAYS_INLINE void *Alloc(int64_t size, bool zeroFill = false)
+    {
+        // memory usage statistics. If the memory_cap_exceed exception is thrown, the memory is not allocated.
+        ThreadMemoryManager::ReportMemory(size);
+        uint8_t *data = nullptr;
+        pool->Allocate(size, &data, zeroFill);
+        MemoryTrace::AddArenaMemory(reinterpret_cast(data), size);
+        return data;
+    }
+
+    ALWAYS_INLINE void Free(void *data, int64_t size)
+    {
+        // memory usage statistics
+        ThreadMemoryManager::ReclaimMemory(size);
+        MemoryTrace::SubArenaMemory(reinterpret_cast(reinterpret_cast(data)), size);
+        pool->Release(reinterpret_cast(data));
+    }
+
+private:
+    Allocator(){};
+    ~Allocator(){};
+    MemoryPool *pool = GetMemoryPool();
+};
+} // omniruntime::mem
+
+#endif // OMNI_RUNTIME_ALLOCATOR_H
diff --git a/core/src/memory/chunk.cpp b/core/src/memory/chunk.cpp
new file mode 100644
index 0000000..b8461d9
--- /dev/null
+++ b/core/src/memory/chunk.cpp
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved.
+ */
+
+#include "chunk.h"
+
+namespace omniruntime {
+namespace mem {
+Chunk::Chunk(Allocator *allocator, void *address, uint64_t sizeInBytes)
+    : address(address), sizeInBytes(sizeInBytes), allocator(allocator)
+{}
+
+Chunk::~Chunk()
+{
+    if (address == nullptr) {
+        std::cerr << "address is null in chunk." << std::endl;
+        return;
+    }
+
+    allocator->Free(address, static_cast(sizeInBytes));
+    address = nullptr;
+}
+
+void *Chunk::GetAddress() const
+{
+    return address;
+}
+
+uint64_t Chunk::GetSizeInBytes()
+{
+    return sizeInBytes;
+}
+} // namespace mem
+} // namespace omniruntime
\ No newline at end of file
diff --git a/core/src/memory/chunk.h b/core/src/memory/chunk.h
new file mode 100644
index 0000000..55c376f
--- /dev/null
+++ b/core/src/memory/chunk.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved.
+ */
+
+#ifndef CHUNK_H
+#define CHUNK_H
+
+#include 
+#include 
+#include "allocator.h"
+
+namespace omniruntime {
+namespace mem {
+class Chunk {
+public:
+    ~Chunk();
+
+    void *GetAddress() const;
+
+    uint64_t GetSizeInBytes();
+
+    static Chunk *NewChunk(Allocator *globalAllocator, uint64_t sizeInByte, bool zeroFill = false)
+    {
+        void *data = globalAllocator->Alloc(static_cast(sizeInByte));
+        if (data != nullptr) {
+            return new Chunk(globalAllocator, data, sizeInByte);
+        }
+        throw OmniException("MEMORY_CHUNK_ERROR", "NewChunk return nullptr error");
+    }
+
+protected:
+    explicit Chunk(Allocator *allocator, void *address, uint64_t sizeInBytes);
+
+private:
+    void *address = nullptr;
+    uint64_t sizeInBytes;
+    Allocator *allocator = nullptr;
+};
+} // namespace mem
+} // namespace omniruntime
+#endif // CHUNK_H
diff --git a/core/src/memory/memory_manager.cpp b/core/src/memory/memory_manager.cpp
new file mode 100644
index 0000000..cd313cf
--- /dev/null
+++ b/core/src/memory/memory_manager.cpp
@@ -0,0 +1,174 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ */
+#include "memory_manager.h"
+
+namespace omniruntime::mem {
+// constructor for globalMemoryManager
+MemoryManager::MemoryManager()
+{
+    parent.store(nullptr, std::memory_order_relaxed);
+    memoryAmount.store(0, std::memory_order_relaxed);
+    memoryLimit.store(UNLIMIT, std::memory_order_relaxed);
+    memoryPeak.store(0, std::memory_order_relaxed);
+    isBlocked.store(false, std::memory_order_relaxed);
+}
+
+// constructor for threadMemoryManager
+MemoryManager::MemoryManager(MemoryManager *parentMemoryManager)
+{
+    // The capacity of each thread can also be limited. It needs to be set the corresponding value if necessary.
+    int64_t globalMemoryLimit = parentMemoryManager->GetMemoryLimit();
+    parent.store(parentMemoryManager, std::memory_order_relaxed);
+    memoryAmount.store(0, std::memory_order_relaxed);
+    memoryLimit.store(globalMemoryLimit, std::memory_order_relaxed);
+    memoryPeak.store(0, std::memory_order_relaxed);
+    isBlocked.store(false, std::memory_order_relaxed);
+}
+
+MemoryManager::~MemoryManager()
+{}
+
+void MemoryManager::AddMemory(int64_t reportedMemory, int64_t curAllocateSize)
+{
+    int64_t newMemoryAmount = memoryAmount.fetch_add(reportedMemory, std::memory_order_relaxed) + reportedMemory;
+    int64_t limit = memoryLimit.load(std::memory_order_relaxed);
+    if (limit == UNLIMIT || newMemoryAmount < limit) {
+        isBlocked.store(false, std::memory_order_relaxed);
+    } else {
+        /* *
+         * A thread cannot throw multiple exceptions in the case of multiple threads.
+         * The If statement is used to avoid the following case: When the thread exceeds the limit,
+         * the other thread is called and the AddMemory interface is executed again,
+         * which may cause the OmniException to be thrown again.
+         * Note: Only global memory manager can throw the MEM_CAP_EXCEEDED exception.
+        *  */
+        if (!isBlocked.load(std::memory_order_relaxed)) {
+            isBlocked.store(true, std::memory_order_relaxed);
+            memoryAmount.fetch_sub(curAllocateSize, std::memory_order_relaxed);
+
+            if (parent.load(std::memory_order_relaxed) == nullptr) {
+                auto message =
+                        op::GetErrorMessage(op::ErrorCode::MEM_CAP_EXCEEDED) + std::to_string(limit / 1024 / 1024)
+                        + "; current Memory Usage Total: " + std::to_string(newMemoryAmount / 1024 / 1024)
+                        + "MB; current Memory Allocate Size: " + std::to_string(curAllocateSize) + "B";
+                throw OmniException(GetErrorCode(op::ErrorCode::MEM_CAP_EXCEEDED), message);
+            }
+        }
+    }
+
+    if (auto parentMemoryManager = parent.load(std::memory_order_relaxed)) {
+        parentMemoryManager->AddMemory(reportedMemory, curAllocateSize);
+    }
+}
+
+void MemoryManager::SubMemory(int64_t reclaimedMemory)
+{
+    memoryAmount.fetch_add(reclaimedMemory, std::memory_order_relaxed);
+    if (auto parentMemoryManager = parent.load(std::memory_order_relaxed)) {
+        parentMemoryManager->SubMemory(reclaimedMemory);
+    }
+}
+
+void MemoryManager::UpdatePeak(int64_t size)
+{
+    int64_t peak = memoryPeak.load(std::memory_order_relaxed);
+    if (size > peak) {
+        memoryPeak.store(size, std::memory_order_relaxed);
+    }
+}
+
+#ifdef DEBUG
+void MemoryManager::AddScopeAmount(const std::string &scope, int64_t size)
+{
+    // keep thread safety.
+    std::lock_guard lock(mapLock);
+    scopeMap[scope] += size;
+    if (auto parentMemoryManager = parent.load(std::memory_order_relaxed)) {
+        parentMemoryManager->scopeMap[scope] += size;
+    }
+}
+
+void MemoryManager::SubScopeAmount(const std::string &scope, int64_t size)
+{
+    // keep thread safety.
+    std::lock_guard lock(mapLock);
+    scopeMap[scope] -= size;
+    if (auto parentMemoryManager = parent.load(std::memory_order_relaxed)) {
+        parentMemoryManager->scopeMap[scope] -= size;
+    }
+}
+#endif
+
+void MemoryManager::SetParent(MemoryManager *parentMemoryManager)
+{
+    parent.store(parentMemoryManager, std::memory_order_relaxed);
+}
+
+MemoryManager *MemoryManager::GetParent()
+{
+    return parent.load(std::memory_order_relaxed);
+}
+
+void MemoryManager::SetMemoryAmount(int64_t amount)
+{
+    memoryAmount.store(amount, std::memory_order_relaxed);
+}
+
+int64_t MemoryManager::GetMemoryAmount()
+{
+    return memoryAmount.load(std::memory_order_relaxed);
+}
+
+void MemoryManager::SetMemoryLimit(int64_t limit)
+{
+    memoryLimit.store(limit, std::memory_order_relaxed);
+}
+
+int64_t MemoryManager::GetMemoryLimit()
+{
+    return memoryLimit.load(std::memory_order_relaxed);
+}
+
+void MemoryManager::SetMemoryPeak(int64_t peak)
+{
+    memoryPeak.store(peak, std::memory_order_relaxed);
+}
+
+int64_t MemoryManager::GetMemoryPeak()
+{
+    return memoryPeak.load(std::memory_order_relaxed);
+}
+
+bool MemoryManager::IsMemoryManagerBlocked()
+{
+    return isBlocked.load(std::memory_order_relaxed);
+}
+
+#ifdef DEBUG
+std::unordered_map, std::equal_to,
+    MemoryManagerAllocator>>
+MemoryManager::GetScopeMap()
+{
+    return scopeMap;
+}
+#endif
+
+void MemoryManager::Clear()
+{
+    memoryAmount.store(0, std::memory_order_relaxed);
+    memoryLimit.store(UNLIMIT, std::memory_order_relaxed);
+    memoryPeak.store(0, std::memory_order_relaxed);
+    isBlocked.store(false, std::memory_order_relaxed);
+#ifdef DEBUG
+    // keep thread safety.
+    std::lock_guard lock(mapLock);
+    if (auto parentMemoryManager = parent.load(std::memory_order_relaxed)) {
+        for (const auto &it : scopeMap) {
+            parentMemoryManager->scopeMap.erase(it.first);
+        }
+    }
+    scopeMap.clear();
+#endif
+}
+}
\ No newline at end of file
diff --git a/core/src/memory/memory_manager.h b/core/src/memory/memory_manager.h
new file mode 100644
index 0000000..956ea25
--- /dev/null
+++ b/core/src/memory/memory_manager.h
@@ -0,0 +1,129 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ */
+
+#ifndef OMNI_RUNTIME_MEMORY_MANAGER_H
+#define OMNI_RUNTIME_MEMORY_MANAGER_H
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include "util/omni_exception.h"
+#include "util/error_code.h"
+#include "util/compiler_util.h"
+#include "memory_manager_allocator.h"
+#include "memory_trace.h"
+
+namespace omniruntime {
+namespace mem {
+using namespace exception;
+
+/**
+ * it is responsible for memory usage statistics of each thread and the global memory usage.
+ **/
+class MemoryManager {
+public:
+    // unlimited memory usage
+    const static int64_t UNLIMIT = -1;
+    static ALWAYS_INLINE MemoryManager *GetGlobalMemoryManager()
+    {
+        static MemoryManager globalMemoryManger;
+        return &globalMemoryManger;
+    }
+
+    static void SetGlobalMemoryLimit(int64_t limit)
+    {
+        LogInfo("set global memory manager limit:%ld Byte", limit);
+        MemoryManager *globalMemoryManager = GetGlobalMemoryManager();
+        globalMemoryManager->SetMemoryLimit(limit);
+    }
+
+    static ALWAYS_INLINE int64_t GetGlobalAccountedMemory()
+    {
+        MemoryManager *globalMemoryManager = GetGlobalMemoryManager();
+        return globalMemoryManager->GetMemoryAmount();
+    }
+
+    static ALWAYS_INLINE int64_t GetGlobalMemoryLimit()
+    {
+        MemoryManager *globalMemoryManager = GetGlobalMemoryManager();
+        return globalMemoryManager->GetMemoryLimit();
+    }
+
+    // constructor for globalMemoryManager
+    MemoryManager();
+
+    // constructor for threadMemoryManager
+    explicit MemoryManager(MemoryManager *globalMemoryManager);
+
+    ~MemoryManager();
+
+    /**
+     * reportedMemory is a positive size, indicate the memory manager need to be added.
+     * curAllocateSize indicate the size of memory to be allocated of one object, which triggers the statistical event.
+     * AddMemory() interface is called when the reportedMemory exceeds the untracked memory threshold like 1MB
+     * */
+    void AddMemory(int64_t reportedMemory, int64_t curAllocateSize = 0);
+
+    /**
+     * reclaimedMemory is a negative size.
+     * SubMemory() interface is called when the reclaimedMemory exceeds the threshold '-THRESHOLD'.
+     * */
+    void SubMemory(int64_t reclaimedMemory);
+
+    // memoryPeak seems to lack actual application scenario, so memoryPeak is not worth ensuring thread safety.
+    void UpdatePeak(int64_t size);
+
+#ifdef DEBUG
+    void AddScopeAmount(const std::string &scope, int64_t size);
+
+    void SubScopeAmount(const std::string &scope, int64_t size);
+#endif
+
+    void SetParent(MemoryManager *parentMemoryManager);
+
+    MemoryManager *GetParent();
+
+    // for UT
+    void SetMemoryAmount(int64_t amount);
+
+    int64_t GetMemoryAmount();
+
+    void SetMemoryLimit(int64_t limit);
+
+    int64_t GetMemoryLimit();
+
+    void SetMemoryPeak(int64_t peak);
+
+    int64_t GetMemoryPeak();
+
+    bool IsMemoryManagerBlocked();
+
+#ifdef DEBUG
+    std::unordered_map, std::equal_to,
+        MemoryManagerAllocator>>
+    GetScopeMap();
+#endif
+
+    void Clear();
+
+private:
+    std::atomic parent;
+    std::atomic memoryAmount;
+    std::atomic memoryLimit;
+    std::atomic memoryPeak;
+    std::atomic isBlocked;
+#ifdef DEBUG
+    std::mutex mapLock;
+    std::unordered_map, std::equal_to,
+        MemoryManagerAllocator>>
+        scopeMap; // Scope : Amount
+#endif
+};
+} // vec
+} // omniruntime
+
+#endif // OMNI_RUNTIME_MEMORY_MANAGER_H
diff --git a/core/src/memory/memory_manager_allocator.h b/core/src/memory/memory_manager_allocator.h
new file mode 100644
index 0000000..e086a62
--- /dev/null
+++ b/core/src/memory/memory_manager_allocator.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ */
+
+#ifndef OMNI_RUNTIME_MEMORY_MANAGER_ALLOCATOR_H
+#define OMNI_RUNTIME_MEMORY_MANAGER_ALLOCATOR_H
+
+#include 
+#include 
+#include 
+
+namespace omniruntime::mem {
+/**
+ * Our memory manager manages the heap memory, such as allocation and statistics. This has a problem, that is,
+ * loop calls may be occur when the internal data structure of the memory manager needs to apply for heap memory.
+ * To solve the problem, the class MemoryManagerAllocator is introduced to replace stl::allocator and
+ * implement memory management of the internal data structure of our memory manager.
+ *  */
+template  class MemoryManagerAllocator {
+public:
+    typedef T value_type;
+    typedef T *pointer;
+    typedef const T *const_pointer;
+    typedef const T &const_reference;
+    typedef T &reference;
+    typedef size_t size_type;
+    typedef ptrdiff_t difference_type;
+
+    template  struct rebind {
+        typedef MemoryManagerAllocator other;
+    };
+
+    MemoryManagerAllocator() noexcept {}
+
+    MemoryManagerAllocator(const MemoryManagerAllocator &) noexcept {}
+
+    MemoryManagerAllocator& operator=(const MemoryManagerAllocator& allocator) noexcept {}
+
+    template  explicit MemoryManagerAllocator(const MemoryManagerAllocator &) noexcept {}
+
+    ~MemoryManagerAllocator() noexcept {}
+
+    pointer address(reference x) noexcept
+    {
+        return static_cast(&x);
+    }
+
+    const_pointer address(const_reference x) noexcept
+    {
+        return static_cast(&x);
+    }
+
+    // modified allocation method from "::operator new(n * sizeof(T))" in stl::allocator to "malloc(n * sizeof(T)))"
+    pointer allocate(size_type n)
+    {
+        void *pMem = nullptr;
+        if (n > this->max_size() || (pMem = malloc(n * sizeof(T))) == nullptr) {
+            throw std::bad_alloc();
+        }
+        return static_cast(pMem);
+    }
+
+    // modified allocation method from "::operator delete(p)" in stl::allocator to "free(p)"
+    void deallocate(pointer p, size_type)
+    {
+        free(p);
+    }
+
+    size_type max_size() const noexcept
+    {
+        return size_t(UINT_MAX / sizeof(T));
+    }
+
+    void construct(pointer p, const_reference value)
+    {
+        new (p)T(value);
+    }
+};
+}
+
+#endif // OMNI_RUNTIME_MEMORY_MANAGER_ALLOCATOR_H
diff --git a/core/src/memory/memory_pool.cpp b/core/src/memory/memory_pool.cpp
new file mode 100644
index 0000000..8f544c9
--- /dev/null
+++ b/core/src/memory/memory_pool.cpp
@@ -0,0 +1,161 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
+ */
+
+#include "memory_pool.h"
+
+#include 
+#include 
+#include "util/omni_exception.h"
+#include "util/compiler_util.h"
+
+using namespace std;
+namespace omniruntime {
+namespace mem {
+class SimpleAllocator {
+public:
+    static void Allocate(int64_t size, uint8_t **buffer, bool zeroFill = false)
+    {
+        // background: If size is 0, then malloc() returns either NULL, or a unique pointer value that can later be
+        // successfully passed to free().
+        if (size < 0) {
+            throw omniruntime::exception::OmniException("OPERATOR_RUNTIME_ERROR", "allocate size is negative.");
+        }
+
+        if (zeroFill) {
+            // alloc based on the size
+            *buffer = static_cast(calloc(1, static_cast(size)));
+        } else {
+            // alloc based on the size
+            *buffer = static_cast(malloc(static_cast(size)));
+        }
+        if (UNLIKELY(*buffer == nullptr)) {
+            throw omniruntime::exception::OmniException("OPERATOR_RUNTIME_ERROR", "allocate fails.");
+        }
+    }
+
+    static void Release(uint8_t *buffer)
+    {
+        // free the memory
+        free(static_cast(buffer));
+    }
+};
+
+class JemallocAllocator {
+public:
+    static void Allocate(int64_t size, uint8_t **buffer, bool zeroFill = false)
+    {
+        if (size < 0) {
+            throw omniruntime::exception::OmniException("OPERATOR_RUNTIME_ERROR", "allocate size is negative.");
+        }
+        // jemalloc alloc
+        if (zeroFill) {
+            *buffer = static_cast(mallocx(
+                static_cast(size),
+                MALLOCX_ALIGN(alignment) | MALLOCX_ZERO
+            ));
+        } else {
+            *buffer = static_cast(mallocx(static_cast(size), MALLOCX_ALIGN(alignment)));
+        }
+        if (UNLIKELY(*buffer == nullptr)) {
+            throw omniruntime::exception::OmniException("OPERATOR_RUNTIME_ERROR", "allocate fails.");
+        }
+    }
+
+    static void Release(uint8_t *buffer)
+    {
+        // jemalloc free
+        dallocx(static_cast(buffer), MALLOCX_ALIGN(alignment));
+    }
+    const static size_t alignment = 64;
+};
+
+template  class BaseMemoryPoolImpl : public MemoryPool {
+public:
+    void Allocate(int64_t size, uint8_t **buffer, bool zeroFill = false) override
+    {
+        Allocator::Allocate(size, buffer, zeroFill);
+    }
+
+    void Release(uint8_t *buffer) override
+    {
+        Allocator::Release(buffer);
+    }
+
+    ~BaseMemoryPoolImpl() override = default;
+
+    uint64_t GetPreferredSize(uint64_t size) override
+    {
+        return size;
+    }
+};
+
+class SimpleMemoryPool : public BaseMemoryPoolImpl {
+public:
+    uint64_t GetPreferredSize(uint64_t size) override
+    {
+        if (size == 0) {
+            return size;
+        }
+
+        const uint64_t smallSize = 8;
+        if (size < smallSize) {
+            return smallSize;
+        }
+        uint32_t bits = 63 - __builtin_clzll(size);
+        size_t lower = 1ULL << bits;
+        // Size is a power of 2.
+        if (lower == size) {
+            return size;
+        }
+        // If size is below 1.5 * previous power of two, return 1.5 *
+        // the previous power of two, else the next power of 2.
+        uint64_t preferredSize = lower + (lower / 2);
+        if (preferredSize >= size) {
+            return preferredSize;
+        }
+        return (lower + lower);
+    }
+};
+
+class JemallocMemoryPool : public BaseMemoryPoolImpl {
+public:
+    uint64_t GetPreferredSize(uint64_t size) override
+    {
+        if (size == 0) {
+            return size;
+        }
+
+        const uint64_t smallSize = 8;
+        if (size < smallSize) {
+            return smallSize;
+        }
+
+        uint32_t bits = 63 - __builtin_clzll(size);
+        size_t lower = 1ULL << bits;
+        // Size is a power of 2.
+        if (lower == size) {
+            return size;
+        }
+        // If size is below 1.5 * previous power of two, return 1.5 *
+        // the previous power of two, else the next power of 2.
+        uint64_t preferredSize = lower + (lower / 2);
+        if (preferredSize >= size) {
+            return preferredSize;
+        }
+        return (lower + lower);
+    }
+};
+
+#ifdef COVERAGE
+        static omniruntime::mem::SimpleMemoryPool g_memoryPoolInstance;
+#else
+        static omniruntime::mem::JemallocMemoryPool g_memoryPoolInstance;
+#endif
+
+MemoryPool *GetMemoryPool()
+{
+    return &g_memoryPoolInstance;
+}
+}
+}
diff --git a/core/src/memory/memory_pool.h b/core/src/memory/memory_pool.h
new file mode 100644
index 0000000..96e02a2
--- /dev/null
+++ b/core/src/memory/memory_pool.h
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2020-2020. All rights reserved.
+ */
+#ifndef __MEMORY_POOL_H__
+#define __MEMORY_POOL_H__
+#pragma once
+
+#include 
+
+namespace omniruntime {
+namespace mem {
+class MemoryPool {
+public:
+    virtual void Allocate(int64_t size, uint8_t **buffer, bool zeroFill = false) = 0;
+
+    virtual void Release(uint8_t *buffer) = 0;
+
+    virtual ~MemoryPool() {}
+
+    virtual uint64_t GetPreferredSize(uint64_t size) = 0;
+
+protected:
+    MemoryPool() = default;
+};
+
+MemoryPool *GetMemoryPool();
+} // namespace mem
+} // namespace omniruntime
+#endif // MEMORY_POOL_H
diff --git a/core/src/memory/memory_trace.cpp b/core/src/memory/memory_trace.cpp
new file mode 100644
index 0000000..f562c06
--- /dev/null
+++ b/core/src/memory/memory_trace.cpp
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+ */
+
+#include "memory_trace.h"
+#include "allocator.h"
+#include "vector/vector.h"
+
+namespace omniruntime::mem {
+void MemoryTrace::AddVectorMemory(uintptr_t ptr, int64_t size)
+{
+#ifdef TRACE
+    ThreadMemoryTrace *threadMemoryTrace = ThreadMemoryTrace::GetThreadMemoryTrace();
+    threadMemoryTrace->AddVectorMemory(ptr, size);
+#endif
+}
+
+void MemoryTrace::SubVectorMemory(uintptr_t ptr, int64_t size)
+{
+#ifdef TRACE
+    ThreadMemoryTrace *threadMemoryTrace = ThreadMemoryTrace::GetThreadMemoryTrace();
+    threadMemoryTrace->RemoveVectorMemory(ptr, size);
+#endif
+}
+
+void MemoryTrace::AddArenaMemory(uintptr_t ptr, int64_t size)
+{
+#ifdef TRACE
+    ThreadMemoryTrace *threadMemoryTrace = ThreadMemoryTrace::GetThreadMemoryTrace();
+    threadMemoryTrace->AddArenaMemory(ptr, size);
+#endif
+}
+
+void MemoryTrace::SubArenaMemory(uintptr_t ptr, int64_t size)
+{
+#ifdef TRACE
+    ThreadMemoryTrace *threadMemoryTrace = ThreadMemoryTrace::GetThreadMemoryTrace();
+    threadMemoryTrace->RemoveArenaMemory(ptr, size);
+#endif
+}
+
+MemoryTrace::MemoryTrace() {}
+
+MemoryTrace::~MemoryTrace()
+{
+    threadMemoryTraceSet.clear();
+}
+
+void MemoryTrace::AddThreadMemoryTrace(ThreadMemoryTrace *threadMemoryTrace)
+{
+    std::lock_guard lock(m_mutex);
+    threadMemoryTraceSet.emplace(threadMemoryTrace);
+}
+
+void MemoryTrace::SubThreadMemoryTrace(ThreadMemoryTrace *threadMemoryTrace)
+{
+    std::lock_guard lock(m_mutex);
+    threadMemoryTraceSet.erase(threadMemoryTrace);
+}
+
+const std::unordered_set &MemoryTrace::GetThreadMemoryTraceSet()
+{
+    return threadMemoryTraceSet;
+}
+
+static MemoryTrace g_globalMemoryTrace;
+MemoryTrace *GetMemoryTrace()
+{
+    return &g_globalMemoryTrace;
+}
+}
\ No newline at end of file
diff --git a/core/src/memory/memory_trace.h b/core/src/memory/memory_trace.h
new file mode 100644
index 0000000..aaa7887
--- /dev/null
+++ b/core/src/memory/memory_trace.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ */
+
+#ifndef OMNI_RUNTIME_MEMORY_TRACE_H
+#define OMNI_RUNTIME_MEMORY_TRACE_H
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include "util/compiler_util.h"
+#include "memory_manager_allocator.h"
+#include "thread_memory_trace.h"
+
+namespace omniruntime {
+namespace mem {
+/**
+ * it is responsible for memory usage trace of each thread and the global memory usage.
+ **/
+class MemoryTrace {
+public:
+    static void AddVectorMemory(uintptr_t ptr, int64_t size);
+
+    static void SubVectorMemory(uintptr_t ptr, int64_t size);
+
+    static void AddArenaMemory(uintptr_t ptr, int64_t size);
+
+    static void SubArenaMemory(uintptr_t ptr, int64_t size);
+
+    MemoryTrace();
+
+    ~MemoryTrace();
+
+    void AddThreadMemoryTrace(ThreadMemoryTrace *threadMemoryTrace);
+
+    void SubThreadMemoryTrace(ThreadMemoryTrace *threadMemoryTrace);
+
+    const std::unordered_set &GetThreadMemoryTraceSet();
+
+private:
+    std::unordered_set threadMemoryTraceSet;
+    std::mutex m_mutex;
+};
+
+MemoryTrace *GetMemoryTrace();
+}
+}
+
+
+#endif // OMNI_RUNTIME_MEMORY_TRACE_H
diff --git a/core/src/memory/simple_arena_allocator.h b/core/src/memory/simple_arena_allocator.h
new file mode 100644
index 0000000..75564e3
--- /dev/null
+++ b/core/src/memory/simple_arena_allocator.h
@@ -0,0 +1,225 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved.
+ */
+
+#ifndef SIMPLE_ARENA_ALLOCATOR_H
+#define SIMPLE_ARENA_ALLOCATOR_H
+
+#include 
+#include "chunk.h"
+
+namespace omniruntime {
+namespace mem {
+// this allocator is not thread-safe, and mainly applies for temporary memory usage for operators,
+// such as when dealing with types such as varchar/decimal and so on.
+class SimpleArenaAllocator {
+public:
+    explicit SimpleArenaAllocator(int64_t minChunkSize = 4096, Allocator *allocator = Allocator::GetAllocator(),
+        uint32_t growthFactor = 2, int64_t linearGrowthThreshold = 512 * 1024)
+        : minChunkSize(minChunkSize),
+          totalBytes(0),
+          usedBytes(0),
+          availBytes(0),
+          availBuf(nullptr),
+          continuousUsedMemoryBytes(0),
+          allocator(allocator),
+          growthFactor(growthFactor),
+          linearGrowthThreshold(linearGrowthThreshold)
+    {}
+
+    ~SimpleArenaAllocator()
+    {
+        ReleaseChunks(false /* retainFirst */);
+    }
+
+    void SetMinChunkSize(uint64_t chunkSize)
+    {
+        if (chunkSize > minChunkSize) {
+            uint32_t bits = 63 - __builtin_clzll(chunkSize);
+            uint64_t lower = 1ULL << bits;
+            if (lower == chunkSize) {
+                minChunkSize = chunkSize;
+            } else {
+                minChunkSize = 2 * lower;
+            }
+        }
+    }
+
+    uint64_t GetNextSize(uint64_t sizeInBytes)
+    {
+        if (chunks.empty()) {
+            return std::max(sizeInBytes, minChunkSize);
+        }
+        auto lastChunkSize = chunks.back()->GetSizeInBytes();
+        if (lastChunkSize < linearGrowthThreshold) {
+            return std::max(sizeInBytes, lastChunkSize * growthFactor);
+        } else {
+            return ((sizeInBytes + linearGrowthThreshold - 1) / linearGrowthThreshold) * linearGrowthThreshold;
+        }
+    }
+
+    uint8_t *Allocate(int64_t sizeInBytes)
+    {
+        if (sizeInBytes == 0) {
+            // a non-null pointer is returned if allocated size is 0.
+            static int64_t zeroAddress[1];
+            return reinterpret_cast(&zeroAddress);
+        }
+        if (availBytes < sizeInBytes) {
+            AllocateChunk(GetNextSize(static_cast(sizeInBytes)));
+        }
+        continuousUsedMemoryBytes = sizeInBytes;
+        uint8_t *ret = availBuf;
+        availBuf += sizeInBytes;
+        availBytes -= sizeInBytes;
+        usedBytes += sizeInBytes;
+        continuousUsed = false;
+        return ret;
+    }
+
+    uint8_t *GetAvailBuf()
+    {
+        return availBuf;
+    }
+
+    uint64_t GetAvailBytes()
+    {
+        return availBytes;
+    }
+
+    uint64_t GetContinuousUsedMemoryBytes()
+    {
+        return continuousUsedMemoryBytes;
+    }
+
+    uint64_t GetMinChunkSize()
+    {
+        return minChunkSize;
+    }
+
+    uint8_t *AllocateContinue(int64_t sizeInBytes, const uint8_t *&start)
+    {
+        continuousUsed = true;
+        // null means a new begin of allocate
+        if (start == nullptr) {
+            uint8_t *ret = (Allocate(sizeInBytes));
+            start = (ret);
+            return ret;
+        }
+
+        uint8_t *ret = AllocateContinueNotNull(sizeInBytes, start);
+        usedBytes += sizeInBytes;
+        return ret;
+    }
+
+    void Reset()
+    {
+        if (chunks.empty()) {
+            // if there are no chunks, nothing to do.
+            return;
+        }
+
+        // Release all but the first chunk.
+        if (chunks.size() > 1) {
+            ReleaseChunks(true);
+            chunks.erase(chunks.cbegin() + 1, chunks.cend());
+        }
+
+        auto chunk = chunks[0];
+        availBuf = reinterpret_cast(chunk->GetAddress());
+        availBytes = totalBytes = chunk->GetSizeInBytes();
+        continuousUsedMemoryBytes = 0;
+        usedBytes = 0;
+    }
+
+    ALWAYS_INLINE void RollBackContinualMem()
+    {
+        if (continuousUsed) {
+            availBuf -= continuousUsedMemoryBytes;
+            availBytes += continuousUsedMemoryBytes;
+            usedBytes -= continuousUsedMemoryBytes;
+        }
+    }
+
+    ALWAYS_INLINE uint64_t TotalBytes()
+    {
+        return totalBytes;
+    }
+
+    ALWAYS_INLINE uint64_t UsedBytes()
+    {
+        return usedBytes;
+    }
+
+    ALWAYS_INLINE uint64_t AvailBytes()
+    {
+        return availBytes;
+    }
+
+    ALWAYS_INLINE Allocator *GetAllocator()
+    {
+        return this->allocator;
+    }
+
+private:
+    void AllocateChunk(int64_t sizeInBytes)
+    {
+        Chunk *chunk = Chunk::NewChunk(allocator, sizeInBytes);
+
+        chunks.emplace_back(chunk);
+        availBuf = reinterpret_cast(chunk->GetAddress());
+        availBytes = sizeInBytes; // left-over bytes in the previous chunk cannot be used anymore.
+        totalBytes += sizeInBytes;
+    }
+
+    void ReleaseChunks(bool retainFirst)
+    {
+        for (auto &chunk : chunks) {
+            if (retainFirst) {
+                // skip freeing first chunk.
+                retainFirst = false;
+                continue;
+            }
+            delete chunk;
+        }
+        continuousUsedMemoryBytes = 0;
+    }
+
+    uint8_t *AllocateContinueNotNull(int64_t sizeInBytes, const uint8_t *&start)
+    {
+        auto *p = const_cast(start);
+        uint8_t *ret = p;
+        if (sizeInBytes == 0) {
+            return ret;
+        }
+        auto newSpace = continuousUsedMemoryBytes + static_cast(sizeInBytes);
+        if (availBytes < sizeInBytes) {
+            AllocateChunk(GetNextSize(static_cast(newSpace)));
+            std::copy(start, start + continuousUsedMemoryBytes, availBuf);
+            start = availBuf;
+            availBuf += continuousUsedMemoryBytes;
+            availBytes -= continuousUsedMemoryBytes;
+        }
+        ret = availBuf;
+        availBuf += sizeInBytes;
+        continuousUsedMemoryBytes += sizeInBytes;
+        availBytes -= sizeInBytes;
+        return ret;
+    }
+
+    uint64_t minChunkSize;
+    uint64_t totalBytes;
+    uint64_t usedBytes;
+    uint64_t availBytes;
+    uint8_t *availBuf;
+    // Record the size of the memory used continuously.
+    uint64_t continuousUsedMemoryBytes;
+    uint32_t continuousUsed = false;
+    std::vector chunks;
+    Allocator *allocator;
+    uint32_t growthFactor;
+    uint64_t linearGrowthThreshold;
+};
+} // namespace mem
+} // namespace omniruntime
+#endif // SIMPLE_ARENA_ALLOCATOR_H
diff --git a/core/src/memory/thread_memory_manager.cpp b/core/src/memory/thread_memory_manager.cpp
new file mode 100644
index 0000000..a6cb1bd
--- /dev/null
+++ b/core/src/memory/thread_memory_manager.cpp
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ */
+#include "thread_memory_manager.h"
+
+namespace omniruntime::mem {
+ThreadMemoryManager::ThreadMemoryManager() noexcept
+{
+    MemoryManager *globalMemoryManager = MemoryManager::GetGlobalMemoryManager();
+    thread_local MemoryManager memoryManger(globalMemoryManager);
+    currentMemoryManager = &memoryManger;
+#ifdef DEBUG
+    pthread_getname_np(pthread_self(), currentScope, THREAD_NAME_SIZE);
+#endif
+}
+
+ThreadMemoryManager::~ThreadMemoryManager() noexcept
+{
+    currentMemoryManager->SubMemory(untrackedMemory);
+    untrackedMemory = 0;
+#ifdef DEBUG
+    DeleteScope(currentScope);
+#endif
+}
+
+#ifdef DEBUG
+/**
+ * The current logic is that when the scope ends,
+ * the value of the thread's scopeMap in thread is set to 0, and the value of the global scopeMap decreases.
+ * todo: In the future, reference count may be introduced to clear key-value pair.
+ *         */
+void ThreadMemoryManager::DeleteScope(const std::string &scope)
+{
+    std::unordered_map, std::equal_to,
+        MemoryManagerAllocator>>
+        map = currentMemoryManager->GetScopeMap();
+    if (map.find(scope) != map.end()) {
+        int64_t size = map.find(scope)->second;
+        currentMemoryManager->SubScopeAmount(scope, size);
+    }
+}
+#endif
+
+void ThreadMemoryManager::Clear()
+{
+    currentMemoryManager->Clear();
+    untrackedMemory = 0;
+    if (auto parentMemoryManager = currentMemoryManager->GetParent()) {
+        parentMemoryManager->Clear();
+    }
+}
+
+int64_t ThreadMemoryManager::GetUntrackedMemory() const
+{
+    return untrackedMemory;
+}
+
+int64_t ThreadMemoryManager::GetThreadAccountedMemory()
+{
+    return currentMemoryManager->GetMemoryAmount();
+}
+}
diff --git a/core/src/memory/thread_memory_manager.h b/core/src/memory/thread_memory_manager.h
new file mode 100644
index 0000000..181d4ab
--- /dev/null
+++ b/core/src/memory/thread_memory_manager.h
@@ -0,0 +1,117 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ */
+
+#ifndef OMNI_RUNTIME_THREAD_MEMORY_MANAGER_H
+#define OMNI_RUNTIME_THREAD_MEMORY_MANAGER_H
+
+#include 
+#include 
+#include 
+#include "memory_manager.h"
+
+namespace omniruntime {
+namespace mem {
+#ifdef DEBUG
+#define THREAD_NAME_SIZE 16
+#endif
+/**
+ * TLS Object, it is responsible for memory aggregation per thread.
+ *      */
+class ThreadMemoryManager {
+public:
+    static ALWAYS_INLINE ThreadMemoryManager *GetThreadMemoryManager()
+    {
+        thread_local ThreadMemoryManager threadMemoryManager;
+        return &threadMemoryManager;
+    }
+
+    ThreadMemoryManager() noexcept;
+
+    ~ThreadMemoryManager() noexcept;
+
+    static ALWAYS_INLINE void ReportMemory(int64_t size)
+    {
+        auto threadMemoryManager = ThreadMemoryManager::GetThreadMemoryManager();
+        threadMemoryManager->ReportMemoryUsage(size);
+    }
+
+    static ALWAYS_INLINE void ReclaimMemory(int64_t size)
+    {
+        auto threadMemoryManager = ThreadMemoryManager::GetThreadMemoryManager();
+        threadMemoryManager->ReclaimMemoryUsage(size);
+    }
+
+    ALWAYS_INLINE void ReportMemoryUsage(int64_t size)
+    {
+        untrackedMemory += size;
+        allocMemory += size;
+        if (currentMemoryManager && untrackedMemory > untrackedMemoryThreshold) {
+            // AddMemory maybe throw an exception, so untrackedMemory needs to be set to 0 in advance.
+            int64_t toReportedMemory = untrackedMemory;
+            untrackedMemory = 0;
+            currentMemoryManager->AddMemory(toReportedMemory, size);
+#ifdef DEBUG
+            currentMemoryManager->AddScopeAmount(currentScope, untrackedMemory);
+#endif
+        }
+    }
+
+    ALWAYS_INLINE void ReclaimMemoryUsage(int64_t size)
+    {
+        untrackedMemory -= size;
+        freeMemory += size;
+        if (currentMemoryManager && labs(untrackedMemory) > untrackedMemoryThreshold) {
+            int64_t toReclaimedMemory = untrackedMemory;
+            untrackedMemory = 0;
+            currentMemoryManager->SubMemory(toReclaimedMemory);
+#ifdef DEBUG
+            currentMemoryManager->SubScopeAmount(currentScope, untrackedMemory);
+#endif
+        }
+    }
+
+#ifdef DEBUG
+    /* *
+     * DeleteScope interface is used to end the memory statistics of a certain sql.
+     * @param scope: scope is mapped to sql
+     *      */
+    void DeleteScope(const std::string &scope);
+#endif
+
+    // for UT
+    void Clear();
+
+    int64_t GetUntrackedMemory() const;
+
+    int64_t GetThreadAccountedMemory();
+    int64_t GetAllocMemory() const
+    {
+        return allocMemory;
+    }
+
+    int64_t GetFreeMemory() const
+    {
+        return freeMemory;
+    }
+
+private:
+#ifdef DEBUG
+    char currentScope[THREAD_NAME_SIZE];
+#endif
+    MemoryManager *currentMemoryManager;
+    /* *
+     * Each thread has an untracked memory. the memory usage is not updated when the memory usage of each thread is
+     * within the range of [-Threshold, Threshold]. Moreover, a memory usage update request is initiated
+     * when the memory usage of each thread exceeds the range.
+     * The benefit is to avoid frequent updates of each thread and global memory usage.
+     *      */
+    int64_t untrackedMemory = 0;
+    int64_t untrackedMemoryThreshold = 1 * 1024 * 1024;
+    int64_t allocMemory = 0;
+    int64_t freeMemory = 0;
+};
+} // mem
+} // omniruntime
+
+#endif // OMNI_RUNTIME_THREAD_MEMORY_MANAGER_H
diff --git a/core/src/memory/thread_memory_trace.cpp b/core/src/memory/thread_memory_trace.cpp
new file mode 100644
index 0000000..48d0854
--- /dev/null
+++ b/core/src/memory/thread_memory_trace.cpp
@@ -0,0 +1,205 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
+ */
+
+#include "thread_memory_trace.h"
+#include "allocator.h"
+#include "vector/vector.h"
+#include "memory_trace.h"
+
+namespace omniruntime::mem {
+ThreadMemoryTrace::ThreadMemoryTrace()
+{
+    MemoryTrace *globalMemoryTrace = GetMemoryTrace();
+    globalMemoryTrace->AddThreadMemoryTrace(this);
+}
+
+ThreadMemoryTrace::~ThreadMemoryTrace()
+{
+    MemoryTrace *globalMemoryTrace = GetMemoryTrace();
+    globalMemoryTrace->SubThreadMemoryTrace(this);
+}
+
+void ThreadMemoryTrace::AddVectorMemory(uintptr_t ptr, int64_t size)
+{
+    vectorTraced.emplace(ptr, size);
+#ifdef TRACE
+    vectorTracedWithLog.emplace(ptr, std::make_pair(size, TraceUtil::GetStack()));
+#endif
+}
+
+void ThreadMemoryTrace::RemoveVectorMemory(uintptr_t ptr, int64_t size)
+{
+    std::unordered_map::iterator iter;
+    if ((iter = vectorTraced.find(ptr)) != vectorTraced.end()) {
+        if (iter->second != size) {
+            auto message =
+                    "wrong vector size, alloc: " + std::to_string(iter->second) + ", free: " + std::to_string(size);
+            throw exception::OmniException("Memory Trace Error", message);
+        }
+        vectorTraced.erase(ptr);
+#ifdef TRACE
+        vectorTracedWithLog.erase(ptr);
+#endif
+    } else {
+        MemoryTrace *globalMemoryTrace = GetMemoryTrace();
+        std::unordered_set set = globalMemoryTrace->GetThreadMemoryTraceSet();
+        std::unordered_set::iterator traceIter;
+        for (traceIter = set.begin(); traceIter != set.end() && *traceIter != this; ++traceIter) {
+            if ((*traceIter)->vectorTraced.find(ptr) != (*traceIter)->vectorTraced.end()) {
+                std::string message = "vector allocated by ThreadA, but freed by ThreadB.";
+#ifdef TRACE
+                auto originStack = (*traceIter)->vectorTracedWithLog.find(ptr)->second.second;
+                auto currentStack = this->vectorTracedWithLog.find(ptr)->second.second;
+                message.append("\n originStack is: " + originStack + "\n currentStack is: " + currentStack);
+#endif
+                (*traceIter)->vectorTraced.erase(ptr);
+            }
+        }
+    }
+}
+
+void ThreadMemoryTrace::AddArenaMemory(uintptr_t ptr, int64_t size)
+{
+    arenaTraced.emplace(ptr, size);
+#ifdef TRACE
+    arenaTracedWithLog.emplace(ptr, std::make_pair(size, TraceUtil::GetStack()));
+#endif
+}
+
+void ThreadMemoryTrace::RemoveArenaMemory(uintptr_t ptr, int64_t size)
+{
+    std::unordered_map::iterator iter;
+    if ((iter = arenaTraced.find(ptr)) != arenaTraced.end()) {
+        if (iter->second != size) {
+            auto message =
+                    "wrong arena size, alloc: " + std::to_string(iter->second) + ", free: " + std::to_string(size);
+            throw exception::OmniException("Memory Trace Error", message);
+        }
+        arenaTraced.erase(ptr);
+#ifdef TRACE
+        arenaTracedWithLog.erase(ptr);
+#endif
+    } else {
+        MemoryTrace *globalMemoryTrace = GetMemoryTrace();
+        std::unordered_set set = globalMemoryTrace->GetThreadMemoryTraceSet();
+        std::unordered_set::iterator traceIter;
+        for (traceIter = set.begin(); traceIter != set.end() && *traceIter != this; ++traceIter) {
+            if ((*traceIter)->arenaTraced.find(ptr) != (*traceIter)->arenaTraced.end()) {
+                std::string message = "arena allocated by ThreadA, but freed by ThreadB.";
+#ifdef TRACE
+                auto originStack = (*traceIter)->arenaTracedWithLog.find(ptr)->second.second;
+                auto currentStack = this->arenaTracedWithLog.find(ptr)->second.second;
+                message.append("\n originStack is: " + originStack + "\n currentStack is: " + currentStack);
+#endif
+                (*traceIter)->arenaTraced.erase(ptr);
+            }
+        }
+    }
+}
+
+std::unordered_map ThreadMemoryTrace::GetVectorTraced()
+{
+    return vectorTraced;
+}
+
+std::unordered_map ThreadMemoryTrace::GetArenaTraced()
+{
+    return arenaTraced;
+}
+
+std::unordered_map> ThreadMemoryTrace::GetVectorTracedWithLog()
+{
+    return vectorTracedWithLog;
+}
+
+std::unordered_map> ThreadMemoryTrace::GetArenaTracedWithLog()
+{
+    return arenaTracedWithLog;
+}
+
+/**
+ * stack will be replaced if vector is created by jni
+ *   */
+void ThreadMemoryTrace::ReplaceVectorTracedLog(uintptr_t ptr, const std::string &stack)
+{
+    std::unordered_map>::iterator iter;
+    if ((iter = vectorTracedWithLog.find(ptr)) != vectorTracedWithLog.end()) {
+        iter->second.second = stack;
+    } else {
+        throw OmniException("Memory Trace Error", "vector create failed!");
+    }
+}
+
+/**
+ * check for memory leak in the thread.
+ *   */
+bool ThreadMemoryTrace::HasMemoryLeak()
+{
+#ifdef TRACE
+    if (!vectorTracedWithLog.empty()) {
+        // print leak stackLog
+        std::unordered_map>::iterator iter;
+        for (iter = vectorTracedWithLog.begin(); iter != vectorTracedWithLog.end(); ++iter) {
+            std::cout << "vector leaked memory: " << iter->second.first << ", stack is: "
+                      << iter->second.second << std::endl;
+        }
+    }
+
+    if (!arenaTracedWithLog.empty()) {
+        // print leak stackLog
+        std::unordered_map>::iterator iter;
+        for (iter = arenaTracedWithLog.begin(); iter != arenaTracedWithLog.end(); ++iter) {
+            std::cout << "arena leaked memory: " << iter->second.first << ", stack is: "
+                      << iter->second.second << std::endl;
+        }
+    }
+#endif
+    return !(vectorTraced.empty() && arenaTraced.empty());
+}
+
+/**
+ * free memory of vector and arena when memory leak happened
+ * */
+void ThreadMemoryTrace::FreeLeakedMemory()
+{
+    std::unordered_map::iterator iter;
+    if (!vectorTraced.empty()) {
+        std::vector vectors;
+        for (iter = vectorTraced.begin(); iter != vectorTraced.end(); ++iter) {
+            // copy the leaked vector record to avoid invalidating the unordered_map iterator.
+            // Because the iterator is traversed at the same time as the map is updated.
+            vectors.emplace_back(iter->first);
+        }
+
+        for (uint32_t i = 0; i < vectors.size(); ++i) {
+            // free vector and buffer if vector type is varchar.
+            delete reinterpret_cast(vectors.at(i));
+        }
+    }
+
+    if (!arenaTraced.empty()) {
+        Allocator *allocator = Allocator::GetAllocator();
+        std::unordered_map arenas;
+        for (iter = arenaTraced.begin(); iter != arenaTraced.end(); ++iter) {
+            // copy the leaked arena record to avoid invalidating the unordered_map iterator.
+            arenas.emplace(iter->first, iter->second);
+        }
+
+        for (iter = arenas.begin(); iter != arenas.end(); ++iter) {
+            // free arena ptr
+            allocator->Free(reinterpret_cast(iter->first), iter->second);
+        }
+    }
+    Clear();
+}
+
+void ThreadMemoryTrace::Clear()
+{
+    vectorTraced.clear();
+    arenaTraced.clear();
+
+    vectorTracedWithLog.clear();
+    arenaTracedWithLog.clear();
+}
+}
diff --git a/core/src/memory/thread_memory_trace.h b/core/src/memory/thread_memory_trace.h
new file mode 100644
index 0000000..63296c3
--- /dev/null
+++ b/core/src/memory/thread_memory_trace.h
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
+ */
+
+#ifndef OMNI_RUNTIME_THREAD_MEMORY_TRACE_H
+#define OMNI_RUNTIME_THREAD_MEMORY_TRACE_H
+
+#include 
+#include 
+
+namespace omniruntime {
+namespace mem {
+/**
+ * TLS Object, it is responsible for memory trace per thread.
+ *      */
+class ThreadMemoryTrace {
+public:
+    static ThreadMemoryTrace *GetThreadMemoryTrace()
+    {
+        thread_local ThreadMemoryTrace threadMemoryTrace;
+        return &threadMemoryTrace;
+    }
+
+    ThreadMemoryTrace();
+
+    ~ThreadMemoryTrace();
+
+    void AddVectorMemory(uintptr_t ptr, int64_t size);
+
+    void RemoveVectorMemory(uintptr_t ptr, int64_t size);
+
+    void AddArenaMemory(uintptr_t ptr, int64_t size);
+
+    void RemoveArenaMemory(uintptr_t ptr, int64_t size);
+
+    std::unordered_map GetVectorTraced();
+
+    std::unordered_map GetArenaTraced();
+
+    std::unordered_map> GetVectorTracedWithLog();
+
+    std::unordered_map> GetArenaTracedWithLog();
+
+    void ReplaceVectorTracedLog(uintptr_t ptr, const std::string &stack);
+
+    bool HasMemoryLeak();
+
+    void FreeLeakedMemory();
+
+    void Clear();
+
+private:
+    // , record the size of each vector.
+    std::unordered_map vectorTraced;
+    // , record the size of each arena.
+    std::unordered_map arenaTraced;
+
+    // >, record the size and stack of each vector.
+    std::unordered_map> vectorTracedWithLog;
+    // >, record the size and stack of each arena.
+    std::unordered_map> arenaTracedWithLog;
+};
+}
+}
+#endif // OMNI_RUNTIME_THREAD_MEMORY_TRACE_H
diff --git a/core/src/metrics/metrics.h b/core/src/metrics/metrics.h
new file mode 100644
index 0000000..762c146
--- /dev/null
+++ b/core/src/metrics/metrics.h
@@ -0,0 +1,122 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
+ * Description: Metrics header
+ */
+
+#ifndef OMNI_RUNTIME_METRICS_H
+#define OMNI_RUNTIME_METRICS_H
+#include 
+#include 
+#include 
+#include 
+#include 
+#include "type/data_types.h"
+#include "metrics_memory_info.h"
+#include "metrics_row_counter.h"
+#include "metrics_spill_info.h"
+#include "util/global_log.h"
+#include "memory/simple_arena_allocator.h"
+namespace omniruntime {
+namespace op {
+const std::string metricsNameHashAgg = "hashAggregation";
+const std::string metricsNameHashBuilder = "hashBuilder";
+const std::string metricsNameFilter = "filter";
+const std::string metricsNameLookUpJoin = "lookUpJoin";
+const std::string metricsNameSort = "sort";
+const std::string metricsNameWindow = "window";
+const std::string metricsNameNestedLoopJoinBuilder = "nestedLoopJoinBuilder";
+const std::string metricsNameNestedLoopJoinLookup = "nestedLoopJoinLookup";
+
+class Metrics {
+public:
+    Metrics() : isDebugEnabled(IsDebugEnable()), tid(std::to_string(pthread_self())), pid(std::to_string(getpid())) {}
+
+    ~Metrics() = default;
+
+    void SetOperatorName(const std::string &operatorName)
+    {
+        this->operatorName = operatorName;
+    }
+
+    void UpdateAddInputInfo(int32_t rowCount,
+        const std::unique_ptr &executionContext)
+    {
+        metricsRowCounter.UpdateAddInputInfo(rowCount);
+        SetRowInfo();
+        SetSpillInfo();
+        SetMemoryInfo(executionContext);
+        std::string allInfo = "In operator:" + operatorName + "," + rowInfoStr + spillInfoStr + memoryInfoStr;
+        LogDebug("%s", allInfo.c_str());
+    }
+
+    void UpdateGetOutputInfo(int32_t rowCount,
+        const std::unique_ptr &executionContext)
+    {
+        metricsRowCounter.UpdateGetOutputInfo(rowCount);
+        SetRowInfo();
+        SetSpillInfo();
+        SetMemoryInfo(executionContext);
+        std::string allInfo = "In operator:" + operatorName + "," + rowInfoStr + spillInfoStr + memoryInfoStr;
+        LogDebug("%s", allInfo.c_str());
+    }
+
+    void UpdateSpillFileInfo(int32_t fileCount,
+        const std::unique_ptr &executionContext)
+    {
+        metricsSpillInfo.UpdateSpillFileInfo(fileCount);
+        SetRowInfo();
+        SetSpillInfo();
+        SetMemoryInfo(executionContext);
+        std::string allInfo = "In operator:" + operatorName + "," + rowInfoStr + spillInfoStr + memoryInfoStr;
+        LogDebug("%s", allInfo.c_str());
+    }
+
+    void UpdateSpillTimesInfo(const std::unique_ptr &executionContext)
+    {
+        metricsSpillInfo.UpdateSpillTimesInfo();
+        SetRowInfo();
+        SetSpillInfo();
+        SetMemoryInfo(executionContext);
+        std::string allInfo = "In operator:" + operatorName + "," + rowInfoStr + spillInfoStr + memoryInfoStr;
+        LogDebug("%s", allInfo.c_str());
+    }
+
+    void UpdateCloseInfo(const std::unique_ptr &executionContext)
+    {
+        SetRowInfo();
+        SetSpillInfo();
+        SetMemoryInfo(executionContext);
+        std::string allInfo = "In operator:" + operatorName + "," + rowInfoStr + spillInfoStr + memoryInfoStr;
+        LogDebug("%s", allInfo.c_str());
+    }
+
+private:
+    bool isDebugEnabled = false;
+    std::string rowInfoStr;
+    std::string memoryInfoStr;
+    std::string spillInfoStr;
+    MetricsMemoryInfo metricsMemoryInfo;
+    MetricsRowCounter metricsRowCounter;
+    MetricsSpillInfo metricsSpillInfo;
+    const std::string pid;
+    const std::string tid;
+    std::string operatorName;
+
+    void SetRowInfo()
+    {
+        rowInfoStr = " pid=" + pid + ",tid=" + tid + "." + metricsRowCounter.GetRowCounterInfo();
+    }
+
+    void SetMemoryInfo(const std::unique_ptr &executionContext)
+    {
+        memoryInfoStr = metricsMemoryInfo.SetMemoryInfo(executionContext);
+    }
+
+    void SetSpillInfo()
+    {
+        spillInfoStr = metricsSpillInfo.GetSpillInfo();
+    }
+};
+}
+}
+#endif // OMNI_RUNTIME_METRICS_H
diff --git a/core/src/metrics/metrics_config.h b/core/src/metrics/metrics_config.h
new file mode 100644
index 0000000..d4c600e
--- /dev/null
+++ b/core/src/metrics/metrics_config.h
@@ -0,0 +1,12 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ * Description: This file declares for metrics_config.h
+ */
+#pragma once
+
+#include 
+
+namespace omniruntime {
+    const std::string opNameForHashBuilder = "hashBuilder";
+    const std::string opNameForLookUpJoin = "lookUpJoin";
+}
\ No newline at end of file
diff --git a/core/src/metrics/metrics_memory_info.h b/core/src/metrics/metrics_memory_info.h
new file mode 100644
index 0000000..c9b1b95
--- /dev/null
+++ b/core/src/metrics/metrics_memory_info.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
+ * Description: Metrics memory info header
+ */
+
+#ifndef OMNI_RUNTIME_MEMORY_INFO_H
+#define OMNI_RUNTIME_MEMORY_INFO_H
+#include "type/data_types.h"
+#include "unistd.h"
+#include "memory/thread_memory_manager.h"
+namespace omniruntime {
+namespace op {
+class MetricsMemoryInfo {
+public:
+    MetricsMemoryInfo() {}
+
+    ~MetricsMemoryInfo() = default;
+
+    std::string &SetMemoryInfo(const std::unique_ptr &executionContext)
+    {
+        threadAllocMemory = omniruntime::mem::ThreadMemoryManager::GetThreadMemoryManager()->GetAllocMemory();
+        threadFreeMemory = omniruntime::mem::ThreadMemoryManager::GetThreadMemoryManager()->GetFreeMemory();
+        processAllocMemory = omniruntime::mem::MemoryManager::GetGlobalAccountedMemory();
+        FillMemoryInfoStr(executionContext);
+        return memoryInfoStr;
+    }
+
+private:
+    uint64_t threadAllocMemory = 0;
+    uint64_t threadFreeMemory = 0;
+    uint64_t processAllocMemory = 0;
+    std::string memoryInfoStr = "";
+
+    void FillMemoryInfoStr(const std::unique_ptr &executionContext)
+    {
+        memoryInfoStr = "processUsedMemory=" + std::to_string(processAllocMemory) +
+            ".ThreadInfo:"
+            "allocMemory=" +
+            std::to_string(threadAllocMemory) + ",freeMemory=" + std::to_string(threadFreeMemory) + ",remainMemory=" +
+            std::to_string(threadAllocMemory - threadFreeMemory) +
+            ".ArenaInfo:"
+            "totalBytes=" +
+            std::to_string(executionContext->GetArena()->TotalBytes()) +
+            ",usedBytes=" + std::to_string(executionContext->GetArena()->UsedBytes()) + ",remainingBytes=" +
+            std::to_string(executionContext->GetArena()->TotalBytes() - executionContext->GetArena()->UsedBytes()) +
+            ".";
+    }
+};
+}
+}
+#endif // OMNI_RUNTIME_MEMORY_INFO_H
diff --git a/core/src/metrics/metrics_row_counter.h b/core/src/metrics/metrics_row_counter.h
new file mode 100644
index 0000000..2966420
--- /dev/null
+++ b/core/src/metrics/metrics_row_counter.h
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
+ * Description: Metrics row info header
+ */
+
+#ifndef OMNI_RUNTIME_METRICS_ROW_COUNTER_H
+#define OMNI_RUNTIME_METRICS_ROW_COUNTER_H
+#include "type/data_types.h"
+namespace omniruntime {
+namespace op {
+class MetricsRowCounter {
+public:
+    MetricsRowCounter() {}
+
+    ~MetricsRowCounter() = default;
+
+    void UpdateAddInputInfo(int32_t rowCount)
+    {
+        addInputTimes++;
+        addInputRowCount += rowCount;
+    }
+
+    void UpdateGetOutputInfo(int32_t rowCount)
+    {
+        getOutputTimes++;
+        getOutputRowCount += rowCount;
+    }
+
+    std::string &GetRowCounterInfo()
+    {
+        rowInfoStr.clear();
+        rowInfoStr = "AddInput:Times=" + std::to_string(addInputTimes) +
+            ",RowCount=" + std::to_string(addInputRowCount) + ",GetOutput:Times=" + std::to_string(getOutputTimes) +
+            ",RowCount=" + std::to_string(getOutputRowCount) + ".";
+        return rowInfoStr;
+    }
+
+private:
+    std::string rowInfoStr = "";
+    int64_t addInputTimes = 0;
+    int64_t addInputRowCount = 0;
+    int64_t getOutputTimes = 0;
+    int64_t getOutputRowCount = 0;
+};
+}
+}
+
+#endif // OMNI_RUNTIME_METRICS_ROW_COUNTER_H
diff --git a/core/src/metrics/metrics_spill_info.h b/core/src/metrics/metrics_spill_info.h
new file mode 100644
index 0000000..e34d021
--- /dev/null
+++ b/core/src/metrics/metrics_spill_info.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
+ * Description: Metrics spill info header
+ */
+
+#ifndef OMNI_RUNTIME_METRICS_SPILL_INFO_H
+#define OMNI_RUNTIME_METRICS_SPILL_INFO_H
+
+#include "type/data_types.h"
+
+namespace omniruntime {
+namespace op {
+class MetricsSpillInfo {
+public:
+    MetricsSpillInfo(){};
+    ~MetricsSpillInfo() = default;
+    void UpdateSpillFileInfo(uint32_t fileCount)
+    {
+        spillFileCount += fileCount;
+    }
+
+    void UpdateSpillTimesInfo()
+    {
+        spillTimes++;
+    }
+    std::string &GetSpillInfo()
+    {
+        rowInfoStr.clear();
+        rowInfoStr = "Spill:Times=" + std::to_string(spillTimes) + ",FileNum=" + std::to_string(spillFileCount) + ".";
+        return rowInfoStr;
+    }
+
+private:
+    std::string rowInfoStr = "";
+    int64_t spillFileCount = 0;
+    int64_t spillTimes = 0;
+};
+}
+}
+
+#endif // OMNI_RUNTIME_METRICS_SPILL_INFO_H
diff --git a/core/src/metrics/omni_metrics.h b/core/src/metrics/omni_metrics.h
new file mode 100644
index 0000000..5ab7068
--- /dev/null
+++ b/core/src/metrics/omni_metrics.h
@@ -0,0 +1,115 @@
+/*
+* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ */
+
+#pragma once
+
+#include 
+
+namespace omniruntime {
+
+struct OmniMetrics {
+    unsigned int numMetrics = 0;
+    long omniToArrow = 0;
+
+    // The underlying memory buffer.
+    std::unique_ptr array;
+
+    // Point to array.get() after the above unique_ptr created.
+    long* arrayRawPtr = nullptr;
+
+    enum TYPE {
+        // Begin from 0.
+        kBegin = 0,
+
+        kInputRows = kBegin,
+        kNumInputVecBatches,
+        kInputBytes,
+
+        kAddInputTime,
+        kAddInputCpuCount,
+
+        kOutputRows,
+        kNumOutputVecBatches,
+        kOutputBytes,
+
+        kGetOutputTime,
+        kGetOutputCpuCount,
+
+        kRawInputRows,
+        kRawInputBytes,
+
+        // CpuWallTiming.
+        kCpuCount,
+        kWallNanos,
+        kCpuNanos,
+
+        kPeakMemoryBytes,
+        kNumMemoryAllocations,
+
+        // Spill.
+        kSpilledInputBytes,
+        kSpilledBytes,
+        kSpilledRows,
+        kSpilledPartitions,
+        kSpilledFiles,
+
+        // For BHJ/SHJ
+        kBuildInputRows,
+        kBuildNumInputVecBatches,
+        kBuildAddInputTime,
+        kBuildGetOutputTime,
+
+        kLookupInputRows,
+        kLookupNumInputVecBatches,
+        kLookupOutputRows,
+        kLookupNumOutputVecBatches,
+        kLookupAddInputTime,
+        kLookupGetOutputTime,
+
+        // Runtime OmniMetrics.
+        kNumDynamicFiltersProduced,
+        kNumDynamicFiltersAccepted,
+        kNumReplacedWithDynamicFilterRows,
+        kFlushRowCount,
+        kLoadedToValueHook,
+        kScanTime,
+        kSkippedSplits,
+        kProcessedSplits,
+        kSkippedStrides,
+        kProcessedStrides,
+        kRemainingFilterTime,
+        kIoWaitTime,
+        kStorageReadBytes,
+        kLocalReadBytes,
+        kRamReadBytes,
+        kPreloadSplits,
+
+        // Write OmniMetrics.
+        kPhysicalWrittenBytes,
+        kWriteIOTime,
+        kNumWrittenFiles,
+
+        // The end of enum items.
+        kEnd,
+        kNum = kEnd - kBegin
+    };
+
+    explicit OmniMetrics(const unsigned int numMetrics) : numMetrics(numMetrics), array(new long[numMetrics * kNum])
+    {
+        memset_s(array.get(), numMetrics * kNum * sizeof(long), 0, numMetrics * kNum * sizeof(long));
+        arrayRawPtr = array.get();
+    }
+
+    OmniMetrics(const OmniMetrics&) = delete;
+    OmniMetrics(OmniMetrics&&) = delete;
+    OmniMetrics& operator=(const OmniMetrics&) = delete;
+    OmniMetrics& operator=(OmniMetrics&&) = delete;
+
+    long* get(TYPE type)
+    {
+        auto offset = (static_cast(type) - static_cast(kBegin)) * numMetrics;
+        return &arrayRawPtr[offset];
+    }
+};
+} // omniruntime
diff --git a/core/src/operator/CMakeLists.txt b/core/src/operator/CMakeLists.txt
new file mode 100644
index 0000000..8a0b219
--- /dev/null
+++ b/core/src/operator/CMakeLists.txt
@@ -0,0 +1,27 @@
+file(GLOB_RECURSE OPERATOR_LIST ${CMAKE_CURRENT_LIST_DIR}/*.cpp ../simd/*.cpp util/*.cpp ../compute/*.cpp ../plannode/*.cpp)
+
+set(OP_TARGET ${OMNI_OPERATOR_SO})
+find_package(nlohmann_json 3.7.3 REQUIRED)
+# compile .a file
+
+add_library(${OP_TARGET} SHARED ${OPERATOR_LIST})
+
+#dependent library
+target_link_libraries(${OP_TARGET} PUBLIC expression ${OMNI_CODEGEN_SO} ${OMNI_VECTOR_SO})
+target_include_directories(${OP_TARGET} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/simd)
+install(TARGETS ${OP_TARGET} DESTINATION ${CMAKE_INSTALL_PREFIX})
+install(FILES ${SOURCE_ROOT}/src/operator/operator.h DESTINATION ${CMAKE_INSTALL_PREFIX}/include/operator)
+install(FILES ${SOURCE_ROOT}/src/operator/operator_factory.h DESTINATION ${CMAKE_INSTALL_PREFIX}/include/operator)
+install(FILES ${SOURCE_ROOT}/src/operator/execution_context.h DESTINATION ${CMAKE_INSTALL_PREFIX}/include/operator)
+install(FILES ${SOURCE_ROOT}/src/operator/memory_builder.h DESTINATION ${CMAKE_INSTALL_PREFIX}/include/operator)
+install(FILES ${SOURCE_ROOT}/src/operator/status.h DESTINATION ${CMAKE_INSTALL_PREFIX}/include/operator)
+install(FILES ${SOURCE_ROOT}/src/operator/config/operator_config.h DESTINATION ${CMAKE_INSTALL_PREFIX}/include/operator/config)
+install(FILES ${SOURCE_ROOT}/src/operator/hash_util.h DESTINATION ${CMAKE_INSTALL_PREFIX}/include/operator)
+install(FILES ${SOURCE_ROOT}/src/operator/util/function_type.h DESTINATION ${CMAKE_INSTALL_PREFIX}/include/operator/util/)
+install(FILES ${SOURCE_ROOT}/src/operator/util/operator_util.h DESTINATION ${CMAKE_INSTALL_PREFIX}/include/operator/util/)
+install(FILES ${SOURCE_ROOT}/src/operator/projection/projection.h DESTINATION ${CMAKE_INSTALL_PREFIX}/include/operator/projection/)
+install(FILES ${SOURCE_ROOT}/src/operator/omni_id_type_vector_traits.h DESTINATION ${CMAKE_INSTALL_PREFIX}/include/operator)
+install(FILES ${SOURCE_ROOT}/src/operator/filter/filter_and_project.h DESTINATION ${CMAKE_INSTALL_PREFIX}/include/operator/filter)
+install(FILES ${SOURCE_ROOT}/src/operator/window/window_frame.h DESTINATION ${CMAKE_INSTALL_PREFIX}/include/operator/window/)
+file(GLOB METRICS_HEAD_FILES ${SOURCE_ROOT}/src/metrics/*.h)
+install(FILES ${METRICS_HEAD_FILES} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/metrics)
\ No newline at end of file
diff --git a/core/src/operator/aggregation/GROUPBY.MD b/core/src/operator/aggregation/GROUPBY.MD
new file mode 100644
index 0000000..39c8ca8
--- /dev/null
+++ b/core/src/operator/aggregation/GROUPBY.MD
@@ -0,0 +1,14 @@
+##### Requirements
+groupby is one of the building block of BI which is widely used to generate high level insights of the data.
+
+We aim to create a set of groupby algorithms which can taking into account the metadata before generating the ASM code. 
+The following metadata will be taken into account:
+
+`sorted`: an `O(n)` algorithm can be easily achieved via scan through the data
+`cardinality`: when cardinality lower than a threshold which allows the ptrs to fit in memory without impacting the system, we can use `perfect identify hash` and use an array to directory 
+store the aggregation result, this would also allow `O(n)` complexity
+`data dictionary`: which can be used to transform strings into numerical values
+
+##### Optimizations
+###### Sorted field groupby
+
diff --git a/core/src/operator/aggregation/agg_util.h b/core/src/operator/aggregation/agg_util.h
new file mode 100644
index 0000000..e90b98f
--- /dev/null
+++ b/core/src/operator/aggregation/agg_util.h
@@ -0,0 +1,112 @@
+/*
+ * @Copyright: Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * @Description: Hash Aggregation WithExpr Header
+ */
+
+
+#ifndef OMNI_RUNTIME_AGG_UTIL_H
+#define OMNI_RUNTIME_AGG_UTIL_H
+
+#include "vector/vector_common.h"
+#include "operator/execution_context.h"
+#include "operator/filter/filter_and_project.h"
+#include "operator/util/operator_util.h"
+
+namespace omniruntime {
+namespace op {
+class AggUtil {
+public:
+    static bool IsAggPositionEligible(int32_t rowId, VectorBatch *inputVecBatch, SimpleFilter *aggSimpleFilter,
+        ExecutionContext *executionContext, DataTypes &originTypes)
+    {
+        const int32_t allColsCount = inputVecBatch->GetVectorCount();
+        auto originTypeIds = originTypes.GetIds();
+        int64_t values[allColsCount];
+        bool nulls[allColsCount];
+        int32_t lengths[allColsCount];
+        std::set &usedVectors = aggSimpleFilter->GetVectorIndexes();
+        for (auto iter = usedVectors.begin(); iter != usedVectors.end(); ++iter) {
+            auto vecIdx = *iter;
+            auto vector = inputVecBatch->Get(vecIdx);
+            nulls[vecIdx] = vector->IsNull(rowId);
+            values[vecIdx] = OperatorUtil::GetValuePtrAndLength(vector, rowId, lengths + vecIdx, originTypeIds[vecIdx]);
+        }
+
+        return aggSimpleFilter->Evaluate(values, nulls, lengths, reinterpret_cast(&executionContext));
+    }
+
+    static inline bool IsColumnFilter(VectorBatch *inputVecBatch, SimpleFilter *aggSimpleFilter)
+    {
+        return aggSimpleFilter->IsColumnFilter() &&
+               inputVecBatch->Get(*(aggSimpleFilter->GetVectorIndexes().begin()))->GetEncoding() == OMNI_FLAT;
+    }
+
+    static BaseVector *ColumnFilterFlatVectorSliceHelper(VectorBatch *inputVecBatch, SimpleFilter *aggSimpleFilter)
+    {
+        auto *sourceVector = inputVecBatch->Get(*aggSimpleFilter->GetVectorIndexes().begin());
+        return reinterpret_cast *>(sourceVector)->Slice(0, sourceVector->GetSize());
+    }
+
+    static VectorBatch *AggFilterRequiredVectors(VectorBatch *inputVecBatch, const DataTypes &originTypes,
+        const DataTypes &inputTypes, const std::vector> &projections,
+        ExecutionContext *executionContext)
+    {
+        int32_t rowCount = inputVecBatch->GetRowCount();
+        auto newInputVecBatch = std::make_unique(rowCount);
+        if (rowCount == 0) {
+            VectorHelper::AppendVectors(newInputVecBatch.get(), inputTypes, rowCount);
+            return newInputVecBatch.release();
+        }
+
+        int32_t originVecCount = inputVecBatch->GetVectorCount();
+        int64_t valueAddrs[originVecCount];
+        int64_t nullAddrs[originVecCount];
+        int64_t offsetAddrs[originVecCount];
+        int64_t dictionaryVectors[originVecCount];
+        GetAddr(*inputVecBatch, valueAddrs, nullAddrs, offsetAddrs, dictionaryVectors, originTypes);
+
+        for (auto &projection : projections) {
+            auto projectVec = projection->Project(inputVecBatch, valueAddrs, nullAddrs, offsetAddrs, executionContext,
+                dictionaryVectors, originTypes.GetIds());
+            if (executionContext->HasError()) {
+                executionContext->GetArena()->Reset();
+                std::string errorMessage = executionContext->GetError();
+                throw OmniException("OPERATOR_RUNTIME_ERROR", errorMessage);
+            }
+            newInputVecBatch->Append(projectVec);
+        }
+        return newInputVecBatch.release();
+    }
+
+    static void AddFilterColumn(VectorBatch *inputVecBatch, VectorBatch *newInputVecBatch,
+        std::vector &aggSimpleFilters, ExecutionContext *context, DataTypes &originTypes)
+    {
+        auto aggFilterNum = aggSimpleFilters.size();
+        auto rowCount = inputVecBatch->GetRowCount();
+
+        for (size_t i = 0; i < aggFilterNum; ++i) {
+            auto aggSimpleFilter = aggSimpleFilters[i];
+            if (aggSimpleFilter != nullptr) {
+                if (IsColumnFilter(inputVecBatch, aggSimpleFilter)) {
+                    newInputVecBatch->Append(ColumnFilterFlatVectorSliceHelper(inputVecBatch, aggSimpleFilter));
+                } else {
+                    // if the agg expression has filter, then append a boolean vector to vector batch
+                    auto *booleanVector = new Vector(rowCount);
+                    for (int32_t j = 0; j < rowCount; ++j) {
+                        if (AggUtil::IsAggPositionEligible(j, inputVecBatch, aggSimpleFilter, context, originTypes)) {
+                            booleanVector->SetValue(j, true);
+                        } else {
+                            booleanVector->SetValue(j, false);
+                        }
+                    }
+                    newInputVecBatch->Append(booleanVector);
+                }
+            }
+        }
+    }
+};
+}
+}
+
+
+#endif // OMNI_RUNTIME_AGG_UTIL_H
diff --git a/core/src/operator/aggregation/aggregation.cpp b/core/src/operator/aggregation/aggregation.cpp
new file mode 100644
index 0000000..9ce3782
--- /dev/null
+++ b/core/src/operator/aggregation/aggregation.cpp
@@ -0,0 +1,92 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
+ * Description: Aggregation Base Class
+ */
+
+#include "aggregation.h"
+#include "aggregator/aggregator_factory.h"
+namespace omniruntime {
+namespace op {
+template 
+void AggregationCommonOperatorFactory::CreateAggregatorFactory(
+    std::vector> &aggregatorFactories, int32_t maskCol)
+{
+    if (maskCol == Aggregator::INVALID_MASK_COL) {
+        aggregatorFactories.push_back(std::make_unique());
+    } else {
+        aggregatorFactories.push_back(std::make_unique>(maskCol));
+    }
+}
+
+OmniStatus AggregationCommonOperatorFactory::CreateAggregatorFactories(
+    std::vector> &aggregatorFactories, const std::vector &funcTypesContext,
+    const std::vector &maskCols)
+{
+    OmniStatus ret = OMNI_STATUS_NORMAL;
+
+    for (uint32_t i = 0; i < funcTypesContext.size(); ++i) {
+        switch (funcTypesContext[i]) {
+            case OMNI_AGGREGATION_TYPE_SUM: {
+                if (ConfigUtil::GetSupportContainerVecRule() == SupportContainerVecRule::NOT_SUPPORT) {
+                    CreateAggregatorFactory(aggregatorFactories, maskCols[i]);
+                } else {
+                    CreateAggregatorFactory(aggregatorFactories, maskCols[i]);
+                }
+                break;
+            }
+            case OMNI_AGGREGATION_TYPE_COUNT_COLUMN: {
+                CreateAggregatorFactory(aggregatorFactories, maskCols[i]);
+                break;
+            }
+            case OMNI_AGGREGATION_TYPE_COUNT_ALL: {
+                CreateAggregatorFactory(aggregatorFactories, maskCols[i]);
+                break;
+            }
+            case OMNI_AGGREGATION_TYPE_MAX: {
+                CreateAggregatorFactory(aggregatorFactories, maskCols[i]);
+                break;
+            }
+            case OMNI_AGGREGATION_TYPE_MIN: {
+                CreateAggregatorFactory(aggregatorFactories, maskCols[i]);
+                break;
+            }
+            case OMNI_AGGREGATION_TYPE_AVG: {
+                if (ConfigUtil::GetSupportContainerVecRule() == SupportContainerVecRule::NOT_SUPPORT) {
+                    CreateAggregatorFactory(aggregatorFactories, maskCols[i]);
+                } else {
+                    CreateAggregatorFactory(aggregatorFactories, maskCols[i]);
+                }
+                break;
+            }
+            case OMNI_AGGREGATION_TYPE_SAMP: {
+                CreateAggregatorFactory(aggregatorFactories, maskCols[i]);
+                break;
+            }
+            case OMNI_AGGREGATION_TYPE_FIRST_IGNORENULL: {
+                CreateAggregatorFactory(aggregatorFactories, maskCols[i]);
+                break;
+            }
+            case OMNI_AGGREGATION_TYPE_FIRST_INCLUDENULL: {
+                CreateAggregatorFactory(aggregatorFactories, maskCols[i]);
+                break;
+            }
+            case OMNI_AGGREGATION_TYPE_TRY_SUM: {
+                CreateAggregatorFactory(aggregatorFactories, maskCols[i]);
+                break;
+            }
+            case OMNI_AGGREGATION_TYPE_TRY_AVG: {
+                CreateAggregatorFactory(aggregatorFactories, maskCols[i]);
+                break;
+            }
+            default: {
+                std::string omniExceptionInfo = "In function CreateAggregatorFactories, No such agg func type " +
+                    std::to_string(funcTypesContext[i]);
+                throw omniruntime::exception::OmniException("UNSUPPORTED_ERROR", omniExceptionInfo);
+            }
+        }
+    }
+
+    return ret;
+}
+} // end of namespace op
+} // end of namespace omniruntime
\ No newline at end of file
diff --git a/core/src/operator/aggregation/aggregation.h b/core/src/operator/aggregation/aggregation.h
new file mode 100644
index 0000000..849c6a9
--- /dev/null
+++ b/core/src/operator/aggregation/aggregation.h
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
+ * Description: Aggregation Base Class
+ */
+#ifndef AGGREGATION_H
+#define AGGREGATION_H
+
+#include 
+#include 
+#include "operator/operator_factory.h"
+#include "operator/aggregation/aggregator/only_aggregator_factory.h"
+#include "memory/memory_pool.h"
+#include "operator/status.h"
+
+namespace omniruntime {
+namespace op {
+class AggregationCommonOperator : public Operator {
+public:
+    explicit AggregationCommonOperator(std::vector> &&aggs, std::vector &inputRaws,
+        std::vector &outputPartials, bool isOverflowAsNull = false)
+        : inputRaws(inputRaws), outputPartials(outputPartials), isOverflowAsNull(isOverflowAsNull)
+    {
+        for (auto &agg : aggs) {
+            agg->SetExecutionContext(executionContext.get());
+        }
+        this->aggregators = std::move(aggs);
+    }
+
+    ~AggregationCommonOperator() override {};
+
+protected:
+    std::vector> aggregators;
+    std::vector inputRaws;
+    std::vector outputPartials;
+    bool isOverflowAsNull;
+};
+
+class AggregationCommonOperatorFactory : public OperatorFactory {
+public:
+    AggregationCommonOperatorFactory(std::vector &inputRaws, std::vector &outputPartials,
+        std::vector &maskColsContext, bool isOverflowAsNull = false, bool isStatisticalAggregate = false)
+        : inputRaws(inputRaws),
+          outputPartials(outputPartials),
+          isOverflowAsNull(isOverflowAsNull),
+          isStatisticalAggregate(isStatisticalAggregate)
+    {
+        for (size_t i = 0; i < maskColsContext.size(); ++i) {
+            maskCols.push_back(maskColsContext[i]);
+        }
+    }
+
+    ~AggregationCommonOperatorFactory() override {};
+
+    std::vector &GetMaskColumns()
+    {
+        return maskCols;
+    }
+
+    virtual OmniStatus Init() = 0;
+    virtual OmniStatus Close() = 0;
+
+    template 
+    void CreateAggregatorFactory(std::vector> &aggregatorFactories, int32_t maskCol);
+
+    OmniStatus CreateAggregatorFactories(std::vector> &aggregatorFactories,
+        const std::vector &funcTypesContext, const std::vector &maskCols);
+
+protected:
+    std::vector inputRaws;
+    std::vector outputPartials;
+    std::vector maskCols;
+    bool isOverflowAsNull;
+    bool isStatisticalAggregate;
+};
+}
+}
+#endif
\ No newline at end of file
diff --git a/core/src/operator/aggregation/aggregator/aggregator.h b/core/src/operator/aggregation/aggregator/aggregator.h
new file mode 100644
index 0000000..da7d863
--- /dev/null
+++ b/core/src/operator/aggregation/aggregator/aggregator.h
@@ -0,0 +1,333 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2024. All rights reserved.
+ * Description: Inner supported aggregators header
+ */
+#ifndef AGGREGATOR_H
+#define AGGREGATOR_H
+
+#include "operator/aggregation/definitions.h"
+#include "type/data_types.h"
+#include "type/data_type.h"
+#include "type/decimal128.h"
+#include "type/base_operations.h"
+#include "vector/vector.h"
+#include "vector/vector_common.h"
+#include "operator/execution_context.h"
+#include "operator/util/function_type.h"
+#include "operator/join/row_ref.h"
+#include "util/type_util.h"
+#include "util/config_util.h"
+#include "state_flag_operation.h"
+
+namespace omniruntime {
+namespace op {
+using namespace omniruntime::exception;
+using namespace omniruntime::vec;
+
+struct ColumnIndex {
+    int32_t idx;
+    type::DataTypePtr input;
+    type::DataTypePtr output;
+};
+
+using AggregateState = uint8_t;
+
+struct DecimalAverageState {
+    int64_t count;
+    int64_t overflow;
+    type::int128_t val;
+};
+
+struct DecimalSumState {
+    int64_t overflow;
+    type::int128_t val;
+};
+
+struct KeyValue {
+    char *keyAddr;
+    size_t keyLen;
+    AggregateState *value;
+};
+
+struct UnspillRowInfo {
+    AggregateState *state;
+    VectorBatch *batch;
+    int32_t rowIdx;
+};
+
+// Avg decimal and overflow is decode/encode in continuous memory
+static inline void DecodeAvgDecimal(op::DecimalAverageState *statePtr, type::int128_t &val, int64_t &overflow,
+    int64_t &count)
+{
+    count = statePtr->count;
+    overflow = statePtr->overflow;
+    val = statePtr->val;
+}
+
+static inline void EncodeAvgDecimal(op::DecimalAverageState *statePtr, const type::int128_t &val,
+    const int64_t &overflow, const int64_t &count)
+{
+    statePtr->count = count;
+    statePtr->overflow = overflow;
+    statePtr->val = val;
+}
+
+template  int32_t ALWAYS_INLINE Compare(const T &leftVal, const T &rightVal)
+{
+    return (leftVal > rightVal ? 1 : (leftVal < rightVal ? -1 : 0));
+}
+
+class Aggregator {
+public:
+    /* Initiate this aggregator, such as setting default values for states.
+     * @param aggregateType indicates which aggregate function this aggregator stands for
+     * @param inputTypes indicates this aggregator's input data types(support multi-input)
+     * it can use normal vector or container vector
+     * @param outputTypes indicates this aggregator's output data types(support multi-input)
+     * it can use normal vector or container vector
+     * @param channels indicates this aggregator's input channels for VectorBatch.
+     * @param inputRaw indicates this aggregator's input data type
+     * true for raw input, false for intermeidate input, default value as true.
+     * @param outputPartial indicates this aggregator's output data type.
+     * true for intermeidate output, false for final output, default value as false.
+     * @param isOverflowAsNull indicates aggregator handle overflow calculation result
+     * true overflow as null value, false throw exception, default value as false.
+     *
+     */
+    Aggregator(const FunctionType aggregateType, const type::DataTypes &inputTypes, const type::DataTypes &outputTypes,
+        const std::vector &channels, const bool inputRaw = true, const bool outputPartial = false,
+        const bool isOverflowAsNull = false)
+        : type(aggregateType),
+          inputTypes(inputTypes),
+          outputTypes(outputTypes),
+          inputRaw(inputRaw),
+          outputPartial(outputPartial),
+          isOverflowAsNull(isOverflowAsNull),
+          channels(channels)
+    {}
+
+    virtual ~Aggregator() = default;
+
+    virtual void SetExecutionContext(ExecutionContext *executionContext)
+    {
+        this->executionContext = executionContext;
+        this->arenaAllocator = executionContext->GetArena();
+    }
+
+    virtual void ProcessGroup(AggregateState *state, VectorBatch *vectorBatch, int32_t rowIndex)
+    {
+        throw OmniException("Not implemented",
+            "ProcessGroup(AggregateState &, VectorBatch *, int32_t) not implemented for " +
+            std::to_string(as_integer(type)));
+    }
+
+    virtual std::vector GetSpillType()
+    {
+        throw OmniException("UNSUPPORTED_ERROR",
+            "GetSpillType not implemented for " + std::to_string(as_integer(type)));
+    }
+
+    // for no groupby aggregation
+    virtual void ProcessGroup(AggregateState *state, VectorBatch *vectorBatch, const int32_t rowOffset,
+        const int32_t rowCount)
+    {
+#ifdef DEBUG
+        LogWarn("Using not-optimized aggregator api for aggregator %d", as_integer(type));
+#endif
+        int32_t end = rowOffset + rowCount;
+        for (int32_t i = rowOffset; i < end; ++i) {
+            ProcessGroup(state + aggStateOffset, vectorBatch, i);
+        }
+    }
+
+    virtual void AlignAggSchemaWithFilter(VectorBatch *result, VectorBatch *inputVecBatch,
+        const int32_t filterIndex) = 0;
+
+    virtual void AlignAggSchema(VectorBatch *result, VectorBatch *inputVecBatch) = 0;
+
+    static bool DoNeedHandleAggFilter(Vector *filterVec, const int32_t rowOffset, const int32_t size)
+    {
+        int32_t rowEnd = rowOffset + size;
+        for (int32_t start = rowOffset, end = rowEnd - 1; start <= end; ++start, --end) {
+            if (!filterVec->GetValue(start) || !filterVec->GetValue(end)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    // for no groupby aggregation  with filter
+    virtual void ProcessGroupFilter(AggregateState *state, VectorBatch *vectorBatch, const int32_t rowOffset,
+        const int32_t filterIndex)
+    {
+#ifdef DEBUG
+        LogWarn("Using not-optimized aggregator api for aggregator %d", as_integer(type));
+#endif
+
+        int32_t rowEnd = rowOffset + vectorBatch->GetRowCount();
+        auto filterVec = static_cast *>(vectorBatch->Get(filterIndex));
+        bool needFilterJude = DoNeedHandleAggFilter(filterVec, rowOffset, vectorBatch->GetRowCount());
+        if (needFilterJude) {
+            for (int32_t i = rowOffset; i < rowEnd; ++i) {
+                if (filterVec->GetValue(i)) {
+                    ProcessGroup(state + aggStateOffset, vectorBatch, i);
+                }
+            }
+        } else {
+            for (int32_t i = rowOffset; i < rowEnd; ++i) {
+                ProcessGroup(state + aggStateOffset, vectorBatch, i);
+            }
+        }
+    }
+
+    // for groupby hash aggregation
+    virtual void ProcessGroup(std::vector &rowStates, VectorBatch *vectorBatch,
+        const int32_t rowOffset)
+    {
+#ifdef DEBUG
+        LogWarn("Using not-optimized aggregator api for aggregator %d", as_integer(type));
+#endif
+
+        int32_t rowIndex = rowOffset;
+        size_t rowCount = rowStates.size();
+
+        for (size_t i = 0; i < rowCount; ++i) {
+            ProcessGroup(rowStates[i] + aggStateOffset, vectorBatch, rowIndex++);
+        }
+    }
+
+    virtual void ProcessGroupUnspill(std::vector &unspillRows, int32_t rowCount, int32_t &vectorIndex)
+    {
+        throw OmniException("UNSUPPORTED_ERROR",
+            "ProcessGroupUnspill not implemented for " + std::to_string(as_integer(type)));
+    }
+
+    // for groupby hash aggregation
+    virtual void ProcessGroupFilter(std::vector &rowStates, const size_t aggIdx,
+        VectorBatch *vectorBatch, const int32_t filterOffset, const int32_t rowOffset)
+    {
+#ifdef DEBUG
+        LogWarn("Using not-optimized aggregator api for aggregator %d", as_integer(type));
+#endif
+        auto filterVecIdx = static_cast(filterOffset);
+        auto filterVec = static_cast *>(vectorBatch->Get(filterVecIdx));
+        auto rowCount = static_cast(rowStates.size());
+        bool needFilterJude = DoNeedHandleAggFilter(filterVec, rowOffset, rowCount);
+
+        int32_t rowIndex = rowOffset;
+        if (needFilterJude) {
+            for (int32_t i = 0; i < rowCount; ++i) {
+                if (filterVec->GetValue(i)) {
+                    ProcessGroup(rowStates[i] + aggStateOffset, vectorBatch, rowIndex++);
+                    continue;
+                }
+                rowIndex++;
+            }
+        } else {
+            for (int32_t i = 0; i < rowCount; ++i) {
+                ProcessGroup(rowStates[i] + aggStateOffset, vectorBatch, rowIndex++);
+            }
+        }
+    }
+
+    virtual void InitState(AggregateState *state)
+    {
+        throw OmniException("not implement", "InitState");
+    }
+
+    virtual void SetStateOffset(int32_t offset)
+    {
+        aggStateOffset = offset;
+    }
+
+    virtual size_t GetStateSize() = 0;
+
+    virtual void InitStates(std::vector &groupStates)
+    {
+        throw OmniException("not implement", "InitStates");
+    };
+
+    // set result to output vector
+    virtual void ExtractValues(const AggregateState *state, std::vector &vectors,
+        const int32_t rowIndex) = 0;
+
+    virtual void ExtractValuesBatch(std::vector &groupStates, std::vector &vectors,
+        int32_t rowOffset, int32_t rowCount) = 0;
+
+    virtual void ExtractValuesForSpill(std::vector &groupStates,
+        std::vector &vectors) = 0;
+
+    virtual bool IsTypedAggregator()
+    {
+        return false;
+    }
+
+    virtual bool IsInputRaw() const
+    {
+        return this->inputRaw;
+    }
+
+    virtual bool IsOutputPartial() const
+    {
+        return this->outputPartial;
+    }
+
+    virtual bool IsOverflowAsNull() const
+    {
+        return this->isOverflowAsNull;
+    }
+
+    virtual void SetStatisticalAggregate(bool statisticalAggregate)
+    {
+        this->isStatisticalAggregate = statisticalAggregate;
+    }
+
+    virtual bool IsStatisticalAggregate() const
+    {
+        return this->isStatisticalAggregate;
+    }
+
+    virtual FunctionType GetType() const
+    {
+        return type;
+    }
+
+    virtual const type::DataTypes &GetInputTypes() const
+    {
+        return inputTypes;
+    }
+
+    virtual const type::DataTypes &GetOutputTypes() const
+    {
+        return outputTypes;
+    }
+
+    virtual const std::vector &GetInputChannels() const
+    {
+        return channels;
+    }
+
+    const ExecutionContext *GetExecutionContext() const
+    {
+        return executionContext;
+    }
+
+public:
+    static constexpr int32_t INVALID_MASK_COL = -1;
+
+protected:
+    const FunctionType type;
+    type::DataTypes inputTypes;
+    type::DataTypes outputTypes;
+    const bool inputRaw;
+    const bool outputPartial;
+    const bool isOverflowAsNull;
+    bool isStatisticalAggregate = false;
+    const std::vector channels;
+    ExecutionContext *executionContext = nullptr;
+    SimpleArenaAllocator *arenaAllocator = nullptr;
+    int32_t aggStateOffset;
+};
+} // end of namespace op
+} // end of namespace omniruntime
+#endif
diff --git a/core/src/operator/aggregation/aggregator/aggregator_factory.cpp b/core/src/operator/aggregation/aggregator/aggregator_factory.cpp
new file mode 100644
index 0000000..5c6929c
--- /dev/null
+++ b/core/src/operator/aggregation/aggregator/aggregator_factory.cpp
@@ -0,0 +1,384 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ */
+
+#include "aggregator_factory.h"
+
+namespace omniruntime {
+namespace op {
+std::unique_ptr CreateAggregatorFactory(FunctionType aggType)
+{
+    switch (aggType) {
+        case OMNI_AGGREGATION_TYPE_SUM: {
+            if (ConfigUtil::GetSupportContainerVecRule() == SupportContainerVecRule::NOT_SUPPORT) {
+                return std::make_unique();
+            } else {
+                return std::make_unique();
+            }
+        }
+        case OMNI_AGGREGATION_TYPE_AVG: {
+            if (ConfigUtil::GetSupportContainerVecRule() == SupportContainerVecRule::NOT_SUPPORT) {
+                return std::make_unique();
+            } else {
+                return std::make_unique();
+            }
+        }
+        case OMNI_AGGREGATION_TYPE_SAMP: {
+            return std::make_unique();
+        }
+        case OMNI_AGGREGATION_TYPE_MIN: {
+            return std::make_unique();
+        }
+        case OMNI_AGGREGATION_TYPE_MAX: {
+            return std::make_unique();
+        }
+        case OMNI_AGGREGATION_TYPE_COUNT_COLUMN: {
+            return std::make_unique();
+        }
+        case OMNI_AGGREGATION_TYPE_COUNT_ALL: {
+            return std::make_unique();
+        }
+        case OMNI_AGGREGATION_TYPE_FIRST_IGNORENULL: {
+            return std::make_unique();
+        }
+        case OMNI_AGGREGATION_TYPE_FIRST_INCLUDENULL: {
+            return std::make_unique();
+        }
+        default: {
+            std::string omniExceptionInfo =
+                "In function CreateAggregatorFactory, no such aggregate type " + std::to_string(aggType);
+            throw omniruntime::exception::OmniException("UNSUPPORTED_ERROR", omniExceptionInfo);
+        }
+    }
+}
+
+template