diff --git a/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java b/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java index c812f921baee824f30fa1ff9d9375be687d54010..1679895fec7f19d2ccb5ef361a4d5392c19ed216 100644 --- a/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java +++ b/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java @@ -37,6 +37,9 @@ import java.util.Optional; import java.util.OptionalInt; import static com.google.common.base.Preconditions.checkArgument; +import static io.prestosql.spi.HetuConstant.EXTENSION_EXECUTION_PLANNER_CLASS_PATH; +import static io.prestosql.spi.HetuConstant.EXTENSION_EXECUTION_PLANNER_ENABLED; +import static io.prestosql.spi.HetuConstant.EXTENSION_EXECUTION_PLANNER_JAR_PATH; import static io.prestosql.spi.StandardErrorCode.INVALID_SESSION_PROPERTY; import static io.prestosql.spi.session.PropertyMetadata.booleanProperty; import static io.prestosql.spi.session.PropertyMetadata.dataSizeProperty; @@ -810,6 +813,22 @@ public final class SystemSessionProperties SKIP_NON_APPLICABLE_RULES_ENABLED, "Whether to skip applying some selected rules based on query pattern", featuresConfig.isSkipNonApplicableRulesEnabled(), + false), + // add extension execution planner and operator + stringProperty( + EXTENSION_EXECUTION_PLANNER_JAR_PATH, + "extension execution planner jar path", + hetuConfig.getExtensionExecutionPlannerJarPath(), + false), + stringProperty( + EXTENSION_EXECUTION_PLANNER_CLASS_PATH, + "extension execution planner class path", + hetuConfig.getExtensionExecutionPlannerClassPath(), + false), + booleanProperty( + EXTENSION_EXECUTION_PLANNER_ENABLED, + "extension execution planner enabled", + hetuConfig.getExtensionExecutionPlannerEnabled(), false)); } @@ -1418,4 +1437,19 @@ public final class SystemSessionProperties { return session.getSystemProperty(SKIP_NON_APPLICABLE_RULES_ENABLED, Boolean.class); } + + public static Boolean isExtensionExecutionPlannerEnabled(Session session) + { + return session.getSystemProperty(EXTENSION_EXECUTION_PLANNER_ENABLED, Boolean.class); + } + + public static String getExtensionExecutionPlannerJarPath(Session session) + { + return session.getSystemProperty(EXTENSION_EXECUTION_PLANNER_JAR_PATH, String.class); + } + + public static String getExtensionExecutionPlannerClassPath(Session session) + { + return session.getSystemProperty(EXTENSION_EXECUTION_PLANNER_CLASS_PATH, String.class); + } } diff --git a/presto-main/src/main/java/io/prestosql/execution/SqlTaskExecutionFactory.java b/presto-main/src/main/java/io/prestosql/execution/SqlTaskExecutionFactory.java index fb8c1fc153e4a4a1323d0406d9c106c810776770..2461df8a33f1d15c78bb74cf047180e3e096d468 100644 --- a/presto-main/src/main/java/io/prestosql/execution/SqlTaskExecutionFactory.java +++ b/presto-main/src/main/java/io/prestosql/execution/SqlTaskExecutionFactory.java @@ -14,6 +14,7 @@ package io.prestosql.execution; import io.airlift.concurrent.SetThreadName; +import io.airlift.log.Logger; import io.hetu.core.transport.execution.buffer.PagesSerdeFactory; import io.prestosql.Session; import io.prestosql.event.SplitMonitor; @@ -23,12 +24,19 @@ import io.prestosql.memory.QueryContext; import io.prestosql.metadata.Metadata; import io.prestosql.operator.CommonTableExecutionContext; import io.prestosql.operator.TaskContext; +import io.prestosql.spi.PrestoException; import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.sql.planner.LocalExecutionPlanner; import io.prestosql.sql.planner.LocalExecutionPlanner.LocalExecutionPlan; import io.prestosql.sql.planner.PlanFragment; import io.prestosql.sql.planner.TypeProvider; +import javax.annotation.Nullable; + +import java.lang.reflect.Constructor; +import java.net.MalformedURLException; +import java.net.URL; +import java.net.URLClassLoader; import java.util.List; import java.util.Map; import java.util.Optional; @@ -36,12 +44,19 @@ import java.util.OptionalInt; import java.util.concurrent.Executor; import static com.google.common.base.Throwables.throwIfUnchecked; +import static io.prestosql.SystemSessionProperties.getExtensionExecutionPlannerClassPath; +import static io.prestosql.SystemSessionProperties.getExtensionExecutionPlannerJarPath; import static io.prestosql.SystemSessionProperties.isExchangeCompressionEnabled; +import static io.prestosql.SystemSessionProperties.isExtensionExecutionPlannerEnabled; import static io.prestosql.execution.SqlTaskExecution.createSqlTaskExecution; +import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static java.util.Objects.requireNonNull; public class SqlTaskExecutionFactory { + private static final Logger log = Logger.get(SqlTaskExecutionFactory.class); + private static LocalExecutionPlanner extensionPlanner; + private static boolean extensionPlannerInitialized; private final Executor taskNotificationExecutor; private final TaskExecutor taskExecutor; @@ -83,20 +98,32 @@ public class SqlTaskExecutionFactory consumer, new PagesSerdeFactory(metadata.getFunctionAndTypeManager().getBlockEncodingSerde(), isExchangeCompressionEnabled(session))); - LocalExecutionPlan localExecutionPlan; + LocalExecutionPlan localExecutionPlan = null; try (SetThreadName ignored = new SetThreadName("Task-%s", taskStateMachine.getTaskId())) { try { - localExecutionPlan = planner.plan( - taskContext, - fragment.getRoot(), - TypeProvider.copyOf(fragment.getSymbols()), - fragment.getPartitioningScheme(), - fragment.getStageExecutionDescriptor(), - fragment.getPartitionedSources(), - outputBuffer, - fragment.getFeederCTEId(), - fragment.getFeederCTEParentId(), - cteCtx); + if (isExtensionExecutionPlannerEnabled(session)) { + String jarPath = getExtensionExecutionPlannerJarPath(session); + String classPath = getExtensionExecutionPlannerClassPath(session); + if (jarPath != null && !jarPath.equals("") && classPath != null && !classPath.equals("")) { + localExecutionPlan = loadExtensionLocalExecutionPlan(outputBuffer, fragment, taskContext, cteCtx, jarPath, classPath); + } + else { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Extension execution planner jar path or class path isn't configured correctly"); + } + } + if (localExecutionPlan == null) { + localExecutionPlan = planner.plan( + taskContext, + fragment.getRoot(), + TypeProvider.copyOf(fragment.getSymbols()), + fragment.getPartitioningScheme(), + fragment.getStageExecutionDescriptor(), + fragment.getPartitionedSources(), + outputBuffer, + fragment.getFeederCTEId(), + fragment.getFeederCTEParentId(), + cteCtx); + } } catch (Throwable e) { // planning failed @@ -115,4 +142,46 @@ public class SqlTaskExecutionFactory taskNotificationExecutor, splitMonitor); } + + @Nullable + private LocalExecutionPlan loadExtensionLocalExecutionPlan(OutputBuffer outputBuffer, PlanFragment fragment, TaskContext taskContext, Map cteCtx, String jarPath, String classPath) + { + if (!extensionPlannerInitialized) { + try { + ExtensionClassLoader extensionClassLoader = new ExtensionClassLoader(jarPath, Thread.currentThread().getContextClassLoader()); + Thread.currentThread().setContextClassLoader(extensionClassLoader); + Class aClass = extensionClassLoader.loadClass(classPath); + Constructor constructor = aClass.getConstructor(LocalExecutionPlanner.class); + extensionPlanner = (LocalExecutionPlanner) constructor.newInstance(planner); + extensionPlannerInitialized = true; + } + catch (Throwable e) { + log.warn("get extension LocalExecutionPlanner failed: %s", e.toString()); + throw new PrestoException(GENERIC_INTERNAL_ERROR, e); + } + } + if (extensionPlanner != null) { + return extensionPlanner.plan( + taskContext, + fragment.getRoot(), + TypeProvider.copyOf(fragment.getSymbols()), + fragment.getPartitioningScheme(), + fragment.getStageExecutionDescriptor(), + fragment.getPartitionedSources(), + outputBuffer, + fragment.getFeederCTEId(), + fragment.getFeederCTEParentId(), + cteCtx); + } + return null; + } + + public static class ExtensionClassLoader + extends URLClassLoader + { + public ExtensionClassLoader(final String path, ClassLoader parent) throws MalformedURLException + { + super(new URL[] {new URL(path)}, parent); + } + } } diff --git a/presto-main/src/main/java/io/prestosql/execution/StateMachine.java b/presto-main/src/main/java/io/prestosql/execution/StateMachine.java index 8a6d064e33dfc556cefb873d7d7eaa05556837a4..24801797c96e7ac264c4e7b8ed57e72c498bbb26 100644 --- a/presto-main/src/main/java/io/prestosql/execution/StateMachine.java +++ b/presto-main/src/main/java/io/prestosql/execution/StateMachine.java @@ -49,6 +49,7 @@ public class StateMachine private final Executor executor; private final Object lock = new Object(); private final Set terminalStates; + private StateChangeListener tailStateChangeListener; @GuardedBy("lock") private volatile T state; @@ -279,6 +280,12 @@ public class StateMachine inTerminalState = isTerminalState(currentState); if (!inTerminalState) { stateChangeListeners.add(stateChangeListener); + if (tailStateChangeListener != null) { + if (stateChangeListeners.contains(tailStateChangeListener)) { + stateChangeListeners.remove(tailStateChangeListener); + } + stateChangeListeners.add(tailStateChangeListener); + } } } @@ -287,6 +294,12 @@ public class StateMachine safeExecute(() -> stateChangeListener.stateChanged(currentState)); } + public void addStateChangeListenerToTail(StateChangeListener stateChangeListener) + { + tailStateChangeListener = stateChangeListener; + addStateChangeListener(stateChangeListener); + } + @VisibleForTesting boolean isTerminalState(T state) { diff --git a/presto-main/src/main/java/io/prestosql/execution/TaskStateMachine.java b/presto-main/src/main/java/io/prestosql/execution/TaskStateMachine.java index aa56d93aec29f092e942785233e6dd44bc6950a7..ca02f8c80535ea05d1027c7fd335fea70fa84442 100644 --- a/presto-main/src/main/java/io/prestosql/execution/TaskStateMachine.java +++ b/presto-main/src/main/java/io/prestosql/execution/TaskStateMachine.java @@ -121,6 +121,15 @@ public class TaskStateMachine taskState.addStateChangeListener(stateChangeListener); } + /** + * Add listener to the tail, this listener will be notified at last when state changed. + * @param stateChangeListener listener of state change. + */ + public void addStateChangeListenerToTail(StateChangeListener stateChangeListener) + { + taskState.addStateChangeListenerToTail(stateChangeListener); + } + @Override public String toString() { diff --git a/presto-main/src/main/java/io/prestosql/metadata/InternalBlockEncodingSerde.java b/presto-main/src/main/java/io/prestosql/metadata/InternalBlockEncodingSerde.java index daf1ef743b90e918095b59e8741658ee7a48c7bd..eb66633f0e215a91afb600f01e790fb5422f71c5 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/InternalBlockEncodingSerde.java +++ b/presto-main/src/main/java/io/prestosql/metadata/InternalBlockEncodingSerde.java @@ -24,7 +24,7 @@ import java.util.Optional; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Objects.requireNonNull; -final class InternalBlockEncodingSerde +public class InternalBlockEncodingSerde implements BlockEncodingSerde { private final FunctionAndTypeManager functionAndTypeManager; diff --git a/presto-main/src/main/java/io/prestosql/operator/EmptyLookupSource.java b/presto-main/src/main/java/io/prestosql/operator/EmptyLookupSource.java index 720d006c3d8091009d1c847876a00e4dedc3862d..eac8185d1cd567cf2ae82b1225b661dbc7f64fab 100644 --- a/presto-main/src/main/java/io/prestosql/operator/EmptyLookupSource.java +++ b/presto-main/src/main/java/io/prestosql/operator/EmptyLookupSource.java @@ -16,7 +16,7 @@ package io.prestosql.operator; import io.prestosql.spi.Page; import io.prestosql.spi.PageBuilder; -final class EmptyLookupSource +public final class EmptyLookupSource implements LookupSource { @Override diff --git a/presto-main/src/main/java/io/prestosql/operator/ExchangeClient.java b/presto-main/src/main/java/io/prestosql/operator/ExchangeClient.java index 29a9ae89e3562153545fe09d1362c567000047cd..ee04c6d52f3416da3f5b8e9d1c94e1b026ad8df0 100644 --- a/presto-main/src/main/java/io/prestosql/operator/ExchangeClient.java +++ b/presto-main/src/main/java/io/prestosql/operator/ExchangeClient.java @@ -24,6 +24,7 @@ import io.airlift.log.Logger; import io.airlift.units.DataSize; import io.airlift.units.Duration; import io.hetu.core.transport.execution.buffer.PageCodecMarker; +import io.hetu.core.transport.execution.buffer.PagesSerde; import io.hetu.core.transport.execution.buffer.SerializedPage; import io.prestosql.failuredetector.FailureDetector; import io.prestosql.memory.context.LocalMemoryContext; @@ -31,6 +32,7 @@ import io.prestosql.operator.HttpPageBufferClient.ClientCallback; import io.prestosql.operator.WorkProcessor.ProcessState; import io.prestosql.snapshot.MultiInputSnapshotState; import io.prestosql.snapshot.QuerySnapshotManager; +import io.prestosql.spi.Page; import io.prestosql.spi.snapshot.BlockEncodingSerdeProvider; import org.apache.commons.lang3.tuple.Pair; @@ -122,6 +124,8 @@ public class ExchangeClient @GuardedBy("this") private long averageBytesPerRequest; + private List pages = new ArrayList<>(); + private final AtomicBoolean closed = new AtomicBoolean(); private final AtomicReference failure = new AtomicReference<>(); @@ -343,6 +347,19 @@ public class ExchangeClient }); } + public List getPages(String target, PagesSerde pagesSerde) + { + SerializedPage serializedPage = pollPage(target).getLeft(); + if (serializedPage == null) { + if (isFinished()) { + return pages; + } + return null; + } + pages.add(pagesSerde.deserialize(serializedPage)); + return null; + } + @Nullable public Pair pollPage(String target) { diff --git a/presto-main/src/main/java/io/prestosql/operator/HashAggregationOperator.java b/presto-main/src/main/java/io/prestosql/operator/HashAggregationOperator.java index 43de68a8b884437de685f8d382fdfbd168e01f57..6f4674996ff70dd2ad50a4d2a540405742ccf9b6 100644 --- a/presto-main/src/main/java/io/prestosql/operator/HashAggregationOperator.java +++ b/presto-main/src/main/java/io/prestosql/operator/HashAggregationOperator.java @@ -117,7 +117,7 @@ public class HashAggregationOperator } @VisibleForTesting - HashAggregationOperatorFactory( + public HashAggregationOperatorFactory( int operatorId, PlanNodeId planNodeId, List groupByTypes, diff --git a/presto-main/src/main/java/io/prestosql/operator/Operator.java b/presto-main/src/main/java/io/prestosql/operator/Operator.java index 7f3da701338f67ea4e1fbe7ac21a75c16e73c2be..b6954ef8c9b66f3ada598716d81a61cac8f4827e 100644 --- a/presto-main/src/main/java/io/prestosql/operator/Operator.java +++ b/presto-main/src/main/java/io/prestosql/operator/Operator.java @@ -66,7 +66,10 @@ public interface Operator /** * For Snapshot - If next output is a marker page, then return it, otherwise return null */ - Page pollMarker(); + default Page pollMarker() + { + return null; + } /** * After calling this method operator should revoke all reserved revocable memory. diff --git a/presto-main/src/main/java/io/prestosql/operator/OperatorFactory.java b/presto-main/src/main/java/io/prestosql/operator/OperatorFactory.java index 5b6b81b5c8687ef68d9dc388482cd09090545184..81e60525372616181d33b528fdd987c233bd1666 100644 --- a/presto-main/src/main/java/io/prestosql/operator/OperatorFactory.java +++ b/presto-main/src/main/java/io/prestosql/operator/OperatorFactory.java @@ -13,7 +13,11 @@ */ package io.prestosql.operator; +import com.google.common.collect.ImmutableList; import io.prestosql.execution.Lifespan; +import io.prestosql.spi.type.Type; + +import java.util.List; public interface OperatorFactory { @@ -47,4 +51,14 @@ public interface OperatorFactory } OperatorFactory duplicate(); + + default boolean isExtensionOperatorFactory() + { + return false; + } + + default List getSourceTypes() + { + return ImmutableList.of(); + } } diff --git a/presto-main/src/main/java/io/prestosql/operator/PageUtils.java b/presto-main/src/main/java/io/prestosql/operator/PageUtils.java index 296c978eca6d32bd4399e8b3b4b5ad269f2bc09d..a32c6964382e72ec5c11b101b4d6f5bf8155a393 100644 --- a/presto-main/src/main/java/io/prestosql/operator/PageUtils.java +++ b/presto-main/src/main/java/io/prestosql/operator/PageUtils.java @@ -19,13 +19,13 @@ import io.prestosql.spi.block.LazyBlock; import java.util.function.LongConsumer; -final class PageUtils +public final class PageUtils { private PageUtils() { } - static Page recordMaterializedBytes(Page page, LongConsumer sizeInBytesConsumer) + public static Page recordMaterializedBytes(Page page, LongConsumer sizeInBytesConsumer) { // account processed bytes from lazy blocks only when they are loaded Block[] blocks = new Block[page.getChannelCount()]; diff --git a/presto-main/src/main/java/io/prestosql/operator/TaskContext.java b/presto-main/src/main/java/io/prestosql/operator/TaskContext.java index 8bda2d5551d0cd48e8d1795bafe6c1ef4760345d..68d90cdc215726dd9c51656cee4c42efeefb7226 100644 --- a/presto-main/src/main/java/io/prestosql/operator/TaskContext.java +++ b/presto-main/src/main/java/io/prestosql/operator/TaskContext.java @@ -40,7 +40,9 @@ import org.joda.time.DateTime; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.OptionalInt; import java.util.Set; import java.util.concurrent.CopyOnWriteArrayList; @@ -111,6 +113,8 @@ public class TaskContext private final PagesSerdeFactory serdeFactory; private final TaskSnapshotManager snapshotManager; + private final Map taskExtendProperties = new HashMap<>(); + public static TaskContext createTaskContext( QueryContext queryContext, TaskStateMachine taskStateMachine, @@ -185,6 +189,16 @@ public class TaskContext return snapshotManager; } + public TaskStateMachine getTaskStateMachine() + { + return taskStateMachine; + } + + public Map getTaskExtendProperties() + { + return taskExtendProperties; + } + public PipelineContext addPipelineContext(int pipelineId, boolean inputPipeline, boolean outputPipeline, boolean partitioned) { PipelineContext pipelineContext = new PipelineContext( diff --git a/presto-main/src/main/java/io/prestosql/operator/WindowFunctionDefinition.java b/presto-main/src/main/java/io/prestosql/operator/WindowFunctionDefinition.java index 821721986b0a76c58cc69aa57f38b87add5ff9f3..f0684b92100342f8f48abbcc628e678ef79bd446 100644 --- a/presto-main/src/main/java/io/prestosql/operator/WindowFunctionDefinition.java +++ b/presto-main/src/main/java/io/prestosql/operator/WindowFunctionDefinition.java @@ -64,6 +64,16 @@ public class WindowFunctionDefinition return type; } + public WindowFunctionSupplier getFunctionSupplier() + { + return functionSupplier; + } + + public List getArgumentChannels() + { + return argumentChannels; + } + public WindowFunction createWindowFunction() { return functionSupplier.createWindowFunction(argumentChannels); diff --git a/presto-main/src/main/java/io/prestosql/operator/exchange/LocalExchange.java b/presto-main/src/main/java/io/prestosql/operator/exchange/LocalExchange.java index 7daf270d0397e92802e4cb49602acaea5418a719..891eac7a748bf997768a8cbc609b3722ed3179e3 100644 --- a/presto-main/src/main/java/io/prestosql/operator/exchange/LocalExchange.java +++ b/presto-main/src/main/java/io/prestosql/operator/exchange/LocalExchange.java @@ -66,11 +66,11 @@ public class LocalExchange { private static final Logger LOG = Logger.get(LocalExchange.class); - private final Supplier exchangerSupplier; + protected Supplier exchangerSupplier; - private final List sources; + protected final List sources; - private final LocalExchangeMemoryManager memoryManager; + protected final LocalExchangeMemoryManager memoryManager; @GuardedBy("this") private boolean allSourcesFinished; @@ -352,27 +352,27 @@ public class LocalExchange @ThreadSafe public static class LocalExchangeFactory { - private final PartitioningHandle partitioning; - private final List types; - private final List partitionChannels; - private final Optional partitionHashChannel; - private final PipelineExecutionStrategy exchangeSourcePipelineExecutionStrategy; - private final DataSize maxBufferedBytes; - private final int bufferCount; - private final boolean isForMerge; - private final AggregationNode.AggregationType aggregationType; + protected final PartitioningHandle partitioning; + protected final List types; + protected final List partitionChannels; + protected final Optional partitionHashChannel; + protected final PipelineExecutionStrategy exchangeSourcePipelineExecutionStrategy; + protected final DataSize maxBufferedBytes; + protected final int bufferCount; + protected final boolean isForMerge; + protected final AggregationNode.AggregationType aggregationType; @GuardedBy("this") - private boolean noMoreSinkFactories; + protected boolean noMoreSinkFactories; // The number of total sink factories are tracked at planning time // so that the exact number of sink factory is known by the time execution starts. @GuardedBy("this") - private int numSinkFactories; + protected int numSinkFactories; @GuardedBy("this") - private final Map localExchangeMap = new HashMap<>(); + protected final Map localExchangeMap = new HashMap<>(); @GuardedBy("this") - private final List closedSinkFactories = new ArrayList<>(); + protected final List closedSinkFactories = new ArrayList<>(); public LocalExchangeFactory( PartitioningHandle partitioning, @@ -514,7 +514,7 @@ public class LocalExchange { private final LocalExchange exchange; - private LocalExchangeSinkFactory(LocalExchange exchange) + public LocalExchangeSinkFactory(LocalExchange exchange) { this.exchange = requireNonNull(exchange, "exchange is null"); } diff --git a/presto-main/src/main/java/io/prestosql/operator/exchange/LocalExchangeSinkOperator.java b/presto-main/src/main/java/io/prestosql/operator/exchange/LocalExchangeSinkOperator.java index 47721b356f1f75ab5884be4c9b698d020c3bb273..f2eef21953ba701580f1a681a3f9f93d843e856f 100644 --- a/presto-main/src/main/java/io/prestosql/operator/exchange/LocalExchangeSinkOperator.java +++ b/presto-main/src/main/java/io/prestosql/operator/exchange/LocalExchangeSinkOperator.java @@ -117,7 +117,7 @@ public class LocalExchangeSinkOperator private final Function pagePreprocessor; private final SingleInputSnapshotState snapshotState; - LocalExchangeSinkOperator(String id, OperatorContext operatorContext, LocalExchangeSink sink, Function pagePreprocessor) + public LocalExchangeSinkOperator(String id, OperatorContext operatorContext, LocalExchangeSink sink, Function pagePreprocessor) { this.id = id; this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); diff --git a/presto-main/src/main/java/io/prestosql/operator/exchange/LocalExchangeSource.java b/presto-main/src/main/java/io/prestosql/operator/exchange/LocalExchangeSource.java index 4036268b43149d257e6fe41fcbfba95894ef7a9b..65045f5aed8d7fc0cab8b804a722d9efdd6f652f 100644 --- a/presto-main/src/main/java/io/prestosql/operator/exchange/LocalExchangeSource.java +++ b/presto-main/src/main/java/io/prestosql/operator/exchange/LocalExchangeSource.java @@ -66,6 +66,8 @@ public class LocalExchangeSource private final Object lock = new Object(); + private List pages = new ArrayList<>(); + @GuardedBy("lock") private SettableFuture notEmptyFuture = NOT_EMPTY; @@ -99,7 +101,7 @@ public class LocalExchangeSource return Collections.unmodifiableSet(inputChannels); } - void addPage(PageReference pageReference, String origin) + public void addPage(PageReference pageReference, String origin) { checkNotHoldsLock(); @@ -238,6 +240,19 @@ public class LocalExchangeSource return Pair.of(page, origin.orElse(null)); } + public List getPages() + { + Page page = removePage().getLeft(); + if (page == null) { + if (isFinished()) { + return pages; + } + return null; + } + pages.add(page); + return null; + } + public ListenableFuture waitForReading() { checkNotHoldsLock(); diff --git a/presto-main/src/main/java/io/prestosql/operator/exchange/PageReference.java b/presto-main/src/main/java/io/prestosql/operator/exchange/PageReference.java index 0c85a3abd6233165645c6399638a4a2d6b025382..3b8cc44e9c8252e5f5973629c37a82ea06b064e6 100644 --- a/presto-main/src/main/java/io/prestosql/operator/exchange/PageReference.java +++ b/presto-main/src/main/java/io/prestosql/operator/exchange/PageReference.java @@ -23,7 +23,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; @ThreadSafe -class PageReference +public class PageReference { private final Page page; private final Runnable onFree; diff --git a/presto-main/src/main/java/io/prestosql/server/PluginManager.java b/presto-main/src/main/java/io/prestosql/server/PluginManager.java index 695ecba93920b2725007c39016a0735def219af8..ea281a49d0ae57d72c3b3acb2ee50389ad39d3ee 100644 --- a/presto-main/src/main/java/io/prestosql/server/PluginManager.java +++ b/presto-main/src/main/java/io/prestosql/server/PluginManager.java @@ -89,6 +89,7 @@ public class PluginManager .add("io.airlift.units.") .add("org.openjdk.jol.") .add("io.prestosql.sql.tree.") + .add("nova.hetu.omniruntime.vector.") .build(); private static final Logger log = Logger.get(PluginManager.class); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java index 2fae9cdfc4fab16e0f96e3894fa9a9533a8963a7..a3b4ecc50dae50f18539baafdfc0d2485d313ce4 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java @@ -342,37 +342,188 @@ public class LocalExecutionPlanner { private static final Logger log = Logger.get(LocalExecutionPlanner.class); - private final Metadata metadata; - private final TypeAnalyzer typeAnalyzer; - private final Optional explainAnalyzeContext; - private final PageSourceProvider pageSourceProvider; - private final IndexManager indexManager; - private final NodePartitioningManager nodePartitioningManager; - private final PageSinkManager pageSinkManager; - private final ExchangeClientSupplier exchangeClientSupplier; - private final ExpressionCompiler expressionCompiler; - private final PageFunctionCompiler pageFunctionCompiler; - private final JoinFilterFunctionCompiler joinFilterFunctionCompiler; - private final DataSize maxIndexMemorySize; - private final IndexJoinLookupStats indexJoinLookupStats; - private final DataSize maxPartialAggregationMemorySize; - private final DataSize maxPagePartitioningBufferSize; - private final DataSize maxLocalExchangeBufferSize; - private final SpillerFactory spillerFactory; - private final SingleStreamSpillerFactory singleStreamSpillerFactory; - private final PartitioningSpillerFactory partitioningSpillerFactory; - private final PagesIndex.Factory pagesIndexFactory; - private final JoinCompiler joinCompiler; - private final LookupJoinOperators lookupJoinOperators; - private final OrderingCompiler orderingCompiler; - private final StateStoreProvider stateStoreProvider; - private final NodeInfo nodeInfo; - private final CubeManager cubeManager; - private final StateStoreListenerManager stateStoreListenerManager; - private final DynamicFilterCacheManager dynamicFilterCacheManager; - private final HeuristicIndexerManager heuristicIndexerManager; - private final FunctionResolution functionResolution; - private final LogicalRowExpressions logicalRowExpressions; + protected final Metadata metadata; + protected final TypeAnalyzer typeAnalyzer; + protected final Optional explainAnalyzeContext; + protected final PageSourceProvider pageSourceProvider; + protected final IndexManager indexManager; + protected final NodePartitioningManager nodePartitioningManager; + protected final PageSinkManager pageSinkManager; + protected final ExchangeClientSupplier exchangeClientSupplier; + protected final ExpressionCompiler expressionCompiler; + protected final PageFunctionCompiler pageFunctionCompiler; + protected final JoinFilterFunctionCompiler joinFilterFunctionCompiler; + protected final DataSize maxIndexMemorySize; + protected final IndexJoinLookupStats indexJoinLookupStats; + protected final DataSize maxPartialAggregationMemorySize; + protected final DataSize maxPagePartitioningBufferSize; + protected final DataSize maxLocalExchangeBufferSize; + protected final SpillerFactory spillerFactory; + protected final SingleStreamSpillerFactory singleStreamSpillerFactory; + protected final PartitioningSpillerFactory partitioningSpillerFactory; + protected final PagesIndex.Factory pagesIndexFactory; + protected final JoinCompiler joinCompiler; + protected final LookupJoinOperators lookupJoinOperators; + protected final OrderingCompiler orderingCompiler; + protected final StateStoreProvider stateStoreProvider; + protected final NodeInfo nodeInfo; + protected final CubeManager cubeManager; + protected final StateStoreListenerManager stateStoreListenerManager; + protected final DynamicFilterCacheManager dynamicFilterCacheManager; + protected final HeuristicIndexerManager heuristicIndexerManager; + protected final FunctionResolution functionResolution; + protected final LogicalRowExpressions logicalRowExpressions; + protected final TaskManagerConfig taskManagerConfig; + + public Metadata getMetadata() + { + return metadata; + } + + public TypeAnalyzer getTypeAnalyzer() + { + return typeAnalyzer; + } + + public Optional getExplainAnalyzeContext() + { + return explainAnalyzeContext; + } + + public PageSourceProvider getPageSourceProvider() + { + return pageSourceProvider; + } + + public IndexManager getIndexManager() + { + return indexManager; + } + + public NodePartitioningManager getNodePartitioningManager() + { + return nodePartitioningManager; + } + + public PageSinkManager getPageSinkManager() + { + return pageSinkManager; + } + + public ExchangeClientSupplier getExchangeClientSupplier() + { + return exchangeClientSupplier; + } + + public ExpressionCompiler getExpressionCompiler() + { + return expressionCompiler; + } + + public PageFunctionCompiler getPageFunctionCompiler() + { + return pageFunctionCompiler; + } + + public JoinFilterFunctionCompiler getJoinFilterFunctionCompiler() + { + return joinFilterFunctionCompiler; + } + + public DataSize getMaxIndexMemorySize() + { + return maxIndexMemorySize; + } + + public IndexJoinLookupStats getIndexJoinLookupStats() + { + return indexJoinLookupStats; + } + + public DataSize getMaxPartialAggregationMemorySize() + { + return maxPartialAggregationMemorySize; + } + + public DataSize getMaxPagePartitioningBufferSize() + { + return maxPagePartitioningBufferSize; + } + + public DataSize getMaxLocalExchangeBufferSize() + { + return maxLocalExchangeBufferSize; + } + + public SpillerFactory getSpillerFactory() + { + return spillerFactory; + } + + public SingleStreamSpillerFactory getSingleStreamSpillerFactory() + { + return singleStreamSpillerFactory; + } + + public PartitioningSpillerFactory getPartitioningSpillerFactory() + { + return partitioningSpillerFactory; + } + + public PagesIndex.Factory getPagesIndexFactory() + { + return pagesIndexFactory; + } + + public JoinCompiler getJoinCompiler() + { + return joinCompiler; + } + + public LookupJoinOperators getLookupJoinOperators() + { + return lookupJoinOperators; + } + + public OrderingCompiler getOrderingCompiler() + { + return orderingCompiler; + } + + public StateStoreProvider getStateStoreProvider() + { + return stateStoreProvider; + } + + public NodeInfo getNodeInfo() + { + return nodeInfo; + } + + public CubeManager getCubeManager() + { + return cubeManager; + } + + public StateStoreListenerManager getStateStoreListenerManager() + { + return stateStoreListenerManager; + } + + public DynamicFilterCacheManager getDynamicFilterCacheManager() + { + return dynamicFilterCacheManager; + } + + public HeuristicIndexerManager getHeuristicIndexerManager() + { + return heuristicIndexerManager; + } + + public TaskManagerConfig getTaskManagerConfig() + { + return taskManagerConfig; + } @Inject public LocalExecutionPlanner( @@ -415,6 +566,7 @@ public class LocalExecutionPlanner this.pageFunctionCompiler = requireNonNull(pageFunctionCompiler, "pageFunctionCompiler is null"); this.joinFilterFunctionCompiler = requireNonNull(joinFilterFunctionCompiler, "compiler is null"); this.indexJoinLookupStats = requireNonNull(indexJoinLookupStats, "indexJoinLookupStats is null"); + this.taskManagerConfig = taskManagerConfig; this.maxIndexMemorySize = requireNonNull(taskManagerConfig, "taskManagerConfig is null").getMaxIndexMemoryUsage(); this.spillerFactory = requireNonNull(spillerFactory, "spillerFactory is null"); this.singleStreamSpillerFactory = requireNonNull(singleStreamSpillerFactory, "singleStreamSpillerFactory is null"); @@ -633,7 +785,7 @@ public class LocalExecutionPlanner return first instanceof LookupOuterOperatorFactory && isTableScanPipeline(context.outerToJoinMap.get(driverFactory)); } - private static void addLookupOuterDrivers(LocalExecutionPlanContext context) + protected static void addLookupOuterDrivers(LocalExecutionPlanContext context) { // For an outer join on the lookup side (RIGHT or FULL) add an additional // driver to output the unused rows in the lookup source @@ -664,29 +816,29 @@ public class LocalExecutionPlanner } } - private static class LocalExecutionPlanContext + public static class LocalExecutionPlanContext { - private final TaskContext taskContext; - private final TypeProvider types; - private final List driverFactories; - private final Optional indexSourceContext; + protected final TaskContext taskContext; + protected final TypeProvider types; + protected List driverFactories; + protected final Optional indexSourceContext; // the collector is shared with all subContexts to allow local dynamic filtering // with multiple table scans (e.g. co-located joins). - private final LocalDynamicFiltersCollector dynamicFiltersCollector; + protected final LocalDynamicFiltersCollector dynamicFiltersCollector; // this is shared with all subContexts - private final AtomicInteger nextPipelineId; + protected final AtomicInteger nextPipelineId; private int nextOperatorId; private boolean inputDriver = true; private OptionalInt driverInstanceCount = OptionalInt.empty(); private Map cteOperationMap = new HashMap<>(); - private Map cteCtx; + protected Map cteCtx; private static Map sourceInitialized = new ConcurrentHashMap<>(); private final PlanNodeId consumerId; - private final Optional feederCTEId; - private final Optional feederCTEParentId; + protected final Optional feederCTEId; + protected final Optional feederCTEParentId; // Snapshot: record pipeline that corresponds to the lookup-outer pipeline. // This is used to help determine if a lookup-outer pipeline should be treated as a tabel-scan pipeine. @@ -700,7 +852,7 @@ public class LocalExecutionPlanner this(taskContext, types, new ArrayList<>(), Optional.empty(), new LocalDynamicFiltersCollector(taskContext, Optional.of(metadata), dynamicFilterCacheManager), new AtomicInteger(0), feederCTEId, feederCTEParentId, cteCtx); } - private LocalExecutionPlanContext( + protected LocalExecutionPlanContext( TaskContext taskContext, TypeProvider types, List driverFactories, @@ -745,11 +897,16 @@ public class LocalExecutionPlanner return driverFactory; } - private List getDriverFactories() + public List getDriverFactories() { return ImmutableList.copyOf(driverFactories); } + public void setDriverFactories(List driverFactories) + { + this.driverFactories = driverFactories; + } + public Session getSession() { return taskContext.getSession(); @@ -780,22 +937,27 @@ public class LocalExecutionPlanner return indexSourceContext; } - private int getNextPipelineId() + private AtomicInteger getPipelineId() + { + return nextPipelineId; + } + + public int getNextPipelineId() { return nextPipelineId.getAndIncrement(); } - private int getNextOperatorId() + public int getNextOperatorId() { return nextOperatorId++; } - private boolean isInputDriver() + public boolean isInputDriver() { return inputDriver; } - private void setInputDriver(boolean inputDriver) + public void setInputDriver(boolean inputDriver) { this.inputDriver = inputDriver; } @@ -866,7 +1028,7 @@ public class LocalExecutionPlanner } } - private static class IndexSourceContext + public static class IndexSourceContext { private final SetMultimap indexLookupToProbeInput; @@ -918,13 +1080,13 @@ public class LocalExecutionPlanner } } - private class Visitor + public class Visitor extends InternalPlanVisitor { - private final Session session; - private final StageExecutionDescriptor stageExecutionDescriptor; + protected final Session session; + protected final StageExecutionDescriptor stageExecutionDescriptor; - private Visitor(Session session, StageExecutionDescriptor stageExecutionDescriptor) + public Visitor(Session session, StageExecutionDescriptor stageExecutionDescriptor) { this.session = session; this.stageExecutionDescriptor = stageExecutionDescriptor; @@ -1618,7 +1780,7 @@ public class LocalExecutionPlanner } } - private Supplier>> getDynamicFilterSupplier(Optional>> dynamicFilters, PlanNode sourceNode, LocalExecutionPlanContext context) + protected Supplier>> getDynamicFilterSupplier(Optional>> dynamicFilters, PlanNode sourceNode, LocalExecutionPlanContext context) { if (dynamicFilters.isPresent() && !dynamicFilters.get().isEmpty()) { log.debug("[TableScan] Dynamic filters: %s", dynamicFilters); @@ -1643,7 +1805,7 @@ public class LocalExecutionPlanner return null; } - private RowExpression bindChannels(RowExpression inputExpression, Map sourceLayout, TypeProvider types) + public RowExpression bindChannels(RowExpression inputExpression, Map sourceLayout, TypeProvider types) { RowExpression expression = inputExpression; Type type = expression.getType(); @@ -1864,12 +2026,12 @@ public class LocalExecutionPlanner stageExecutionDescriptor.isScanGroupedExecution(node.getId()) ? GROUPED_EXECUTION : UNGROUPED_EXECUTION); } - private ImmutableMap makeLayout(PlanNode node) + protected ImmutableMap makeLayout(PlanNode node) { return makeLayoutFromOutputSymbols(node.getOutputSymbols()); } - private ImmutableMap makeLayoutFromOutputSymbols(List outputSymbols) + protected ImmutableMap makeLayoutFromOutputSymbols(List outputSymbols) { ImmutableMap.Builder outputMappings = ImmutableMap.builder(); int channel = 0; @@ -2231,7 +2393,7 @@ public class LocalExecutionPlanner return symbols.stream().map(SymbolUtils::toSymbolReference).collect(toImmutableSet()); } - private PhysicalOperation createNestedLoopJoin(JoinNode node, LocalExecutionPlanContext context) + protected PhysicalOperation createNestedLoopJoin(JoinNode node, LocalExecutionPlanContext context) { PhysicalOperation probeSource = node.getLeft().accept(this, context); @@ -2459,7 +2621,7 @@ public class LocalExecutionPlanner return new PhysicalOperation(operator, outputMappings.build(), context, probeSource); } - private Optional createDynamicFilter(JoinNode node, LocalExecutionPlanContext context, int partitionCount) + protected Optional createDynamicFilter(JoinNode node, LocalExecutionPlanContext context, int partitionCount) { if (!isEnableDynamicFiltering(context.getSession())) { return Optional.empty(); @@ -2625,7 +2787,7 @@ public class LocalExecutionPlanner return lookupSourceFactoryManager; } - private JoinFilterFunctionFactory compileJoinFilterFunction( + protected JoinFilterFunctionFactory compileJoinFilterFunction( RowExpression filterExpression, Map probeLayout, Map buildLayout, @@ -2636,7 +2798,7 @@ public class LocalExecutionPlanner return joinFilterFunctionCompiler.compileJoinFilterFunction(bindChannels(filterExpression, joinSourcesLayout, types), buildLayout.size()); } - private int sortExpressionAsSortChannel( + public int sortExpressionAsSortChannel( RowExpression sortExpression, Map probeLayout, Map buildLayout, @@ -2682,7 +2844,7 @@ public class LocalExecutionPlanner } } - private Map createJoinSourcesLayout(Map lookupSourceLayout, Map probeSourceLayout) + protected Map createJoinSourcesLayout(Map lookupSourceLayout, Map probeSourceLayout) { ImmutableMap.Builder joinSourcesLayout = ImmutableMap.builder(); joinSourcesLayout.putAll(lookupSourceLayout); @@ -3203,19 +3365,19 @@ public class LocalExecutionPlanner throw new UnsupportedOperationException("not yet implemented"); } - private List getSourceOperatorTypes(PlanNode node, TypeProvider types) + protected List getSourceOperatorTypes(PlanNode node, TypeProvider types) { return getSymbolTypes(node.getOutputSymbols(), types); } - private List getSymbolTypes(List symbols, TypeProvider types) + protected List getSymbolTypes(List symbols, TypeProvider types) { return symbols.stream() .map(types::get) .collect(toImmutableList()); } - private AccumulatorFactory buildAccumulatorFactory( + protected AccumulatorFactory buildAccumulatorFactory( PhysicalOperation source, Aggregation aggregation) { @@ -3558,7 +3720,7 @@ public class LocalExecutionPlanner }; } - private static Function enforceLayoutProcessor(List expectedLayout, Map inputLayout) + protected static Function enforceLayoutProcessor(List expectedLayout, Map inputLayout) { int[] channels = expectedLayout.stream() .peek(symbol -> checkArgument(inputLayout.containsKey(symbol), "channel not found for symbol: %s", symbol)) @@ -3573,7 +3735,7 @@ public class LocalExecutionPlanner return new PageChannelSelector(channels); } - private static List getChannelsForSymbols(List symbols, Map layout) + protected static List getChannelsForSymbols(List symbols, Map layout) { ImmutableList.Builder builder = ImmutableList.builder(); for (Symbol symbol : symbols) { @@ -3582,7 +3744,7 @@ public class LocalExecutionPlanner return builder.build(); } - private static Function channelGetter(PhysicalOperation source) + protected static Function channelGetter(PhysicalOperation source) { return input -> { checkArgument(source.getLayout().containsKey(input)); @@ -3593,7 +3755,7 @@ public class LocalExecutionPlanner /** * Encapsulates an physical operator plus the mapping of logical symbols to channel/field */ - private static class PhysicalOperation + public static class PhysicalOperation { private final List operatorFactories; private final Map layout; @@ -3664,7 +3826,7 @@ public class LocalExecutionPlanner return layout; } - private List getOperatorFactories() + public List getOperatorFactories() { return operatorFactories; } @@ -3675,7 +3837,7 @@ public class LocalExecutionPlanner } } - private static class DriverFactoryParameters + protected static class DriverFactoryParameters { private final LocalExecutionPlanContext subContext; private final PhysicalOperation source; diff --git a/presto-main/src/main/java/io/prestosql/utils/HetuConfig.java b/presto-main/src/main/java/io/prestosql/utils/HetuConfig.java index 77cc9712db7d62742ff770899e5a9a6e92491463..61effbfb113dc368f3b7f15061f35578c6bb267c 100644 --- a/presto-main/src/main/java/io/prestosql/utils/HetuConfig.java +++ b/presto-main/src/main/java/io/prestosql/utils/HetuConfig.java @@ -62,10 +62,53 @@ public class HetuConfig private Duration splitCacheStateUpdateInterval = new Duration(2, TimeUnit.SECONDS); private boolean isTraceStackVisible; + private String extensionExecutionPlannerJarPath; + private String extensionExecutionPlannerClassPath; + private boolean extensionExecutionPlannerEnabled; + public HetuConfig() { } + public boolean getExtensionExecutionPlannerEnabled() + { + return extensionExecutionPlannerEnabled; + } + + @Config(HetuConstant.EXTENSION_EXECUTION_PLANNER_ENABLED) + @ConfigDescription("extension execution planner enable from config") + public HetuConfig setExtensionExecutionPlannerEnabled(boolean extensionExecutionPlannerEnabled) + { + this.extensionExecutionPlannerEnabled = extensionExecutionPlannerEnabled; + return this; + } + + public String getExtensionExecutionPlannerJarPath() + { + return extensionExecutionPlannerJarPath; + } + + @Config(HetuConstant.EXTENSION_EXECUTION_PLANNER_JAR_PATH) + @ConfigDescription("extension execution planner jar path from config") + public HetuConfig setExtensionExecutionPlannerJarPath(String extensionExecutionPlannerJarPath) + { + this.extensionExecutionPlannerJarPath = extensionExecutionPlannerJarPath; + return this; + } + + public String getExtensionExecutionPlannerClassPath() + { + return extensionExecutionPlannerClassPath; + } + + @Config(HetuConstant.EXTENSION_EXECUTION_PLANNER_CLASS_PATH) + @ConfigDescription("extension execution planner class path from config") + public HetuConfig setExtensionExecutionPlannerClassPath(String extensionExecutionPlannerClassPath) + { + this.extensionExecutionPlannerClassPath = extensionExecutionPlannerClassPath; + return this; + } + @NotNull public boolean isFilterEnabled() { diff --git a/presto-main/src/test/java/io/prestosql/operator/window/AbstractTestWindowFunction.java b/presto-main/src/test/java/io/prestosql/operator/window/AbstractTestWindowFunction.java index ba531f33144489db35c667f7a258373b6b510fce..e3c686709ab79794a3ee070d66b4a4e73790fbf9 100644 --- a/presto-main/src/test/java/io/prestosql/operator/window/AbstractTestWindowFunction.java +++ b/presto-main/src/test/java/io/prestosql/operator/window/AbstractTestWindowFunction.java @@ -30,7 +30,7 @@ public abstract class AbstractTestWindowFunction protected LocalQueryRunner queryRunner; @BeforeClass - public final void initTestWindowFunction() + public void initTestWindowFunction() { queryRunner = new LocalQueryRunner(TEST_SESSION); } diff --git a/presto-main/src/test/java/io/prestosql/utils/TestHetuConfig.java b/presto-main/src/test/java/io/prestosql/utils/TestHetuConfig.java index b23308f1b71c587aba9c0f423f8be0bae2363f81..b71d946ca6e99e0372442ae3b38793110385ebc9 100644 --- a/presto-main/src/test/java/io/prestosql/utils/TestHetuConfig.java +++ b/presto-main/src/test/java/io/prestosql/utils/TestHetuConfig.java @@ -54,7 +54,10 @@ public class TestHetuConfig .setSplitCacheMapEnabled(false) .setSplitCacheStateUpdateInterval(new Duration(2, TimeUnit.SECONDS)) .setTraceStackVisible(false) - .setIndexToPreload("")); + .setIndexToPreload("") + .setExtensionExecutionPlannerEnabled(false) + .setExtensionExecutionPlannerJarPath(null) + .setExtensionExecutionPlannerClassPath(null)); } @Test @@ -85,6 +88,9 @@ public class TestHetuConfig .put("hetu.split-cache-map.state-update-interval", "5s") .put("stack-trace-visible", "true") .put("hetu.heuristicindex.filter.cache.preload-indices", "idx1,idx2") + .put("extension_execution_planner_enabled", "true") + .put("extension_execution_planner_jar_path", "") + .put("extension_execution_planner_class_path", "") .build(); HetuConfig expected = new HetuConfig() @@ -111,7 +117,10 @@ public class TestHetuConfig .setSplitCacheMapEnabled(true) .setSplitCacheStateUpdateInterval(new Duration(5, TimeUnit.SECONDS)) .setTraceStackVisible(true) - .setIndexToPreload("idx1,idx2"); + .setIndexToPreload("idx1,idx2") + .setExtensionExecutionPlannerEnabled(true) + .setExtensionExecutionPlannerJarPath("") + .setExtensionExecutionPlannerClassPath(""); ConfigAssertions.assertFullMapping(properties, expected); } diff --git a/presto-spi/src/main/java/io/prestosql/spi/HetuConstant.java b/presto-spi/src/main/java/io/prestosql/spi/HetuConstant.java index 3b4940323380f318a6f520a0cb6232c3b15636ff..77db581f8545546a7b55445fde1680f02032e0a9 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/HetuConstant.java +++ b/presto-spi/src/main/java/io/prestosql/spi/HetuConstant.java @@ -58,4 +58,9 @@ public class HetuConstant // error message public static final String HINDEX_CONFIG_ERROR_MSG = "Heuristic Index is not enabled in config.properties or is configured incorrectly."; + + // extension support message + public static final String EXTENSION_EXECUTION_PLANNER_ENABLED = "extension_execution_planner_enabled"; + public static final String EXTENSION_EXECUTION_PLANNER_JAR_PATH = "extension_execution_planner_jar_path"; + public static final String EXTENSION_EXECUTION_PLANNER_CLASS_PATH = "extension_execution_planner_class_path"; } diff --git a/presto-spi/src/main/java/io/prestosql/spi/Page.java b/presto-spi/src/main/java/io/prestosql/spi/Page.java index 840d09821b0cefd7e7c03a4cb7f41bb86ba1eb4e..b4bb9357ffeb3c9f8a793bcb0477f94f8c74db48 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/Page.java +++ b/presto-spi/src/main/java/io/prestosql/spi/Page.java @@ -370,4 +370,9 @@ public class Page { pageMetadata.setProperty(key, value); } + + public Block[] getBlocks() + { + return blocks; + } } diff --git a/presto-spi/src/main/java/io/prestosql/spi/block/AbstractSingleRowBlock.java b/presto-spi/src/main/java/io/prestosql/spi/block/AbstractSingleRowBlock.java index 41a392236a2e1d9df816a0eb2fccc9050c081ff4..2ed3252ab5659d55ae86ce20c516703f0ce21844 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/block/AbstractSingleRowBlock.java +++ b/presto-spi/src/main/java/io/prestosql/spi/block/AbstractSingleRowBlock.java @@ -70,6 +70,13 @@ public abstract class AbstractSingleRowBlock return getRawFieldBlock(position).getLong(rowIndex, offset); } + @Override + public double getDouble(int position, int offset) + { + checkFieldIndex(position); + return getRawFieldBlock(position).getDouble(rowIndex, offset); + } + @Override public Slice getSlice(int position, int offset, int length) { diff --git a/presto-spi/src/main/java/io/prestosql/spi/block/Block.java b/presto-spi/src/main/java/io/prestosql/spi/block/Block.java index a3bf3accb9df1af7d44c0908592e160f99023959..1934351002e3a6eb48de3f3491a82cfcb0d2aeb4 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/block/Block.java +++ b/presto-spi/src/main/java/io/prestosql/spi/block/Block.java @@ -69,6 +69,14 @@ public interface Block throw new UnsupportedOperationException(getClass().getName()); } + /** + * Gets a little endian double at {@code offset} in the value at {@code position}. + */ + default double getDouble(int position, int offset) + { + throw new UnsupportedOperationException(getClass().getName()); + } + /** * Gets a slice at {@code offset} in the value at {@code position}. */ @@ -340,4 +348,32 @@ public interface Block System.arraycopy(positions, positionCount, matchedPositions, positionCount, positionCount); return positionCount; } + + default Object getValues() + { + throw new UnsupportedOperationException(); + } + + default int getBlockOffset() + { + throw new UnsupportedOperationException(); + } + + default boolean[] getValueNulls() + { + throw new UnsupportedOperationException(); + } + + default void close() + { + } + + default boolean isExtensionBlock() + { + return false; + } + + default void setClosable(boolean isClosable) + { + } } diff --git a/presto-spi/src/main/java/io/prestosql/spi/block/BlockBuilder.java b/presto-spi/src/main/java/io/prestosql/spi/block/BlockBuilder.java index f81e9d208b801307a8acf6c2d59b6177973d0fcd..66cf9597babdbcdafdd3c89655908738c00fcf6f 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/block/BlockBuilder.java +++ b/presto-spi/src/main/java/io/prestosql/spi/block/BlockBuilder.java @@ -51,6 +51,14 @@ public interface BlockBuilder throw new UnsupportedOperationException(getClass().getName()); } + /** + * Write a double to the current entry; + */ + default BlockBuilder writeDouble(double value) + { + throw new UnsupportedOperationException(getClass().getName()); + } + /** * Write a byte sequences to the current entry; */ diff --git a/presto-spi/src/main/java/io/prestosql/spi/block/BlockUtil.java b/presto-spi/src/main/java/io/prestosql/spi/block/BlockUtil.java index 6d18d24969c042c83244530b8d9dce1c91b8098d..f6e0e2dcc2a0f406e189b648c347f91e3a804ff2 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/block/BlockUtil.java +++ b/presto-spi/src/main/java/io/prestosql/spi/block/BlockUtil.java @@ -22,7 +22,7 @@ import static java.lang.Math.ceil; import static java.lang.String.format; import static java.util.Objects.requireNonNull; -final class BlockUtil +public final class BlockUtil { private static final double BLOCK_RESET_SKEW = 1.25; @@ -34,7 +34,7 @@ final class BlockUtil { } - static void checkArrayRange(int[] array, int offset, int length) + public static void checkArrayRange(int[] array, int offset, int length) { requireNonNull(array, "array is null"); if (offset < 0 || length < 0 || offset + length > array.length) { @@ -42,28 +42,28 @@ final class BlockUtil } } - static void checkValidRegion(int positionCount, int positionOffset, int length) + public static void checkValidRegion(int positionCount, int positionOffset, int length) { if (positionOffset < 0 || length < 0 || positionOffset + length > positionCount) { throw new IndexOutOfBoundsException(format("Invalid position %s and length %s in block with %s positions", positionOffset, length, positionCount)); } } - static void checkValidPositions(boolean[] positions, int positionCount) + public static void checkValidPositions(boolean[] positions, int positionCount) { if (positions.length != positionCount) { throw new IllegalArgumentException(format("Invalid positions array size %d, actual position count is %d", positions.length, positionCount)); } } - static void checkValidPosition(int position, int positionCount) + public static void checkValidPosition(int position, int positionCount) { if (position < 0 || position >= positionCount) { throw new IllegalArgumentException(format("Invalid position %s in block with %s positions", position, positionCount)); } } - static int calculateNewArraySize(int currentSize) + public static int calculateNewArraySize(int currentSize) { // grow array by 50% long newSize = (long) currentSize + (currentSize >> 1); @@ -81,7 +81,7 @@ final class BlockUtil return (int) newSize; } - static int calculateBlockResetSize(int currentSize) + public static int calculateBlockResetSize(int currentSize) { long newSize = (long) ceil(currentSize * BLOCK_RESET_SKEW); @@ -95,7 +95,7 @@ final class BlockUtil return (int) newSize; } - static int calculateBlockResetBytes(int currentBytes) + public static int calculateBlockResetBytes(int currentBytes) { long newBytes = (long) ceil(currentBytes * BLOCK_RESET_SKEW); if (newBytes > MAX_ARRAY_SIZE) { @@ -110,7 +110,7 @@ final class BlockUtil * with the first value set to 0. * If the range matches the entire offsets array, the input array will be returned. */ - static int[] compactOffsets(int[] offsets, int index, int length) + public static int[] compactOffsets(int[] offsets, int index, int length) { if (index == 0 && offsets.length == length + 1) { return offsets; @@ -128,7 +128,7 @@ final class BlockUtil * If the range matches the entire slice, the input slice will be returned. * Otherwise, a copy will be returned. */ - static Slice compactSlice(Slice slice, int index, int length) + public static Slice compactSlice(Slice slice, int index, int length) { if (slice.isCompact() && index == 0 && length == slice.length()) { return slice; @@ -141,7 +141,7 @@ final class BlockUtil * If the range matches the entire array, the input array will be returned. * Otherwise, a copy will be returned. */ - static boolean[] compactArray(boolean[] array, int index, int length) + public static boolean[] compactArray(boolean[] array, int index, int length) { if (index == 0 && length == array.length) { return array; @@ -149,7 +149,7 @@ final class BlockUtil return Arrays.copyOfRange(array, index, index + length); } - static byte[] compactArray(byte[] array, int index, int length) + public static byte[] compactArray(byte[] array, int index, int length) { if (index == 0 && length == array.length) { return array; @@ -157,7 +157,7 @@ final class BlockUtil return Arrays.copyOfRange(array, index, index + length); } - static short[] compactArray(short[] array, int index, int length) + public static short[] compactArray(short[] array, int index, int length) { if (index == 0 && length == array.length) { return array; @@ -165,7 +165,7 @@ final class BlockUtil return Arrays.copyOfRange(array, index, index + length); } - static int[] compactArray(int[] array, int index, int length) + public static int[] compactArray(int[] array, int index, int length) { if (index == 0 && length == array.length) { return array; @@ -173,7 +173,7 @@ final class BlockUtil return Arrays.copyOfRange(array, index, index + length); } - static long[] compactArray(long[] array, int index, int length) + public static long[] compactArray(long[] array, int index, int length) { if (index == 0 && length == array.length) { return array; @@ -181,7 +181,7 @@ final class BlockUtil return Arrays.copyOfRange(array, index, index + length); } - static int countUsedPositions(boolean[] positions) + public static int countUsedPositions(boolean[] positions) { int used = 0; for (boolean position : positions) { @@ -196,7 +196,7 @@ final class BlockUtil * Returns true if the two specified arrays contain the same object in every position. * Unlike the {@link Arrays#equals(Object[], Object[])} method, this method compares using reference equals. */ - static boolean arraySame(Object[] array1, Object[] array2) + public static boolean arraySame(Object[] array1, Object[] array2) { if (array1 == null || array2 == null || array1.length != array2.length) { throw new IllegalArgumentException("array1 and array2 cannot be null and should have same length"); diff --git a/presto-spi/src/main/java/io/prestosql/spi/block/ByteArrayBlock.java b/presto-spi/src/main/java/io/prestosql/spi/block/ByteArrayBlock.java index db69176a0b85a24775324e5e6aec2010633f5ec1..d0130ec7e0d11b9aec8967ea7186119ce26f0261 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/block/ByteArrayBlock.java +++ b/presto-spi/src/main/java/io/prestosql/spi/block/ByteArrayBlock.java @@ -120,6 +120,24 @@ public class ByteArrayBlock return positionCount; } + @Override + public byte[] getValues() + { + return values; + } + + @Override + public int getBlockOffset() + { + return arrayOffset; + } + + @Override + public boolean[] getValueNulls() + { + return valueIsNull; + } + @Override public byte getByte(int position, int offset) { diff --git a/presto-spi/src/main/java/io/prestosql/spi/block/DictionaryBlock.java b/presto-spi/src/main/java/io/prestosql/spi/block/DictionaryBlock.java index 9c6cabdc983945431521d22e77299c35b3ac8806..57e1a2df5367cf576b78860c56068a13ccae1bee 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/block/DictionaryBlock.java +++ b/presto-spi/src/main/java/io/prestosql/spi/block/DictionaryBlock.java @@ -397,11 +397,23 @@ public class DictionaryBlock return dictionary; } - Slice getIds() + public Slice getIds() { return Slices.wrappedIntArray(ids, idsOffset, positionCount); } + public int[] getIdsArray() + { + if (idsOffset == 0) { + return ids; + } + else { + int[] res = new int[positionCount]; + System.arraycopy(ids, idsOffset, res, 0, positionCount); + return res; + } + } + public int getId(int position) { checkValidPosition(position, positionCount); diff --git a/presto-spi/src/main/java/io/prestosql/spi/block/EncoderUtil.java b/presto-spi/src/main/java/io/prestosql/spi/block/EncoderUtil.java index 2e2541ae6d9f1feba818161e0b50657d157515dc..188a05836f9d0e2a35635f31012478c137aea0b5 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/block/EncoderUtil.java +++ b/presto-spi/src/main/java/io/prestosql/spi/block/EncoderUtil.java @@ -18,7 +18,7 @@ import io.airlift.slice.SliceOutput; import java.util.Optional; -final class EncoderUtil +public final class EncoderUtil { private EncoderUtil() { diff --git a/presto-spi/src/main/java/io/prestosql/spi/block/Int128ArrayBlock.java b/presto-spi/src/main/java/io/prestosql/spi/block/Int128ArrayBlock.java index 6f31599b88c8d29f3bf986ee2da243b8576e3a36..351b57a6a26de2e8368d46dec6a74e67a2237958 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/block/Int128ArrayBlock.java +++ b/presto-spi/src/main/java/io/prestosql/spi/block/Int128ArrayBlock.java @@ -123,6 +123,24 @@ public class Int128ArrayBlock return positionCount; } + @Override + public long[] getValues() + { + return values; + } + + @Override + public int getBlockOffset() + { + return positionOffset; + } + + @Override + public boolean[] getValueNulls() + { + return valueIsNull; + } + @Override public long getLong(int position, int offset) { diff --git a/presto-spi/src/main/java/io/prestosql/spi/block/IntArrayBlock.java b/presto-spi/src/main/java/io/prestosql/spi/block/IntArrayBlock.java index 2868fd050acfd95b9145ce56728cc87f9f50bbdc..55e109da45405a154acad6ec82261a4c47a234a8 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/block/IntArrayBlock.java +++ b/presto-spi/src/main/java/io/prestosql/spi/block/IntArrayBlock.java @@ -120,6 +120,24 @@ public class IntArrayBlock return positionCount; } + @Override + public int[] getValues() + { + return values; + } + + @Override + public int getBlockOffset() + { + return arrayOffset; + } + + @Override + public boolean[] getValueNulls() + { + return valueIsNull; + } + @Override public int getInt(int position, int offset) { diff --git a/presto-spi/src/main/java/io/prestosql/spi/block/IntArrayList.java b/presto-spi/src/main/java/io/prestosql/spi/block/IntArrayList.java index 848f764b2edd065de851a5b3a0d0269ba116d864..bb947cf2dd93f277971e2e1cc3b02f37251dbd9d 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/block/IntArrayList.java +++ b/presto-spi/src/main/java/io/prestosql/spi/block/IntArrayList.java @@ -21,7 +21,7 @@ import static java.lang.String.format; /** * A simplified version of fastutils IntArrayList for the purpose of positions copying. */ -class IntArrayList +public class IntArrayList { private static final int DEFAULT_INITIAL_CAPACITY = 16; private int[] array; diff --git a/presto-spi/src/main/java/io/prestosql/spi/block/LazyBlock.java b/presto-spi/src/main/java/io/prestosql/spi/block/LazyBlock.java index 1fc03a16953d9cc1632a7bf8ee1075ad70ab96e1..ad5bbd3a93abbec6980b12f384caa44f7156f32d 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/block/LazyBlock.java +++ b/presto-spi/src/main/java/io/prestosql/spi/block/LazyBlock.java @@ -98,6 +98,11 @@ public class LazyBlock return block.getObject(position, clazz); } + public Block getBlock() + { + return block; + } + @Override public boolean bytesEqual(int position, int offset, Slice otherSlice, int otherOffset, int length) { diff --git a/presto-spi/src/main/java/io/prestosql/spi/block/LongArrayBlock.java b/presto-spi/src/main/java/io/prestosql/spi/block/LongArrayBlock.java index 2b47a8942a9f11bfb7e0753c338c198e3cd5019a..773b679a7ba676f9f06385d4c33b1a1ead9d740b 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/block/LongArrayBlock.java +++ b/presto-spi/src/main/java/io/prestosql/spi/block/LongArrayBlock.java @@ -124,6 +124,24 @@ public class LongArrayBlock return positionCount; } + @Override + public long[] getValues() + { + return values; + } + + @Override + public int getBlockOffset() + { + return arrayOffset; + } + + @Override + public boolean[] getValueNulls() + { + return valueIsNull; + } + @Override public long getLong(int position, int offset) { diff --git a/presto-spi/src/main/java/io/prestosql/spi/block/SingleRowBlockWriter.java b/presto-spi/src/main/java/io/prestosql/spi/block/SingleRowBlockWriter.java index 616ab98db6215004527e667b07829c3c1151d5f5..aba745ec644e9c61fe0a8cd56284a80891a2bc5f 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/block/SingleRowBlockWriter.java +++ b/presto-spi/src/main/java/io/prestosql/spi/block/SingleRowBlockWriter.java @@ -132,6 +132,14 @@ public class SingleRowBlockWriter return this; } + @Override + public BlockBuilder writeDouble(double value) + { + checkFieldIndexToWrite(); + fieldBlockBuilders[currentFieldIndexToWrite].writeDouble(value); + return this; + } + @Override public BlockBuilder writeBytes(Slice source, int sourceIndex, int length) { diff --git a/presto-spi/src/main/java/io/prestosql/spi/block/VariableWidthBlock.java b/presto-spi/src/main/java/io/prestosql/spi/block/VariableWidthBlock.java index d95992ebc1b490c24c0d8f1de508aa722ec06592..b23893c53b5dedab3b0cb539740e20a9542963e8 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/block/VariableWidthBlock.java +++ b/presto-spi/src/main/java/io/prestosql/spi/block/VariableWidthBlock.java @@ -99,6 +99,11 @@ public class VariableWidthBlock this.isInitialized = true; } + public int[] getOffsets() + { + return offsets; + } + @Override protected final int getPositionOffset(int position) { @@ -130,6 +135,18 @@ public class VariableWidthBlock return positionCount; } + @Override + public int getBlockOffset() + { + return arrayOffset; + } + + @Override + public boolean[] getValueNulls() + { + return valueIsNull; + } + @Override public long getSizeInBytes() { @@ -203,7 +220,7 @@ public class VariableWidthBlock } @Override - protected Slice getRawSlice(int position) + public Slice getRawSlice(int position) { return slice; }