From 82d38c279d2eac1669fb2ada22ee473e1db8a0a6 Mon Sep 17 00:00:00 2001 From: PaddlePaddle-Gardener Date: Wed, 12 Jan 2022 14:44:45 +0800 Subject: [PATCH] mirgate_38688 --- .../elementwise/elementwise_op_broadcast.cu.h | 389 ++------ .../elementwise/elementwise_op_impl.cu.h | 3 +- .../datamover_primitives_xpu2.h | 567 ++++++++++++ .../kernel_primitives/kernel_primitives.h | 49 +- paddle/fluid/platform/hostdevice.h | 9 +- paddle/pten/kernels/gpu/elementwise.h | 863 ++++++++++++++++++ 6 files changed, 1538 insertions(+), 342 deletions(-) create mode 100644 paddle/pten/kernels/gpu/elementwise.h diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h index 549a6be0b4..e3d4607b71 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h @@ -22,355 +22,68 @@ namespace operators { namespace kps = paddle::operators::kernel_primitives; -struct DimensionsTransform { - using DimVector = std::vector; - typedef void (*MergeFunctor)(bool &, std::vector &, DimVector &, - int, int); - int64_t dim_size; - DimVector out_dims; - std::vector in_dims; - - private: - // To compensate the lackage of input_tensors` dimension with input variable - // 'axis' - void InputDimensionsExtend(int N, int axis) { - for (auto &in_dim : in_dims) { - int64_t in_idx = 0; - if (in_dim.size() < dim_size) { - DimVector tmp_dim(dim_size, 1); - do { - if (in_dim[in_idx] == out_dims[axis] || in_dim[in_idx] == 1) { - tmp_dim[axis] = in_dim[in_idx]; - in_idx++; - axis++; - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "The %d-th dimension of input tensor is expected to be equal " - "with the %d-th dimension of output tensor %d or 1, but " - "recieved %d.", - in_idx + 1, axis + 1, out_dims[axis], in_dim[in_idx])); - } - } while (in_idx < in_dim.size()); - in_dim.resize(dim_size); - std::copy(tmp_dim.begin(), tmp_dim.end(), in_dim.begin()); - } else { - do { - if (in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1) { - in_idx++; - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "The %d-th dimension of input tensor is expected to be equal " - "with the %d-th dimension of output tensor %d or 1, but " - "recieved %d.", - in_idx + 1, in_idx + 1, out_dims[in_idx], in_dim[in_idx])); - } - } while (in_idx < dim_size); - } - std::reverse(in_dim.begin(), in_dim.end()); - } - std::reverse(out_dims.begin(), out_dims.end()); - } - - template - __inline__ void MergeDimensions(MergeFunctor merge_func, int N) { - auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) { - (*vec)[m_idx - 1] = - std::accumulate(vec->begin() + l_idx, vec->begin() + m_idx, 1, - std::multiplies()); - vec->erase(vec->begin() + l_idx, vec->begin() + m_idx - 1); - }; - - int64_t i = 0; - while (i < dim_size) { - int cnt = 0; - int low_idx = i; - bool equal = true; - do { - merge_func(equal, in_dims, out_dims, i, N); - if (equal) { - i++; - cnt++; - } else { - break; - } - } while (i < dim_size); - - if (cnt > 1) { - for (auto &in_dim : in_dims) { - VectorReorganise(&in_dim, low_idx, i); - } - VectorReorganise(&out_dims, low_idx, i); - dim_size -= --cnt; - i -= cnt; - } else if (cnt < 1) { - i++; - } - } - } - - public: - explicit DimensionsTransform( - const std::vector &ins, - const framework::DDim &dims, int axis) { - const int N = ins.size(); - dim_size = dims.size(); - out_dims = framework::vectorize(dims); - in_dims.resize(N); - for (int j = 0; j < N; 
++j) { - in_dims[j] = framework::vectorize(ins[j]->dims()); - } - InputDimensionsExtend(N, axis); - - auto merge_sequential_dims = [](bool &equal, - std::vector &in_dims, - DimVector &out, int i, int num) { - for (int j = 1; j < num; ++j) { - equal = (in_dims[0][i] == in_dims[j][i]) ? true : false; - } - }; - auto merge_sequential_one_dims = [](bool &equal, - std::vector &in_dims, - DimVector &out, int i, int num) { - equal = in_dims[0][i] == 1; - if (equal) { - for (int j = 1; j < num; ++j) { - equal = in_dims[j][i] == out[i]; - } - } - }; - // To Merge the dimensions of input_tensors while the consequtive - // equal-dimensions appears. - MergeFunctor merge_ptr = merge_sequential_dims; - MergeDimensions(merge_ptr, N); - - int min_idx = 0; - int min_val = std::accumulate(in_dims[0].begin(), in_dims[0].end(), 1, - std::multiplies()); - for (int j = 1; j < N; ++j) { - int temp = std::accumulate(in_dims[j].begin(), in_dims[j].end(), 1, - std::multiplies()); - min_val = min_val > temp ? temp : min_val; - min_idx = min_val == temp ? j : min_idx; - } - std::swap(in_dims[0], in_dims[min_idx]); - - // To Merge the dimension of input_tensors while the consequtive - // 1-value-dimensions appears. - merge_ptr = merge_sequential_one_dims; - MergeDimensions(merge_ptr, N); - std::swap(in_dims[min_idx], in_dims[0]); +template +void LaunchBroadcastElementwiseCudaKernel( + const KPDevice &ctx, const std::vector &ins, + std::vector *outs, int axis, Functor func) { + std::vector pt_inputs; + std::vector pt_outputs; + // TODO(YuanRisheng) *_tmp for cache DenseTensor, because the temporary + // DenseTensor obj + // generated by MakePtenDenseTensor can be destroyed when exits loop. *_tmp + // can be deleted + // when DenseTensor support copy constructor. + std::vector> pt_inputs_tmp; + std::vector> pt_outputs_tmp; + for (auto in : ins) { + pt_inputs_tmp.emplace_back( + std::move(paddle::experimental::MakePtenDenseTensor(*in))); } -}; - -template -__device__ __forceinline__ void LoadData( - T *dst, const T *__restrict__ src, uint32_t block_offset, - const kps::details::BroadcastConfig &config, int numel, int num, - bool need_broadcast) { - // numel : whole num of output - // num: how many data will be deal with in this time - if (need_broadcast) { - kps::ReadDataBc(dst, src, block_offset, - config, numel); - } else { - kps::ReadData(dst, src + block_offset, num); + for (auto out : *outs) { + pt_outputs_tmp.emplace_back( + std::move(paddle::experimental::MakePtenDenseTensor(*out))); } -} - -template -__device__ void DealSegment( - const framework::Array &ins, OutT *out, - const framework::Array &use_broadcast, uint32_t numel, - const framework::Array, Arity> &configs, - int num, Functor func) { - InT args[Arity][VecSize]; - OutT result[VecSize]; - - int block_offset = blockIdx.x * blockDim.x * VecSize; - -#pragma unroll - for (int i = 0; i < Arity; i++) { - kps::Init(args[i], static_cast(1.0f)); - LoadData(args[i], ins[i], block_offset, - configs[i], numel, num, - use_broadcast[i]); + for (int i = 0; i < pt_inputs_tmp.size(); i++) { + pt_inputs.push_back(pt_inputs_tmp[i].get()); } - - const bool kCallElementwiseAny = - platform::FunctionTraits::has_pointer_args; - ElementwisePrimitiveCaller()(func, args, result); - kps::WriteData(out + block_offset, result, - num); -} - -template -__global__ void BroadcastKernel( - framework::Array ins, OutT *out, - framework::Array use_broadcast, uint32_t numel, - framework::Array, Arity> configs, - int main_tid, int tail_tid, Functor func) { - int block_offset = blockIdx.x * 
blockDim.x * VecSize; - // data offset of this block - if (blockIdx.x < main_tid) { - int num = blockDim.x * VecSize; // blockIdx.x < main_tid - DealSegment( - ins, out, use_broadcast, numel, configs, num, func); - } else { // reminder - int num = tail_tid; - DealSegment( - ins, out, use_broadcast, numel, configs, num, func); + for (int i = 0; i < pt_outputs_tmp.size(); i++) { + pt_outputs.push_back(pt_outputs_tmp[i].get()); } + pten::LaunchBroadcastElementwiseCudaKernel( + ctx, pt_inputs, &pt_outputs, axis, func); } -template -void LaunchKernel(const platform::CUDADeviceContext &ctx, - const std::vector &ins, - framework::Tensor *out, Functor func, - DimensionsTransform merge_dims) { - int numel = out->numel(); - const int threads = 256; - int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads; - - int main_tid = numel / (VecSize * threads); - int tail_tid = numel % (VecSize * threads); - auto stream = ctx.stream(); - OutT *out_data = out->data(); - - framework::Array, Arity> configs; - framework::Array use_broadcast; - framework::Array ins_data; - - for (int i = 0; i < Arity; i++) { - use_broadcast[i] = (ins[i]->numel() != numel); - ins_data[i] = ins[i]->data(); - if (use_broadcast[i]) { - // get the broadcast config, - // if data shape is[m, n], then you should set data_dim = {n, m} - // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3} - configs[i] = kps::details::BroadcastConfig( - merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size); - } - } - - BroadcastKernel<<>>( - ins_data, out_data, use_broadcast, numel, configs, main_tid, tail_tid, - func); -} - -template -void LaunchBroadcastKernelForDifferentVecSize( - const platform::CUDADeviceContext &ctx, - const std::vector &ins, framework::Tensor *out, - int axis, Functor func) { - const auto merge_dims = DimensionsTransform(ins, out->dims(), axis); - -#define CALL_BROADCAST_FOR_DIM_SIZE(rank) \ - case rank: { \ - LaunchKernel(ctx, ins, out, \ - func, merge_dims); \ - } break; - - switch (merge_dims.dim_size) { - CALL_BROADCAST_FOR_DIM_SIZE(1); - CALL_BROADCAST_FOR_DIM_SIZE(2); - CALL_BROADCAST_FOR_DIM_SIZE(3); - CALL_BROADCAST_FOR_DIM_SIZE(4); - CALL_BROADCAST_FOR_DIM_SIZE(5); - CALL_BROADCAST_FOR_DIM_SIZE(6); - CALL_BROADCAST_FOR_DIM_SIZE(7); - CALL_BROADCAST_FOR_DIM_SIZE(8); - default: { - PADDLE_THROW(platform::errors::InvalidArgument( - "The maximum dimension of input tensor is expected to be less than " - "%d, but recieved %d.\n", - merge_dims.dim_size, framework::DDim::kMaxRank)); - } - } -#undef CALL_BROADCAST_FOR_DIM_SIZE -} - -template -void LaunchBroadcastElementwiseCudaKernel( - const platform::CUDADeviceContext &ctx, - const std::vector &ins, +template +void LaunchElementwiseCudaKernel( + const KPDevice &ctx, const std::vector &ins, std::vector *outs, int axis, Functor func) { - using Traits = platform::FunctionTraits; - const int kArity = - Traits::has_pointer_args ? static_cast(ET) : Traits::arity; - PADDLE_ENFORCE_EQ(ins.size(), kArity, - platform::errors::InvalidArgument( - "The number of inputs is expected to be equal to the " - "arity of functor. 
But recieved: the number of inputs " - "is %d, the arity of functor is %d.", - ins.size(), kArity)); - PADDLE_ENFORCE_EQ(kArity, 2, - platform::errors::InvalidArgument( - "Currently only broadcast of binary is supported and " - "verified, but received %d.", - kArity)); - - int in_vec_size = 4; - framework::Tensor *out = (*outs)[0]; - for (auto *in : ins) { - auto temp_size = platform::GetVectorizedSize(in->data()); - in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size) - : in_vec_size; + std::vector pt_inputs; + std::vector pt_outputs; + // TODO(YuanRisheng) *_tmp for cache DenseTensor, because the temporary + // DenseTensor obj + // generated by MakePtenDenseTensor can be destroyed when exits loop. *_tmp + // can be deleted + // when DenseTensor support copy constructor. + std::vector> pt_inputs_tmp; + std::vector> pt_outputs_tmp; + for (auto in : ins) { + pt_inputs_tmp.emplace_back( + std::move(paddle::experimental::MakePtenDenseTensor(*in))); } - int out_vec_size = platform::GetVectorizedSize(out->data()); - int vec_size = std::min(out_vec_size, in_vec_size); - - switch (vec_size) { - case 4: { - LaunchBroadcastKernelForDifferentVecSize( - ctx, ins, out, axis, func); - break; - } - case 2: { - LaunchBroadcastKernelForDifferentVecSize( - ctx, ins, out, axis, func); - break; - } - case 1: { - LaunchBroadcastKernelForDifferentVecSize( - ctx, ins, out, axis, func); - break; - } - default: { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported vectorized size: %d !", vec_size)); - break; - } + for (auto out : *outs) { + pt_outputs_tmp.emplace_back( + std::move(paddle::experimental::MakePtenDenseTensor(*out))); } -} - -template -void LaunchElementwiseCudaKernel( - const platform::CUDADeviceContext &cuda_ctx, - const std::vector &ins, - std::vector *outs, int axis, Functor func) { - std::vector dims_size; - bool no_broadcast_flag = true; - for (auto *in : ins) { - no_broadcast_flag = ins[0]->dims() == in->dims(); - dims_size.emplace_back(in->dims().size()); + for (int i = 0; i < pt_inputs_tmp.size(); i++) { + pt_inputs.push_back(pt_inputs_tmp[i].get()); } - - if (no_broadcast_flag) { - LaunchSameDimsElementwiseCudaKernel(cuda_ctx, ins, outs, - func); - } else { - axis = axis == -1 - ? 
*std::max_element(dims_size.begin(), dims_size.end()) - - *std::min_element(dims_size.begin(), dims_size.end()) - : axis; - LaunchBroadcastElementwiseCudaKernel(cuda_ctx, ins, outs, - axis, func); + for (int i = 0; i < pt_outputs_tmp.size(); i++) { + pt_outputs.push_back(pt_outputs_tmp[i].get()); } + pten::LaunchElementwiseCudaKernel( + ctx, pt_inputs, &pt_outputs, axis, func); } } // namespace operators diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h index 1d8acd5eca..36ff1ae254 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h @@ -35,8 +35,7 @@ using ElementwiseType = pten::ElementwiseType; template void LaunchSameDimsElementwiseCudaKernel( - const platform::CUDADeviceContext &ctx, - const std::vector &ins, + const KPDevice &ctx, const std::vector &ins, std::vector *outs, Functor func) { std::vector pt_inputs; std::vector pt_outputs; diff --git a/paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h b/paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h index e69de29bb2..3338995358 100644 --- a/paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h +++ b/paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h @@ -0,0 +1,567 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "xpu/kernel/cluster_header.h" +#include "xpu/kernel/debug.h" +#include "xpu/kernel/math.h" + +namespace paddle { +namespace operators { +namespace kernel_primitives { +namespace details { + +template +struct alignas(sizeof(T) * VecSize) VectorType { + T val[VecSize]; +}; + +/** + * Configuration of broadcast. Calculate the input data index according to the + * index of the output data. if input or output shape is [dim0, dim1] then dims + * must be [dim1, dim0]. 
+ */ +#pragma pack(4) +template +struct BroadcastConfig { + int strides_in[framework::DDim::kMaxRank]; + int strides_out[framework::DDim::kMaxRank]; + int in_dim[framework::DDim::kMaxRank]; + + HOSTDEVICE BroadcastConfig() {} + + HOSTDEVICE BroadcastConfig(const std::vector& out_dims, + const std::vector& in_dims, + int dim_size) { + std::vector strides_in_tmp; + std::vector strides_out_tmp; + std::vector dim_tmp; + strides_in_tmp.resize(dim_size, 1); + strides_out_tmp.resize(dim_size, 1); + dim_tmp.resize(dim_size, 1); + for (int i = 1; i < dim_size; i++) { + strides_in_tmp[i] = strides_in_tmp[i - 1] * in_dims[i - 1]; + strides_out_tmp[i] = strides_out_tmp[i - 1] * out_dims[i - 1]; + } + + for (int i = 0; i < dim_size; i++) { + dim_tmp[i] = in_dims[i]; + } + + memcpy(strides_in, strides_in_tmp.data(), kDims * sizeof(int)); + memcpy(strides_out, strides_out_tmp.data(), kDims * sizeof(int)); + memcpy(in_dim, dim_tmp.data(), kDims * sizeof(int)); + } + + __device__ inline int operator()(int index_output) const { + int index_src = 0; +#pragma unroll + for (int i = kDims - 1; i >= 0; --i) { + int tmp_index = (index_output / strides_out[i]); + index_output = index_output - tmp_index * strides_out[i]; + index_src += (tmp_index % in_dim[i]) * strides_in[i]; + } + return index_src; + } +}; +#pragma pack() + +} // namespace details + +/** + * @brief Read 2D data from global memory to register according to Tx type, and + * store it as Ty type into register. + * + * @template paraments + * Tx: The type of data stored in the global memory. + * Ty: The type of data that needs to be stored in registers. + * NX: The number of data columns loaded by each thread. + * NY: The number of data rows loaded by each thread. + * BlockSize: Identifies the current device thread index method. For xpu, + * core_id() is used as the index. + * IsBoundary: Indicates whether to perform block access storage out-of-bounds + * judgment. When the number of data processed by the block is less than + * NX x NY x core_num(), boundary judgment is required to avoid memory access + * crossing the boundary. + * + * @param: + * dst: The register pointer of the thread, the size is NX * NY. + * src: The data pointer of the current block. + * size_nx: The maximum offset of the current block is size_nx elements in the + * lowest dimension. The parameters are only calculated when isboundary = true. + * size_ny: The maximum offset of the current block is size_ny elements in the + * first dimension. The parameters are only calculated when isboundary = true. + * stride_nx: Each read one element stride stride_nx elements in the last dim. + * stride_ny: Each read one element stride stride_ny elements in the first dim. 
+ */ +template +__device__ __inline__ void ReadData(Ty* dst, const Tx _global_ptr_* src, + int size_nx, int size_ny, int stride_nx, + int stride_ny) { + int thread_offset = core_id(); + int left_size_nx = size_nx - thread_offset; + __local__ Tx in_temp[1]; + // Each branch is added for better performance + if (NX == 1 && NY == 1) { // for NX == 1 and NY == 1 + if (IsBoundary) { + if (left_size_nx > 0) { + GM2LM(src + thread_offset, in_temp, sizeof(Tx)); + dst[0] = static_cast(in_temp[0]); + } + } else { + GM2LM(src + thread_offset, in_temp, sizeof(Tx)); + dst[0] = static_cast(in_temp[0]); + } + } else if (NX == 1) { // for NX == 1 and NY != 1 +#pragma unroll + for (int idy = 0; idy < NY; ++idy) { + if (IsBoundary) { + if (idy * stride_ny >= size_ny) { + break; + } + } + GM2LM(src + thread_offset + idy * stride_ny, in_temp, sizeof(Tx)); + dst[idy] = static_cast(in_temp[0]); + } + } else if (NY == 1) { // for NY == 1 and NX != 1 +#pragma unroll + for (int idx = 0; idx < NX; ++idx) { + if (IsBoundary) { + if (idx * stride_nx >= left_size_nx) { + break; + } + } + GM2LM(src + thread_offset + idx * stride_nx, in_temp, sizeof(Tx)); + dst[idx] = static_cast(in_temp[0]); + } + } else { // for NX != 1 and NY != 1 +#pragma unroll + for (int idx = 0; idx < NX; ++idx) { +#pragma unroll + for (int idy = 0; idy < NY; ++idy) { + if (IsBoundary) { + if (idy * stride_ny >= size_ny || idx * stride_nx >= left_size_nx) { + break; + } + } + int fix = thread_offset + idx * stride_nx + idy * stride_ny; + GM2LM(src + fix, in_temp, sizeof(Tx)); + dst[idy * NX + idx] = static_cast(in_temp[0]); + } + } + } +} + +/** + * @brief Initialize register with init_data. + * + * @template paraments + * T: Data type of register. + * NX: Number of data to initialize. + * + * @param: + * dst: The register pointer of the thread, the size is NX. + * init_data: Initial value. + */ +template +__device__ __inline__ void Init(T* dst, T init_data) { +#pragma unroll + for (int i = 0; i < NX; i++) { + dst[i] = init_data; + } +} + +/** + * @brief Read 1D data from global memory to register. When IsBoundary = true + * and (NX % 4 == 0 or Nx % 2 == 0), vectorized load data will be used to + * improve memory access efficiency. + * + * @template paraments + * T: The type of data. + * NX: Each thread load NX data from global memory continuously. + * NY: Each thread need to load NY rows, only NY = 1 was supported. + * BlockSize: Identifies the current device thread index method. For xpu, + * core_id() is used as the index. + * IsBoundary: Whether to make an out-of-bounds judgment on access to memory. + * When the number of data processed by this block is less than + * NX x NY x core_num(), boundary judgment is required to avoid memory access + * crossing the boundary. + * + * @param: + * dst: The register pointer of the thread, the size is NX * NY. + * src: The data pointer of the current block. + * size: The current block needs to load size data continuously. + */ +template +__device__ __inline__ void ReadData(T* dst, const T _global_ptr_* src, + int num) { + int thread_offset = core_id() * NX; + __local__ T in_temp[1]; + if (IsBoundary) { // core_num() * NX > num +#pragma unroll + for (int idx = 0; idx < NX; ++idx) { + if (idx + thread_offset < num) { + GM2LM(src + thread_offset + idx, in_temp, sizeof(T)); + dst[idx] = in_temp[0]; + } + } + } else { // core_num() * NX < num + GM2LM(src + thread_offset, dst, NX * sizeof(T)); + } +} + +/** + * @brief Read 2D data from global memory to registers with broadcast form. 
+ * + * @template paraments + * T: The type of data stored in the global memory. + * NX: The number of data columns loaded by each thread. + * NY: The number of data rows loaded by each thread. + * BlockSize: Identifies the current device thread index method. For xpu, + * core_id() is used as the index. + * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2. + * IsBoundary: Indicates whether to perform block access storage out-of-bounds + * judgment. When the number of data processed by the block is less than + * NX x NY x core_num(), boundary judgment is required to avoid memory access + * crossing the boundary. + * + * @param: + * dst: The register pointer of the thread, the size is NX * NY. + * src: Raw input data pointer of kernel. + * block_offset: Data offset of this block, core_num() * cluster_id() * NX; + * config: Calculation configuration of broadcast. It is used to calculate the + * coordinate mapping relationship between output data and input data. + * total_num_output: Total number of original output. + * stride_nx: Each read one element stride stride_nx elements in the last dim. + * stride_ny: Each read one element stride stride_ny elements in the first dim. + */ +template +__device__ __inline__ void ReadDataBc(T* dst, const T _global_ptr_* src, + uint32_t block_offset, + details::BroadcastConfig config, + int total_num_output, int stride_nx, + int stride_ny) { + uint32_t thread_offset = block_offset + core_id(); + uint32_t index_src = 0; + __local__ T in_temp[1]; + +#pragma unroll + for (int ny = 0; ny < NY; ++ny) { +#pragma unroll + for (uint32_t nx = 0; nx < NX; ++nx) { + uint32_t index_output = thread_offset + ny * stride_ny + nx * stride_nx; + index_src = 0; + if (IsBoundary) { + if (index_output >= (uint32_t)total_num_output) { + break; + } + } + index_src = config(index_output); + GM2LM(src + index_src, in_temp, sizeof(T)); + dst[nx + ny * NX] = in_temp[0]; + } + } +} + +/** + * @brief Read 2D data from global memory to register with reduce form. + * + * @template paraments + * T: The type of data. + * NX: The number of data columns loaded by each thread. + * NY: The number of data rows loaded by each thread. + * BlockSize: Identifies the current device thread index method. For xpu, + * core_id() is used as the index. + * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2. + * IsBoundary: Indicates whether to perform block access storage out-of-bounds + * judgment. When the number of data processed by the block is less than + * NX x NY x core_num(), boundary judgment is required to avoid memory access + * crossing the boundary. + * + * @param: + * dst: The register pointer of the thread, the size is NX * NY. + * src: The input data pointer of this block. + * block_offset: The data offset of this block, blockDim.x * cluster_id() * NX. + * index_cal: Calculation configuration of Reduce. It is used to calculate the + * coordinate mapping relationship between output data and input data. + * size_nx: The current block needs to load size_nx columns of data, this + * parameter will participate in the calculation when isboundary = true. + * size_ny: The current block needs to load size_ny rows of data, this parameter + * will participate in the calculation when isboundary = true. + * will be used when IsBoundary = true. + * stride_nx: Each read one element stride stride_nx columns. + * stride_ny: Each read one element stride stride_ny raws. 
+ * reduce_last_dim: Used to indicate whether the dimension of reduce contains + * the lowest dimension. + */ +template +__device__ __inline__ void ReadDataReduce(T* dst, const T _global_ptr_* src, + int block_offset, + const IndexCal& index_cal, + int size_nx, int size_ny, + int stride_nx, int stride_ny, + bool reduce_last_dim) { + __local__ Tx in_temp[1]; + int thread_offset = 0; + int left_idx = 0; + if (reduce_last_dim) { + thread_offset = core_id(); + left_idx = 0; + } else { + thread_offset = 0; + left_idx = 0; + } + + if (NX == 1) { +#pragma unroll + for (int ny = 0; ny < NY; ++ny) { + if (IsBoundary) { + if (thread_offset >= size_ny) { + break; + } + } + uint32_t index_src = index_cal(thread_offset + block_offset); + GM2LM(src + index_src, in_temp, sizeof(Tx)); + dst[ny] = static_cast(func(in_temp[0])); + thread_offset += stride_ny; + } + } else { +#pragma unroll + for (int nx = 0; nx < NX; ++nx) { +#pragma unroll + for (int ny = 0; ny < NY; ++ny) { + if (IsBoundary) { + if ((thread_offset >= size_ny) || + (left_idx + nx * stride_nx >= size_nx)) { + break; + } + } + uint32_t index_src = index_cal(thread_offset + block_offset); + GM2LM(src + index_src, in_temp, sizeof(Tx)); + dst[nx + ny * NX] = static_cast(func(in_temp[0])); + thread_offset += stride_ny; + } + } + } +} +/** + * @brief Write 1D data from registers to global memory. When IsBoundary = true + * and (NX % 4 == 0 or Nx % 2 == 0), the data will be vectorized to improve the + * data loading efficiency + * + * @template paraments + * T: The type of data. + * NX: The number of data continuously writed by each thread. + * NY: The number of data rows loaded by each thread, only NY = 1 was supported. + * BlockSize: Identifies the current device thread index method. For xpu, + * core_id() is used as the index. + * IsBoundary: Indicates whether to perform block access storage out-of-bounds + * judgment. When the number of data processed by the block is less than + * NX x NY x core_num(), boundary judgment is required to avoid memory access + * crossing the boundary. + * + * @param: + * dst: The data pointer of the current block. + * src: The register pointer, the size is NX * NY. + * size: The current block needs to load size elements continuously. + */ + +template +__device__ void WriteData(T _global_ptr_* dst, const T* src, int num) { + int thread_offset = core_id() * NX; + __local__ T in_temp[1]; + if (IsBoundary) { // core_num() * NX > num +#pragma unroll + for (int idx = 0; idx < NX; ++idx) { + if (idx + thread_offset < num) { + in_temp[0] = src[idx]; + LM2GM(in_temp, dst + idx + thread_offset, sizeof(T)); + } + } + } else { // core_num() * NX < num + LM2GM(src, dst + thread_offset, NX * sizeof(T)); + } +} + +/** + * @brief Write 2D data from register to global memory according to Tx type, and + * store it as Ty type. + * + * @template paraments + * Tx: The type of data that needs to be stored in registers. + * Ty: The type of data stored in the global memory. + * NX: The number of data columns loaded by each thread. + * NY: The number of data rows loaded by each thread. + * BlockSize: Identifies the current device thread index method. For xpu, + * core_id() is used as the index. + * IsBoundary: Indicates whether to perform block access storage out-of-bounds + * judgment. When the number of data processed by the block is less than + * NX x NY x core_num(), boundary judgment is required to avoid memory access + * crossing the boundary. + * + * @param: + * dst: Data pointer of the current block. 
+ * src: The register pointer of the thread, the size is NX * NY. + * size_nx: The current block needs to load size_nx columns of data, this + * parameter will be used when IsBoundary = true. + * size_ny: The current block needs to load size_ny rows of data. This parameter + * will be used when IsBoundary = true. + * stride_nx: Each read one element stride stride_nx elements in the last dim. + * stride_ny: Each read one element stride stride_ny elements in the first dim. + */ +template +__device__ __inline__ void WriteData(Ty _global_ptr_* dst, const Tx* src, + int size_nx, int size_ny, int stride_nx, + int stride_ny) { + int thread_offset = core_id(); + int left_size_nx = size_nx - thread_offset; + __local__ Ty in_temp[1]; + + // Each branch is added for better performance + if (NX == 1 && NY == 1) { + if (IsBoundary) { + if (left_size_nx > 0) { + in_temp[0] = static_cast(src[0]); + LM2GM(in_temp, dst + thread_offset, sizeof(Ty)); + } + } else { + in_temp[0] = static_cast(src[0]); + LM2GM(in_temp, dst + thread_offset, sizeof(Ty)); + } + } else if (NX == 1) { +#pragma unroll + for (int idy = 0; idy < NY; ++idy) { + if (IsBoundary) { + if (idy * stride_ny >= size_ny) { + break; + } + } + + in_temp[0] = static_cast(src[idy]); + LM2GM(in_temp, dst + thread_offset + idy * stride_ny, sizeof(Ty)); + } + } else if (NY == 1) { // for NY == 1 and NX != 1 +#pragma unroll + for (int idx = 0; idx < NX; ++idx) { + if (IsBoundary) { + if (idx * stride_nx >= left_size_nx) { + break; + } + } + + in_temp[0] = static_cast(src[idx]); + LM2GM(in_temp, dst + thread_offset + idx * stride_nx, sizeof(Ty)); + } + } else { // for NX != 1 and NY != 1 +#pragma unroll + for (int idx = 0; idx < NX; ++idx) { + if (IsBoundary) { + if (idx * stride_nx >= left_size_nx) { + break; + } + } +#pragma unroll + for (int idy = 0; idy < NY; ++idy) { + if (IsBoundary) { + if (idy * stride_ny >= size_ny) { + break; + } + } + in_temp[0] = static_cast(src[idx + idy * NX]); + LM2GM(in_temp, dst + thread_offset + idx * stride_nx + idy * stride_ny, + sizeof(Ty)); + } + } + } +} + +/** + * @brief Initialize register with init_data. + * + * @template paraments + * T: Data type of register. + * NX: Number of data to initialize. + * + * @param: + * dst: The register pointer of the thread, the size is NX. + * init_data: The register pointer of init data, the size is NX. + */ +template +__device__ __inline__ void Init(T* dst, T* init_data, int num) { +#pragma unroll + for (int i = 0; i < NX; i++) { + if (IsBoundary) { + if (i >= num) { + break; + } + } + dst[i] = init_data[i]; + } +} + +/** + * @brief Read 1D data from global memory to register with broadcast form. + * + * @template paraments + * T: The type of data stored in the global memory. + * NX: The number of data continuously loaded by each thread. + * NY: The number of data rows loaded by each thread, only NY = 1 was supported. + * BlockSize: Identifies the current device thread index method. For xpu, + * core_id() is used as the index. + * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2. + * IsBoundary: Indicates whether to perform block access storage out-of-bounds + * judgment. When the number of data processed by the block is less than + * NX x NY x core_num(), boundary judgment is required to avoid memory access + * crossing the boundary. + * + * @param: + * dst: The register pointer of the thread, the size is NX * NY. + * src: The original input data pointer of kernel. 
+ * block_offset: The data offset of this block, core_num() * blockIdx.x * NX; + * config: Calculation configuration of broadcast. It is used to calculate the + * coordinate mapping relationship between output data and input data. + * total_num_output: Total number of original output. + */ +template +__device__ __inline__ void ReadDataBc(T* dst, const T _global_ptr_* src, + uint32_t block_offset, + details::BroadcastConfig config, + int total_num_output) { + int thread_offset = block_offset + core_id() * NX; + int index_src = 0; + + __local__ T in_temp; +#pragma unroll + for (int nx = 0; nx < NX; ++nx) { + int index_output = thread_offset + nx; + index_src = 0; + if (IsBoundary) { + if (index_output >= total_num_output) { + break; + } + } + index_src = config(index_output); + GM2LM(src + index_src, &in_temp, sizeof(T)); + dst[nx] = in_temp; + } +} + +} // namespace kernel_primitives +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/kernel_primitives/kernel_primitives.h b/paddle/fluid/operators/kernel_primitives/kernel_primitives.h index 9a4f8bb026..558f8c81c6 100644 --- a/paddle/fluid/operators/kernel_primitives/kernel_primitives.h +++ b/paddle/fluid/operators/kernel_primitives/kernel_primitives.h @@ -13,11 +13,58 @@ // limitations under the License. #pragma once +#include "paddle/fluid/operators/kernel_primitives/helper_primitives.h" +#ifdef PADDLE_WITH_XPU2 +#include "paddle/fluid/operators/kernel_primitives/compute_primitives_xpu2.h" +#include "paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h" +#include "paddle/fluid/operators/kernel_primitives/functor_primitives_xpu2.h" + +#define KPStream XPUStream +#define KPDevice paddle::platform::XPUDeviceContext +#define _ptr_ _global_ptr_ +#define __forceinline__ __inline__ +#define __restrict__ + +#define THREAD_ID_X core_id() +#define THREAD_ID_Y 0 +#define THREAD_ID_Z 0 +#define BLOCK_NUM_X core_num() +#define BLOCK_NUM_Y 0 +#define BLOCK_NUM_Z 0 + +#define BLOCK_ID_X cluster_id() +#define BLOCK_ID_Y 0 +#define BLOCK_ID_Z 0 + +#define GRID_NUM_X cluster_num() +#define GRID_NUM_Y 0 +#define GRID_NUM_Z 0 +#else #include "paddle/fluid/operators/kernel_primitives/compute_primitives.h" #include "paddle/fluid/operators/kernel_primitives/datamover_primitives.h" #include "paddle/fluid/operators/kernel_primitives/functor_primitives.h" -#include "paddle/fluid/operators/kernel_primitives/helper_primitives.h" + +#define KPStream gpuStream_t +#define KPDevice paddle::platform::CUDADeviceContext +#define _ptr_ + +#define THREAD_ID_X threadIdx.x +#define THREAD_ID_Y threadIdx.y +#define THREAD_ID_Z threadIdx.z + +#define BLOCK_NUM_X blockDim.x +#define BLOCK_NUM_Y blockDim.y +#define BLOCK_NUM_Z blockDim.z + +#define BLOCK_ID_X blockIdx.x +#define BLOCK_ID_Y blockIdx.y +#define BLOCK_ID_Z blockIdx.z + +#define GRID_NUM_X gridDim.x +#define GRID_NUM_Y gridDim.y +#define GRID_NUM_Z gridDim.z +#endif namespace paddle { namespace operators { diff --git a/paddle/fluid/platform/hostdevice.h b/paddle/fluid/platform/hostdevice.h index 1ffbbc217e..65005a5adb 100644 --- a/paddle/fluid/platform/hostdevice.h +++ b/paddle/fluid/platform/hostdevice.h @@ -17,7 +17,14 @@ #include #endif -#if (defined(__CUDACC__) || defined(__HIPCC__)) +#ifdef __xpu_kp__ +#include +#include "xpu/kernel/cluster_header.h" +#include "xpu/kernel/debug.h" +#include "xpu/kernel/math.h" +#endif + +#if (defined(__CUDACC__) || defined(__HIPCC__) || defined(__xpu_kp__)) #define HOSTDEVICE __host__ __device__ #define DEVICE __device__ #define HOST 
__host__ diff --git a/paddle/pten/kernels/gpu/elementwise.h b/paddle/pten/kernels/gpu/elementwise.h new file mode 100644 index 0000000000..e4cc894e48 --- /dev/null +++ b/paddle/pten/kernels/gpu/elementwise.h @@ -0,0 +1,863 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" +#include "paddle/fluid/platform/aligned_vector.h" +#include "paddle/fluid/platform/function_traits.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/kernels/funcs/cuda_kernel_config.h" + +namespace pten { + +namespace kps = paddle::operators::kernel_primitives; +enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 }; + +/* Packing scalar type T(float, int etc.) into Array type + for supporting multiple-output feature in elementwise system.*/ +template +using ConditionalT = + typename std::conditional_t>; + +template +struct ElementwisePrimitiveCaller { + __device__ inline void operator()(Functor func, + InT (*args)[VecSize], + OutT *result); +}; + +template +struct ElementwisePrimitiveCaller { + __device__ inline void operator()(Functor func, + InT (*args)[VecSize], + OutT *result) { + kps::ElementwiseAny( + result, args, func); + } +}; + +template +struct ElementwisePrimitiveCaller { + __device__ inline void operator()(Functor func, + InT (*args)[VecSize], + OutT *result) { + kps::ElementwiseUnary( + result, args[0], func); + } +}; + +template +struct ElementwisePrimitiveCaller { + __device__ inline void operator()(Functor func, + InT (*args)[VecSize], + OutT *result) { + kps::ElementwiseBinary( + result, args[0], args[1], func); + } +}; + +template +struct ElementwisePrimitiveCaller { + __device__ inline void operator()(Functor func, + InT (*args)[VecSize], + OutT *result) { + kps::ElementwiseTernary( + result, args[0], args[1], args[2], func); + } +}; + +template +struct ElementwiseWriteDataCaller { + __device__ __forceinline__ void operator()( + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs, + ConditionalT src[VecSize], + int block_offset, + int num) { + OutT dst[NumOuts][VecSize]; +#pragma unroll + for (int i = 0; i < VecSize; ++i) { +#pragma unroll + for (int j = 0; j < NumOuts; ++j) { + dst[j][i] = (src[i])[j]; + } + } +#pragma unroll + for (int i = 0; i < NumOuts; ++i) { + kps::WriteData( + outs[i] + block_offset, dst[i], num); + } + } +}; + +template +struct ElementwiseWriteDataCaller { + __device__ __forceinline__ void operator()( + paddle::framework::Array<_ptr_ OutT *, 1> outs, + OutT src[VecSize], + int block_offset, + int num) { + kps::WriteData( + outs[0] + block_offset, src, num); + } +}; + +template +__device__ void VectorizedElementwiseKernelImpl( + const paddle::framework::Array &in, + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs, + int num, + int data_offset, + Functor func) { + InT args[Arity][VecSize]; + ConditionalT result[VecSize]; + +#pragma unroll + for (int i = 0; i < Arity; i++) { + kps::Init(args[i], 
static_cast(1.0f)); + kps::ReadData( + args[i], in[i] + data_offset, num); + } + + constexpr bool kCallElementwiseAny = + paddle::platform::FunctionTraits::has_pointer_args; + ElementwisePrimitiveCaller, + VecSize, + Functor, + Arity, + kCallElementwiseAny>()(func, args, result); + + ElementwiseWriteDataCaller()( + outs, result, data_offset, num); +} + +template +__global__ void VectorizedElementwiseKernel( + paddle::framework::Array ins, + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs, + int size, + int main_offset, + Functor func) { + int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize; + int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize; + for (; data_offset < main_offset; data_offset += stride) { + VectorizedElementwiseKernelImpl( + ins, outs, VecSize * BLOCK_NUM_X, data_offset, func); + } + + int num = size - data_offset; + if (num > 0) { + VectorizedElementwiseKernelImpl(ins, outs, num, data_offset, func); + } +} + +template +int GetVectorizedSizeForTensors(const std::vector &ins, + const std::vector &outs) { + int vec_size = 4; + for (auto iter = ins.begin(); iter != ins.end(); ++iter) { + vec_size = std::min( + vec_size, paddle::platform::GetVectorizedSize((*iter)->data())); + } + for (auto iter = outs.begin(); iter != outs.end(); ++iter) { + vec_size = std::min( + vec_size, paddle::platform::GetVectorizedSize((*iter)->data())); + } + return vec_size; +} + +template +void ElementwiseCudaKernel(const KPDevice &ctx, + const std::vector &ins, + std::vector *outs, + Functor func) { + auto numel = ins[0]->numel(); + paddle::framework::Array ins_data; + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs_data; + + for (int i = 0; i < Arity; ++i) { + ins_data[i] = ins[i]->data(); + } + for (int i = 0; i < NumOuts; ++i) { + outs_data[i] = (*outs)[i]->mutable_data(); + } +#ifdef PADDLE_WITH_XPU2 + int block_size = 64; + int grid_size = 8; + auto stream = ctx.x_context()->xpu_stream; + int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size; + VectorizedElementwiseKernel<<>>( + ins_data, outs_data, numel, main_offset, func); +#else + int block_size = funcs::GetThreadsConfig(ctx, numel, VecSize); + int grid_size = + ((numel + VecSize - 1) / VecSize + block_size - 1) / block_size; + int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size; + auto stream = ctx.stream(); + VectorizedElementwiseKernel<<>>( + ins_data, outs_data, numel, main_offset, func); +#endif +} + +template +void LaunchSameDimsElementwiseCudaKernel( + const KPDevice &ctx, + const std::vector &ins, + std::vector *outs, + Functor func) { + using Traits = paddle::platform::FunctionTraits; + const int kArity = + Traits::has_pointer_args ? static_cast(ET) : Traits::arity; + PADDLE_ENFORCE_EQ(ins.size(), + kArity, + paddle::platform::errors::InvalidArgument( + "The number of inputs is expected to be equal to the " + "arity of functor. 
But recieved: the number of inputs " + "is %d, the arity of functor is %d.", + ins.size(), + kArity)); + PADDLE_ENFORCE_EQ(outs->size(), + NumOuts, + paddle::platform::errors::InvalidArgument( + "Number of outputs shall equal to number of functions, " + "but number of outputs is %d, of functions is %d.", + outs->size(), + NumOuts)); + + if (NumOuts > 1) { + for (int i = 1; i < NumOuts; ++i) { + PADDLE_ENFORCE_EQ( + (*outs)[i]->dims(), + (*outs)[0]->dims(), + paddle::platform::errors::InvalidArgument( + "The shape of each output tensor shall be identical yet, " + "but %dth output tensor`s shape is not.", + i)); + } + } + + // calculate the max vec_size for all ins and outs + int vec_size = GetVectorizedSizeForTensors(ins, *outs); + switch (vec_size) { + case 4: + ElementwiseCudaKernel( + ctx, ins, outs, func); + break; + case 2: + ElementwiseCudaKernel( + ctx, ins, outs, func); + break; + case 1: + ElementwiseCudaKernel( + ctx, ins, outs, func); + break; + default: { + PADDLE_THROW(paddle::platform::errors::Unimplemented( + "Unsupported vectorized size: %d !", vec_size)); + break; + } + } +} + +struct DimensionsTransform { + using DimVector = std::vector; + typedef void (*MergeFunctor)( + bool &, std::vector &, DimVector &, int, int); + int64_t dim_size; + DimVector out_dims; + std::vector in_dims; + + private: + // To compensate the lackage of input_tensors` dimension with input variable + // 'axis' + void InputDimensionsExtend(int N, int axis) { + for (auto &in_dim : in_dims) { + int64_t in_idx = 0; + if (in_dim.size() < dim_size) { + DimVector tmp_dim(dim_size, 1); + do { + if (in_dim[in_idx] == out_dims[axis] || in_dim[in_idx] == 1) { + tmp_dim[axis] = in_dim[in_idx]; + in_idx++; + axis++; + } else { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "The %d-th dimension of input tensor is expected to be equal " + "with the %d-th dimension of output tensor %d or 1, but " + "recieved %d.", + in_idx + 1, + axis + 1, + out_dims[axis], + in_dim[in_idx])); + } + } while (in_idx < in_dim.size()); + in_dim.resize(dim_size); + std::copy(tmp_dim.begin(), tmp_dim.end(), in_dim.begin()); + } else { + do { + if (in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1) { + in_idx++; + } else { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "The %d-th dimension of input tensor is expected to be equal " + "with the %d-th dimension of output tensor %d or 1, but " + "recieved %d.", + in_idx + 1, + in_idx + 1, + out_dims[in_idx], + in_dim[in_idx])); + } + } while (in_idx < dim_size); + } + std::reverse(in_dim.begin(), in_dim.end()); + } + std::reverse(out_dims.begin(), out_dims.end()); + } + + template + __inline__ void MergeDimensions(MergeFunctor merge_func, int N) { + auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) { + (*vec)[m_idx - 1] = std::accumulate(vec->begin() + l_idx, + vec->begin() + m_idx, + 1, + std::multiplies()); + vec->erase(vec->begin() + l_idx, vec->begin() + m_idx - 1); + }; + + int64_t i = 0; + while (i < dim_size) { + int cnt = 0; + int low_idx = i; + bool equal = true; + do { + merge_func(equal, in_dims, out_dims, i, N); + if (equal) { + i++; + cnt++; + } else { + break; + } + } while (i < dim_size); + + if (cnt > 1) { + for (auto &in_dim : in_dims) { + VectorReorganise(&in_dim, low_idx, i); + } + VectorReorganise(&out_dims, low_idx, i); + dim_size -= --cnt; + i -= cnt; + } else if (cnt < 1) { + i++; + } + } + } + + public: + explicit DimensionsTransform(const std::vector &ins, + const paddle::framework::DDim &dims, + int axis) { + const int 
N = ins.size(); + dim_size = dims.size(); + out_dims = paddle::framework::vectorize(dims); + in_dims.resize(N); + for (int j = 0; j < N; ++j) { + in_dims[j] = paddle::framework::vectorize(ins[j]->dims()); + } + InputDimensionsExtend(N, axis); + + auto merge_sequential_dims = [](bool &equal, + std::vector &in_dims, + DimVector &out, + int i, + int num) { + for (int j = 1; j < num; ++j) { + equal &= (in_dims[0][i] == in_dims[j][i]) ? true : false; + } + }; + auto merge_sequential_one_dims = [](bool &equal, + std::vector &in_dims, + DimVector &out, + int i, + int num) { + equal = in_dims[0][i] == 1; + if (equal) { + for (int j = 1; j < num; ++j) { + equal &= in_dims[j][i] == out[i]; + } + } + }; + // To Merge the dimensions of input_tensors while the consequtive + // equal-dimensions appears. + MergeFunctor merge_ptr = merge_sequential_dims; + MergeDimensions(merge_ptr, N); + + int min_idx = 0; + int min_val = std::accumulate( + in_dims[0].begin(), in_dims[0].end(), 1, std::multiplies()); + for (int j = 1; j < N; ++j) { + int temp = std::accumulate( + in_dims[j].begin(), in_dims[j].end(), 1, std::multiplies()); + min_val = min_val > temp ? temp : min_val; + min_idx = min_val == temp ? j : min_idx; + } + std::swap(in_dims[0], in_dims[min_idx]); + + // To Merge the dimension of input_tensors while the consequtive + // 1-value-dimensions appears. + merge_ptr = merge_sequential_one_dims; + MergeDimensions(merge_ptr, N); + std::swap(in_dims[min_idx], in_dims[0]); + } +}; + +template +__device__ __forceinline__ void LoadData( + T *dst, + const _ptr_ T *src, + uint32_t block_offset, + const kps::details::BroadcastConfig &config, + int numel, + int num, + int need_broadcast) { + // numel : whole num of output + // num: how many data will be deal with in this time + if (need_broadcast) { + kps::ReadDataBc( + dst, src, block_offset, config, numel); + } else { + kps::ReadData(dst, src + block_offset, num); + } +} + +template +__device__ void ElementwiseBroadcastKernelImpl( + const paddle::framework::Array &ins, + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs, + const paddle::framework::Array &use_broadcast, + uint32_t numel, + const paddle::framework::Array, Arity> + &configs, + int num, + int block_offset, + Functor func) { + InT args[Arity][VecSize]; + ConditionalT result[VecSize]; + +#pragma unroll + for (int i = 0; i < Arity; i++) { + kps::Init(args[i], static_cast(1.0f)); + LoadData(args[i], + ins[i], + block_offset, + configs[i], + numel, + num, + use_broadcast[i]); + } + constexpr bool kCallElementwiseAny = + paddle::platform::FunctionTraits::has_pointer_args; + ElementwisePrimitiveCaller, + VecSize, + Functor, + Arity, + kCallElementwiseAny>()(func, args, result); + + ElementwiseWriteDataCaller()( + outs, result, block_offset, num); +} + +template +__global__ void ElementwiseBroadcastKernel( + paddle::framework::Array ins, + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs, + paddle::framework::Array use_broadcast, + uint32_t numel, + paddle::framework::Array, Arity> + configs, + int main_offset, + int tail_tid, + Functor func) { + int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize; + int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize; + +#ifdef PADDLE_WITH_XPU2 + for (; block_offset < main_offset; block_offset += stride) { + ElementwiseBroadcastKernelImpl(ins, + outs, + use_broadcast, + numel, + configs, + BLOCK_NUM_X * VecSize, + block_offset, + func); + } + int num = numel - block_offset; + if (num > 0) { + ElementwiseBroadcastKernelImpl( + ins, outs, use_broadcast, numel, configs, 
num, block_offset, func); + } +#else + if (block_offset < main_offset) { + ElementwiseBroadcastKernelImpl(ins, + outs, + use_broadcast, + numel, + configs, + BLOCK_NUM_X * VecSize, + block_offset, + func); + } else { + ElementwiseBroadcastKernelImpl( + ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func); + } +#endif +} + +template +void LaunchKernel(const KPDevice &ctx, + const std::vector &ins, + std::vector *outs, + Functor func, + DimensionsTransform merge_dims) { + int numel = (*outs)[0]->numel(); + paddle::framework::Array, Arity> configs; + paddle::framework::Array use_broadcast; + paddle::framework::Array ins_data; + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs_data; + + for (int i = 0; i < NumOuts; ++i) { + outs_data[i] = (*outs)[i]->mutable_data(); + } + + for (int i = 0; i < Arity; i++) { + use_broadcast[i] = (ins[i]->numel() != numel); + ins_data[i] = (_ptr_ InT *)(ins[i]->data()); + if (use_broadcast[i]) { + // get the broadcast config, + // if data shape is[m, n], then you should set data_dim = {n, m} + // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3} + configs[i] = kps::details::BroadcastConfig( + merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size); + } + } + +#ifdef PADDLE_WITH_XPU2 + const int threads = 64; + const int blocks = 8; + int main_offset = (numel / (VecSize * threads)) * VecSize * threads; + int tail_tid = numel % (VecSize * threads); + auto stream = ctx.x_context()->xpu_stream; + ElementwiseBroadcastKernel<<>>(ins_data, + outs_data, + use_broadcast, + numel, + configs, + main_offset, + tail_tid, + func); +#else + const int threads = 256; + int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads; + int main_offset = (numel / (VecSize * threads)) * VecSize * threads; + int tail_tid = numel % (VecSize * threads); + auto stream = ctx.stream(); + ElementwiseBroadcastKernel<<>>( + ins_data, + outs_data, + use_broadcast, + numel, + configs, + main_offset, + tail_tid, + func); +#endif +} + +template +void LaunchBroadcastKernelForDifferentVecSize( + const KPDevice &ctx, + const std::vector &ins, + std::vector *outs, + int axis, + Functor func) { + const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis); + +#define CALL_BROADCAST_FOR_DIM_SIZE(rank) \ + case rank: { \ + LaunchKernel( \ + ctx, ins, outs, func, merge_dims); \ + } break; + + switch (merge_dims.dim_size) { + CALL_BROADCAST_FOR_DIM_SIZE(1); + CALL_BROADCAST_FOR_DIM_SIZE(2); + CALL_BROADCAST_FOR_DIM_SIZE(3); + CALL_BROADCAST_FOR_DIM_SIZE(4); + CALL_BROADCAST_FOR_DIM_SIZE(5); + CALL_BROADCAST_FOR_DIM_SIZE(6); + CALL_BROADCAST_FOR_DIM_SIZE(7); + CALL_BROADCAST_FOR_DIM_SIZE(8); + default: { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "The maximum dimension of input tensor is expected to be less than " + "%d, but recieved %d.\n", + merge_dims.dim_size, + paddle::framework::DDim::kMaxRank)); + } + } +#undef CALL_BROADCAST_FOR_DIM_SIZE +} + +template +void LaunchBroadcastElementwiseCudaKernel( + const KPDevice &ctx, + const std::vector &ins, + std::vector *outs, + int axis, + Functor func) { + using Traits = paddle::platform::FunctionTraits; + const int kArity = + Traits::has_pointer_args ? static_cast(ET) : Traits::arity; + PADDLE_ENFORCE_EQ(ins.size(), + kArity, + paddle::platform::errors::InvalidArgument( + "The number of inputs is expected to be equal to the " + "arity of functor. 
But recieved: the number of inputs " + "is %d, the arity of functor is %d.", + ins.size(), + kArity)); + PADDLE_ENFORCE_LE(kArity, + 3, + paddle::platform::errors::InvalidArgument( + "Currently only broadcast of ternary is supported " + "and verified, but received %d.", + kArity)); + PADDLE_ENFORCE_EQ(outs->size(), + NumOuts, + paddle::platform::errors::InvalidArgument( + "Number of outputs shall equal to number of functions, " + "but number of outputs is %d, of functions is %d.", + outs->size(), + NumOuts)); + int in_vec_size = 4; + int out_vec_size = 4; + if (NumOuts > 1) { + for (int i = 0; i < NumOuts; ++i) { + PADDLE_ENFORCE_EQ( + (*outs)[i]->dims(), + (*outs)[0]->dims(), + paddle::platform::errors::InvalidArgument( + "The shape of each output tensor shall be identical yet, but " + "%dth output tensor`s shape is not.", + i)); + out_vec_size = std::min( + paddle::platform::GetVectorizedSize((*outs)[i]->data()), + out_vec_size); + } + } else { + out_vec_size = + paddle::platform::GetVectorizedSize((*outs)[0]->data()); + } + + for (auto *in : ins) { + auto temp_size = paddle::platform::GetVectorizedSize(in->data()); + in_vec_size = in->dims() == (*outs)[0]->dims() + ? std::min(temp_size, in_vec_size) + : in_vec_size; + } + int vec_size = std::min(out_vec_size, in_vec_size); + + switch (vec_size) { + case 4: { + LaunchBroadcastKernelForDifferentVecSize(ctx, ins, outs, axis, func); + break; + } + case 2: { + LaunchBroadcastKernelForDifferentVecSize(ctx, ins, outs, axis, func); + break; + } + case 1: { + LaunchBroadcastKernelForDifferentVecSize(ctx, ins, outs, axis, func); + break; + } + default: { + PADDLE_THROW(paddle::platform::errors::Unimplemented( + "Unsupported vectorized size: %d !", vec_size)); + break; + } + } +} + +template +void LaunchElementwiseCudaKernel(const KPDevice &ctx, + const std::vector &ins, + std::vector *outs, + int axis, + Functor func) { + std::vector dims_size; + bool no_broadcast_flag = true; + for (auto *in : ins) { + no_broadcast_flag &= ins[0]->dims() == in->dims(); + dims_size.emplace_back(in->dims().size()); + } + if (no_broadcast_flag) { + LaunchSameDimsElementwiseCudaKernel( + ctx, ins, outs, func); + } else { + axis = axis == -1 + ? *std::max_element(dims_size.begin(), dims_size.end()) - + *std::min_element(dims_size.begin(), dims_size.end()) + : axis; + LaunchBroadcastElementwiseCudaKernel( + ctx, ins, outs, axis, func); + } +} + +} // namespace pten -- Gitee
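
The migrated entry point above, pten::LaunchElementwiseCudaKernel, is what the thinned-out fluid wrappers now forward to. The following standalone sketch shows how an operator kernel might drive it for a binary broadcast add; the template parameter order (ElementwiseType, InT, OutT, then Functor/NumOuts deduced) is assumed from the call sites in this patch, and AddFunctor is an illustrative functor rather than the exact one shipped with Paddle.

#include <vector>
#include "paddle/fluid/platform/hostdevice.h"
#include "paddle/pten/kernels/gpu/elementwise.h"

template <typename T>
struct AddFunctor {
  // Element-wise body applied per vectorized lane by the kernel primitives.
  inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a + b; }
};

template <typename T>
void ElementwiseAddSketch(const KPDevice& ctx,
                          const pten::DenseTensor& x,
                          const pten::DenseTensor& y,
                          pten::DenseTensor* z) {
  std::vector<const pten::DenseTensor*> ins = {&x, &y};
  std::vector<pten::DenseTensor*> outs = {z};
  // axis = -1 lets the launcher derive the broadcast axis from the rank gap
  // between inputs, exactly as LaunchElementwiseCudaKernel does above.
  pten::LaunchElementwiseCudaKernel<pten::ElementwiseType::kBinary, T, T>(
      ctx, ins, &outs, /*axis=*/-1, AddFunctor<T>());
}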
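
The BroadcastConfig added for the XPU data movers (and its GPU counterpart used by the broadcast launcher) precomputes input and output strides so a kernel can turn an output element's linear index into the matching input index, with dimensions stored fastest-varying first because DimensionsTransform reverses them before building the config. The host-side sketch below mirrors that index arithmetic; the shapes and the helper name are illustrative only.

#include <cstdio>
#include <vector>

// Map an output linear index to the corresponding input linear index under
// broadcasting, mirroring kps::details::BroadcastConfig::operator().
int BroadcastSrcIndex(int index_output,
                      const std::vector<int>& out_dims,   // fastest dim first
                      const std::vector<int>& in_dims) {  // fastest dim first
  const int dim_size = static_cast<int>(out_dims.size());
  std::vector<int> strides_in(dim_size, 1);
  std::vector<int> strides_out(dim_size, 1);
  for (int i = 1; i < dim_size; ++i) {
    strides_in[i] = strides_in[i - 1] * in_dims[i - 1];
    strides_out[i] = strides_out[i - 1] * out_dims[i - 1];
  }
  int index_src = 0;
  for (int i = dim_size - 1; i >= 0; --i) {
    int tmp_index = index_output / strides_out[i];
    index_output -= tmp_index * strides_out[i];
    index_src += (tmp_index % in_dims[i]) * strides_in[i];
  }
  return index_src;
}

int main() {
  // out shape [3, 45] broadcast from in shape [1, 45]; dims are passed
  // reversed (fastest first), as DimensionsTransform arranges them.
  std::vector<int> out_dims = {45, 3};
  std::vector<int> in_dims = {45, 1};
  std::printf("%d %d\n",
              BroadcastSrcIndex(0, out_dims, in_dims),        // -> 0
              BroadcastSrcIndex(45 + 7, out_dims, in_dims));  // -> 7
  return 0;
}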
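
Both launchers above pick a single vector width for the whole launch: GetVectorizedSizeForTensors (and the broadcast variant) takes the minimum paddle::platform::GetVectorizedSize over every input and output pointer, and the switch then instantiates the kernel for VecSize 4, 2, or 1. The sketch below reproduces that min-reduction; the alignment rule inside VectorizedSizeSketch is an assumption about what GetVectorizedSize checks, not a copy of it.

#include <algorithm>
#include <cstdint>
#include <vector>

// Widest element count (4, 2, or 1) whose byte footprint divides the pointer
// address; assumed to approximate paddle::platform::GetVectorizedSize.
template <typename T>
int VectorizedSizeSketch(const T* ptr) {
  auto addr = reinterpret_cast<std::uintptr_t>(ptr);
  if (addr % (sizeof(T) * 4) == 0) return 4;
  if (addr % (sizeof(T) * 2) == 0) return 2;
  return 1;
}

// Minimum over all inputs and outputs, as GetVectorizedSizeForTensors does,
// so one VecSize instantiation is safe for every pointer in the launch.
template <typename T>
int CommonVecSizeSketch(const std::vector<const T*>& ins,
                        const std::vector<const T*>& outs) {
  int vec_size = 4;
  for (const T* p : ins) vec_size = std::min(vec_size, VectorizedSizeSketch(p));
  for (const T* p : outs) vec_size = std::min(vec_size, VectorizedSizeSketch(p));
  return vec_size;  // feeds the switch over VecSize = 4 / 2 / 1
}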